diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp index 119073776a..e19d502bc1 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp @@ -254,10 +254,6 @@ public: } private: - using OperationDecompilerFn = std::string (GLSLDecompiler::*)(Operation); - using OperationDecompilersArray = - std::array(OperationCode::Amount)>; - void DeclareVertex() { if (stage != ShaderStage::Vertex) return; @@ -1400,14 +1396,10 @@ private: return fmt::format("{}[{}]", pair, VisitOperand(operation, 1, Type::Uint)); } - std::string LogicalAll2(Operation operation) { + std::string LogicalAnd2(Operation operation) { return GenerateUnary(operation, "all", Type::Bool, Type::Bool2); } - std::string LogicalAny2(Operation operation) { - return GenerateUnary(operation, "any", Type::Bool, Type::Bool2); - } - template std::string GenerateHalfComparison(Operation operation, const std::string& compare_op) { const std::string comparison{GenerateBinaryCall(operation, compare_op, Type::Bool2, @@ -1714,7 +1706,7 @@ private: return "utof(gl_WorkGroupID"s + GetSwizzle(element) + ')'; } - static constexpr OperationDecompilersArray operation_decompilers = { + static constexpr std::array operation_decompilers = { &GLSLDecompiler::Assign, &GLSLDecompiler::Select, @@ -1798,8 +1790,7 @@ private: &GLSLDecompiler::LogicalXor, &GLSLDecompiler::LogicalNegate, &GLSLDecompiler::LogicalPick2, - &GLSLDecompiler::LogicalAll2, - &GLSLDecompiler::LogicalAny2, + &GLSLDecompiler::LogicalAnd2, &GLSLDecompiler::LogicalLessThan, &GLSLDecompiler::LogicalEqual, @@ -1863,6 +1854,7 @@ private: &GLSLDecompiler::WorkGroupId<1>, &GLSLDecompiler::WorkGroupId<2>, }; + static_assert(operation_decompilers.size() == static_cast(OperationCode::Amount)); std::string GetRegister(u32 index) const { return GetDeclarationWithSuffix(index, "gpr"); diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 9b2d8e987d..d267712c9f 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -205,10 +205,6 @@ public: } private: - using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation); - using OperationDecompilersArray = - std::array(OperationCode::Amount)>; - static constexpr auto INTERNAL_FLAGS_COUNT = static_cast(InternalFlag::Amount); void AllocateBindings() { @@ -804,12 +800,7 @@ private: return {}; } - Id LogicalAll2(Operation operation) { - UNIMPLEMENTED(); - return {}; - } - - Id LogicalAny2(Operation operation) { + Id LogicalAnd2(Operation operation) { UNIMPLEMENTED(); return {}; } @@ -1206,7 +1197,7 @@ private: return {}; } - static constexpr OperationDecompilersArray operation_decompilers = { + static constexpr std::array operation_decompilers = { &SPIRVDecompiler::Assign, &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float, @@ -1291,8 +1282,7 @@ private: &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>, &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>, &SPIRVDecompiler::LogicalPick2, - &SPIRVDecompiler::LogicalAll2, - &SPIRVDecompiler::LogicalAny2, + &SPIRVDecompiler::LogicalAnd2, &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>, &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>, @@ -1357,6 +1347,7 @@ private: &SPIRVDecompiler::WorkGroupId<1>, &SPIRVDecompiler::WorkGroupId<2>, }; + static_assert(operation_decompilers.size() == static_cast(OperationCode::Amount)); const VKDevice& device; const ShaderIR& ir; diff --git a/src/video_core/shader/decode/half_set_predicate.cpp b/src/video_core/shader/decode/half_set_predicate.cpp index ff41fb2b5b..ad180d6dfd 100644 --- a/src/video_core/shader/decode/half_set_predicate.cpp +++ b/src/video_core/shader/decode/half_set_predicate.cpp @@ -51,26 +51,23 @@ u32 ShaderIR::DecodeHalfSetPredicate(NodeBlock& bb, u32 pc) { op_b = Immediate(0); } - // We can't use the constant predicate as destination. - ASSERT(instr.hsetp2.pred3 != static_cast(Pred::UnusedIndex)); - - const Node second_pred = GetPredicate(instr.hsetp2.pred39, instr.hsetp2.neg_pred != 0); - const OperationCode combiner = GetPredicateCombiner(instr.hsetp2.op); - const OperationCode pair_combiner = - h_and ? OperationCode::LogicalAll2 : OperationCode::LogicalAny2; + const Node pred39 = GetPredicate(instr.hsetp2.pred39, instr.hsetp2.neg_pred); + + const auto Write = [&](u64 dest, Node src) { + SetPredicate(bb, dest, Operation(combiner, std::move(src), pred39)); + }; const Node comparison = GetPredicateComparisonHalf(cond, op_a, op_b); - const Node first_pred = Operation(pair_combiner, comparison); - - // Set the primary predicate to the result of Predicate OP SecondPredicate - const Node value = Operation(combiner, first_pred, second_pred); - SetPredicate(bb, instr.hsetp2.pred3, value); - - if (instr.hsetp2.pred0 != static_cast(Pred::UnusedIndex)) { - // Set the secondary predicate to the result of !Predicate OP SecondPredicate, if enabled - const Node negated_pred = Operation(OperationCode::LogicalNegate, first_pred); - SetPredicate(bb, instr.hsetp2.pred0, Operation(combiner, negated_pred, second_pred)); + const u64 first = instr.hsetp2.pred0; + const u64 second = instr.hsetp2.pred3; + if (h_and) { + const Node joined = Operation(OperationCode::LogicalAnd2, comparison); + Write(first, joined); + Write(second, Operation(OperationCode::LogicalNegate, joined)); + } else { + Write(first, Operation(OperationCode::LogicalPick2, comparison, Immediate(0u))); + Write(second, Operation(OperationCode::LogicalPick2, comparison, Immediate(1u))); } return pc; diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h index 7427ed896d..715184d67a 100644 --- a/src/video_core/shader/node.h +++ b/src/video_core/shader/node.h @@ -101,8 +101,7 @@ enum class OperationCode { LogicalXor, /// (bool a, bool b) -> bool LogicalNegate, /// (bool a) -> bool LogicalPick2, /// (bool2 pair, uint index) -> bool - LogicalAll2, /// (bool2 a) -> bool - LogicalAny2, /// (bool2 a) -> bool + LogicalAnd2, /// (bool2 a) -> bool LogicalFLessThan, /// (float a, float b) -> bool LogicalFEqual, /// (float a, float b) -> bool