Fix vertex “built-ins”

Only declare main func out in main

Fix simd_ballot

Fix thread_index_in_simdgroup outside of compute

Fix atomic operations

instance_index
This commit is contained in:
Isaac Marovitz 2024-06-21 16:58:58 +01:00 committed by Isaac Marovitz
parent 4578ee53d3
commit b094d34575
6 changed files with 66 additions and 25 deletions

View file

@ -70,18 +70,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true); DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
}
switch (stage) switch (stage)
{ {
case ShaderStage.Vertex: case ShaderStage.Vertex:
context.AppendLine("VertexOut out;"); context.AppendLine("VertexOut out;");
// TODO: Only add if necessary
context.AppendLine("uint instance_index = instance_id + base_instance;");
break; break;
case ShaderStage.Fragment: case ShaderStage.Fragment:
context.AppendLine("FragmentOut out;"); context.AppendLine("FragmentOut out;");
break; break;
} }
// TODO: Only add if necessary
if (stage != ShaderStage.Compute)
{
// MSL does not give us access to [[thread_index_in_simdgroup]]
// outside compute. But we may still need to provide this value in frag/vert.
context.AppendLine("uint thread_index_in_simdgroup = simd_prefix_exclusive_sum(1);");
}
}
foreach (AstOperand decl in function.Locals) foreach (AstOperand decl in function.Locals)
{ {
string name = context.OperandManager.DeclareLocal(decl); string name = context.OperandManager.DeclareLocal(decl);
@ -90,15 +100,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
} }
public static string GetVarTypeName(CodeGenContext context, AggregateType type) public static string GetVarTypeName(CodeGenContext context, AggregateType type, bool atomic = false)
{ {
var s32 = atomic ? "atomic_int" : "int";
var u32 = atomic ? "atomic_uint" : "uint";
return type switch return type switch
{ {
AggregateType.Void => "void", AggregateType.Void => "void",
AggregateType.Bool => "bool", AggregateType.Bool => "bool",
AggregateType.FP32 => "float", AggregateType.FP32 => "float",
AggregateType.S32 => "int", AggregateType.S32 => s32,
AggregateType.U32 => "uint", AggregateType.U32 => u32,
AggregateType.Vector2 | AggregateType.Bool => "bool2", AggregateType.Vector2 | AggregateType.Bool => "bool2",
AggregateType.Vector2 | AggregateType.FP32 => "float2", AggregateType.Vector2 | AggregateType.FP32 => "float2",
AggregateType.Vector2 | AggregateType.S32 => "int2", AggregateType.Vector2 | AggregateType.S32 => "int2",

View file

@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation; using Ryujinx.Graphics.Shader.Translation;
using System; using System;
using System.Text; using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -43,15 +44,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory)) if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{ {
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32 AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32 ? AggregateType.S32
: AggregateType.U32; : AggregateType.U32;
builder.Append($"(device {Declarations.GetVarTypeName(context, dstType, true)}*)&{GenerateLoadOrStore(context, operation, isStore: false)}");
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++) for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{ {
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}"); builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}, memory_order_relaxed");
} }
} }
else else
@ -118,6 +120,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{ {
switch (inst & Instruction.Mask) switch (inst & Instruction.Mask)
{ {
case Instruction.Ballot:
return Ballot(context, operation);
case Instruction.Barrier: case Instruction.Barrier:
return "threadgroup_barrier(mem_flags::mem_threadgroup)"; return "threadgroup_barrier(mem_flags::mem_threadgroup)";
case Instruction.Call: case Instruction.Call:

View file

@ -0,0 +1,21 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenBallot
{
public static string Ballot(CodeGenContext context, AstOperation operation)
{
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
string arg = GetSourceExpr(context, operation.GetSource(0), dstType);
char component = "xyzw"[operation.Index];
return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}";
}
}
}

View file

@ -15,17 +15,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
_infoTable = new InstInfo[(int)Instruction.Count]; _infoTable = new InstInfo[(int)Instruction.Count];
#pragma warning disable IDE0055 // Disable formatting #pragma warning disable IDE0055 // Disable formatting
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_add_explicit"); Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_fetch_add_explicit");
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_and_explicit"); Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_fetch_and_explicit");
Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit"); Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit");
Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_max_explicit"); Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_fetch_max_explicit");
Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_min_explicit"); Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_fetch_min_explicit");
Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_or_explicit"); Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_fetch_or_explicit");
Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit"); Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit");
Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_xor_explicit"); Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_fetch_xor_explicit");
Add(Instruction.Absolute, InstType.CallUnary, "abs"); Add(Instruction.Absolute, InstType.CallUnary, "abs");
Add(Instruction.Add, InstType.OpBinaryCom, "+", 2); Add(Instruction.Add, InstType.OpBinaryCom, "+", 2);
Add(Instruction.Ballot, InstType.CallUnary, "simd_ballot"); Add(Instruction.Ballot, InstType.Special);
Add(Instruction.Barrier, InstType.Special); Add(Instruction.Barrier, InstType.Special);
Add(Instruction.BitCount, InstType.CallUnary, "popcount"); Add(Instruction.BitCount, InstType.CallUnary, "popcount");
Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits"); Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits");

View file

@ -17,15 +17,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{ {
var returnValue = ioVariable switch var returnValue = ioVariable switch
{ {
IoVariable.BaseInstance => ("base_instance", AggregateType.S32), IoVariable.BaseInstance => ("base_instance", AggregateType.U32),
IoVariable.BaseVertex => ("base_vertex", AggregateType.S32), IoVariable.BaseVertex => ("base_vertex", AggregateType.U32),
IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.ClipDistance => ("clip_distance", AggregateType.Array | AggregateType.FP32), IoVariable.ClipDistance => ("clip_distance", AggregateType.Array | AggregateType.FP32),
IoVariable.FragmentOutputColor => ($"out.color{location}", AggregateType.Vector4 | AggregateType.FP32), IoVariable.FragmentOutputColor => ($"out.color{location}", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32), IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32),
IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool), IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool),
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InstanceId => ("instance_id", AggregateType.U32),
IoVariable.InstanceIndex => ("instance_index", AggregateType.U32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32),

View file

@ -137,6 +137,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
args = args.Append("uint vertex_id [[vertex_id]]").ToArray(); args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray(); args = args.Append("uint instance_id [[instance_id]]").ToArray();
args = args.Append("uint base_instance [[base_instance]]").ToArray();
args = args.Append("uint base_vertex [[base_vertex]]").ToArray();
} }
else if (stage == ShaderStage.Compute) else if (stage == ShaderStage.Compute)
{ {