mirror of
https://git.suyu.dev/suyu/suyu.git
synced 2024-12-23 17:00:57 +01:00
vk_shader_compiler: Implement the decompiler in SPIR-V
This commit is contained in:
parent
0366c18d87
commit
ca9901867e
3 changed files with 301 additions and 23 deletions
|
@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
class ASTDecompiler;
|
||||||
|
class ExprDecompiler;
|
||||||
|
|
||||||
class SPIRVDecompiler : public Sirit::Module {
|
class SPIRVDecompiler : public Sirit::Module {
|
||||||
public:
|
public:
|
||||||
explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
|
explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
|
||||||
|
@ -97,27 +100,7 @@ public:
|
||||||
AddExtension("SPV_KHR_variable_pointers");
|
AddExtension("SPV_KHR_variable_pointers");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Decompile() {
|
void DecompileBranchMode() {
|
||||||
AllocateBindings();
|
|
||||||
AllocateLabels();
|
|
||||||
|
|
||||||
DeclareVertex();
|
|
||||||
DeclareGeometry();
|
|
||||||
DeclareFragment();
|
|
||||||
DeclareRegisters();
|
|
||||||
DeclarePredicates();
|
|
||||||
DeclareLocalMemory();
|
|
||||||
DeclareInternalFlags();
|
|
||||||
DeclareInputAttributes();
|
|
||||||
DeclareOutputAttributes();
|
|
||||||
DeclareConstantBuffers();
|
|
||||||
DeclareGlobalBuffers();
|
|
||||||
DeclareSamplers();
|
|
||||||
|
|
||||||
execute_function =
|
|
||||||
Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
|
|
||||||
Emit(OpLabel());
|
|
||||||
|
|
||||||
const u32 first_address = ir.GetBasicBlocks().begin()->first;
|
const u32 first_address = ir.GetBasicBlocks().begin()->first;
|
||||||
const Id loop_label = OpLabel("loop");
|
const Id loop_label = OpLabel("loop");
|
||||||
const Id merge_label = OpLabel("merge");
|
const Id merge_label = OpLabel("merge");
|
||||||
|
@ -174,6 +157,43 @@ public:
|
||||||
Emit(continue_label);
|
Emit(continue_label);
|
||||||
Emit(OpBranch(loop_label));
|
Emit(OpBranch(loop_label));
|
||||||
Emit(merge_label);
|
Emit(merge_label);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecompileAST();
|
||||||
|
|
||||||
|
void Decompile() {
|
||||||
|
const bool is_fully_decompiled = ir.IsDecompiled();
|
||||||
|
AllocateBindings();
|
||||||
|
if (!is_fully_decompiled) {
|
||||||
|
AllocateLabels();
|
||||||
|
}
|
||||||
|
|
||||||
|
DeclareVertex();
|
||||||
|
DeclareGeometry();
|
||||||
|
DeclareFragment();
|
||||||
|
DeclareRegisters();
|
||||||
|
DeclarePredicates();
|
||||||
|
if (is_fully_decompiled) {
|
||||||
|
DeclareFlowVariables();
|
||||||
|
}
|
||||||
|
DeclareLocalMemory();
|
||||||
|
DeclareInternalFlags();
|
||||||
|
DeclareInputAttributes();
|
||||||
|
DeclareOutputAttributes();
|
||||||
|
DeclareConstantBuffers();
|
||||||
|
DeclareGlobalBuffers();
|
||||||
|
DeclareSamplers();
|
||||||
|
|
||||||
|
execute_function =
|
||||||
|
Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
|
||||||
|
Emit(OpLabel());
|
||||||
|
|
||||||
|
if (is_fully_decompiled) {
|
||||||
|
DecompileAST();
|
||||||
|
} else {
|
||||||
|
DecompileBranchMode();
|
||||||
|
}
|
||||||
|
|
||||||
Emit(OpReturn());
|
Emit(OpReturn());
|
||||||
Emit(OpFunctionEnd());
|
Emit(OpFunctionEnd());
|
||||||
}
|
}
|
||||||
|
@ -206,6 +226,9 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class ASTDecompiler;
|
||||||
|
friend class ExprDecompiler;
|
||||||
|
|
||||||
static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
|
static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
|
||||||
|
|
||||||
void AllocateBindings() {
|
void AllocateBindings() {
|
||||||
|
@ -294,6 +317,14 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DeclareFlowVariables() {
|
||||||
|
for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
|
||||||
|
const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
|
||||||
|
Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
|
||||||
|
flow_variables.emplace(i, AddGlobalVariable(id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DeclareLocalMemory() {
|
void DeclareLocalMemory() {
|
||||||
if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
|
if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
|
||||||
const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
|
const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
|
||||||
|
@ -1019,7 +1050,7 @@ private:
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
Id Exit(Operation operation) {
|
Id PreExit() {
|
||||||
switch (stage) {
|
switch (stage) {
|
||||||
case ShaderStage::Vertex: {
|
case ShaderStage::Vertex: {
|
||||||
// TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
|
// TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
|
||||||
|
@ -1067,6 +1098,11 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
Id Exit(Operation operation) {
|
||||||
|
PreExit();
|
||||||
BranchingOp([&]() { Emit(OpReturn()); });
|
BranchingOp([&]() { Emit(OpReturn()); });
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
@ -1545,6 +1581,7 @@ private:
|
||||||
Id per_vertex{};
|
Id per_vertex{};
|
||||||
std::map<u32, Id> registers;
|
std::map<u32, Id> registers;
|
||||||
std::map<Tegra::Shader::Pred, Id> predicates;
|
std::map<Tegra::Shader::Pred, Id> predicates;
|
||||||
|
std::map<u32, Id> flow_variables;
|
||||||
Id local_memory{};
|
Id local_memory{};
|
||||||
std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
|
std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
|
||||||
std::map<Attribute::Index, Id> input_attributes;
|
std::map<Attribute::Index, Id> input_attributes;
|
||||||
|
@ -1580,6 +1617,223 @@ private:
|
||||||
std::map<u32, Id> labels;
|
std::map<u32, Id> labels;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ExprDecompiler {
|
||||||
|
public:
|
||||||
|
ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprAnd& expr) {
|
||||||
|
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
|
||||||
|
const Id op1 = Visit(expr.operand1);
|
||||||
|
const Id op2 = Visit(expr.operand2);
|
||||||
|
current_id = decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprOr& expr) {
|
||||||
|
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
|
||||||
|
const Id op1 = Visit(expr.operand1);
|
||||||
|
const Id op2 = Visit(expr.operand2);
|
||||||
|
current_id = decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprNot& expr) {
|
||||||
|
const Id type_def = decomp.GetTypeDefinition(Type::Bool);
|
||||||
|
const Id op1 = Visit(expr.operand1);
|
||||||
|
current_id = decomp.Emit(decomp.OpLogicalNot(type_def, op1));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprPredicate& expr) {
|
||||||
|
auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
|
||||||
|
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprCondCode& expr) {
|
||||||
|
Node cc = decomp.ir.GetConditionCode(expr.cc);
|
||||||
|
Id target;
|
||||||
|
|
||||||
|
if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
|
||||||
|
const auto index = pred->GetIndex();
|
||||||
|
switch (index) {
|
||||||
|
case Tegra::Shader::Pred::NeverExecute:
|
||||||
|
target = decomp.v_false;
|
||||||
|
case Tegra::Shader::Pred::UnusedIndex:
|
||||||
|
target = decomp.v_true;
|
||||||
|
default:
|
||||||
|
target = decomp.predicates.at(index);
|
||||||
|
}
|
||||||
|
} else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
|
||||||
|
target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
|
||||||
|
}
|
||||||
|
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprVar& expr) {
|
||||||
|
current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ExprBoolean& expr) {
|
||||||
|
current_id = expr.value ? decomp.v_true : decomp.v_false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Id GetResult() {
|
||||||
|
return current_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
Id Visit(VideoCommon::Shader::Expr& node) {
|
||||||
|
std::visit(*this, *node);
|
||||||
|
return current_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Id current_id;
|
||||||
|
SPIRVDecompiler& decomp;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ASTDecompiler {
|
||||||
|
public:
|
||||||
|
ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTProgram& ast) {
|
||||||
|
ASTNode current = ast.nodes.GetFirst();
|
||||||
|
while (current) {
|
||||||
|
Visit(current);
|
||||||
|
current = current->GetNext();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTIfThen& ast) {
|
||||||
|
ExprDecompiler expr_parser{decomp};
|
||||||
|
const Id condition = expr_parser.Visit(ast.condition);
|
||||||
|
const Id then_label = decomp.OpLabel();
|
||||||
|
const Id endif_label = decomp.OpLabel();
|
||||||
|
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
|
||||||
|
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
|
||||||
|
decomp.Emit(then_label);
|
||||||
|
ASTNode current = ast.nodes.GetFirst();
|
||||||
|
while (current) {
|
||||||
|
Visit(current);
|
||||||
|
current = current->GetNext();
|
||||||
|
}
|
||||||
|
decomp.Emit(endif_label);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTIfElse& ast) {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
|
||||||
|
decomp.VisitBasicBlock(ast.nodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTVarSet& ast) {
|
||||||
|
ExprDecompiler expr_parser{decomp};
|
||||||
|
const Id condition = expr_parser.Visit(ast.condition);
|
||||||
|
decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTLabel& ast) {
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTGoto& ast) {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
|
||||||
|
const Id loop_label = decomp.OpLabel();
|
||||||
|
const Id endloop_label = decomp.OpLabel();
|
||||||
|
const Id loop_start_block = decomp.OpLabel();
|
||||||
|
const Id loop_end_block = decomp.OpLabel();
|
||||||
|
current_loop_exit = endloop_label;
|
||||||
|
decomp.Emit(loop_label);
|
||||||
|
decomp.Emit(decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
|
||||||
|
decomp.Emit(decomp.OpBranch(loop_start_block));
|
||||||
|
decomp.Emit(loop_start_block);
|
||||||
|
ASTNode current = ast.nodes.GetFirst();
|
||||||
|
while (current) {
|
||||||
|
Visit(current);
|
||||||
|
current = current->GetNext();
|
||||||
|
}
|
||||||
|
decomp.Emit(decomp.OpBranch(loop_end_block));
|
||||||
|
decomp.Emit(loop_end_block);
|
||||||
|
ExprDecompiler expr_parser{decomp};
|
||||||
|
const Id condition = expr_parser.Visit(ast.condition);
|
||||||
|
decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
|
||||||
|
decomp.Emit(endloop_label);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTReturn& ast) {
|
||||||
|
bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
|
||||||
|
if (!is_true) {
|
||||||
|
ExprDecompiler expr_parser{decomp};
|
||||||
|
const Id condition = expr_parser.Visit(ast.condition);
|
||||||
|
const Id then_label = decomp.OpLabel();
|
||||||
|
const Id endif_label = decomp.OpLabel();
|
||||||
|
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
|
||||||
|
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
|
||||||
|
decomp.Emit(then_label);
|
||||||
|
if (ast.kills) {
|
||||||
|
decomp.Emit(decomp.OpKill());
|
||||||
|
} else {
|
||||||
|
decomp.PreExit();
|
||||||
|
decomp.Emit(decomp.OpReturn());
|
||||||
|
}
|
||||||
|
decomp.Emit(endif_label);
|
||||||
|
} else {
|
||||||
|
decomp.Emit(decomp.OpLabel());
|
||||||
|
if (ast.kills) {
|
||||||
|
decomp.Emit(decomp.OpKill());
|
||||||
|
} else {
|
||||||
|
decomp.PreExit();
|
||||||
|
decomp.Emit(decomp.OpReturn());
|
||||||
|
}
|
||||||
|
decomp.Emit(decomp.OpLabel());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(VideoCommon::Shader::ASTBreak& ast) {
|
||||||
|
bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
|
||||||
|
if (!is_true) {
|
||||||
|
ExprDecompiler expr_parser{decomp};
|
||||||
|
const Id condition = expr_parser.Visit(ast.condition);
|
||||||
|
const Id then_label = decomp.OpLabel();
|
||||||
|
const Id endif_label = decomp.OpLabel();
|
||||||
|
decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
|
||||||
|
decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
|
||||||
|
decomp.Emit(then_label);
|
||||||
|
decomp.Emit(decomp.OpBranch(current_loop_exit));
|
||||||
|
decomp.Emit(endif_label);
|
||||||
|
} else {
|
||||||
|
decomp.Emit(decomp.OpLabel());
|
||||||
|
decomp.Emit(decomp.OpBranch(current_loop_exit));
|
||||||
|
decomp.Emit(decomp.OpLabel());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit(VideoCommon::Shader::ASTNode& node) {
|
||||||
|
std::visit(*this, *node->GetInnerData());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SPIRVDecompiler& decomp;
|
||||||
|
Id current_loop_exit;
|
||||||
|
};
|
||||||
|
|
||||||
|
void SPIRVDecompiler::DecompileAST() {
|
||||||
|
u32 num_flow_variables = ir.GetASTNumVariables();
|
||||||
|
for (u32 i = 0; i < num_flow_variables; i++) {
|
||||||
|
const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
|
||||||
|
Name(id, fmt::format("flow_var_{}", i));
|
||||||
|
flow_variables.emplace(i, AddGlobalVariable(id));
|
||||||
|
}
|
||||||
|
ASTDecompiler decompiler{*this};
|
||||||
|
VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
|
||||||
|
decompiler.Visit(program);
|
||||||
|
}
|
||||||
|
|
||||||
DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
|
DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
|
||||||
Maxwell::ShaderStage stage) {
|
Maxwell::ShaderStage stage) {
|
||||||
auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
|
auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
|
||||||
|
|
|
@ -205,13 +205,29 @@ public:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MarkLabelUnused() const {
|
void MarkLabelUnused() {
|
||||||
auto inner = std::get_if<ASTLabel>(&data);
|
auto inner = std::get_if<ASTLabel>(&data);
|
||||||
if (inner) {
|
if (inner) {
|
||||||
inner->unused = true;
|
inner->unused = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsLabelUnused() const {
|
||||||
|
auto inner = std::get_if<ASTLabel>(&data);
|
||||||
|
if (inner) {
|
||||||
|
return inner->unused;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
u32 GetLabelIndex() const {
|
||||||
|
auto inner = std::get_if<ASTLabel>(&data);
|
||||||
|
if (inner) {
|
||||||
|
return inner->index;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
Expr GetIfCondition() const {
|
Expr GetIfCondition() const {
|
||||||
auto inner = std::get_if<ASTIfThen>(&data);
|
auto inner = std::get_if<ASTIfThen>(&data);
|
||||||
if (inner) {
|
if (inner) {
|
||||||
|
@ -336,6 +352,10 @@ public:
|
||||||
return variables;
|
return variables;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<ASTNode>& GetLabels() const {
|
||||||
|
return labels;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const;
|
bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const;
|
||||||
|
|
||||||
|
|
|
@ -151,6 +151,10 @@ public:
|
||||||
return decompiled;
|
return decompiled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ASTManager& GetASTManager() const {
|
||||||
|
return program_manager;
|
||||||
|
}
|
||||||
|
|
||||||
ASTNode GetASTProgram() const {
|
ASTNode GetASTProgram() const {
|
||||||
return program_manager.GetProgram();
|
return program_manager.GetProgram();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue