From 4fde66e6094b57201d208b8abd3d7715341cd5db Mon Sep 17 00:00:00 2001 From: Fernando Sahmkow Date: Thu, 27 Jun 2019 18:57:47 -0400 Subject: [PATCH] shader_ir: Add basic goto elimination --- src/video_core/shader/ast.cpp | 348 ++++++++++++++++++++++++++++++++-- src/video_core/shader/ast.h | 178 ++++++++++++++--- 2 files changed, 486 insertions(+), 40 deletions(-) diff --git a/src/video_core/shader/ast.cpp b/src/video_core/shader/ast.cpp index 5d0e85f427..56a1b29f32 100644 --- a/src/video_core/shader/ast.cpp +++ b/src/video_core/shader/ast.cpp @@ -11,6 +11,147 @@ namespace VideoCommon::Shader { +ASTZipper::ASTZipper() = default; +ASTZipper::ASTZipper(ASTNode new_first) : first{}, last{} { + first = new_first; + last = new_first; + ASTNode current = first; + while (current) { + current->manager = this; + last = current; + current = current->next; + } +} + +void ASTZipper::PushBack(ASTNode new_node) { + new_node->previous = last; + if (last) { + last->next = new_node; + } + new_node->next.reset(); + last = new_node; + if (!first) { + first = new_node; + } + new_node->manager = this; +} + +void ASTZipper::PushFront(ASTNode new_node) { + new_node->previous.reset(); + new_node->next = first; + if (first) { + first->previous = first; + } + first = new_node; + if (!last) { + last = new_node; + } + new_node->manager = this; +} + +void ASTZipper::InsertAfter(ASTNode new_node, ASTNode at_node) { + if (!at_node) { + PushFront(new_node); + return; + } + new_node->previous = at_node; + if (at_node == last) { + last = new_node; + } + new_node->next = at_node->next; + at_node->next = new_node; + new_node->manager = this; +} + +void ASTZipper::SetParent(ASTNode new_parent) { + ASTNode current = first; + while (current) { + current->parent = new_parent; + current = current->next; + } +} + +void ASTZipper::DetachTail(ASTNode node) { + ASSERT(node->manager == this); + if (node == first) { + first.reset(); + last.reset(); + return; + } + + last = node->previous; + node->previous.reset(); +} + +void ASTZipper::DetachSegment(ASTNode start, ASTNode end) { + ASSERT(start->manager == this && end->manager == this); + ASTNode prev = start->previous; + ASTNode post = end->next; + if (!prev) { + first = post; + } else { + prev->next = post; + } + if (!post) { + last = prev; + } else { + post->previous = prev; + } + start->previous.reset(); + end->next.reset(); + ASTNode current = start; + bool found = false; + while (current) { + current->manager = nullptr; + current->parent.reset(); + found |= current == end; + current = current->next; + } + ASSERT(found); +} + +void ASTZipper::DetachSingle(ASTNode node) { + ASSERT(node->manager == this); + ASTNode prev = node->previous; + ASTNode post = node->next; + node->previous.reset(); + node->next.reset(); + if (!prev) { + first = post; + } else { + prev->next = post; + } + if (!post) { + last = prev; + } else { + post->previous = prev; + } + + node->manager = nullptr; + node->parent.reset(); +} + + +void ASTZipper::Remove(ASTNode node) { + ASSERT(node->manager == this); + ASTNode next = node->next; + ASTNode previous = node->previous; + if (previous) { + previous->next = next; + } + if (next) { + next->previous = previous; + } + node->parent.reset(); + node->manager = nullptr; + if (node == last) { + last = previous; + } + if (node == first) { + first = next; + } +} + class ExprPrinter final { public: ExprPrinter() = default; @@ -72,32 +213,39 @@ public: void operator()(ASTProgram& ast) { scope++; inner += "program {\n"; - for (ASTNode& node : ast.nodes) { - Visit(node); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); } inner += "}\n"; scope--; } - void operator()(ASTIf& ast) { + void operator()(ASTIfThen& ast) { ExprPrinter expr_parser{}; std::visit(expr_parser, *ast.condition); inner += Ident() + "if (" + expr_parser.GetResult() + ") {\n"; scope++; - for (auto& node : ast.then_nodes) { - Visit(node); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); } scope--; - if (ast.else_nodes.size() > 0) { - inner += Ident() + "} else {\n"; - scope++; - for (auto& node : ast.else_nodes) { - Visit(node); - } - scope--; - } else { - inner += Ident() + "}\n"; + inner += Ident() + "}\n"; + } + + void operator()(ASTIfElse& ast) { + inner += Ident() + "else {\n"; + scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); } + scope--; + inner += Ident() + "}\n"; } void operator()(ASTBlockEncoded& ast) { @@ -128,8 +276,10 @@ public: std::visit(expr_parser, *ast.condition); inner += Ident() + "do {\n"; scope++; - for (auto& node : ast.loop_nodes) { - Visit(node); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); } scope--; inner += Ident() + "} while (" + expr_parser.GetResult() + ")\n"; @@ -142,6 +292,12 @@ public: (ast.kills ? "discard" : "exit") + ";\n"; } + void operator()(ASTBreak& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "(" + expr_parser.GetResult() + ") -> break;\n"; + } + std::string& Ident() { if (memo_scope == scope) { return tabs_memo; @@ -177,4 +333,164 @@ std::string ASTManager::Print() { return printer.GetResult(); } +#pragma optimize("", off) + +void ASTManager::Decompile() { + auto it = gotos.begin(); + while (it != gotos.end()) { + ASTNode goto_node = *it; + u32 label_index = goto_node->GetGotoLabel(); + ASTNode label = labels[label_index]; + if (IndirectlyRelated(goto_node, label)) { + while (!DirectlyRelated(goto_node, label)) { + MoveOutward(goto_node); + } + } + if (DirectlyRelated(goto_node, label)) { + u32 goto_level = goto_node->GetLevel(); + u32 label_level = goto_node->GetLevel(); + while (label_level > goto_level) { + MoveOutward(goto_node); + goto_level++; + } + } + if (label->GetParent() == goto_node->GetParent()) { + bool is_loop = false; + ASTNode current = goto_node->GetPrevious(); + while (current) { + if (current == label) { + is_loop = true; + break; + } + current = current->GetPrevious(); + } + + if (is_loop) { + EncloseDoWhile(goto_node, label); + } else { + EncloseIfThen(goto_node, label); + } + it = gotos.erase(it); + continue; + } + it++; + } + /* + for (ASTNode label : labels) { + auto& manager = label->GetManager(); + manager.Remove(label); + } + labels.clear(); + */ +} + +bool ASTManager::IndirectlyRelated(ASTNode first, ASTNode second) { + return !(first->GetParent() == second->GetParent() || DirectlyRelated(first, second)); +} + +bool ASTManager::DirectlyRelated(ASTNode first, ASTNode second) { + if (first->GetParent() == second->GetParent()) { + return false; + } + u32 first_level = first->GetLevel(); + u32 second_level = second->GetLevel(); + u32 min_level; + u32 max_level; + ASTNode max; + ASTNode min; + if (first_level > second_level) { + min_level = second_level; + min = second; + max_level = first_level; + max = first; + } else { + min_level = first_level; + min = first; + max_level = second_level; + max = second; + } + + while (min_level < max_level) { + min_level++; + min = min->GetParent(); + } + + return (min->GetParent() == max->GetParent()); +} + +void ASTManager::EncloseDoWhile(ASTNode goto_node, ASTNode label) { + ASTZipper& zipper = goto_node->GetManager(); + ASTNode loop_start = label->GetNext(); + if (loop_start == goto_node) { + zipper.Remove(goto_node); + return; + } + ASTNode parent = label->GetParent(); + Expr condition = goto_node->GetGotoCondition(); + zipper.DetachSegment(loop_start, goto_node); + ASTNode do_while_node = ASTBase::Make(parent, condition, ASTZipper(loop_start)); + zipper.InsertAfter(do_while_node, label); + ASTZipper* sub_zipper = do_while_node->GetSubNodes(); + sub_zipper->SetParent(do_while_node); + sub_zipper->Remove(goto_node); +} + +void ASTManager::EncloseIfThen(ASTNode goto_node, ASTNode label) { + ASTZipper& zipper = goto_node->GetManager(); + ASTNode if_end = label->GetPrevious(); + if (if_end == goto_node) { + zipper.Remove(goto_node); + return; + } + ASTNode prev = goto_node->GetPrevious(); + ASTNode parent = label->GetParent(); + Expr condition = goto_node->GetGotoCondition(); + Expr neg_condition = MakeExpr(condition); + zipper.DetachSegment(goto_node, if_end); + ASTNode if_node = ASTBase::Make(parent, condition, ASTZipper(goto_node)); + zipper.InsertAfter(if_node, prev); + ASTZipper* sub_zipper = if_node->GetSubNodes(); + sub_zipper->SetParent(if_node); + sub_zipper->Remove(goto_node); +} + +void ASTManager::MoveOutward(ASTNode goto_node) { + ASTZipper& zipper = goto_node->GetManager(); + ASTNode parent = goto_node->GetParent(); + bool is_loop = parent->IsLoop(); + bool is_if = parent->IsIfThen() || parent->IsIfElse(); + + ASTNode prev = goto_node->GetPrevious(); + + Expr condition = goto_node->GetGotoCondition(); + u32 var_index = NewVariable(); + Expr var_condition = MakeExpr(var_index); + ASTNode var_node = ASTBase::Make(parent, var_index, condition); + zipper.DetachSingle(goto_node); + zipper.InsertAfter(var_node, prev); + goto_node->SetGotoCondition(var_condition); + if (is_loop) { + ASTNode break_node = ASTBase::Make(parent, var_condition); + zipper.InsertAfter(break_node, var_node); + } else if (is_if) { + ASTNode post = var_node->GetNext(); + if (post) { + zipper.DetachTail(post); + ASTNode if_node = ASTBase::Make(parent, var_condition, ASTZipper(post)); + zipper.InsertAfter(if_node, var_node); + ASTZipper* sub_zipper = if_node->GetSubNodes(); + sub_zipper->SetParent(if_node); + } + } else { + UNREACHABLE(); + } + ASTZipper& zipper2 = parent->GetManager(); + ASTNode next = parent->GetNext(); + if (is_if && next && next->IsIfElse()) { + zipper2.InsertAfter(goto_node, next); + return; + } + zipper2.InsertAfter(goto_node, parent); +} + } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/ast.h b/src/video_core/shader/ast.h index ca71543fba..22ac8884cc 100644 --- a/src/video_core/shader/ast.h +++ b/src/video_core/shader/ast.h @@ -18,32 +18,71 @@ namespace VideoCommon::Shader { class ASTBase; class ASTProgram; -class ASTIf; +class ASTIfThen; +class ASTIfElse; class ASTBlockEncoded; class ASTVarSet; class ASTGoto; class ASTLabel; class ASTDoWhile; class ASTReturn; +class ASTBreak; -using ASTData = std::variant; +using ASTData = std::variant; using ASTNode = std::shared_ptr; -class ASTProgram { -public: - ASTProgram() = default; - std::list nodes; +enum class ASTZipperType : u32 { + Program, + IfThen, + IfElse, + Loop, }; -class ASTIf { +class ASTZipper final { public: - ASTIf(Expr condition, std::list then_nodes, std::list else_nodes) - : condition(condition), then_nodes{then_nodes}, else_nodes{then_nodes} {} + ASTZipper(); + ASTZipper(ASTNode first); + + ASTNode GetFirst() { + return first; + } + + ASTNode GetLast() { + return last; + } + + void PushBack(ASTNode new_node); + void PushFront(ASTNode new_node); + void InsertAfter(ASTNode new_node, ASTNode at_node); + void SetParent(ASTNode new_parent); + void DetachTail(ASTNode node); + void DetachSingle(ASTNode node); + void DetachSegment(ASTNode start, ASTNode end); + void Remove(ASTNode node); + + ASTNode first{}; + ASTNode last{}; +}; + +class ASTProgram { +public: + ASTProgram() : nodes{} {}; + ASTZipper nodes; +}; + +class ASTIfThen { +public: + ASTIfThen(Expr condition, ASTZipper nodes) : condition(condition), nodes{nodes} {} Expr condition; - std::list then_nodes; - std::list else_nodes; + ASTZipper nodes; +}; + +class ASTIfElse { +public: + ASTIfElse(ASTZipper nodes) : nodes{nodes} {} + ASTZipper nodes; }; class ASTBlockEncoded { @@ -75,10 +114,9 @@ public: class ASTDoWhile { public: - ASTDoWhile(Expr condition, std::list loop_nodes) - : condition(condition), loop_nodes{loop_nodes} {} + ASTDoWhile(Expr condition, ASTZipper nodes) : condition(condition), nodes{nodes} {} Expr condition; - std::list loop_nodes; + ASTZipper nodes; }; class ASTReturn { @@ -88,6 +126,12 @@ public: bool kills; }; +class ASTBreak { +public: + ASTBreak(Expr condition) : condition{condition} {} + Expr condition; +}; + class ASTBase { public: explicit ASTBase(ASTNode parent, ASTData data) : parent{parent}, data{data} {} @@ -111,9 +155,9 @@ public: u32 GetLevel() const { u32 level = 0; - auto next = parent; - while (next) { - next = next->GetParent(); + auto next_parent = parent; + while (next_parent) { + next_parent = next_parent->GetParent(); level++; } return level; @@ -123,15 +167,83 @@ public: return &data; } + ASTNode GetNext() { + return next; + } + + ASTNode GetPrevious() { + return previous; + } + + ASTZipper& GetManager() { + return *manager; + } + + u32 GetGotoLabel() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->label; + } + return -1; + } + + Expr GetGotoCondition() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->condition; + } + return nullptr; + } + + void SetGotoCondition(Expr new_condition) { + auto inner = std::get_if(&data); + if (inner) { + inner->condition = new_condition; + } + } + + bool IsIfThen() const { + return std::holds_alternative(data); + } + + bool IsIfElse() const { + return std::holds_alternative(data); + } + + bool IsLoop() const { + return std::holds_alternative(data); + } + + ASTZipper* GetSubNodes() { + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + return nullptr; + } + private: + friend class ASTZipper; + ASTData data; ASTNode parent; + ASTNode next{}; + ASTNode previous{}; + ASTZipper* manager{}; }; class ASTManager final { public: explicit ASTManager() { - main_node = ASTBase::Make(nullptr); + main_node = ASTBase::Make(ASTNode{}); program = std::get_if(main_node->GetInnerData()); } @@ -147,31 +259,49 @@ public: u32 index = labels_map[address]; ASTNode label = ASTBase::Make(main_node, index); labels[index] = label; - program->nodes.push_back(label); + program->nodes.PushBack(label); } void InsertGoto(Expr condition, u32 address) { u32 index = labels_map[address]; ASTNode goto_node = ASTBase::Make(main_node, condition, index); gotos.push_back(goto_node); - program->nodes.push_back(goto_node); + program->nodes.PushBack(goto_node); } void InsertBlock(u32 start_address, u32 end_address) { ASTNode block = ASTBase::Make(main_node, start_address, end_address); - program->nodes.push_back(block); + program->nodes.PushBack(block); } void InsertReturn(Expr condition, bool kills) { ASTNode node = ASTBase::Make(main_node, condition, kills); - program->nodes.push_back(node); + program->nodes.PushBack(node); } std::string Print(); - void Decompile() {} + void Decompile(); + + private: + bool IndirectlyRelated(ASTNode first, ASTNode second); + + bool DirectlyRelated(ASTNode first, ASTNode second); + + void EncloseDoWhile(ASTNode goto_node, ASTNode label); + + void EncloseIfThen(ASTNode goto_node, ASTNode label); + + void MoveOutward(ASTNode goto_node) ; + + u32 NewVariable() { + u32 new_var = variables; + variables++; + return new_var; + } + std::unordered_map labels_map{}; u32 labels_count{}; std::vector labels{};