From 9170200a11715d131645d1ffb92e86e6ef0d7e88 Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Thu, 11 Feb 2021 16:39:06 -0300 Subject: [PATCH] shader: Initial implementation of an AST --- externals/sirit | 2 +- src/shader_recompiler/CMakeLists.txt | 4 +- .../backend/spirv/emit_spirv.cpp | 45 +- .../backend/spirv/emit_spirv.h | 18 +- .../spirv/emit_spirv_context_get_set.cpp | 8 + .../backend/spirv/emit_spirv_control_flow.cpp | 25 - .../backend/spirv/emit_spirv_undefined.cpp | 16 +- .../frontend/ir/basic_block.cpp | 64 +- .../frontend/ir/basic_block.h | 40 +- .../frontend/ir/condition.cpp | 14 +- src/shader_recompiler/frontend/ir/condition.h | 2 +- src/shader_recompiler/frontend/ir/function.h | 2 +- .../frontend/ir/ir_emitter.cpp | 43 +- .../frontend/ir/ir_emitter.h | 23 +- .../frontend/ir/microinstruction.cpp | 4 +- src/shader_recompiler/frontend/ir/opcodes.inc | 16 +- .../frontend/ir/structured_control_flow.cpp | 742 ++++++++++++++++++ .../frontend/ir/structured_control_flow.h | 22 + .../frontend/maxwell/control_flow.cpp | 424 ++++------ .../frontend/maxwell/control_flow.h | 77 +- .../frontend/maxwell/location.h | 12 +- .../frontend/maxwell/program.cpp | 69 +- .../frontend/maxwell/program.h | 2 +- .../frontend/maxwell/termination_code.cpp | 86 -- .../frontend/maxwell/termination_code.h | 17 - .../translate/impl/integer_shift_left.cpp | 2 +- .../frontend/maxwell/translate/translate.cpp | 17 +- .../frontend/maxwell/translate/translate.h | 7 +- .../ir_opt/constant_propagation_pass.cpp | 50 ++ .../ir_opt/ssa_rewrite_pass.cpp | 24 +- .../ir_opt/verification_pass.cpp | 4 + src/shader_recompiler/main.cpp | 29 +- src/shader_recompiler/shader_info.h | 28 + 33 files changed, 1347 insertions(+), 591 deletions(-) create mode 100644 src/shader_recompiler/frontend/ir/structured_control_flow.cpp create mode 100644 src/shader_recompiler/frontend/ir/structured_control_flow.h delete mode 100644 src/shader_recompiler/frontend/maxwell/termination_code.cpp delete mode 100644 src/shader_recompiler/frontend/maxwell/termination_code.h create mode 100644 src/shader_recompiler/shader_info.h diff --git a/externals/sirit b/externals/sirit index 1f7b70730..c374bfd9f 160000 --- a/externals/sirit +++ b/externals/sirit @@ -1 +1 @@ -Subproject commit 1f7b70730d610cfbd5099ab93dd38ec8a78e7e35 +Subproject commit c374bfd9fdff02a0cff85d005488967b1b0f675e diff --git a/src/shader_recompiler/CMakeLists.txt b/src/shader_recompiler/CMakeLists.txt index 12fbcb37c..27fc79e21 100644 --- a/src/shader_recompiler/CMakeLists.txt +++ b/src/shader_recompiler/CMakeLists.txt @@ -36,6 +36,8 @@ add_executable(shader_recompiler frontend/ir/program.cpp frontend/ir/program.h frontend/ir/reg.h + frontend/ir/structured_control_flow.cpp + frontend/ir/structured_control_flow.h frontend/ir/type.cpp frontend/ir/type.h frontend/ir/value.cpp @@ -51,8 +53,6 @@ add_executable(shader_recompiler frontend/maxwell/opcodes.h frontend/maxwell/program.cpp frontend/maxwell/program.h - frontend/maxwell/termination_code.cpp - frontend/maxwell/termination_code.h frontend/maxwell/translate/impl/common_encoding.h frontend/maxwell/translate/impl/floating_point_add.cpp frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp index 7c4269fad..5022b5159 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp @@ -105,8 +105,26 @@ void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) { throw LogicError("Invalid opcode {}", inst->Opcode()); } -void EmitSPIRV::EmitPhi(EmitContext&) { - throw NotImplementedException("SPIR-V Instruction"); +static Id TypeId(const EmitContext& ctx, IR::Type type) { + switch (type) { + case IR::Type::U1: + return ctx.u1; + default: + throw NotImplementedException("Phi node type {}", type); + } +} + +Id EmitSPIRV::EmitPhi(EmitContext& ctx, IR::Inst* inst) { + const size_t num_args{inst->NumArgs()}; + boost::container::small_vector operands; + operands.reserve(num_args * 2); + for (size_t index = 0; index < num_args; ++index) { + IR::Block* const phi_block{inst->PhiBlock(index)}; + operands.push_back(ctx.Def(inst->Arg(index))); + operands.push_back(ctx.BlockLabel(phi_block)); + } + const Id result_type{TypeId(ctx, inst->Arg(0).Type())}; + return ctx.OpPhi(result_type, std::span(operands.data(), operands.size())); } void EmitSPIRV::EmitVoid(EmitContext&) {} @@ -115,6 +133,29 @@ void EmitSPIRV::EmitIdentity(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } +// FIXME: Move to its own file +void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Inst* inst) { + ctx.OpBranch(ctx.BlockLabel(inst->Arg(0).Label())); +} + +void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, IR::Inst* inst) { + ctx.OpBranchConditional(ctx.Def(inst->Arg(0)), ctx.BlockLabel(inst->Arg(1).Label()), + ctx.BlockLabel(inst->Arg(2).Label())); +} + +void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Inst* inst) { + ctx.OpLoopMerge(ctx.BlockLabel(inst->Arg(0).Label()), ctx.BlockLabel(inst->Arg(1).Label()), + spv::LoopControlMask::MaskNone); +} + +void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst) { + ctx.OpSelectionMerge(ctx.BlockLabel(inst->Arg(0).Label()), spv::SelectionControlMask::MaskNone); +} + +void EmitSPIRV::EmitReturn(EmitContext& ctx) { + ctx.OpReturn(); +} + void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h index 3f4b68a7d..9aa83b5de 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv.h @@ -124,18 +124,20 @@ private: void EmitInst(EmitContext& ctx, IR::Inst* inst); // Microinstruction emitters - void EmitPhi(EmitContext& ctx); + Id EmitPhi(EmitContext& ctx, IR::Inst* inst); void EmitVoid(EmitContext& ctx); void EmitIdentity(EmitContext& ctx); void EmitBranch(EmitContext& ctx, IR::Inst* inst); void EmitBranchConditional(EmitContext& ctx, IR::Inst* inst); - void EmitExit(EmitContext& ctx); + void EmitLoopMerge(EmitContext& ctx, IR::Inst* inst); + void EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst); void EmitReturn(EmitContext& ctx); - void EmitUnreachable(EmitContext& ctx); void EmitGetRegister(EmitContext& ctx); void EmitSetRegister(EmitContext& ctx); void EmitGetPred(EmitContext& ctx); void EmitSetPred(EmitContext& ctx); + void EmitSetGotoVariable(EmitContext& ctx); + void EmitGetGotoVariable(EmitContext& ctx); Id EmitGetCbuf(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset); void EmitGetAttribute(EmitContext& ctx); void EmitSetAttribute(EmitContext& ctx); @@ -151,11 +153,11 @@ private: void EmitSetOFlag(EmitContext& ctx); Id EmitWorkgroupId(EmitContext& ctx); Id EmitLocalInvocationId(EmitContext& ctx); - void EmitUndef1(EmitContext& ctx); - void EmitUndef8(EmitContext& ctx); - void EmitUndef16(EmitContext& ctx); - void EmitUndef32(EmitContext& ctx); - void EmitUndef64(EmitContext& ctx); + Id EmitUndefU1(EmitContext& ctx); + void EmitUndefU8(EmitContext& ctx); + void EmitUndefU16(EmitContext& ctx); + void EmitUndefU32(EmitContext& ctx); + void EmitUndefU64(EmitContext& ctx); void EmitLoadGlobalU8(EmitContext& ctx); void EmitLoadGlobalS8(EmitContext& ctx); void EmitLoadGlobalU16(EmitContext& ctx); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index b121305ea..1eab739ed 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -22,6 +22,14 @@ void EmitSPIRV::EmitSetPred(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } +void EmitSPIRV::EmitSetGotoVariable(EmitContext&) { + throw NotImplementedException("SPIR-V Instruction"); +} + +void EmitSPIRV::EmitGetGotoVariable(EmitContext&) { + throw NotImplementedException("SPIR-V Instruction"); +} + Id EmitSPIRV::EmitGetCbuf(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) { if (!binding.IsImmediate()) { throw NotImplementedException("Constant buffer indexing"); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp index 770fe113c..66ce6c8c5 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp @@ -3,28 +3,3 @@ // Refer to the license.txt file included. #include "shader_recompiler/backend/spirv/emit_spirv.h" - -namespace Shader::Backend::SPIRV { - -void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Inst* inst) { - ctx.OpBranch(ctx.BlockLabel(inst->Arg(0).Label())); -} - -void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, IR::Inst* inst) { - ctx.OpBranchConditional(ctx.Def(inst->Arg(0)), ctx.BlockLabel(inst->Arg(1).Label()), - ctx.BlockLabel(inst->Arg(2).Label())); -} - -void EmitSPIRV::EmitExit(EmitContext& ctx) { - ctx.OpReturn(); -} - -void EmitSPIRV::EmitReturn(EmitContext&) { - throw NotImplementedException("SPIR-V Instruction"); -} - -void EmitSPIRV::EmitUnreachable(EmitContext&) { - throw NotImplementedException("SPIR-V Instruction"); -} - -} // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp index 3850b072c..859b60a95 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp @@ -6,23 +6,23 @@ namespace Shader::Backend::SPIRV { -void EmitSPIRV::EmitUndef1(EmitContext&) { +Id EmitSPIRV::EmitUndefU1(EmitContext& ctx) { + return ctx.OpUndef(ctx.u1); +} + +void EmitSPIRV::EmitUndefU8(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } -void EmitSPIRV::EmitUndef8(EmitContext&) { +void EmitSPIRV::EmitUndefU16(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } -void EmitSPIRV::EmitUndef16(EmitContext&) { +void EmitSPIRV::EmitUndefU32(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } -void EmitSPIRV::EmitUndef32(EmitContext&) { - throw NotImplementedException("SPIR-V Instruction"); -} - -void EmitSPIRV::EmitUndef64(EmitContext&) { +void EmitSPIRV::EmitUndefU64(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } diff --git a/src/shader_recompiler/frontend/ir/basic_block.cpp b/src/shader_recompiler/frontend/ir/basic_block.cpp index da33ff6f1..b5616f394 100644 --- a/src/shader_recompiler/frontend/ir/basic_block.cpp +++ b/src/shader_recompiler/frontend/ir/basic_block.cpp @@ -17,6 +17,8 @@ namespace Shader::IR { Block::Block(ObjectPool& inst_pool_, u32 begin, u32 end) : inst_pool{&inst_pool_}, location_begin{begin}, location_end{end} {} +Block::Block(ObjectPool& inst_pool_) : Block{inst_pool_, 0, 0} {} + Block::~Block() = default; void Block::AppendNewInst(Opcode op, std::initializer_list args) { @@ -38,8 +40,25 @@ Block::iterator Block::PrependNewInst(iterator insertion_point, Opcode op, return result_it; } -void Block::AddImmediatePredecessor(IR::Block* immediate_predecessor) { - imm_predecessors.push_back(immediate_predecessor); +void Block::SetBranches(Condition cond, Block* branch_true_, Block* branch_false_) { + branch_cond = cond; + branch_true = branch_true_; + branch_false = branch_false_; +} + +void Block::SetBranch(Block* branch) { + branch_cond = Condition{true}; + branch_true = branch; +} + +void Block::SetReturn() { + branch_cond = Condition{true}; + branch_true = nullptr; + branch_false = nullptr; +} + +bool Block::IsVirtual() const noexcept { + return location_begin == location_end; } u32 Block::LocationBegin() const noexcept { @@ -58,6 +77,12 @@ const Block::InstructionList& Block::Instructions() const noexcept { return instructions; } +void Block::AddImmediatePredecessor(Block* block) { + if (std::ranges::find(imm_predecessors, block) == imm_predecessors.end()) { + imm_predecessors.push_back(block); + } +} + std::span Block::ImmediatePredecessors() const noexcept { return imm_predecessors; } @@ -70,8 +95,17 @@ static std::string BlockToIndex(const std::map& block_to_i return fmt::format("$", reinterpret_cast(block)); } +static size_t InstIndex(std::map& inst_to_index, size_t& inst_index, + const Inst* inst) { + const auto [it, is_inserted]{inst_to_index.emplace(inst, inst_index + 1)}; + if (is_inserted) { + ++inst_index; + } + return it->second; +} + static std::string ArgToIndex(const std::map& block_to_index, - const std::map& inst_to_index, + std::map& inst_to_index, size_t& inst_index, const Value& arg) { if (arg.IsEmpty()) { return ""; @@ -80,10 +114,7 @@ static std::string ArgToIndex(const std::map& block_to_ind return BlockToIndex(block_to_index, arg.Label()); } if (!arg.IsImmediate()) { - if (const auto it{inst_to_index.find(arg.Inst())}; it != inst_to_index.end()) { - return fmt::format("%{}", it->second); - } - return fmt::format("%", reinterpret_cast(arg.Inst())); + return fmt::format("%{}", InstIndex(inst_to_index, inst_index, arg.Inst())); } switch (arg.Type()) { case Type::U1: @@ -125,14 +156,14 @@ std::string DumpBlock(const Block& block, const std::map& const Opcode op{inst.Opcode()}; ret += fmt::format("[{:016x}] ", reinterpret_cast(&inst)); if (TypeOf(op) != Type::Void) { - ret += fmt::format("%{:<5} = {}", inst_index, op); + ret += fmt::format("%{:<5} = {}", InstIndex(inst_to_index, inst_index, &inst), op); } else { ret += fmt::format(" {}", op); // '%00000 = ' -> 1 + 5 + 3 = 9 spaces } - const size_t arg_count{NumArgsOf(op)}; + const size_t arg_count{inst.NumArgs()}; for (size_t arg_index = 0; arg_index < arg_count; ++arg_index) { const Value arg{inst.Arg(arg_index)}; - const std::string arg_str{ArgToIndex(block_to_index, inst_to_index, arg)}; + const std::string arg_str{ArgToIndex(block_to_index, inst_to_index, inst_index, arg)}; ret += arg_index != 0 ? ", " : " "; if (op == Opcode::Phi) { ret += fmt::format("[ {}, {} ]", arg_index, @@ -140,10 +171,12 @@ std::string DumpBlock(const Block& block, const std::map& } else { ret += arg_str; } - const Type actual_type{arg.Type()}; - const Type expected_type{ArgTypeOf(op, arg_index)}; - if (!AreTypesCompatible(actual_type, expected_type)) { - ret += fmt::format("", actual_type, expected_type); + if (op != Opcode::Phi) { + const Type actual_type{arg.Type()}; + const Type expected_type{ArgTypeOf(op, arg_index)}; + if (!AreTypesCompatible(actual_type, expected_type)) { + ret += fmt::format("", actual_type, expected_type); + } } } if (TypeOf(op) != Type::Void) { @@ -151,9 +184,6 @@ std::string DumpBlock(const Block& block, const std::map& } else { ret += '\n'; } - - inst_to_index.emplace(&inst, inst_index); - ++inst_index; } return ret; } diff --git a/src/shader_recompiler/frontend/ir/basic_block.h b/src/shader_recompiler/frontend/ir/basic_block.h index ec3ad6263..3205705e7 100644 --- a/src/shader_recompiler/frontend/ir/basic_block.h +++ b/src/shader_recompiler/frontend/ir/basic_block.h @@ -11,7 +11,9 @@ #include +#include "shader_recompiler/frontend/ir/condition.h" #include "shader_recompiler/frontend/ir/microinstruction.h" +#include "shader_recompiler/frontend/ir/value.h" #include "shader_recompiler/object_pool.h" namespace Shader::IR { @@ -26,6 +28,7 @@ public: using const_reverse_iterator = InstructionList::const_reverse_iterator; explicit Block(ObjectPool& inst_pool_, u32 begin, u32 end); + explicit Block(ObjectPool& inst_pool_); ~Block(); Block(const Block&) = delete; @@ -41,9 +44,15 @@ public: iterator PrependNewInst(iterator insertion_point, Opcode op, std::initializer_list args = {}, u64 flags = 0); - /// Adds a new immediate predecessor to the basic block. - void AddImmediatePredecessor(IR::Block* immediate_predecessor); + /// Set the branches to jump to when all instructions have executed. + void SetBranches(Condition cond, Block* branch_true, Block* branch_false); + /// Set the branch to unconditionally jump to when all instructions have executed. + void SetBranch(Block* branch); + /// Mark the block as a return block. + void SetReturn(); + /// Returns true when the block does not implement any guest instructions directly. + [[nodiscard]] bool IsVirtual() const noexcept; /// Gets the starting location of this basic block. [[nodiscard]] u32 LocationBegin() const noexcept; /// Gets the end location for this basic block. @@ -54,8 +63,23 @@ public: /// Gets an immutable reference to the instruction list for this basic block. [[nodiscard]] const InstructionList& Instructions() const noexcept; + /// Adds a new immediate predecessor to this basic block. + void AddImmediatePredecessor(Block* block); /// Gets an immutable span to the immediate predecessors. - [[nodiscard]] std::span ImmediatePredecessors() const noexcept; + [[nodiscard]] std::span ImmediatePredecessors() const noexcept; + + [[nodiscard]] Condition BranchCondition() const noexcept { + return branch_cond; + } + [[nodiscard]] bool IsTerminationBlock() const noexcept { + return !branch_true && !branch_false; + } + [[nodiscard]] Block* TrueBranch() const noexcept { + return branch_true; + } + [[nodiscard]] Block* FalseBranch() const noexcept { + return branch_false; + } [[nodiscard]] bool empty() const { return instructions.empty(); @@ -129,10 +153,18 @@ private: /// List of instructions in this block InstructionList instructions; + /// Condition to choose the branch to take + Condition branch_cond{true}; + /// Block to jump into when the branch condition evaluates as true + Block* branch_true{nullptr}; + /// Block to jump into when the branch condition evaluates as false + Block* branch_false{nullptr}; /// Block immediate predecessors - std::vector imm_predecessors; + std::vector imm_predecessors; }; +using BlockList = std::vector; + [[nodiscard]] std::string DumpBlock(const Block& block); [[nodiscard]] std::string DumpBlock(const Block& block, diff --git a/src/shader_recompiler/frontend/ir/condition.cpp b/src/shader_recompiler/frontend/ir/condition.cpp index edff35dc7..ec1659e2b 100644 --- a/src/shader_recompiler/frontend/ir/condition.cpp +++ b/src/shader_recompiler/frontend/ir/condition.cpp @@ -16,15 +16,13 @@ std::string NameOf(Condition condition) { ret = fmt::to_string(condition.FlowTest()); } const auto [pred, negated]{condition.Pred()}; - if (pred != Pred::PT || negated) { - if (!ret.empty()) { - ret += '&'; - } - if (negated) { - ret += '!'; - } - ret += fmt::to_string(pred); + if (!ret.empty()) { + ret += '&'; } + if (negated) { + ret += '!'; + } + ret += fmt::to_string(pred); return ret; } diff --git a/src/shader_recompiler/frontend/ir/condition.h b/src/shader_recompiler/frontend/ir/condition.h index 52737025c..16b4ae888 100644 --- a/src/shader_recompiler/frontend/ir/condition.h +++ b/src/shader_recompiler/frontend/ir/condition.h @@ -26,7 +26,7 @@ public: explicit Condition(Pred pred_, bool pred_negated_ = false) noexcept : Condition(FlowTest::T, pred_, pred_negated_) {} - Condition(bool value) : Condition(Pred::PT, !value) {} + explicit Condition(bool value) : Condition(Pred::PT, !value) {} auto operator<=>(const Condition&) const noexcept = default; diff --git a/src/shader_recompiler/frontend/ir/function.h b/src/shader_recompiler/frontend/ir/function.h index bba7d1d39..fd7d56419 100644 --- a/src/shader_recompiler/frontend/ir/function.h +++ b/src/shader_recompiler/frontend/ir/function.h @@ -11,7 +11,7 @@ namespace Shader::IR { struct Function { - boost::container::small_vector blocks; + BlockList blocks; }; } // namespace Shader::IR diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.cpp b/src/shader_recompiler/frontend/ir/ir_emitter.cpp index ada0be834..30932043f 100644 --- a/src/shader_recompiler/frontend/ir/ir_emitter.cpp +++ b/src/shader_recompiler/frontend/ir/ir_emitter.cpp @@ -44,26 +44,29 @@ F64 IREmitter::Imm64(f64 value) const { return F64{Value{value}}; } -void IREmitter::Branch(IR::Block* label) { +void IREmitter::Branch(Block* label) { + label->AddImmediatePredecessor(block); Inst(Opcode::Branch, label); } -void IREmitter::BranchConditional(const U1& cond, IR::Block* true_label, IR::Block* false_label) { - Inst(Opcode::BranchConditional, cond, true_label, false_label); +void IREmitter::BranchConditional(const U1& condition, Block* true_label, Block* false_label) { + true_label->AddImmediatePredecessor(block); + false_label->AddImmediatePredecessor(block); + Inst(Opcode::BranchConditional, condition, true_label, false_label); } -void IREmitter::Exit() { - Inst(Opcode::Exit); +void IREmitter::LoopMerge(Block* merge_block, Block* continue_target) { + Inst(Opcode::LoopMerge, merge_block, continue_target); +} + +void IREmitter::SelectionMerge(Block* merge_block) { + Inst(Opcode::SelectionMerge, merge_block); } void IREmitter::Return() { Inst(Opcode::Return); } -void IREmitter::Unreachable() { - Inst(Opcode::Unreachable); -} - U32 IREmitter::GetReg(IR::Reg reg) { return Inst(Opcode::GetRegister, reg); } @@ -81,6 +84,14 @@ U1 IREmitter::GetPred(IR::Pred pred, bool is_negated) { } } +U1 IREmitter::GetGotoVariable(u32 id) { + return Inst(Opcode::GetGotoVariable, id); +} + +void IREmitter::SetGotoVariable(u32 id, const U1& value) { + Inst(Opcode::SetGotoVariable, id, value); +} + void IREmitter::SetPred(IR::Pred pred, const U1& value) { Inst(Opcode::SetPred, pred, value); } @@ -121,6 +132,20 @@ void IREmitter::SetOFlag(const U1& value) { Inst(Opcode::SetOFlag, value); } +U1 IREmitter::Condition(IR::Condition cond) { + if (cond == IR::Condition{true}) { + return Imm1(true); + } else if (cond == IR::Condition{false}) { + return Imm1(false); + } + const FlowTest flow_test{cond.FlowTest()}; + const auto [pred, is_negated]{cond.Pred()}; + if (flow_test == FlowTest::T) { + return GetPred(pred, is_negated); + } + throw NotImplementedException("Condition {}", cond); +} + F32 IREmitter::GetAttribute(IR::Attribute attribute) { return Inst(Opcode::GetAttribute, attribute); } diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.h b/src/shader_recompiler/frontend/ir/ir_emitter.h index bfd9916cc..4decb46bc 100644 --- a/src/shader_recompiler/frontend/ir/ir_emitter.h +++ b/src/shader_recompiler/frontend/ir/ir_emitter.h @@ -16,11 +16,11 @@ namespace Shader::IR { class IREmitter { public: - explicit IREmitter(Block& block_) : block{block_}, insertion_point{block.end()} {} + explicit IREmitter(Block& block_) : block{&block_}, insertion_point{block->end()} {} explicit IREmitter(Block& block_, Block::iterator insertion_point_) - : block{block_}, insertion_point{insertion_point_} {} + : block{&block_}, insertion_point{insertion_point_} {} - Block& block; + Block* block; [[nodiscard]] U1 Imm1(bool value) const; [[nodiscard]] U8 Imm8(u8 value) const; @@ -31,11 +31,11 @@ public: [[nodiscard]] U64 Imm64(u64 value) const; [[nodiscard]] F64 Imm64(f64 value) const; - void Branch(IR::Block* label); - void BranchConditional(const U1& cond, IR::Block* true_label, IR::Block* false_label); - void Exit(); + void Branch(Block* label); + void BranchConditional(const U1& condition, Block* true_label, Block* false_label); + void LoopMerge(Block* merge_block, Block* continue_target); + void SelectionMerge(Block* merge_block); void Return(); - void Unreachable(); [[nodiscard]] U32 GetReg(IR::Reg reg); void SetReg(IR::Reg reg, const U32& value); @@ -43,6 +43,9 @@ public: [[nodiscard]] U1 GetPred(IR::Pred pred, bool is_negated = false); void SetPred(IR::Pred pred, const U1& value); + [[nodiscard]] U1 GetGotoVariable(u32 id); + void SetGotoVariable(u32 id, const U1& value); + [[nodiscard]] U32 GetCbuf(const U32& binding, const U32& byte_offset); [[nodiscard]] U1 GetZFlag(); @@ -55,6 +58,8 @@ public: void SetCFlag(const U1& value); void SetOFlag(const U1& value); + [[nodiscard]] U1 Condition(IR::Condition cond); + [[nodiscard]] F32 GetAttribute(IR::Attribute attribute); void SetAttribute(IR::Attribute attribute, const F32& value); @@ -168,7 +173,7 @@ private: template T Inst(Opcode op, Args... args) { - auto it{block.PrependNewInst(insertion_point, op, {Value{args}...})}; + auto it{block->PrependNewInst(insertion_point, op, {Value{args}...})}; return T{Value{&*it}}; } @@ -184,7 +189,7 @@ private: T Inst(Opcode op, Flags flags, Args... args) { u64 raw_flags{}; std::memcpy(&raw_flags, &flags.proxy, sizeof(flags.proxy)); - auto it{block.PrependNewInst(insertion_point, op, {Value{args}...}, raw_flags)}; + auto it{block->PrependNewInst(insertion_point, op, {Value{args}...}, raw_flags)}; return T{Value{&*it}}; } }; diff --git a/src/shader_recompiler/frontend/ir/microinstruction.cpp b/src/shader_recompiler/frontend/ir/microinstruction.cpp index e7ca92039..b4ae371bd 100644 --- a/src/shader_recompiler/frontend/ir/microinstruction.cpp +++ b/src/shader_recompiler/frontend/ir/microinstruction.cpp @@ -51,9 +51,9 @@ bool Inst::MayHaveSideEffects() const noexcept { switch (op) { case Opcode::Branch: case Opcode::BranchConditional: - case Opcode::Exit: + case Opcode::LoopMerge: + case Opcode::SelectionMerge: case Opcode::Return: - case Opcode::Unreachable: case Opcode::SetAttribute: case Opcode::SetAttributeIndexed: case Opcode::WriteGlobalU8: diff --git a/src/shader_recompiler/frontend/ir/opcodes.inc b/src/shader_recompiler/frontend/ir/opcodes.inc index 5dc65f2df..ede5e20c2 100644 --- a/src/shader_recompiler/frontend/ir/opcodes.inc +++ b/src/shader_recompiler/frontend/ir/opcodes.inc @@ -10,15 +10,17 @@ OPCODE(Identity, Opaque, Opaq // Control flow OPCODE(Branch, Void, Label, ) OPCODE(BranchConditional, Void, U1, Label, Label, ) -OPCODE(Exit, Void, ) +OPCODE(LoopMerge, Void, Label, Label, ) +OPCODE(SelectionMerge, Void, Label, ) OPCODE(Return, Void, ) -OPCODE(Unreachable, Void, ) // Context getters/setters OPCODE(GetRegister, U32, Reg, ) OPCODE(SetRegister, Void, Reg, U32, ) OPCODE(GetPred, U1, Pred, ) OPCODE(SetPred, Void, Pred, U1, ) +OPCODE(GetGotoVariable, U1, U32, ) +OPCODE(SetGotoVariable, Void, U32, U1, ) OPCODE(GetCbuf, U32, U32, U32, ) OPCODE(GetAttribute, U32, Attribute, ) OPCODE(SetAttribute, Void, Attribute, U32, ) @@ -36,11 +38,11 @@ OPCODE(WorkgroupId, U32x3, OPCODE(LocalInvocationId, U32x3, ) // Undefined -OPCODE(Undef1, U1, ) -OPCODE(Undef8, U8, ) -OPCODE(Undef16, U16, ) -OPCODE(Undef32, U32, ) -OPCODE(Undef64, U64, ) +OPCODE(UndefU1, U1, ) +OPCODE(UndefU8, U8, ) +OPCODE(UndefU16, U16, ) +OPCODE(UndefU32, U32, ) +OPCODE(UndefU64, U64, ) // Memory operations OPCODE(LoadGlobalU8, U32, U64, ) diff --git a/src/shader_recompiler/frontend/ir/structured_control_flow.cpp b/src/shader_recompiler/frontend/ir/structured_control_flow.cpp new file mode 100644 index 000000000..2e9ce2525 --- /dev/null +++ b/src/shader_recompiler/frontend/ir/structured_control_flow.cpp @@ -0,0 +1,742 @@ +// Copyright 2021 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include "shader_recompiler/frontend/ir/basic_block.h" +#include "shader_recompiler/frontend/ir/ir_emitter.h" +#include "shader_recompiler/object_pool.h" + +namespace Shader::IR { +namespace { +struct Statement; + +// Use normal_link because we are not guaranteed to destroy the tree in order +using ListBaseHook = + boost::intrusive::list_base_hook>; + +using Tree = boost::intrusive::list, + // Avoid linear complexity on splice, size is never called + boost::intrusive::constant_time_size>; +using Node = Tree::iterator; +using ConstNode = Tree::const_iterator; + +enum class StatementType { + Code, + Goto, + Label, + If, + Loop, + Break, + Return, + Function, + Identity, + Not, + Or, + SetVariable, + Variable, +}; + +bool HasChildren(StatementType type) { + switch (type) { + case StatementType::If: + case StatementType::Loop: + case StatementType::Function: + return true; + default: + return false; + } +} + +struct Goto {}; +struct Label {}; +struct If {}; +struct Loop {}; +struct Break {}; +struct Return {}; +struct FunctionTag {}; +struct Identity {}; +struct Not {}; +struct Or {}; +struct SetVariable {}; +struct Variable {}; + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26495) // Always initialize a member variable, expected in Statement +#endif +struct Statement : ListBaseHook { + Statement(Block* code_, Statement* up_) : code{code_}, up{up_}, type{StatementType::Code} {} + Statement(Goto, Statement* cond_, Node label_, Statement* up_) + : label{label_}, cond{cond_}, up{up_}, type{StatementType::Goto} {} + Statement(Label, u32 id_, Statement* up_) : id{id_}, up{up_}, type{StatementType::Label} {} + Statement(If, Statement* cond_, Tree&& children_, Statement* up_) + : children{std::move(children_)}, cond{cond_}, up{up_}, type{StatementType::If} {} + Statement(Loop, Statement* cond_, Tree&& children_, Statement* up_) + : children{std::move(children_)}, cond{cond_}, up{up_}, type{StatementType::Loop} {} + Statement(Break, Statement* cond_, Statement* up_) + : cond{cond_}, up{up_}, type{StatementType::Break} {} + Statement(Return) : type{StatementType::Return} {} + Statement(FunctionTag) : children{}, type{StatementType::Function} {} + Statement(Identity, Condition cond_) : guest_cond{cond_}, type{StatementType::Identity} {} + Statement(Not, Statement* op_) : op{op_}, type{StatementType::Not} {} + Statement(Or, Statement* op_a_, Statement* op_b_) + : op_a{op_a_}, op_b{op_b_}, type{StatementType::Or} {} + Statement(SetVariable, u32 id_, Statement* op_, Statement* up_) + : op{op_}, id{id_}, up{up_}, type{StatementType::SetVariable} {} + Statement(Variable, u32 id_) : id{id_}, type{StatementType::Variable} {} + + ~Statement() { + if (HasChildren(type)) { + std::destroy_at(&children); + } + } + + union { + Block* code; + Node label; + Tree children; + Condition guest_cond; + Statement* op; + Statement* op_a; + }; + union { + Statement* cond; + Statement* op_b; + u32 id; + }; + Statement* up{}; + StatementType type; +}; +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +std::string DumpExpr(const Statement* stmt) { + switch (stmt->type) { + case StatementType::Identity: + return fmt::format("{}", stmt->guest_cond); + case StatementType::Not: + return fmt::format("!{}", DumpExpr(stmt->op)); + case StatementType::Or: + return fmt::format("{} || {}", DumpExpr(stmt->op_a), DumpExpr(stmt->op_b)); + case StatementType::Variable: + return fmt::format("goto_L{}", stmt->id); + default: + return ""; + } +} + +std::string DumpTree(const Tree& tree, u32 indentation = 0) { + std::string ret; + std::string indent(indentation, ' '); + for (auto stmt = tree.begin(); stmt != tree.end(); ++stmt) { + switch (stmt->type) { + case StatementType::Code: + ret += fmt::format("{} Block {:04x};\n", indent, stmt->code->LocationBegin()); + break; + case StatementType::Goto: + ret += fmt::format("{} if ({}) goto L{};\n", indent, DumpExpr(stmt->cond), + stmt->label->id); + break; + case StatementType::Label: + ret += fmt::format("{}L{}:\n", indent, stmt->id); + break; + case StatementType::If: + ret += fmt::format("{} if ({}) {{\n", indent, DumpExpr(stmt->cond)); + ret += DumpTree(stmt->children, indentation + 4); + ret += fmt::format("{} }}\n", indent); + break; + case StatementType::Loop: + ret += fmt::format("{} do {{\n", indent); + ret += DumpTree(stmt->children, indentation + 4); + ret += fmt::format("{} }} while ({});\n", indent, DumpExpr(stmt->cond)); + break; + case StatementType::Break: + ret += fmt::format("{} if ({}) break;\n", indent, DumpExpr(stmt->cond)); + break; + case StatementType::Return: + ret += fmt::format("{} return;\n", indent); + break; + case StatementType::SetVariable: + ret += fmt::format("{} goto_L{} = {};\n", indent, stmt->id, DumpExpr(stmt->op)); + break; + case StatementType::Function: + case StatementType::Identity: + case StatementType::Not: + case StatementType::Or: + case StatementType::Variable: + throw LogicError("Statement can't be printed"); + } + } + return ret; +} + +bool HasNode(const Tree& tree, ConstNode stmt) { + const auto end{tree.end()}; + for (auto it = tree.begin(); it != end; ++it) { + if (it == stmt || (HasChildren(it->type) && HasNode(it->children, stmt))) { + return true; + } + } + return false; +} + +Node FindStatementWithLabel(Tree& tree, ConstNode goto_stmt) { + const ConstNode label_stmt{goto_stmt->label}; + const ConstNode end{tree.end()}; + for (auto it = tree.begin(); it != end; ++it) { + if (it == label_stmt || (HasChildren(it->type) && HasNode(it->children, label_stmt))) { + return it; + } + } + throw LogicError("Lift label not in tree"); +} + +void SanitizeNoBreaks(const Tree& tree) { + if (std::ranges::find(tree, StatementType::Break, &Statement::type) != tree.end()) { + throw NotImplementedException("Capturing statement with break nodes"); + } +} + +size_t Level(Node stmt) { + size_t level{0}; + Statement* node{stmt->up}; + while (node) { + ++level; + node = node->up; + } + return level; +} + +bool IsDirectlyRelated(Node goto_stmt, Node label_stmt) { + const size_t goto_level{Level(goto_stmt)}; + const size_t label_level{Level(label_stmt)}; + size_t min_level; + size_t max_level; + Node min; + Node max; + if (label_level < goto_level) { + min_level = label_level; + max_level = goto_level; + min = label_stmt; + max = goto_stmt; + } else { // goto_level < label_level + min_level = goto_level; + max_level = label_level; + min = goto_stmt; + max = label_stmt; + } + while (max_level > min_level) { + --max_level; + max = max->up; + } + return min->up == max->up; +} + +bool IsIndirectlyRelated(Node goto_stmt, Node label_stmt) { + return goto_stmt->up != label_stmt->up && !IsDirectlyRelated(goto_stmt, label_stmt); +} + +bool SearchNode(const Tree& tree, ConstNode stmt, size_t& offset) { + ++offset; + + const auto end = tree.end(); + for (ConstNode it = tree.begin(); it != end; ++it) { + ++offset; + if (stmt == it) { + return true; + } + if (HasChildren(it->type) && SearchNode(it->children, stmt, offset)) { + return true; + } + } + return false; +} + +class GotoPass { +public: + explicit GotoPass(std::span blocks, ObjectPool& stmt_pool) + : pool{stmt_pool} { + std::vector gotos{BuildUnorderedTreeGetGotos(blocks)}; + fmt::print(stdout, "BEFORE\n{}\n", DumpTree(root_stmt.children)); + for (const Node& goto_stmt : gotos | std::views::reverse) { + RemoveGoto(goto_stmt); + } + fmt::print(stdout, "AFTER\n{}\n", DumpTree(root_stmt.children)); + } + + Statement& RootStatement() noexcept { + return root_stmt; + } + +private: + void RemoveGoto(Node goto_stmt) { + // Force goto_stmt and label_stmt to be directly related + const Node label_stmt{goto_stmt->label}; + if (IsIndirectlyRelated(goto_stmt, label_stmt)) { + // Move goto_stmt out using outward-movement transformation until it becomes + // directly related to label_stmt + while (!IsDirectlyRelated(goto_stmt, label_stmt)) { + goto_stmt = MoveOutward(goto_stmt); + } + } + // Force goto_stmt and label_stmt to be siblings + if (IsDirectlyRelated(goto_stmt, label_stmt)) { + const size_t label_level{Level(label_stmt)}; + size_t goto_level{Level(goto_stmt)}; + if (goto_level > label_level) { + // Move goto_stmt out of its level using outward-movement transformations + while (goto_level > label_level) { + goto_stmt = MoveOutward(goto_stmt); + --goto_level; + } + } else { // Level(goto_stmt) < Level(label_stmt) + if (Offset(goto_stmt) > Offset(label_stmt)) { + // Lift goto_stmt to above stmt containing label_stmt using goto-lifting + // transformations + goto_stmt = Lift(goto_stmt); + } + // Move goto_stmt into label_stmt's level using inward-movement transformation + while (goto_level < label_level) { + goto_stmt = MoveInward(goto_stmt); + ++goto_level; + } + } + } + // TODO: Remove this + Node it{goto_stmt}; + bool sibling{false}; + do { + sibling |= it == label_stmt; + --it; + } while (it != goto_stmt->up->children.begin()); + while (it != goto_stmt->up->children.end()) { + sibling |= it == label_stmt; + ++it; + } + if (!sibling) { + throw LogicError("Not siblings"); + } + + // goto_stmt and label_stmt are guaranteed to be siblings, eliminate + if (std::next(goto_stmt) == label_stmt) { + // Simply eliminate the goto if the label is next to it + goto_stmt->up->children.erase(goto_stmt); + } else if (Offset(goto_stmt) < Offset(label_stmt)) { + // Eliminate goto_stmt with a conditional + EliminateAsConditional(goto_stmt, label_stmt); + } else { + // Eliminate goto_stmt with a loop + EliminateAsLoop(goto_stmt, label_stmt); + } + } + + std::vector BuildUnorderedTreeGetGotos(std::span blocks) { + // Assume all blocks have two branches + std::vector gotos; + gotos.reserve(blocks.size() * 2); + + const std::unordered_map labels_map{BuildLabels(blocks)}; + Tree& root{root_stmt.children}; + auto insert_point{root.begin()}; + for (Block* const block : blocks) { + ++insert_point; // Skip label + ++insert_point; // Skip set variable + root.insert(insert_point, *pool.Create(block, &root_stmt)); + + if (block->IsTerminationBlock()) { + root.insert(insert_point, *pool.Create(Return{})); + continue; + } + const Condition cond{block->BranchCondition()}; + Statement* const true_cond{pool.Create(Identity{}, Condition{true})}; + if (cond == Condition{true} || cond == Condition{false}) { + const bool is_true{cond == Condition{true}}; + const Block* const branch{is_true ? block->TrueBranch() : block->FalseBranch()}; + const Node label{labels_map.at(branch)}; + Statement* const goto_stmt{pool.Create(Goto{}, true_cond, label, &root_stmt)}; + gotos.push_back(root.insert(insert_point, *goto_stmt)); + } else { + Statement* const ident_cond{pool.Create(Identity{}, cond)}; + const Node true_label{labels_map.at(block->TrueBranch())}; + const Node false_label{labels_map.at(block->FalseBranch())}; + Statement* goto_true{pool.Create(Goto{}, ident_cond, true_label, &root_stmt)}; + Statement* goto_false{pool.Create(Goto{}, true_cond, false_label, &root_stmt)}; + gotos.push_back(root.insert(insert_point, *goto_true)); + gotos.push_back(root.insert(insert_point, *goto_false)); + } + } + return gotos; + } + + std::unordered_map BuildLabels(std::span blocks) { + // TODO: Consider storing labels intrusively inside the block + std::unordered_map labels_map; + Tree& root{root_stmt.children}; + u32 label_id{0}; + for (const Block* const block : blocks) { + Statement* const label{pool.Create(Label{}, label_id, &root_stmt)}; + labels_map.emplace(block, root.insert(root.end(), *label)); + Statement* const false_stmt{pool.Create(Identity{}, Condition{false})}; + root.push_back(*pool.Create(SetVariable{}, label_id, false_stmt, &root_stmt)); + ++label_id; + } + return labels_map; + } + + void UpdateTreeUp(Statement* tree) { + for (Statement& stmt : tree->children) { + stmt.up = tree; + } + } + + void EliminateAsConditional(Node goto_stmt, Node label_stmt) { + Tree& body{goto_stmt->up->children}; + Tree if_body; + if_body.splice(if_body.begin(), body, std::next(goto_stmt), label_stmt); + Statement* const cond{pool.Create(Not{}, goto_stmt->cond)}; + Statement* const if_stmt{pool.Create(If{}, cond, std::move(if_body), goto_stmt->up)}; + UpdateTreeUp(if_stmt); + body.insert(goto_stmt, *if_stmt); + body.erase(goto_stmt); + } + + void EliminateAsLoop(Node goto_stmt, Node label_stmt) { + Tree& body{goto_stmt->up->children}; + Tree loop_body; + loop_body.splice(loop_body.begin(), body, label_stmt, goto_stmt); + Statement* const cond{goto_stmt->cond}; + Statement* const loop{pool.Create(Loop{}, cond, std::move(loop_body), goto_stmt->up)}; + UpdateTreeUp(loop); + body.insert(goto_stmt, *loop); + body.erase(goto_stmt); + } + + [[nodiscard]] Node MoveOutward(Node goto_stmt) { + switch (goto_stmt->up->type) { + case StatementType::If: + return MoveOutwardIf(goto_stmt); + case StatementType::Loop: + return MoveOutwardLoop(goto_stmt); + default: + throw LogicError("Invalid outward movement"); + } + } + + [[nodiscard]] Node MoveInward(Node goto_stmt) { + Statement* const parent{goto_stmt->up}; + Tree& body{parent->children}; + const Node label_nested_stmt{FindStatementWithLabel(body, goto_stmt)}; + const Node label{goto_stmt->label}; + const u32 label_id{label->id}; + + Statement* const goto_cond{goto_stmt->cond}; + Statement* const set_var{pool.Create(SetVariable{}, label_id, goto_cond, parent)}; + body.insert(goto_stmt, *set_var); + + Tree if_body; + if_body.splice(if_body.begin(), body, std::next(goto_stmt), label_nested_stmt); + Statement* const variable{pool.Create(Variable{}, label_id)}; + Statement* const neg_var{pool.Create(Not{}, variable)}; + if (!if_body.empty()) { + Statement* const if_stmt{pool.Create(If{}, neg_var, std::move(if_body), parent)}; + UpdateTreeUp(if_stmt); + body.insert(goto_stmt, *if_stmt); + } + body.erase(goto_stmt); + + // Update nested if condition + switch (label_nested_stmt->type) { + case StatementType::If: + label_nested_stmt->cond = pool.Create(Or{}, neg_var, label_nested_stmt->cond); + break; + case StatementType::Loop: + break; + default: + throw LogicError("Invalid inward movement"); + } + Tree& nested_tree{label_nested_stmt->children}; + Statement* const new_goto{pool.Create(Goto{}, variable, label, &*label_nested_stmt)}; + return nested_tree.insert(nested_tree.begin(), *new_goto); + } + + [[nodiscard]] Node Lift(Node goto_stmt) { + Statement* const parent{goto_stmt->up}; + Tree& body{parent->children}; + const Node label{goto_stmt->label}; + const u32 label_id{label->id}; + const Node label_nested_stmt{FindStatementWithLabel(body, goto_stmt)}; + const auto type{label_nested_stmt->type}; + + Tree loop_body; + loop_body.splice(loop_body.begin(), body, label_nested_stmt, goto_stmt); + SanitizeNoBreaks(loop_body); + Statement* const variable{pool.Create(Variable{}, label_id)}; + Statement* const loop_stmt{pool.Create(Loop{}, variable, std::move(loop_body), parent)}; + UpdateTreeUp(loop_stmt); + const Node loop_node{body.insert(goto_stmt, *loop_stmt)}; + + Statement* const new_goto{pool.Create(Goto{}, variable, label, loop_stmt)}; + loop_stmt->children.push_front(*new_goto); + const Node new_goto_node{loop_stmt->children.begin()}; + + Statement* const set_var{pool.Create(SetVariable{}, label_id, goto_stmt->cond, loop_stmt)}; + loop_stmt->children.push_back(*set_var); + + body.erase(goto_stmt); + return new_goto_node; + } + + Node MoveOutwardIf(Node goto_stmt) { + const Node parent{Tree::s_iterator_to(*goto_stmt->up)}; + Tree& body{parent->children}; + const u32 label_id{goto_stmt->label->id}; + Statement* const goto_cond{goto_stmt->cond}; + Statement* const set_goto_var{pool.Create(SetVariable{}, label_id, goto_cond, &*parent)}; + body.insert(goto_stmt, *set_goto_var); + + Tree if_body; + if_body.splice(if_body.begin(), body, std::next(goto_stmt), body.end()); + if_body.pop_front(); + Statement* const cond{pool.Create(Variable{}, label_id)}; + Statement* const neg_cond{pool.Create(Not{}, cond)}; + Statement* const if_stmt{pool.Create(If{}, neg_cond, std::move(if_body), &*parent)}; + UpdateTreeUp(if_stmt); + body.insert(goto_stmt, *if_stmt); + + body.erase(goto_stmt); + + Statement* const new_cond{pool.Create(Variable{}, label_id)}; + Statement* const new_goto{pool.Create(Goto{}, new_cond, goto_stmt->label, parent->up)}; + Tree& parent_tree{parent->up->children}; + return parent_tree.insert(std::next(parent), *new_goto); + } + + Node MoveOutwardLoop(Node goto_stmt) { + Statement* const parent{goto_stmt->up}; + Tree& body{parent->children}; + const u32 label_id{goto_stmt->label->id}; + Statement* const goto_cond{goto_stmt->cond}; + Statement* const set_goto_var{pool.Create(SetVariable{}, label_id, goto_cond, parent)}; + Statement* const cond{pool.Create(Variable{}, label_id)}; + Statement* const break_stmt{pool.Create(Break{}, cond, parent)}; + body.insert(goto_stmt, *set_goto_var); + body.insert(goto_stmt, *break_stmt); + body.erase(goto_stmt); + + const Node loop{Tree::s_iterator_to(*goto_stmt->up)}; + Statement* const new_goto_cond{pool.Create(Variable{}, label_id)}; + Statement* const new_goto{pool.Create(Goto{}, new_goto_cond, goto_stmt->label, loop->up)}; + Tree& parent_tree{loop->up->children}; + return parent_tree.insert(std::next(loop), *new_goto); + } + + size_t Offset(ConstNode stmt) const { + size_t offset{0}; + if (!SearchNode(root_stmt.children, stmt, offset)) { + fmt::print(stdout, "{}\n", DumpTree(root_stmt.children)); + throw LogicError("Node not found in tree"); + } + return offset; + } + + ObjectPool& pool; + Statement root_stmt{FunctionTag{}}; +}; + +Block* TryFindForwardBlock(const Statement& stmt) { + const Tree& tree{stmt.up->children}; + const ConstNode end{tree.cend()}; + ConstNode forward_node{std::next(Tree::s_iterator_to(stmt))}; + while (forward_node != end && !HasChildren(forward_node->type)) { + if (forward_node->type == StatementType::Code) { + return forward_node->code; + } + ++forward_node; + } + return nullptr; +} + +[[nodiscard]] U1 VisitExpr(IREmitter& ir, const Statement& stmt) { + switch (stmt.type) { + case StatementType::Identity: + return ir.Condition(stmt.guest_cond); + case StatementType::Not: + return ir.LogicalNot(U1{VisitExpr(ir, *stmt.op)}); + case StatementType::Or: + return ir.LogicalOr(VisitExpr(ir, *stmt.op_a), VisitExpr(ir, *stmt.op_b)); + case StatementType::Variable: + return ir.GetGotoVariable(stmt.id); + default: + throw NotImplementedException("Statement type {}", stmt.type); + } +} + +class TranslatePass { +public: + TranslatePass(ObjectPool& inst_pool_, ObjectPool& block_pool_, + ObjectPool& stmt_pool_, Statement& root_stmt, + const std::function& func_, BlockList& block_list_) + : stmt_pool{stmt_pool_}, inst_pool{inst_pool_}, block_pool{block_pool_}, func{func_}, + block_list{block_list_} { + Visit(root_stmt, nullptr, nullptr); + } + +private: + void Visit(Statement& parent, Block* continue_block, Block* break_block) { + Tree& tree{parent.children}; + Block* current_block{nullptr}; + + for (auto it = tree.begin(); it != tree.end(); ++it) { + Statement& stmt{*it}; + switch (stmt.type) { + case StatementType::Label: + // Labels can be ignored + break; + case StatementType::Code: { + if (current_block && current_block != stmt.code) { + IREmitter ir{*current_block}; + ir.Branch(stmt.code); + } + current_block = stmt.code; + func(stmt.code); + block_list.push_back(stmt.code); + break; + } + case StatementType::SetVariable: { + if (!current_block) { + current_block = MergeBlock(parent, stmt); + } + IREmitter ir{*current_block}; + ir.SetGotoVariable(stmt.id, VisitExpr(ir, *stmt.op)); + break; + } + case StatementType::If: { + if (!current_block) { + current_block = block_pool.Create(inst_pool); + block_list.push_back(current_block); + } + Block* const merge_block{MergeBlock(parent, stmt)}; + + // Visit children + const size_t first_block_index{block_list.size()}; + Visit(stmt, merge_block, break_block); + + // Implement if header block + Block* const first_if_block{block_list.at(first_block_index)}; + IREmitter ir{*current_block}; + const U1 cond{VisitExpr(ir, *stmt.cond)}; + ir.SelectionMerge(merge_block); + ir.BranchConditional(cond, first_if_block, merge_block); + + current_block = merge_block; + break; + } + case StatementType::Loop: { + Block* const loop_header_block{block_pool.Create(inst_pool)}; + if (current_block) { + IREmitter{*current_block}.Branch(loop_header_block); + } + block_list.push_back(loop_header_block); + + Block* const new_continue_block{block_pool.Create(inst_pool)}; + Block* const merge_block{MergeBlock(parent, stmt)}; + + // Visit children + const size_t first_block_index{block_list.size()}; + Visit(stmt, new_continue_block, merge_block); + + // The continue block is located at the end of the loop + block_list.push_back(new_continue_block); + + // Implement loop header block + Block* const first_loop_block{block_list.at(first_block_index)}; + IREmitter ir{*loop_header_block}; + ir.LoopMerge(merge_block, new_continue_block); + ir.Branch(first_loop_block); + + // Implement continue block + IREmitter continue_ir{*new_continue_block}; + const U1 continue_cond{VisitExpr(continue_ir, *stmt.cond)}; + continue_ir.BranchConditional(continue_cond, ir.block, merge_block); + + current_block = merge_block; + break; + } + case StatementType::Break: { + if (!current_block) { + current_block = block_pool.Create(inst_pool); + block_list.push_back(current_block); + } + Block* const skip_block{MergeBlock(parent, stmt)}; + + IREmitter ir{*current_block}; + ir.BranchConditional(VisitExpr(ir, *stmt.cond), break_block, skip_block); + + current_block = skip_block; + break; + } + case StatementType::Return: { + if (!current_block) { + current_block = block_pool.Create(inst_pool); + block_list.push_back(current_block); + } + IREmitter{*current_block}.Return(); + current_block = nullptr; + break; + } + default: + throw NotImplementedException("Statement type {}", stmt.type); + } + } + if (current_block && continue_block) { + IREmitter ir{*current_block}; + ir.Branch(continue_block); + } + } + + Block* MergeBlock(Statement& parent, Statement& stmt) { + if (Block* const block{TryFindForwardBlock(stmt)}) { + return block; + } + // Create a merge block we can visit later + Block* const block{block_pool.Create(inst_pool)}; + Statement* const merge_stmt{stmt_pool.Create(block, &parent)}; + parent.children.insert(std::next(Tree::s_iterator_to(stmt)), *merge_stmt); + return block; + } + + ObjectPool& stmt_pool; + ObjectPool& inst_pool; + ObjectPool& block_pool; + const std::function& func; + BlockList& block_list; +}; +} // Anonymous namespace + +BlockList VisitAST(ObjectPool& inst_pool, ObjectPool& block_pool, + std::span unordered_blocks, + const std::function& func) { + ObjectPool stmt_pool; + GotoPass goto_pass{unordered_blocks, stmt_pool}; + BlockList block_list; + TranslatePass translate_pass{inst_pool, block_pool, stmt_pool, goto_pass.RootStatement(), + func, block_list}; + return block_list; +} + +} // namespace Shader::IR diff --git a/src/shader_recompiler/frontend/ir/structured_control_flow.h b/src/shader_recompiler/frontend/ir/structured_control_flow.h new file mode 100644 index 000000000..a574c24f7 --- /dev/null +++ b/src/shader_recompiler/frontend/ir/structured_control_flow.h @@ -0,0 +1,22 @@ +// Copyright 2021 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include + +#include "shader_recompiler/frontend/ir/basic_block.h" +#include "shader_recompiler/frontend/ir/microinstruction.h" +#include "shader_recompiler/object_pool.h" + +namespace Shader::IR { + +[[nodiscard]] BlockList VisitAST(ObjectPool& inst_pool, ObjectPool& block_pool, + std::span unordered_blocks, + const std::function& func); + +} // namespace Shader::IR diff --git a/src/shader_recompiler/frontend/maxwell/control_flow.cpp b/src/shader_recompiler/frontend/maxwell/control_flow.cpp index 21ee98137..e766b555b 100644 --- a/src/shader_recompiler/frontend/maxwell/control_flow.cpp +++ b/src/shader_recompiler/frontend/maxwell/control_flow.cpp @@ -17,38 +17,49 @@ #include "shader_recompiler/frontend/maxwell/location.h" namespace Shader::Maxwell::Flow { +namespace { +struct Compare { + bool operator()(const Block& lhs, Location rhs) const noexcept { + return lhs.begin < rhs; + } + + bool operator()(Location lhs, const Block& rhs) const noexcept { + return lhs < rhs.begin; + } + + bool operator()(const Block& lhs, const Block& rhs) const noexcept { + return lhs.begin < rhs.begin; + } +}; +} // Anonymous namespace static u32 BranchOffset(Location pc, Instruction inst) { return pc.Offset() + inst.branch.Offset() + 8; } -static std::array Split(Block&& block, Location pc, BlockId new_id) { - if (pc <= block.begin || pc >= block.end) { +static void Split(Block* old_block, Block* new_block, Location pc) { + if (pc <= old_block->begin || pc >= old_block->end) { throw InvalidArgument("Invalid address to split={}", pc); } - return { - Block{ - .begin{block.begin}, - .end{pc}, - .end_class{EndClass::Branch}, - .id{block.id}, - .stack{block.stack}, - .cond{true}, - .branch_true{new_id}, - .branch_false{UNREACHABLE_BLOCK_ID}, - .imm_predecessors{}, - }, - Block{ - .begin{pc}, - .end{block.end}, - .end_class{block.end_class}, - .id{new_id}, - .stack{std::move(block.stack)}, - .cond{block.cond}, - .branch_true{block.branch_true}, - .branch_false{block.branch_false}, - .imm_predecessors{}, - }, + *new_block = Block{ + .begin{pc}, + .end{old_block->end}, + .end_class{old_block->end_class}, + .stack{old_block->stack}, + .cond{old_block->cond}, + .branch_true{old_block->branch_true}, + .branch_false{old_block->branch_false}, + .ir{nullptr}, + }; + *old_block = Block{ + .begin{old_block->begin}, + .end{pc}, + .end_class{EndClass::Branch}, + .stack{std::move(old_block->stack)}, + .cond{IR::Condition{true}}, + .branch_true{new_block}, + .branch_false{nullptr}, + .ir{nullptr}, }; } @@ -112,7 +123,7 @@ static bool HasFlowTest(Opcode opcode) { static std::string NameOf(const Block& block) { if (block.begin.IsVirtual()) { - return fmt::format("\"Virtual {}\"", block.id); + return fmt::format("\"Virtual {}\"", block.begin); } else { return fmt::format("\"{}\"", block.begin); } @@ -158,126 +169,23 @@ bool Block::Contains(Location pc) const noexcept { Function::Function(Location start_address) : entrypoint{start_address}, labels{{ .address{start_address}, - .block_id{0}, + .block{nullptr}, .stack{}, }} {} -void Function::BuildBlocksMap() { - const size_t num_blocks{NumBlocks()}; - blocks_map.resize(num_blocks); - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { - Block& block{blocks_data[block_index]}; - blocks_map[block.id] = █ - } -} - -void Function::BuildImmediatePredecessors() { - for (const Block& block : blocks_data) { - if (block.branch_true != UNREACHABLE_BLOCK_ID) { - blocks_map[block.branch_true]->imm_predecessors.push_back(block.id); - } - if (block.branch_false != UNREACHABLE_BLOCK_ID) { - blocks_map[block.branch_false]->imm_predecessors.push_back(block.id); - } - } -} - -void Function::BuildPostOrder() { - boost::container::small_vector block_stack; - post_order_map.resize(NumBlocks()); - - Block& first_block{blocks_data[blocks.front()]}; - first_block.post_order_visited = true; - block_stack.push_back(first_block.id); - - const auto visit_branch = [&](BlockId block_id, BlockId branch_id) { - if (branch_id == UNREACHABLE_BLOCK_ID) { - return false; - } - if (blocks_map[branch_id]->post_order_visited) { - return false; - } - blocks_map[branch_id]->post_order_visited = true; - - // Calling push_back twice is faster than insert on msvc - block_stack.push_back(block_id); - block_stack.push_back(branch_id); - return true; - }; - while (!block_stack.empty()) { - const Block* const block{blocks_map[block_stack.back()]}; - block_stack.pop_back(); - - if (!visit_branch(block->id, block->branch_true) && - !visit_branch(block->id, block->branch_false)) { - post_order_map[block->id] = static_cast(post_order_blocks.size()); - post_order_blocks.push_back(block->id); - } - } -} - -void Function::BuildImmediateDominators() { - auto transform_block_id{std::views::transform([this](BlockId id) { return blocks_map[id]; })}; - auto reverse_order_but_first{std::views::reverse | std::views::drop(1) | transform_block_id}; - auto has_idom{std::views::filter([](Block* block) { return block->imm_dominator; })}; - auto intersect{[this](Block* finger1, Block* finger2) { - while (finger1 != finger2) { - while (post_order_map[finger1->id] < post_order_map[finger2->id]) { - finger1 = finger1->imm_dominator; - } - while (post_order_map[finger2->id] < post_order_map[finger1->id]) { - finger2 = finger2->imm_dominator; - } - } - return finger1; - }}; - for (Block& block : blocks_data) { - block.imm_dominator = nullptr; - } - Block* const start_block{&blocks_data[blocks.front()]}; - start_block->imm_dominator = start_block; - - bool changed{true}; - while (changed) { - changed = false; - for (Block* const block : post_order_blocks | reverse_order_but_first) { - Block* new_idom{}; - for (Block* predecessor : block->imm_predecessors | transform_block_id | has_idom) { - new_idom = new_idom ? intersect(predecessor, new_idom) : predecessor; - } - changed |= block->imm_dominator != new_idom; - block->imm_dominator = new_idom; - } - } -} - -void Function::BuildDominanceFrontier() { - auto transform_block_id{std::views::transform([this](BlockId id) { return blocks_map[id]; })}; - auto has_enough_predecessors{[](Block& block) { return block.imm_predecessors.size() >= 2; }}; - for (Block& block : blocks_data | std::views::filter(has_enough_predecessors)) { - for (Block* current : block.imm_predecessors | transform_block_id) { - while (current != block.imm_dominator) { - current->dominance_frontiers.push_back(current->id); - current = current->imm_dominator; - } - } - } -} - -CFG::CFG(Environment& env_, Location start_address) : env{env_} { - VisitFunctions(start_address); - - for (Function& function : functions) { - function.BuildBlocksMap(); - function.BuildImmediatePredecessors(); - function.BuildPostOrder(); - function.BuildImmediateDominators(); - function.BuildDominanceFrontier(); - } -} - -void CFG::VisitFunctions(Location start_address) { +CFG::CFG(Environment& env_, ObjectPool& block_pool_, Location start_address) + : env{env_}, block_pool{block_pool_} { functions.emplace_back(start_address); + functions.back().labels.back().block = block_pool.Create(Block{ + .begin{start_address}, + .end{start_address}, + .end_class{EndClass::Branch}, + .stack{}, + .cond{IR::Condition{true}}, + .branch_true{nullptr}, + .branch_false{nullptr}, + .ir{nullptr}, + }); for (FunctionId function_id = 0; function_id < functions.size(); ++function_id) { while (!functions[function_id].labels.empty()) { Function& function{functions[function_id]}; @@ -294,35 +202,16 @@ void CFG::AnalyzeLabel(FunctionId function_id, Label& label) { return; } // Try to find the next block - Function* function{&functions[function_id]}; + Function* const function{&functions[function_id]}; Location pc{label.address}; - const auto next{std::upper_bound(function->blocks.begin(), function->blocks.end(), pc, - [function](Location pc, u32 block_index) { - return pc < function->blocks_data[block_index].begin; - })}; - const auto next_index{std::distance(function->blocks.begin(), next)}; - const bool is_last{next == function->blocks.end()}; - Location next_pc; - BlockId next_id{UNREACHABLE_BLOCK_ID}; - if (!is_last) { - next_pc = function->blocks_data[*next].begin; - next_id = function->blocks_data[*next].id; - } + const auto next_it{function->blocks.upper_bound(pc, Compare{})}; + const bool is_last{next_it == function->blocks.end()}; + Block* const next{is_last ? nullptr : &*next_it}; // Insert before the next block - Block block{ - .begin{pc}, - .end{pc}, - .end_class{EndClass::Branch}, - .id{label.block_id}, - .stack{std::move(label.stack)}, - .cond{true}, - .branch_true{UNREACHABLE_BLOCK_ID}, - .branch_false{UNREACHABLE_BLOCK_ID}, - .imm_predecessors{}, - }; + Block* const block{label.block}; // Analyze instructions until it reaches an already visited block or there's a branch bool is_branch{false}; - while (is_last || pc < next_pc) { + while (!next || pc < next->begin) { is_branch = AnalyzeInst(block, function_id, pc) == AnalysisState::Branch; if (is_branch) { break; @@ -332,43 +221,36 @@ void CFG::AnalyzeLabel(FunctionId function_id, Label& label) { if (!is_branch) { // If the block finished without a branch, // it means that the next instruction is already visited, jump to it - block.end = pc; - block.cond = true; - block.branch_true = next_id; - block.branch_false = UNREACHABLE_BLOCK_ID; + block->end = pc; + block->cond = IR::Condition{true}; + block->branch_true = next; + block->branch_false = nullptr; } // Function's pointer might be invalid, resolve it again - function = &functions[function_id]; - const u32 new_block_index = static_cast(function->blocks_data.size()); - function->blocks.insert(function->blocks.begin() + next_index, new_block_index); - function->blocks_data.push_back(std::move(block)); + // Insert the new block + functions[function_id].blocks.insert(*block); } bool CFG::InspectVisitedBlocks(FunctionId function_id, const Label& label) { const Location pc{label.address}; Function& function{functions[function_id]}; - const auto it{std::ranges::find_if(function.blocks, [&function, pc](u32 block_index) { - return function.blocks_data[block_index].Contains(pc); - })}; + const auto it{ + std::ranges::find_if(function.blocks, [pc](auto& block) { return block.Contains(pc); })}; if (it == function.blocks.end()) { // Address has not been visited return false; } - Block& block{function.blocks_data[*it]}; - if (block.begin == pc) { - throw LogicError("Dangling branch"); + Block* const visited_block{&*it}; + if (visited_block->begin == pc) { + throw LogicError("Dangling block"); } - const u32 first_index{*it}; - const u32 second_index{static_cast(function.blocks_data.size())}; - const std::array new_indices{first_index, second_index}; - std::array split_blocks{Split(std::move(block), pc, label.block_id)}; - function.blocks_data[*it] = std::move(split_blocks[0]); - function.blocks_data.push_back(std::move(split_blocks[1])); - function.blocks.insert(function.blocks.erase(it), new_indices.begin(), new_indices.end()); + Block* const new_block{label.block}; + Split(visited_block, new_block, pc); + function.blocks.insert(it, *new_block); return true; } -CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Location pc) { +CFG::AnalysisState CFG::AnalyzeInst(Block* block, FunctionId function_id, Location pc) { const Instruction inst{env.ReadInstruction(pc.Offset())}; const Opcode opcode{Decode(inst.raw)}; switch (opcode) { @@ -390,12 +272,12 @@ CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Locati AnalyzeBRX(block, pc, inst, IsAbsoluteJump(opcode)); break; case Opcode::RET: - block.end_class = EndClass::Return; + block->end_class = EndClass::Return; break; default: break; } - block.end = pc; + block->end = pc; return AnalysisState::Branch; case Opcode::BRK: case Opcode::CONT: @@ -404,9 +286,9 @@ CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Locati if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) { return AnalysisState::Continue; } - const auto [stack_pc, new_stack]{block.stack.Pop(OpcodeToken(opcode))}; - block.branch_true = AddLabel(block, new_stack, stack_pc, function_id); - block.end = pc; + const auto [stack_pc, new_stack]{block->stack.Pop(OpcodeToken(opcode))}; + block->branch_true = AddLabel(block, new_stack, stack_pc, function_id); + block->end = pc; return AnalysisState::Branch; } case Opcode::PBK: @@ -414,7 +296,7 @@ CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Locati case Opcode::PEXIT: case Opcode::PLONGJMP: case Opcode::SSY: - block.stack.Push(OpcodeToken(opcode), BranchOffset(pc, inst)); + block->stack.Push(OpcodeToken(opcode), BranchOffset(pc, inst)); return AnalysisState::Continue; case Opcode::EXIT: return AnalyzeEXIT(block, function_id, pc, inst); @@ -444,51 +326,51 @@ CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Locati return AnalysisState::Branch; } -void CFG::AnalyzeCondInst(Block& block, FunctionId function_id, Location pc, +void CFG::AnalyzeCondInst(Block* block, FunctionId function_id, Location pc, EndClass insn_end_class, IR::Condition cond) { - if (block.begin != pc) { + if (block->begin != pc) { // If the block doesn't start in the conditional instruction // mark it as a label to visit it later - block.end = pc; - block.cond = true; - block.branch_true = AddLabel(block, block.stack, pc, function_id); - block.branch_false = UNREACHABLE_BLOCK_ID; + block->end = pc; + block->cond = IR::Condition{true}; + block->branch_true = AddLabel(block, block->stack, pc, function_id); + block->branch_false = nullptr; return; } - // Impersonate the visited block with a virtual block - // Jump from this virtual to the real conditional instruction and the next instruction - Function& function{functions[function_id]}; - const BlockId conditional_block_id{++function.current_block_id}; - function.blocks.push_back(static_cast(function.blocks_data.size())); - Block& virtual_block{function.blocks_data.emplace_back(Block{ - .begin{}, // Virtual block - .end{}, + // Create a virtual block and a conditional block + Block* const conditional_block{block_pool.Create()}; + Block virtual_block{ + .begin{block->begin.Virtual()}, + .end{block->begin.Virtual()}, .end_class{EndClass::Branch}, - .id{block.id}, // Impersonating - .stack{block.stack}, + .stack{block->stack}, .cond{cond}, - .branch_true{conditional_block_id}, - .branch_false{UNREACHABLE_BLOCK_ID}, - .imm_predecessors{}, - })}; - // Set the end properties of the conditional instruction and give it a new identity - Block& conditional_block{block}; - conditional_block.end = pc; - conditional_block.end_class = insn_end_class; - conditional_block.id = conditional_block_id; + .branch_true{conditional_block}, + .branch_false{nullptr}, + .ir{nullptr}, + }; + // Save the contents of the visited block in the conditional block + *conditional_block = std::move(*block); + // Impersonate the visited block with a virtual block + *block = std::move(virtual_block); + // Set the end properties of the conditional instruction + conditional_block->end = pc; + conditional_block->end_class = insn_end_class; // Add a label to the instruction after the conditional instruction - const BlockId endif_block_id{AddLabel(conditional_block, block.stack, pc + 1, function_id)}; + Block* const endif_block{AddLabel(conditional_block, block->stack, pc + 1, function_id)}; // Branch to the next instruction from the virtual block - virtual_block.branch_false = endif_block_id; + block->branch_false = endif_block; // And branch to it from the conditional instruction if it is a branch if (insn_end_class == EndClass::Branch) { - conditional_block.cond = true; - conditional_block.branch_true = endif_block_id; - conditional_block.branch_false = UNREACHABLE_BLOCK_ID; + conditional_block->cond = IR::Condition{true}; + conditional_block->branch_true = endif_block; + conditional_block->branch_false = nullptr; } + // Finally insert the condition block into the list of blocks + functions[function_id].blocks.insert(*conditional_block); } -bool CFG::AnalyzeBranch(Block& block, FunctionId function_id, Location pc, Instruction inst, +bool CFG::AnalyzeBranch(Block* block, FunctionId function_id, Location pc, Instruction inst, Opcode opcode) { if (inst.branch.is_cbuf) { throw NotImplementedException("Branch with constant buffer offset"); @@ -500,21 +382,21 @@ bool CFG::AnalyzeBranch(Block& block, FunctionId function_id, Location pc, Instr const bool has_flow_test{HasFlowTest(opcode)}; const IR::FlowTest flow_test{has_flow_test ? inst.branch.flow_test.Value() : IR::FlowTest::T}; if (pred != Predicate{true} || flow_test != IR::FlowTest::T) { - block.cond = IR::Condition(flow_test, static_cast(pred.index), pred.negated); - block.branch_false = AddLabel(block, block.stack, pc + 1, function_id); + block->cond = IR::Condition(flow_test, static_cast(pred.index), pred.negated); + block->branch_false = AddLabel(block, block->stack, pc + 1, function_id); } else { - block.cond = true; + block->cond = IR::Condition{true}; } return true; } -void CFG::AnalyzeBRA(Block& block, FunctionId function_id, Location pc, Instruction inst, +void CFG::AnalyzeBRA(Block* block, FunctionId function_id, Location pc, Instruction inst, bool is_absolute) { const Location bra_pc{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)}; - block.branch_true = AddLabel(block, block.stack, bra_pc, function_id); + block->branch_true = AddLabel(block, block->stack, bra_pc, function_id); } -void CFG::AnalyzeBRX(Block&, Location, Instruction, bool is_absolute) { +void CFG::AnalyzeBRX(Block*, Location, Instruction, bool is_absolute) { throw NotImplementedException("{}", is_absolute ? "JMX" : "BRX"); } @@ -528,7 +410,7 @@ void CFG::AnalyzeCAL(Location pc, Instruction inst, bool is_absolute) { } } -CFG::AnalysisState CFG::AnalyzeEXIT(Block& block, FunctionId function_id, Location pc, +CFG::AnalysisState CFG::AnalyzeEXIT(Block* block, FunctionId function_id, Location pc, Instruction inst) { const IR::FlowTest flow_test{inst.branch.flow_test}; const Predicate pred{inst.Pred()}; @@ -537,41 +419,52 @@ CFG::AnalysisState CFG::AnalyzeEXIT(Block& block, FunctionId function_id, Locati return AnalysisState::Continue; } if (pred != Predicate{true} || flow_test != IR::FlowTest::T) { - if (block.stack.Peek(Token::PEXIT).has_value()) { + if (block->stack.Peek(Token::PEXIT).has_value()) { throw NotImplementedException("Conditional EXIT with PEXIT token"); } const IR::Condition cond{flow_test, static_cast(pred.index), pred.negated}; AnalyzeCondInst(block, function_id, pc, EndClass::Exit, cond); return AnalysisState::Branch; } - if (const std::optional exit_pc{block.stack.Peek(Token::PEXIT)}) { - const Stack popped_stack{block.stack.Remove(Token::PEXIT)}; - block.cond = true; - block.branch_true = AddLabel(block, popped_stack, *exit_pc, function_id); - block.branch_false = UNREACHABLE_BLOCK_ID; + if (const std::optional exit_pc{block->stack.Peek(Token::PEXIT)}) { + const Stack popped_stack{block->stack.Remove(Token::PEXIT)}; + block->cond = IR::Condition{true}; + block->branch_true = AddLabel(block, popped_stack, *exit_pc, function_id); + block->branch_false = nullptr; return AnalysisState::Branch; } - block.end = pc; - block.end_class = EndClass::Exit; + block->end = pc; + block->end_class = EndClass::Exit; return AnalysisState::Branch; } -BlockId CFG::AddLabel(const Block& block, Stack stack, Location pc, FunctionId function_id) { +Block* CFG::AddLabel(Block* block, Stack stack, Location pc, FunctionId function_id) { Function& function{functions[function_id]}; - if (block.begin == pc) { - return block.id; + if (block->begin == pc) { + // Jumps to itself + return block; } - const auto target{std::ranges::find(function.blocks_data, pc, &Block::begin)}; - if (target != function.blocks_data.end()) { - return target->id; + if (const auto it{function.blocks.find(pc, Compare{})}; it != function.blocks.end()) { + // Block already exists and it has been visited + return &*it; } - const BlockId block_id{++function.current_block_id}; + // TODO: FIX DANGLING BLOCKS + Block* const new_block{block_pool.Create(Block{ + .begin{pc}, + .end{pc}, + .end_class{EndClass::Branch}, + .stack{stack}, + .cond{IR::Condition{true}}, + .branch_true{nullptr}, + .branch_false{nullptr}, + .ir{nullptr}, + })}; function.labels.push_back(Label{ .address{pc}, - .block_id{block_id}, + .block{new_block}, .stack{std::move(stack)}, }); - return block_id; + return new_block; } std::string CFG::Dot() const { @@ -581,18 +474,12 @@ std::string CFG::Dot() const { for (const Function& function : functions) { dot += fmt::format("\tsubgraph cluster_{} {{\n", function.entrypoint); dot += fmt::format("\t\tnode [style=filled];\n"); - for (const u32 block_index : function.blocks) { - const Block& block{function.blocks_data[block_index]}; + for (const Block& block : function.blocks) { const std::string name{NameOf(block)}; - const auto add_branch = [&](BlockId branch_id, bool add_label) { - const auto it{std::ranges::find(function.blocks_data, branch_id, &Block::id)}; - dot += fmt::format("\t\t{}->", name); - if (it == function.blocks_data.end()) { - dot += fmt::format("\"Unknown label {}\"", branch_id); - } else { - dot += NameOf(*it); - }; - if (add_label && block.cond != true && block.cond != false) { + const auto add_branch = [&](Block* branch, bool add_label) { + dot += fmt::format("\t\t{}->{}", name, NameOf(*branch)); + if (add_label && block.cond != IR::Condition{true} && + block.cond != IR::Condition{false}) { dot += fmt::format(" [label=\"{}\"]", block.cond); } dot += '\n'; @@ -600,10 +487,10 @@ std::string CFG::Dot() const { dot += fmt::format("\t\t{};\n", name); switch (block.end_class) { case EndClass::Branch: - if (block.cond != false) { + if (block.cond != IR::Condition{false}) { add_branch(block.branch_true, true); } - if (block.cond != true) { + if (block.cond != IR::Condition{true}) { add_branch(block.branch_false, false); } break; @@ -619,12 +506,6 @@ std::string CFG::Dot() const { node_uid); ++node_uid; break; - case EndClass::Unreachable: - dot += fmt::format("\t\t{}->N{};\n", name, node_uid); - dot += fmt::format( - "\t\tN{} [label=\"Unreachable\"][shape=square][style=stripped];\n", node_uid); - ++node_uid; - break; } } if (function.entrypoint == 8) { @@ -635,10 +516,11 @@ std::string CFG::Dot() const { dot += "\t}\n"; } if (!functions.empty()) { - if (functions.front().blocks.empty()) { + auto& function{functions.front()}; + if (function.blocks.empty()) { dot += "Start;\n"; } else { - dot += fmt::format("\tStart -> {};\n", NameOf(functions.front().blocks_data.front())); + dot += fmt::format("\tStart -> {};\n", NameOf(*function.blocks.begin())); } dot += fmt::format("\tStart [shape=diamond];\n"); } diff --git a/src/shader_recompiler/frontend/maxwell/control_flow.h b/src/shader_recompiler/frontend/maxwell/control_flow.h index 49b369282..8179787b8 100644 --- a/src/shader_recompiler/frontend/maxwell/control_flow.h +++ b/src/shader_recompiler/frontend/maxwell/control_flow.h @@ -11,25 +11,27 @@ #include #include +#include #include "shader_recompiler/environment.h" #include "shader_recompiler/frontend/ir/condition.h" #include "shader_recompiler/frontend/maxwell/instruction.h" #include "shader_recompiler/frontend/maxwell/location.h" #include "shader_recompiler/frontend/maxwell/opcodes.h" +#include "shader_recompiler/object_pool.h" + +namespace Shader::IR { +class Block; +} namespace Shader::Maxwell::Flow { -using BlockId = u32; using FunctionId = size_t; -constexpr BlockId UNREACHABLE_BLOCK_ID{static_cast(-1)}; - enum class EndClass { Branch, Exit, Return, - Unreachable, }; enum class Token { @@ -59,58 +61,37 @@ private: boost::container::small_vector entries; }; -struct Block { +struct Block : boost::intrusive::set_base_hook< + // Normal link is ~2.5% faster compared to safe link + boost::intrusive::link_mode> { [[nodiscard]] bool Contains(Location pc) const noexcept; + bool operator<(const Block& rhs) const noexcept { + return begin < rhs.begin; + } + Location begin; Location end; EndClass end_class; - BlockId id; Stack stack; IR::Condition cond; - BlockId branch_true; - BlockId branch_false; - boost::container::small_vector imm_predecessors; - boost::container::small_vector dominance_frontiers; - union { - bool post_order_visited{false}; - Block* imm_dominator; - }; + Block* branch_true; + Block* branch_false; + IR::Block* ir; }; struct Label { Location address; - BlockId block_id; + Block* block; Stack stack; }; struct Function { Function(Location start_address); - void BuildBlocksMap(); - - void BuildImmediatePredecessors(); - - void BuildPostOrder(); - - void BuildImmediateDominators(); - - void BuildDominanceFrontier(); - - [[nodiscard]] size_t NumBlocks() const noexcept { - return static_cast(current_block_id) + 1; - } - Location entrypoint; - BlockId current_block_id{0}; boost::container::small_vector labels; - boost::container::small_vector blocks; - boost::container::small_vector blocks_data; - // Translates from BlockId to block index - boost::container::small_vector blocks_map; - - boost::container::small_vector post_order_blocks; - boost::container::small_vector post_order_map; + boost::intrusive::set blocks; }; class CFG { @@ -120,7 +101,7 @@ class CFG { }; public: - explicit CFG(Environment& env, Location start_address); + explicit CFG(Environment& env, ObjectPool& block_pool, Location start_address); CFG& operator=(const CFG&) = delete; CFG(const CFG&) = delete; @@ -133,35 +114,37 @@ public: [[nodiscard]] std::span Functions() const noexcept { return std::span(functions.data(), functions.size()); } + [[nodiscard]] std::span Functions() noexcept { + return std::span(functions.data(), functions.size()); + } private: - void VisitFunctions(Location start_address); - void AnalyzeLabel(FunctionId function_id, Label& label); /// Inspect already visited blocks. /// Return true when the block has already been visited bool InspectVisitedBlocks(FunctionId function_id, const Label& label); - AnalysisState AnalyzeInst(Block& block, FunctionId function_id, Location pc); + AnalysisState AnalyzeInst(Block* block, FunctionId function_id, Location pc); - void AnalyzeCondInst(Block& block, FunctionId function_id, Location pc, EndClass insn_end_class, + void AnalyzeCondInst(Block* block, FunctionId function_id, Location pc, EndClass insn_end_class, IR::Condition cond); /// Return true when the branch instruction is confirmed to be a branch - bool AnalyzeBranch(Block& block, FunctionId function_id, Location pc, Instruction inst, + bool AnalyzeBranch(Block* block, FunctionId function_id, Location pc, Instruction inst, Opcode opcode); - void AnalyzeBRA(Block& block, FunctionId function_id, Location pc, Instruction inst, + void AnalyzeBRA(Block* block, FunctionId function_id, Location pc, Instruction inst, bool is_absolute); - void AnalyzeBRX(Block& block, Location pc, Instruction inst, bool is_absolute); + void AnalyzeBRX(Block* block, Location pc, Instruction inst, bool is_absolute); void AnalyzeCAL(Location pc, Instruction inst, bool is_absolute); - AnalysisState AnalyzeEXIT(Block& block, FunctionId function_id, Location pc, Instruction inst); + AnalysisState AnalyzeEXIT(Block* block, FunctionId function_id, Location pc, Instruction inst); /// Return the branch target block id - BlockId AddLabel(const Block& block, Stack stack, Location pc, FunctionId function_id); + Block* AddLabel(Block* block, Stack stack, Location pc, FunctionId function_id); Environment& env; + ObjectPool& block_pool; boost::container::small_vector functions; FunctionId current_function_id{0}; }; diff --git a/src/shader_recompiler/frontend/maxwell/location.h b/src/shader_recompiler/frontend/maxwell/location.h index 66b51a19e..26d29eae2 100644 --- a/src/shader_recompiler/frontend/maxwell/location.h +++ b/src/shader_recompiler/frontend/maxwell/location.h @@ -15,7 +15,7 @@ namespace Shader::Maxwell { class Location { - static constexpr u32 VIRTUAL_OFFSET{std::numeric_limits::max()}; + static constexpr u32 VIRTUAL_BIAS{4}; public: constexpr Location() = default; @@ -27,12 +27,18 @@ public: Align(); } + constexpr Location Virtual() const noexcept { + Location virtual_location; + virtual_location.offset = offset - VIRTUAL_BIAS; + return virtual_location; + } + [[nodiscard]] constexpr u32 Offset() const noexcept { return offset; } [[nodiscard]] constexpr bool IsVirtual() const { - return offset == VIRTUAL_OFFSET; + return offset % 8 == VIRTUAL_BIAS; } constexpr auto operator<=>(const Location&) const noexcept = default; @@ -89,7 +95,7 @@ private: offset -= 8 + (offset % 32 == 8 ? 8 : 0); } - u32 offset{VIRTUAL_OFFSET}; + u32 offset{0xcccccccc}; }; } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/program.cpp b/src/shader_recompiler/frontend/maxwell/program.cpp index 8cdd20804..9fa912ed8 100644 --- a/src/shader_recompiler/frontend/maxwell/program.cpp +++ b/src/shader_recompiler/frontend/maxwell/program.cpp @@ -4,57 +4,58 @@ #include #include +#include #include "shader_recompiler/frontend/ir/basic_block.h" +#include "shader_recompiler/frontend/ir/structured_control_flow.h" #include "shader_recompiler/frontend/maxwell/program.h" -#include "shader_recompiler/frontend/maxwell/termination_code.h" #include "shader_recompiler/frontend/maxwell/translate/translate.h" #include "shader_recompiler/ir_opt/passes.h" namespace Shader::Maxwell { namespace { -void TranslateCode(ObjectPool& inst_pool, ObjectPool& block_pool, - Environment& env, const Flow::Function& cfg_function, IR::Function& function, - std::span block_map) { +IR::BlockList TranslateCode(ObjectPool& inst_pool, ObjectPool& block_pool, + Environment& env, Flow::Function& cfg_function) { const size_t num_blocks{cfg_function.blocks.size()}; - function.blocks.reserve(num_blocks); - - for (const Flow::BlockId block_id : cfg_function.blocks) { - const Flow::Block& flow_block{cfg_function.blocks_data[block_id]}; - - IR::Block* const ir_block{block_pool.Create(Translate(inst_pool, env, flow_block))}; - block_map[flow_block.id] = ir_block; - function.blocks.emplace_back(ir_block); - } -} - -void EmitTerminationInsts(const Flow::Function& cfg_function, - std::span block_map) { - for (const Flow::BlockId block_id : cfg_function.blocks) { - const Flow::Block& flow_block{cfg_function.blocks_data[block_id]}; - EmitTerminationCode(flow_block, block_map); - } -} - -void TranslateFunction(ObjectPool& inst_pool, ObjectPool& block_pool, - Environment& env, const Flow::Function& cfg_function, - IR::Function& function) { - std::vector block_map; - block_map.resize(cfg_function.blocks_data.size()); - - TranslateCode(inst_pool, block_pool, env, cfg_function, function, block_map); - EmitTerminationInsts(cfg_function, block_map); + std::vector blocks(cfg_function.blocks.size()); + std::ranges::for_each(cfg_function.blocks, [&, i = size_t{0}](auto& cfg_block) mutable { + const u32 begin{cfg_block.begin.Offset()}; + const u32 end{cfg_block.end.Offset()}; + blocks[i] = block_pool.Create(inst_pool, begin, end); + cfg_block.ir = blocks[i]; + ++i; + }); + std::ranges::for_each(cfg_function.blocks, [&, i = size_t{0}](auto& cfg_block) mutable { + IR::Block* const block{blocks[i]}; + ++i; + if (cfg_block.end_class != Flow::EndClass::Branch) { + block->SetReturn(); + } else if (cfg_block.cond == IR::Condition{true}) { + block->SetBranch(cfg_block.branch_true->ir); + } else if (cfg_block.cond == IR::Condition{false}) { + block->SetBranch(cfg_block.branch_false->ir); + } else { + block->SetBranches(cfg_block.cond, cfg_block.branch_true->ir, + cfg_block.branch_false->ir); + } + }); + return IR::VisitAST(inst_pool, block_pool, blocks, + [&](IR::Block* block) { Translate(env, block); }); } } // Anonymous namespace IR::Program TranslateProgram(ObjectPool& inst_pool, ObjectPool& block_pool, - Environment& env, const Flow::CFG& cfg) { + Environment& env, Flow::CFG& cfg) { IR::Program program; auto& functions{program.functions}; functions.reserve(cfg.Functions().size()); - for (const Flow::Function& cfg_function : cfg.Functions()) { - TranslateFunction(inst_pool, block_pool, env, cfg_function, functions.emplace_back()); + for (Flow::Function& cfg_function : cfg.Functions()) { + functions.push_back(IR::Function{ + .blocks{TranslateCode(inst_pool, block_pool, env, cfg_function)}, + }); } + + fmt::print(stdout, "No optimizations: {}", IR::DumpProgram(program)); std::ranges::for_each(functions, Optimization::SsaRewritePass); for (IR::Function& function : functions) { Optimization::Invoke(Optimization::GlobalMemoryToStorageBufferPass, function); diff --git a/src/shader_recompiler/frontend/maxwell/program.h b/src/shader_recompiler/frontend/maxwell/program.h index 3355ab129..542621a1d 100644 --- a/src/shader_recompiler/frontend/maxwell/program.h +++ b/src/shader_recompiler/frontend/maxwell/program.h @@ -19,6 +19,6 @@ namespace Shader::Maxwell { [[nodiscard]] IR::Program TranslateProgram(ObjectPool& inst_pool, ObjectPool& block_pool, Environment& env, - const Flow::CFG& cfg); + Flow::CFG& cfg); } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/termination_code.cpp b/src/shader_recompiler/frontend/maxwell/termination_code.cpp deleted file mode 100644 index ed5137f20..000000000 --- a/src/shader_recompiler/frontend/maxwell/termination_code.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2021 yuzu Emulator Project -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#include - -#include "shader_recompiler/exception.h" -#include "shader_recompiler/frontend/ir/basic_block.h" -#include "shader_recompiler/frontend/ir/ir_emitter.h" -#include "shader_recompiler/frontend/maxwell/control_flow.h" -#include "shader_recompiler/frontend/maxwell/termination_code.h" - -namespace Shader::Maxwell { - -static void EmitExit(IR::IREmitter& ir) { - ir.Exit(); -} - -static IR::U1 GetFlowTest(IR::FlowTest flow_test, IR::IREmitter& ir) { - switch (flow_test) { - case IR::FlowTest::T: - return ir.Imm1(true); - case IR::FlowTest::F: - return ir.Imm1(false); - case IR::FlowTest::NE: - // FIXME: Verify this - return ir.LogicalNot(ir.GetZFlag()); - case IR::FlowTest::NaN: - // FIXME: Verify this - return ir.LogicalAnd(ir.GetSFlag(), ir.GetZFlag()); - default: - throw NotImplementedException("Flow test {}", flow_test); - } -} - -static IR::U1 GetCond(IR::Condition cond, IR::IREmitter& ir) { - const IR::FlowTest flow_test{cond.FlowTest()}; - const auto [pred, pred_negated]{cond.Pred()}; - if (pred == IR::Pred::PT && !pred_negated) { - return GetFlowTest(flow_test, ir); - } - if (flow_test == IR::FlowTest::T) { - return ir.GetPred(pred, pred_negated); - } - return ir.LogicalAnd(ir.GetPred(pred, pred_negated), GetFlowTest(flow_test, ir)); -} - -static void EmitBranch(const Flow::Block& flow_block, std::span block_map, - IR::IREmitter& ir) { - const auto add_immediate_predecessor = [&](Flow::BlockId label) { - block_map[label]->AddImmediatePredecessor(&ir.block); - }; - if (flow_block.cond == true) { - add_immediate_predecessor(flow_block.branch_true); - return ir.Branch(block_map[flow_block.branch_true]); - } - if (flow_block.cond == false) { - add_immediate_predecessor(flow_block.branch_false); - return ir.Branch(block_map[flow_block.branch_false]); - } - add_immediate_predecessor(flow_block.branch_true); - add_immediate_predecessor(flow_block.branch_false); - return ir.BranchConditional(GetCond(flow_block.cond, ir), block_map[flow_block.branch_true], - block_map[flow_block.branch_false]); -} - -void EmitTerminationCode(const Flow::Block& flow_block, std::span block_map) { - IR::Block* const block{block_map[flow_block.id]}; - IR::IREmitter ir(*block); - switch (flow_block.end_class) { - case Flow::EndClass::Branch: - EmitBranch(flow_block, block_map, ir); - break; - case Flow::EndClass::Exit: - EmitExit(ir); - break; - case Flow::EndClass::Return: - ir.Return(); - break; - case Flow::EndClass::Unreachable: - ir.Unreachable(); - break; - } -} - -} // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/termination_code.h b/src/shader_recompiler/frontend/maxwell/termination_code.h deleted file mode 100644 index 04e044534..000000000 --- a/src/shader_recompiler/frontend/maxwell/termination_code.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2021 yuzu Emulator Project -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#pragma once - -#include - -#include "shader_recompiler/frontend/ir/basic_block.h" -#include "shader_recompiler/frontend/maxwell/control_flow.h" - -namespace Shader::Maxwell { - -/// Emit termination instructions and collect immediate predecessors -void EmitTerminationCode(const Flow::Block& flow_block, std::span block_map); - -} // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/integer_shift_left.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/integer_shift_left.cpp index d4b417d14..b752785d4 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/impl/integer_shift_left.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/impl/integer_shift_left.cpp @@ -28,7 +28,7 @@ void SHL(TranslatorVisitor& v, u64 insn, const IR::U32& unsafe_shift) { IR::U32 result; if (shl.w != 0) { // When .W is set, the shift value is wrapped - // To emulate this we just have to clamp it ourselves. + // To emulate this we just have to wrap it ourselves. const IR::U32 shift{v.ir.BitwiseAnd(unsafe_shift, v.ir.Imm32(31))}; result = v.ir.ShiftLeftLogical(base, shift); } else { diff --git a/src/shader_recompiler/frontend/maxwell/translate/translate.cpp b/src/shader_recompiler/frontend/maxwell/translate/translate.cpp index 7e6bb07a2..f1230f58f 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/translate.cpp +++ b/src/shader_recompiler/frontend/maxwell/translate/translate.cpp @@ -23,14 +23,13 @@ static void Invoke(TranslatorVisitor& visitor, Location pc, u64 insn) { } } -IR::Block Translate(ObjectPool& inst_pool, Environment& env, - const Flow::Block& flow_block) { - IR::Block block{inst_pool, flow_block.begin.Offset(), flow_block.end.Offset()}; - TranslatorVisitor visitor{env, block}; - - const Location pc_end{flow_block.end}; - Location pc{flow_block.begin}; - while (pc != pc_end) { +void Translate(Environment& env, IR::Block* block) { + if (block->IsVirtual()) { + return; + } + TranslatorVisitor visitor{env, *block}; + const Location pc_end{block->LocationEnd()}; + for (Location pc = block->LocationBegin(); pc != pc_end; ++pc) { const u64 insn{env.ReadInstruction(pc.Offset())}; const Opcode opcode{Decode(insn)}; switch (opcode) { @@ -43,9 +42,7 @@ IR::Block Translate(ObjectPool& inst_pool, Environment& env, default: throw LogicError("Invalid opcode {}", opcode); } - ++pc; } - return block; } } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/frontend/maxwell/translate/translate.h b/src/shader_recompiler/frontend/maxwell/translate/translate.h index c1c21b278..e1aa2e0f4 100644 --- a/src/shader_recompiler/frontend/maxwell/translate/translate.h +++ b/src/shader_recompiler/frontend/maxwell/translate/translate.h @@ -6,14 +6,9 @@ #include "shader_recompiler/environment.h" #include "shader_recompiler/frontend/ir/basic_block.h" -#include "shader_recompiler/frontend/ir/microinstruction.h" -#include "shader_recompiler/frontend/maxwell/control_flow.h" -#include "shader_recompiler/frontend/maxwell/location.h" -#include "shader_recompiler/object_pool.h" namespace Shader::Maxwell { -[[nodiscard]] IR::Block Translate(ObjectPool& inst_pool, Environment& env, - const Flow::Block& flow_block); +void Translate(Environment& env, IR::Block* block); } // namespace Shader::Maxwell diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index f1170c61e..9fba6ac23 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -132,6 +132,32 @@ void FoldLogicalAnd(IR::Inst& inst) { } } +void FoldLogicalOr(IR::Inst& inst) { + if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) { + return; + } + const IR::Value rhs{inst.Arg(1)}; + if (rhs.IsImmediate()) { + if (rhs.U1()) { + inst.ReplaceUsesWith(IR::Value{true}); + } else { + inst.ReplaceUsesWith(inst.Arg(0)); + } + } +} + +void FoldLogicalNot(IR::Inst& inst) { + const IR::U1 value{inst.Arg(0)}; + if (value.IsImmediate()) { + inst.ReplaceUsesWith(IR::Value{!value.U1()}); + return; + } + IR::Inst* const arg{value.InstRecursive()}; + if (arg->Opcode() == IR::Opcode::LogicalNot) { + inst.ReplaceUsesWith(arg->Arg(0)); + } +} + template void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) { const IR::Value value{inst.Arg(0)}; @@ -160,6 +186,24 @@ void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); } +void FoldBranchConditional(IR::Inst& inst) { + const IR::U1 cond{inst.Arg(0)}; + if (cond.IsImmediate()) { + // TODO: Convert to Branch + return; + } + const IR::Inst* cond_inst{cond.InstRecursive()}; + if (cond_inst->Opcode() == IR::Opcode::LogicalNot) { + const IR::Value true_label{inst.Arg(1)}; + const IR::Value false_label{inst.Arg(2)}; + // Remove negation on the conditional (take the parameter out of LogicalNot) and swap + // the branches + inst.SetArg(0, cond_inst->Arg(0)); + inst.SetArg(1, false_label); + inst.SetArg(2, true_label); + } +} + void ConstantPropagation(IR::Inst& inst) { switch (inst.Opcode()) { case IR::Opcode::GetRegister: @@ -178,6 +222,10 @@ void ConstantPropagation(IR::Inst& inst) { return FoldSelect(inst); case IR::Opcode::LogicalAnd: return FoldLogicalAnd(inst); + case IR::Opcode::LogicalOr: + return FoldLogicalOr(inst); + case IR::Opcode::LogicalNot: + return FoldLogicalNot(inst); case IR::Opcode::ULessThan: return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); case IR::Opcode::BitFieldUExtract: @@ -188,6 +236,8 @@ void ConstantPropagation(IR::Inst& inst) { } return (base >> shift) & ((1U << count) - 1); }); + case IR::Opcode::BranchConditional: + return FoldBranchConditional(inst); default: break; } diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp index 15a9db90a..8ca996e93 100644 --- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp +++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp @@ -34,6 +34,13 @@ struct SignFlagTag : FlagTag {}; struct CarryFlagTag : FlagTag {}; struct OverflowFlagTag : FlagTag {}; +struct GotoVariable : FlagTag { + GotoVariable() = default; + explicit GotoVariable(u32 index_) : index{index_} {} + + u32 index; +}; + struct DefTable { [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept { return regs[IR::RegIndex(variable)]; @@ -43,6 +50,10 @@ struct DefTable { return preds[IR::PredIndex(variable)]; } + [[nodiscard]] ValueMap& operator[](GotoVariable goto_variable) { + return goto_vars[goto_variable.index]; + } + [[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept { return zero_flag; } @@ -61,6 +72,7 @@ struct DefTable { std::array regs; std::array preds; + boost::container::flat_map goto_vars; ValueMap zero_flag; ValueMap sign_flag; ValueMap carry_flag; @@ -68,15 +80,15 @@ struct DefTable { }; IR::Opcode UndefOpcode(IR::Reg) noexcept { - return IR::Opcode::Undef32; + return IR::Opcode::UndefU32; } IR::Opcode UndefOpcode(IR::Pred) noexcept { - return IR::Opcode::Undef1; + return IR::Opcode::UndefU1; } IR::Opcode UndefOpcode(const FlagTag&) noexcept { - return IR::Opcode::Undef1; + return IR::Opcode::UndefU1; } [[nodiscard]] bool IsPhi(const IR::Inst& inst) noexcept { @@ -165,6 +177,9 @@ void SsaRewritePass(IR::Function& function) { pass.WriteVariable(pred, block, inst.Arg(1)); } break; + case IR::Opcode::SetGotoVariable: + pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1)); + break; case IR::Opcode::SetZFlag: pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); break; @@ -187,6 +202,9 @@ void SsaRewritePass(IR::Function& function) { inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); } break; + case IR::Opcode::GetGotoVariable: + inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block)); + break; case IR::Opcode::GetZFlag: inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); break; diff --git a/src/shader_recompiler/ir_opt/verification_pass.cpp b/src/shader_recompiler/ir_opt/verification_pass.cpp index 8a5adf5a2..32b56eb57 100644 --- a/src/shader_recompiler/ir_opt/verification_pass.cpp +++ b/src/shader_recompiler/ir_opt/verification_pass.cpp @@ -14,6 +14,10 @@ namespace Shader::Optimization { static void ValidateTypes(const IR::Function& function) { for (const auto& block : function.blocks) { for (const IR::Inst& inst : *block) { + if (inst.Opcode() == IR::Opcode::Phi) { + // Skip validation on phi nodes + continue; + } const size_t num_args{inst.NumArgs()}; for (size_t i = 0; i < num_args; ++i) { const IR::Type t1{inst.Arg(i).Type()}; diff --git a/src/shader_recompiler/main.cpp b/src/shader_recompiler/main.cpp index 9887e066d..3ca1677c4 100644 --- a/src/shader_recompiler/main.cpp +++ b/src/shader_recompiler/main.cpp @@ -2,6 +2,7 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include #include #include @@ -36,34 +37,46 @@ void RunDatabase() { ForEachFile("D:\\Shaders\\Database", [&](const std::filesystem::path& path) { map.emplace_back(std::make_unique(path.string().c_str())); }); - for (int i = 0; i < 300; ++i) { + auto block_pool{std::make_unique>()}; + auto t0 = std::chrono::high_resolution_clock::now(); + int N = 1; + int n = 0; + for (int i = 0; i < N; ++i) { for (auto& env : map) { + ++n; // fmt::print(stdout, "Decoding {}\n", path.string()); + const Location start_address{0}; - auto cfg{std::make_unique(*env, start_address)}; + block_pool->ReleaseContents(); + Flow::CFG cfg{*env, *block_pool, start_address}; // fmt::print(stdout, "{}\n", cfg->Dot()); // IR::Program program{env, cfg}; // Optimize(program); // const std::string code{EmitGLASM(program)}; } } + auto t = std::chrono::high_resolution_clock::now(); + fmt::print(stdout, "{} ms", + std::chrono::duration_cast(t - t0).count() / double(N)); } int main() { // RunDatabase(); + auto flow_block_pool{std::make_unique>()}; auto inst_pool{std::make_unique>()}; auto block_pool{std::make_unique>()}; - // FileEnvironment env{"D:\\Shaders\\Database\\test.bin"}; - FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS15C2FB1F0B965767.bin"}; + FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; + // FileEnvironment env{"D:\\Shaders\\shader.bin"}; for (int i = 0; i < 1; ++i) { block_pool->ReleaseContents(); inst_pool->ReleaseContents(); - auto cfg{std::make_unique(env, 0)}; - // fmt::print(stdout, "{}\n", cfg->Dot()); - IR::Program program{TranslateProgram(*inst_pool, *block_pool, env, *cfg)}; - // fmt::print(stdout, "{}\n", IR::DumpProgram(program)); + flow_block_pool->ReleaseContents(); + Flow::CFG cfg{env, *flow_block_pool, 0}; + fmt::print(stdout, "{}\n", cfg.Dot()); + IR::Program program{TranslateProgram(*inst_pool, *block_pool, env, cfg)}; + fmt::print(stdout, "{}\n", IR::DumpProgram(program)); Backend::SPIRV::EmitSPIRV spirv{program}; } } diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h new file mode 100644 index 000000000..1760bf4a9 --- /dev/null +++ b/src/shader_recompiler/shader_info.h @@ -0,0 +1,28 @@ +// Copyright 2021 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include + +#include + +namespace Shader { + +struct Info { + struct ConstantBuffer { + + }; + + struct { + bool workgroup_id{}; + bool local_invocation_id{}; + bool fp16{}; + bool fp64{}; + } uses; + + std::array<18 +}; + +} // namespace Shader