gl_shader_decompiler: Make use of fmt with the decompiler

Allows us to avoid even more string churn by allowing the AddLine
function to make use of fmt formatting so the string is formatted all at
once instead of concatenating multiple strings.

This is similar to how yuzu's decompiler works, which I've made function
the same way in the past.
This commit is contained in:
Lioncash 2020-05-04 22:40:31 -04:00
parent db5b8b9c88
commit 016d43df98

View file

@ -8,6 +8,7 @@
#include <string>
#include <tuple>
#include <utility>
#include <fmt/format.h>
#include <nihstro/shader_bytecode.h>
#include "common/assert.h"
#include "common/common_types.h"
@ -196,12 +197,19 @@ private:
class ShaderWriter {
public:
void AddLine(std::string_view text) {
// Forwards all arguments directly to libfmt.
// Note that all formatting requirements for fmt must be
// obeyed when using this function. (e.g. {{ must be used
// printing the character '{' is desirable. Ditto for }} and '}',
// etc).
template <typename... Args>
void AddLine(std::string_view text, Args&&... args) {
AddExpression(fmt::format(text, std::forward<Args>(args)...));
AddNewLine();
}
void AddNewLine() {
DEBUG_ASSERT(scope >= 0);
if (!text.empty()) {
shader_source.append(static_cast<std::size_t>(scope) * 4, ' ');
}
shader_source += text;
shader_source += '\n';
}
@ -212,6 +220,13 @@ public:
int scope = 0;
private:
void AddExpression(std::string_view text) {
if (!text.empty()) {
shader_source.append(static_cast<std::size_t>(scope) * 4, ' ');
}
shader_source += text;
}
std::string shader_source;
};
@ -222,16 +237,16 @@ std::string GetSelectorSrc(const SwizzlePattern& pattern) {
for (std::size_t i = 0; i < 4; ++i) {
switch ((pattern.*getter)(i)) {
case SwizzlePattern::Selector::x:
out += "x";
out += 'x';
break;
case SwizzlePattern::Selector::y:
out += "y";
out += 'y';
break;
case SwizzlePattern::Selector::z:
out += "z";
out += 'z';
break;
case SwizzlePattern::Selector::w:
out += "w";
out += 'w';
break;
default:
UNREACHABLE();
@ -275,28 +290,28 @@ private:
static std::string EvaluateCondition(Instruction::FlowControlType flow_control) {
using Op = Instruction::FlowControlType::Op;
std::string result_x =
const std::string_view result_x =
flow_control.refx.Value() ? "conditional_code.x" : "!conditional_code.x";
std::string result_y =
const std::string_view result_y =
flow_control.refy.Value() ? "conditional_code.y" : "!conditional_code.y";
switch (flow_control.op) {
case Op::JustX:
return result_x;
return std::string(result_x);
case Op::JustY:
return result_y;
return std::string(result_y);
case Op::Or:
case Op::And: {
std::string and_or = flow_control.op == Op::Or ? "any" : "all";
const std::string_view and_or = flow_control.op == Op::Or ? "any" : "all";
std::string bvec;
if (flow_control.refx.Value() && flow_control.refy.Value()) {
bvec = "conditional_code";
} else if (!flow_control.refx.Value() && !flow_control.refy.Value()) {
bvec = "not(conditional_code)";
} else {
bvec = "bvec2(" + result_x + ", " + result_y + ")";
bvec = fmt::format("bvec2({}, {})", result_x, result_y);
}
return and_or + "(" + bvec + ")";
return fmt::format("{}({})", and_or, bvec);
}
default:
UNREACHABLE();
@ -307,20 +322,19 @@ private:
/// Generates code representing a source register.
std::string GetSourceRegister(const SourceRegister& source_reg,
u32 address_register_index) const {
u32 index = static_cast<u32>(source_reg.GetIndex());
std::string index_str = std::to_string(index);
const u32 index = static_cast<u32>(source_reg.GetIndex());
switch (source_reg.GetRegisterType()) {
case RegisterType::Input:
return inputreg_getter(index);
case RegisterType::Temporary:
return "reg_tmp" + index_str;
return fmt::format("reg_tmp{}", index);
case RegisterType::FloatUniform:
if (address_register_index != 0) {
index_str +=
std::string(" + address_registers.") + "xyz"[address_register_index - 1];
return fmt::format("uniforms.f[{} + address_registers.{}]", index,
"xyz"[address_register_index - 1]);
}
return "uniforms.f[" + index_str + "]";
return fmt::format("uniforms.f[{}]", index);
default:
UNREACHABLE();
return "";
@ -329,13 +343,13 @@ private:
/// Generates code representing a destination register.
std::string GetDestRegister(const DestRegister& dest_reg) const {
u32 index = static_cast<u32>(dest_reg.GetIndex());
const u32 index = static_cast<u32>(dest_reg.GetIndex());
switch (dest_reg.GetRegisterType()) {
case RegisterType::Output:
return outputreg_getter(index);
case RegisterType::Temporary:
return "reg_tmp" + std::to_string(index);
return fmt::format("reg_tmp{}", index);
default:
UNREACHABLE();
return "";
@ -344,7 +358,7 @@ private:
/// Generates code representing a bool uniform
std::string GetUniformBool(u32 index) const {
return "uniforms.b[" + std::to_string(index) + "]";
return fmt::format("uniforms.b[{}]", index);
}
/**
@ -353,12 +367,12 @@ private:
*/
void CallSubroutine(const Subroutine& subroutine) {
if (subroutine.exit_method == ExitMethod::AlwaysEnd) {
shader.AddLine(subroutine.GetName() + "();");
shader.AddLine("{}();", subroutine.GetName());
shader.AddLine("return true;");
} else if (subroutine.exit_method == ExitMethod::Conditional) {
shader.AddLine("if (" + subroutine.GetName() + "()) { return true; }");
shader.AddLine("if ({}()) {{ return true; }}", subroutine.GetName());
} else {
shader.AddLine(subroutine.GetName() + "();");
shader.AddLine("{}();", subroutine.GetName());
}
}
@ -370,7 +384,7 @@ private:
* @param dest_num_components number of components of the destination register.
* @param value_num_components number of components of the value to assign.
*/
void SetDest(const SwizzlePattern& swizzle, const std::string& reg, const std::string& value,
void SetDest(const SwizzlePattern& swizzle, std::string_view reg, std::string_view value,
u32 dest_num_components, u32 value_num_components) {
u32 dest_mask_num_components = 0;
std::string dest_mask_swizzle = ".";
@ -387,18 +401,19 @@ private:
}
DEBUG_ASSERT(value_num_components >= dest_num_components || value_num_components == 1);
std::string dest = reg + (dest_num_components != 1 ? dest_mask_swizzle : "");
const std::string dest =
fmt::format("{}{}", reg, dest_num_components != 1 ? dest_mask_swizzle : "");
std::string src = value;
std::string src{value};
if (value_num_components == 1) {
if (dest_mask_num_components != 1) {
src = "vec" + std::to_string(dest_mask_num_components) + "(" + value + ")";
src = fmt::format("vec{}({})", dest_mask_num_components, value);
}
} else if (value_num_components != dest_mask_num_components) {
src = "(" + value + ")" + dest_mask_swizzle;
src = fmt::format("({}){}", value, dest_mask_swizzle);
}
shader.AddLine(dest + " = " + src + ";");
shader.AddLine("{} = {};", dest, src);
}
/**
@ -417,7 +432,7 @@ private:
: instr.common.operand_desc_id;
const SwizzlePattern swizzle = {swizzle_data[swizzle_offset]};
shader.AddLine("// " + std::to_string(offset) + ": " + instr.opcode.Value().GetInfo().name);
shader.AddLine("// {}: {}", offset, instr.opcode.Value().GetInfo().name);
switch (instr.opcode.Value().GetInfo().type) {
case OpCode::Type::Arithmetic: {
@ -438,31 +453,32 @@ private:
switch (instr.opcode.Value().EffectiveOpCode()) {
case OpCode::Id::ADD: {
SetDest(swizzle, dest_reg, src1 + " + " + src2, 4, 4);
SetDest(swizzle, dest_reg, fmt::format("{} + {}", src1, src2), 4, 4);
break;
}
case OpCode::Id::MUL: {
if (sanitize_mul) {
SetDest(swizzle, dest_reg, "sanitize_mul(" + src1 + ", " + src2 + ")", 4, 4);
SetDest(swizzle, dest_reg, fmt::format("sanitize_mul({}, {})", src1, src2), 4,
4);
} else {
SetDest(swizzle, dest_reg, src1 + " * " + src2, 4, 4);
SetDest(swizzle, dest_reg, fmt::format("{} * {}", src1, src2), 4, 4);
}
break;
}
case OpCode::Id::FLR: {
SetDest(swizzle, dest_reg, "floor(" + src1 + ")", 4, 4);
SetDest(swizzle, dest_reg, fmt::format("floor({})", src1), 4, 4);
break;
}
case OpCode::Id::MAX: {
SetDest(swizzle, dest_reg, "max(" + src1 + ", " + src2 + ")", 4, 4);
SetDest(swizzle, dest_reg, fmt::format("max({}, {})", src1, src2), 4, 4);
break;
}
case OpCode::Id::MIN: {
SetDest(swizzle, dest_reg, "min(" + src1 + ", " + src2 + ")", 4, 4);
SetDest(swizzle, dest_reg, fmt::format("min({}, {})", src1, src2), 4, 4);
break;
}
@ -474,18 +490,20 @@ private:
std::string dot;
if (opcode == OpCode::Id::DP3) {
if (sanitize_mul) {
dot = "dot(vec3(sanitize_mul(" + src1 + ", " + src2 + ")), vec3(1.0))";
dot = fmt::format("dot(vec3(sanitize_mul({}, {})), vec3(1.0))", src1, src2);
} else {
dot = "dot(vec3(" + src1 + "), vec3(" + src2 + "))";
dot = fmt::format("dot(vec3({}), vec3({}))", src1, src2);
}
} else {
std::string src1_ = (opcode == OpCode::Id::DPH || opcode == OpCode::Id::DPHI)
? "vec4(" + src1 + ".xyz, 1.0)"
: src1;
if (sanitize_mul) {
dot = "dot(sanitize_mul(" + src1_ + ", " + src2 + "), vec4(1.0))";
const std::string src1_ =
(opcode == OpCode::Id::DPH || opcode == OpCode::Id::DPHI)
? fmt::format("vec4({}.xyz, 1.0)", src1)
: std::move(src1);
dot = fmt::format("dot(sanitize_mul({}, {}), vec4(1.0))", src1_, src2);
} else {
dot = "dot(" + src1 + ", " + src2 + ")";
dot = fmt::format("dot({}, {})", src1, src2);
}
}
@ -494,17 +512,17 @@ private:
}
case OpCode::Id::RCP: {
SetDest(swizzle, dest_reg, "(1.0 / " + src1 + ".x)", 4, 1);
SetDest(swizzle, dest_reg, fmt::format("(1.0 / {}.x)", src1), 4, 1);
break;
}
case OpCode::Id::RSQ: {
SetDest(swizzle, dest_reg, "inversesqrt(" + src1 + ".x)", 4, 1);
SetDest(swizzle, dest_reg, fmt::format("inversesqrt({}.x)", src1), 4, 1);
break;
}
case OpCode::Id::MOVA: {
SetDest(swizzle, "address_registers", "ivec2(" + src1 + ")", 2, 2);
SetDest(swizzle, "address_registers", fmt::format("ivec2({})", src1), 2, 2);
break;
}
@ -515,26 +533,27 @@ private:
case OpCode::Id::SGE:
case OpCode::Id::SGEI: {
SetDest(swizzle, dest_reg, "vec4(greaterThanEqual(" + src1 + "," + src2 + "))", 4,
4);
SetDest(swizzle, dest_reg,
fmt::format("vec4(greaterThanEqual({}, {}))", src1, src2), 4, 4);
break;
}
case OpCode::Id::SLT:
case OpCode::Id::SLTI: {
SetDest(swizzle, dest_reg, "vec4(lessThan(" + src1 + "," + src2 + "))", 4, 4);
SetDest(swizzle, dest_reg, fmt::format("vec4(lessThan({}, {}))", src1, src2), 4, 4);
break;
}
case OpCode::Id::CMP: {
using CompareOp = Instruction::Common::CompareOpType::Op;
const std::map<CompareOp, std::pair<std::string, std::string>> cmp_ops{
const std::map<CompareOp, std::pair<std::string_view, std::string_view>> cmp_ops{
{CompareOp::Equal, {"==", "equal"}},
{CompareOp::NotEqual, {"!=", "notEqual"}},
{CompareOp::LessThan, {"<", "lessThan"}},
{CompareOp::LessEqual, {"<=", "lessThanEqual"}},
{CompareOp::GreaterThan, {">", "greaterThan"}},
{CompareOp::GreaterEqual, {">=", "greaterThanEqual"}}};
{CompareOp::GreaterEqual, {">=", "greaterThanEqual"}},
};
const CompareOp op_x = instr.common.compare_op.x.Value();
const CompareOp op_y = instr.common.compare_op.y.Value();
@ -544,24 +563,24 @@ private:
} else if (cmp_ops.find(op_y) == cmp_ops.end()) {
LOG_ERROR(HW_GPU, "Unknown compare mode {:x}", static_cast<int>(op_y));
} else if (op_x != op_y) {
shader.AddLine("conditional_code.x = " + src1 + ".x " +
cmp_ops.find(op_x)->second.first + " " + src2 + ".x;");
shader.AddLine("conditional_code.y = " + src1 + ".y " +
cmp_ops.find(op_y)->second.first + " " + src2 + ".y;");
shader.AddLine("conditional_code.x = {}.x {} {}.x;", src1,
cmp_ops.find(op_x)->second.first, src2);
shader.AddLine("conditional_code.y = {}.y {} {}.y;", src1,
cmp_ops.find(op_y)->second.first, src2);
} else {
shader.AddLine("conditional_code = " + cmp_ops.find(op_x)->second.second +
"(vec2(" + src1 + "), vec2(" + src2 + "));");
shader.AddLine("conditional_code = {}(vec2({}), vec2({}));",
cmp_ops.find(op_x)->second.second, src1, src2);
}
break;
}
case OpCode::Id::EX2: {
SetDest(swizzle, dest_reg, "exp2(" + src1 + ".x)", 4, 1);
SetDest(swizzle, dest_reg, fmt::format("exp2({}.x)", src1), 4, 1);
break;
}
case OpCode::Id::LG2: {
SetDest(swizzle, dest_reg, "log2(" + src1 + ".x)", 4, 1);
SetDest(swizzle, dest_reg, fmt::format("log2({}.x)", src1), 4, 1);
break;
}
@ -604,10 +623,10 @@ private:
: "";
if (sanitize_mul) {
SetDest(swizzle, dest_reg, "sanitize_mul(" + src1 + ", " + src2 + ") + " + src3,
4, 4);
SetDest(swizzle, dest_reg,
fmt::format("sanitize_mul({}, {}) + {}", src1, src2, src3), 4, 4);
} else {
SetDest(swizzle, dest_reg, src1 + " * " + src2 + " + " + src3, 4, 4);
SetDest(swizzle, dest_reg, fmt::format("{} * {} + {}", src1, src2, src3), 4, 4);
}
} else {
LOG_ERROR(HW_GPU, "Unhandled multiply-add instruction: 0x{:02x} ({}): 0x{:08x}",
@ -637,13 +656,13 @@ private:
GetUniformBool(instr.flow_control.bool_uniform_id);
}
shader.AddLine("if (" + condition + ") {");
shader.AddLine("if ({}) {{", condition);
++shader.scope;
shader.AddLine("{ jmp_to = " + std::to_string(instr.flow_control.dest_offset) +
"u; break; }");
shader.AddLine("{{ jmp_to = {}u; break; }}",
instr.flow_control.dest_offset.Value());
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
break;
}
@ -657,7 +676,11 @@ private:
condition = GetUniformBool(instr.flow_control.bool_uniform_id);
}
shader.AddLine(condition.empty() ? "{" : "if (" + condition + ") {");
if (condition.empty()) {
shader.AddLine("{{");
} else {
shader.AddLine("if ({}) {{", condition);
}
++shader.scope;
auto& call_sub = GetSubroutine(instr.flow_control.dest_offset,
@ -671,7 +694,7 @@ private:
}
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
break;
}
@ -693,7 +716,7 @@ private:
const u32 endif_offset =
instr.flow_control.dest_offset + instr.flow_control.num_instructions;
shader.AddLine("if (" + condition + ") {");
shader.AddLine("if ({}) {{", condition);
++shader.scope;
auto& if_sub = GetSubroutine(if_offset, else_offset);
@ -702,7 +725,7 @@ private:
if (instr.flow_control.num_instructions != 0) {
--shader.scope;
shader.AddLine("} else {");
shader.AddLine("}} else {{");
++shader.scope;
auto& else_sub = GetSubroutine(else_offset, endif_offset);
@ -716,20 +739,20 @@ private:
}
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
break;
}
case OpCode::Id::LOOP: {
std::string int_uniform =
"uniforms.i[" + std::to_string(instr.flow_control.int_uniform_id) + "]";
const std::string int_uniform =
fmt::format("uniforms.i[{}]", instr.flow_control.int_uniform_id.Value());
shader.AddLine("address_registers.z = int(" + int_uniform + ".y);");
shader.AddLine("address_registers.z = int({}.y);", int_uniform);
std::string loop_var = "loop" + std::to_string(offset);
shader.AddLine("for (uint " + loop_var + " = 0u; " + loop_var +
" <= " + int_uniform + ".x; address_registers.z += int(" +
int_uniform + ".z), ++" + loop_var + ") {");
const std::string loop_var = fmt::format("loop{}", offset);
shader.AddLine(
"for (uint {} = 0u; {} <= {}.x; address_registers.z += int({}.z), ++{}) {{",
loop_var, loop_var, int_uniform, int_uniform, loop_var);
++shader.scope;
auto& loop_sub = GetSubroutine(offset + 1, instr.flow_control.dest_offset + 1);
@ -737,7 +760,7 @@ private:
offset = instr.flow_control.dest_offset;
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
if (loop_sub.exit_method == ExitMethod::AlwaysEnd) {
offset = PROGRAM_END - 1;
@ -782,41 +805,41 @@ private:
void Generate() {
if (sanitize_mul) {
shader.AddLine("vec4 sanitize_mul(vec4 lhs, vec4 rhs) {");
shader.AddLine("vec4 sanitize_mul(vec4 lhs, vec4 rhs) {{");
++shader.scope;
shader.AddLine("vec4 product = lhs * rhs;");
shader.AddLine("return mix(product, mix(mix(vec4(0.0), product, isnan(rhs)), product, "
"isnan(lhs)), isnan(product));");
--shader.scope;
shader.AddLine("}\n");
shader.AddLine("}}\n");
}
// Add declarations for registers
shader.AddLine("bvec2 conditional_code = bvec2(false);");
shader.AddLine("ivec3 address_registers = ivec3(0);");
for (int i = 0; i < 16; ++i) {
shader.AddLine("vec4 reg_tmp" + std::to_string(i) + " = vec4(0.0, 0.0, 0.0, 1.0);");
shader.AddLine("vec4 reg_tmp{} = vec4(0.0, 0.0, 0.0, 1.0);", i);
}
shader.AddLine("");
shader.AddNewLine();
// Add declarations for all subroutines
for (const auto& subroutine : subroutines) {
shader.AddLine("bool " + subroutine.GetName() + "();");
shader.AddLine("bool {}();", subroutine.GetName());
}
shader.AddLine("");
shader.AddNewLine();
// Add the main entry point
shader.AddLine("bool exec_shader() {");
shader.AddLine("bool exec_shader() {{");
++shader.scope;
CallSubroutine(GetSubroutine(main_offset, PROGRAM_END));
--shader.scope;
shader.AddLine("}\n");
shader.AddLine("}}\n");
// Add definitions for all subroutines
for (const auto& subroutine : subroutines) {
std::set<u32> labels = subroutine.labels;
shader.AddLine("bool " + subroutine.GetName() + "() {");
shader.AddLine("bool {}() {{", subroutine.GetName());
++shader.scope;
if (labels.empty()) {
@ -825,14 +848,14 @@ private:
}
} else {
labels.insert(subroutine.begin);
shader.AddLine("uint jmp_to = " + std::to_string(subroutine.begin) + "u;");
shader.AddLine("while (true) {");
shader.AddLine("uint jmp_to = {}u;", subroutine.begin);
shader.AddLine("while (true) {{");
++shader.scope;
shader.AddLine("switch (jmp_to) {");
shader.AddLine("switch (jmp_to) {{");
for (auto label : labels) {
shader.AddLine("case " + std::to_string(label) + "u: {");
shader.AddLine("case {}u: {{", label);
++shader.scope;
auto next_it = labels.lower_bound(label + 1);
@ -841,25 +864,25 @@ private:
u32 compile_end = CompileRange(label, next_label);
if (compile_end > next_label && compile_end != PROGRAM_END) {
// This happens only when there is a label inside a IF/LOOP block
shader.AddLine("{ jmp_to = " + std::to_string(compile_end) + "u; break; }");
shader.AddLine("{{ jmp_to = {}u; break; }}", compile_end);
labels.emplace(compile_end);
}
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
}
shader.AddLine("default: return false;");
shader.AddLine("}");
shader.AddLine("}}");
--shader.scope;
shader.AddLine("}");
shader.AddLine("}}");
shader.AddLine("return false;");
}
--shader.scope;
shader.AddLine("}\n");
shader.AddLine("}}\n");
DEBUG_ASSERT(shader.scope == 0);
}