Fix vertex “built-ins”
Only declare main func out in main Fix simd_ballot Fix thread_index_in_simdgroup outside of compute Fix atomic operations instance_index
This commit is contained in:
parent
4578ee53d3
commit
b094d34575
6 changed files with 66 additions and 25 deletions
|
@ -70,16 +70,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
{
|
||||
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
|
||||
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
|
||||
}
|
||||
|
||||
switch (stage)
|
||||
{
|
||||
case ShaderStage.Vertex:
|
||||
context.AppendLine("VertexOut out;");
|
||||
break;
|
||||
case ShaderStage.Fragment:
|
||||
context.AppendLine("FragmentOut out;");
|
||||
break;
|
||||
switch (stage)
|
||||
{
|
||||
case ShaderStage.Vertex:
|
||||
context.AppendLine("VertexOut out;");
|
||||
// TODO: Only add if necessary
|
||||
context.AppendLine("uint instance_index = instance_id + base_instance;");
|
||||
break;
|
||||
case ShaderStage.Fragment:
|
||||
context.AppendLine("FragmentOut out;");
|
||||
break;
|
||||
}
|
||||
|
||||
// TODO: Only add if necessary
|
||||
if (stage != ShaderStage.Compute)
|
||||
{
|
||||
// MSL does not give us access to [[thread_index_in_simdgroup]]
|
||||
// outside compute. But we may still need to provide this value in frag/vert.
|
||||
context.AppendLine("uint thread_index_in_simdgroup = simd_prefix_exclusive_sum(1);");
|
||||
}
|
||||
}
|
||||
|
||||
foreach (AstOperand decl in function.Locals)
|
||||
|
@ -90,15 +100,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
}
|
||||
}
|
||||
|
||||
public static string GetVarTypeName(CodeGenContext context, AggregateType type)
|
||||
public static string GetVarTypeName(CodeGenContext context, AggregateType type, bool atomic = false)
|
||||
{
|
||||
var s32 = atomic ? "atomic_int" : "int";
|
||||
var u32 = atomic ? "atomic_uint" : "uint";
|
||||
|
||||
return type switch
|
||||
{
|
||||
AggregateType.Void => "void",
|
||||
AggregateType.Bool => "bool",
|
||||
AggregateType.FP32 => "float",
|
||||
AggregateType.S32 => "int",
|
||||
AggregateType.U32 => "uint",
|
||||
AggregateType.S32 => s32,
|
||||
AggregateType.U32 => u32,
|
||||
AggregateType.Vector2 | AggregateType.Bool => "bool2",
|
||||
AggregateType.Vector2 | AggregateType.FP32 => "float2",
|
||||
AggregateType.Vector2 | AggregateType.S32 => "int2",
|
||||
|
|
|
@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.StructuredIr;
|
|||
using Ryujinx.Graphics.Shader.Translation;
|
||||
using System;
|
||||
using System.Text;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
|
||||
|
@ -43,15 +44,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
|
||||
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
|
||||
{
|
||||
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
|
||||
|
||||
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
|
||||
? AggregateType.S32
|
||||
: AggregateType.U32;
|
||||
|
||||
builder.Append($"(device {Declarations.GetVarTypeName(context, dstType, true)}*)&{GenerateLoadOrStore(context, operation, isStore: false)}");
|
||||
|
||||
|
||||
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
|
||||
{
|
||||
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
|
||||
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}, memory_order_relaxed");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -118,6 +120,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
{
|
||||
switch (inst & Instruction.Mask)
|
||||
{
|
||||
case Instruction.Ballot:
|
||||
return Ballot(context, operation);
|
||||
case Instruction.Barrier:
|
||||
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
|
||||
case Instruction.Call:
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
using Ryujinx.Graphics.Shader.StructuredIr;
|
||||
using Ryujinx.Graphics.Shader.Translation;
|
||||
|
||||
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
|
||||
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
|
||||
|
||||
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
||||
{
|
||||
static class InstGenBallot
|
||||
{
|
||||
public static string Ballot(CodeGenContext context, AstOperation operation)
|
||||
{
|
||||
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
|
||||
|
||||
string arg = GetSourceExpr(context, operation.GetSource(0), dstType);
|
||||
char component = "xyzw"[operation.Index];
|
||||
|
||||
return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -15,17 +15,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
_infoTable = new InstInfo[(int)Instruction.Count];
|
||||
|
||||
#pragma warning disable IDE0055 // Disable formatting
|
||||
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_add_explicit");
|
||||
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_and_explicit");
|
||||
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_fetch_add_explicit");
|
||||
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_fetch_and_explicit");
|
||||
Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit");
|
||||
Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_max_explicit");
|
||||
Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_min_explicit");
|
||||
Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_or_explicit");
|
||||
Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_fetch_max_explicit");
|
||||
Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_fetch_min_explicit");
|
||||
Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_fetch_or_explicit");
|
||||
Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit");
|
||||
Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_xor_explicit");
|
||||
Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_fetch_xor_explicit");
|
||||
Add(Instruction.Absolute, InstType.CallUnary, "abs");
|
||||
Add(Instruction.Add, InstType.OpBinaryCom, "+", 2);
|
||||
Add(Instruction.Ballot, InstType.CallUnary, "simd_ballot");
|
||||
Add(Instruction.Ballot, InstType.Special);
|
||||
Add(Instruction.Barrier, InstType.Special);
|
||||
Add(Instruction.BitCount, InstType.CallUnary, "popcount");
|
||||
Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits");
|
||||
|
|
|
@ -17,15 +17,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
|
|||
{
|
||||
var returnValue = ioVariable switch
|
||||
{
|
||||
IoVariable.BaseInstance => ("base_instance", AggregateType.S32),
|
||||
IoVariable.BaseVertex => ("base_vertex", AggregateType.S32),
|
||||
IoVariable.BaseInstance => ("base_instance", AggregateType.U32),
|
||||
IoVariable.BaseVertex => ("base_vertex", AggregateType.U32),
|
||||
IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.ClipDistance => ("clip_distance", AggregateType.Array | AggregateType.FP32),
|
||||
IoVariable.FragmentOutputColor => ($"out.color{location}", AggregateType.Vector4 | AggregateType.FP32),
|
||||
IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32),
|
||||
IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool),
|
||||
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
|
||||
IoVariable.InstanceId => ("instance_id", AggregateType.S32),
|
||||
IoVariable.InstanceId => ("instance_id", AggregateType.U32),
|
||||
IoVariable.InstanceIndex => ("instance_index", AggregateType.U32),
|
||||
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
|
||||
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
|
||||
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
|
||||
|
|
|
@ -137,6 +137,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
|
|||
{
|
||||
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
|
||||
args = args.Append("uint instance_id [[instance_id]]").ToArray();
|
||||
args = args.Append("uint base_instance [[base_instance]]").ToArray();
|
||||
args = args.Append("uint base_vertex [[base_vertex]]").ToArray();
|
||||
}
|
||||
else if (stage == ShaderStage.Compute)
|
||||
{
|
||||
|
|
Loading…
Reference in a new issue