From 4e5cf38009188f213680dac753d87439203f000a Mon Sep 17 00:00:00 2001
From: Isaac Marovitz <isaacryu@icloud.com>
Date: Wed, 24 Jul 2024 12:13:40 +0100
Subject: [PATCH] Image shader gen support

---
 .../CodeGen/Msl/Declarations.cs               |  26 ++++
 .../CodeGen/Msl/Defaults.cs                   |   1 +
 .../CodeGen/Msl/Instructions/InstGen.cs       |   4 +-
 .../CodeGen/Msl/Instructions/InstGenMemory.cs | 133 +++++++++++++++++-
 .../CodeGen/Msl/MslGenerator.cs               |   1 +
 src/Ryujinx.Graphics.Shader/SamplerType.cs    |   4 +-
 6 files changed, 162 insertions(+), 7 deletions(-)

diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
index d7475357c6..3c92b0606c 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
@@ -77,6 +77,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
             DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values, true);
             DeclareBufferStructures(context, context.Properties.StorageBuffers.Values, false);
             DeclareTextures(context, context.Properties.Textures.Values);
+            DeclareImages(context, context.Properties.Images.Values);
 
             if ((info.HelperFunctionsMask & HelperFunctionsMask.FindLSB) != 0)
             {
@@ -270,6 +271,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
             context.AppendLine();
         }
 
+        private static void DeclareImages(CodeGenContext context, IEnumerable<TextureDefinition> images)
+        {
+            context.AppendLine("struct Images");
+            context.EnterScope();
+
+            List<string> argBufferPointers = [];
+
+            // TODO: Avoid Linq if we can
+            var sortedImages = images.OrderBy(x => x.Binding).ToArray();
+
+            foreach (TextureDefinition image in sortedImages)
+            {
+                var imageTypeName = image.Type.ToMslTextureType(true);
+                argBufferPointers.Add($"{imageTypeName} {image.Name};");
+            }
+
+            foreach (var pointer in argBufferPointers)
+            {
+                context.AppendLine(pointer);
+            }
+
+            context.LeaveScope(";");
+            context.AppendLine();
+        }
+
         private static void DeclareInputAttributes(CodeGenContext context, IEnumerable<IoDefinition> inputs)
         {
             if (context.Definitions.IaIndexing)
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs
index c01242ffe1..f43f5f255d 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs
@@ -21,5 +21,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
         public const uint ConstantBuffersIndex = 20;
         public const uint StorageBuffersIndex = 21;
         public const uint TexturesIndex = 22;
+        public const uint ImagesIndex = 23;
     }
 }
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
index 8d4ef0e372..05fc3b2c89 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
@@ -133,11 +133,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
                     case Instruction.GroupMemoryBarrier:
                         return "|| FIND GROUP MEMORY BARRIER ||";
                     case Instruction.ImageLoad:
-                        return "|| IMAGE LOAD ||";
                     case Instruction.ImageStore:
-                        return "|| IMAGE STORE ||";
                     case Instruction.ImageAtomic:
-                        return "|| IMAGE ATOMIC ||";
+                        return ImageLoadOrStore(context, operation);
                     case Instruction.Load:
                         return Load(context, operation);
                     case Instruction.Lod:
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs
index fce76012ed..d13300e050 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs
@@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
 using Ryujinx.Graphics.Shader.StructuredIr;
 using Ryujinx.Graphics.Shader.Translation;
 using System;
+using System.Text;
 using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
 using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
 
@@ -150,6 +151,129 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
             return varName;
         }
 
+        public static string ImageLoadOrStore(CodeGenContext context, AstOperation operation)
+        {
+            AstTextureOperation texOp = (AstTextureOperation)operation;
+
+            bool isArray = (texOp.Type & SamplerType.Array) != 0;
+
+            var texCallBuilder = new StringBuilder();
+
+            string imageName = GetImageName(context.Properties, texOp);
+            texCallBuilder.Append($"images.{imageName}");
+            texCallBuilder.Append('.');
+
+            if (texOp.Inst == Instruction.ImageAtomic)
+            {
+                texCallBuilder.Append((texOp.Flags & TextureFlags.AtomicMask) switch
+                {
+                    TextureFlags.Add => "atomic_fetch_add",
+                    TextureFlags.Minimum => "atomic_min",
+                    TextureFlags.Maximum => "atomic_max",
+                    TextureFlags.Increment => "atomic_fetch_add",
+                    TextureFlags.Decrement => "atomic_fetch_sub",
+                    TextureFlags.BitwiseAnd => "atomic_fetch_and",
+                    TextureFlags.BitwiseOr => "atomic_fetch_or",
+                    TextureFlags.BitwiseXor => "atomic_fetch_xor",
+                    TextureFlags.Swap => "atomic_exchange",
+                    TextureFlags.CAS => "atomic_compare_exchange_weak",
+                    _ => "atomic_fetch_add",
+                });
+            }
+            else
+            {
+                texCallBuilder.Append(texOp.Inst == Instruction.ImageLoad ? "read" : "write");
+            }
+
+            int srcIndex = 0;
+
+            string Src(AggregateType type)
+            {
+                return GetSourceExpr(context, texOp.GetSource(srcIndex++), type);
+            }
+
+            texCallBuilder.Append('(');
+
+            var coordsBuilder = new StringBuilder();
+
+            int coordsCount = texOp.Type.GetDimensions();
+
+            if (coordsCount > 1)
+            {
+                string[] elems = new string[coordsCount];
+
+                for (int index = 0; index < coordsCount; index++)
+                {
+                    elems[index] = Src(AggregateType.S32);
+                }
+
+                coordsBuilder.Append($"uint{coordsCount}({string.Join(", ", elems)})");
+            }
+            else
+            {
+                coordsBuilder.Append(Src(AggregateType.S32));
+            }
+
+            if (isArray)
+            {
+                coordsBuilder.Append(", ");
+                coordsBuilder.Append(Src(AggregateType.S32));
+            }
+
+            if (texOp.Inst == Instruction.ImageStore)
+            {
+                AggregateType type = texOp.Format.GetComponentType();
+
+                string[] cElems = new string[4];
+
+                for (int index = 0; index < 4; index++)
+                {
+                    if (srcIndex < texOp.SourcesCount)
+                    {
+                        cElems[index] = Src(type);
+                    }
+                    else
+                    {
+                        cElems[index] = type switch
+                        {
+                            AggregateType.S32 => NumberFormatter.FormatInt(0),
+                            AggregateType.U32 => NumberFormatter.FormatUint(0),
+                            _ => NumberFormatter.FormatFloat(0),
+                        };
+                    }
+                }
+
+                string prefix = type switch
+                {
+                    AggregateType.S32 => "int",
+                    AggregateType.U32 => "uint",
+                    AggregateType.FP32 => "float",
+                    _ => string.Empty,
+                };
+
+                texCallBuilder.Append($"{prefix}4({string.Join(", ", cElems)})");
+            }
+
+            texCallBuilder.Append(", ");
+            texCallBuilder.Append(coordsBuilder);
+
+            if (texOp.Inst == Instruction.ImageAtomic)
+            {
+                // TODO: Finish atomic stuff
+            }
+            else
+            {
+                texCallBuilder.Append(')');
+
+                if (texOp.Inst == Instruction.ImageLoad)
+                {
+                    texCallBuilder.Append(GetMaskMultiDest(texOp.Index));
+                }
+            }
+
+            return texCallBuilder.ToString();
+        }
+
         public static string Load(CodeGenContext context, AstOperation operation)
         {
             return GenerateLoadOrStore(context, operation, isStore: false);
@@ -359,9 +483,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
             return texCall;
         }
 
-        private static string GetSamplerName(ShaderProperties resourceDefinitions, AstTextureOperation textOp)
+        private static string GetSamplerName(ShaderProperties resourceDefinitions, AstTextureOperation texOp)
         {
-            return resourceDefinitions.Textures[textOp.GetTextureSetAndBinding()].Name;
+            return resourceDefinitions.Textures[texOp.GetTextureSetAndBinding()].Name;
+        }
+
+        private static string GetImageName(ShaderProperties resourceDefinitions, AstTextureOperation texOp)
+        {
+            return resourceDefinitions.Images[texOp.GetTextureSetAndBinding()].Name;
         }
 
         private static string GetMaskMultiDest(int mask)
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs
index 248b7159c1..757abffdcb 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs
@@ -150,6 +150,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
                 args = args.Append($"constant ConstantBuffers &constant_buffers [[buffer({Defaults.ConstantBuffersIndex})]]").ToArray();
                 args = args.Append($"device StorageBuffers &storage_buffers [[buffer({Defaults.StorageBuffersIndex})]]").ToArray();
                 args = args.Append($"constant Textures &textures [[buffer({Defaults.TexturesIndex})]]").ToArray();
+                args = args.Append($"constant Images &images [[buffer({Defaults.ImagesIndex})]]").ToArray();
             }
 
             var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
diff --git a/src/Ryujinx.Graphics.Shader/SamplerType.cs b/src/Ryujinx.Graphics.Shader/SamplerType.cs
index 67c5080127..44ff132948 100644
--- a/src/Ryujinx.Graphics.Shader/SamplerType.cs
+++ b/src/Ryujinx.Graphics.Shader/SamplerType.cs
@@ -156,7 +156,7 @@ namespace Ryujinx.Graphics.Shader
             return typeName;
         }
 
-        public static string ToMslTextureType(this SamplerType type)
+        public static string ToMslTextureType(this SamplerType type, bool image = false)
         {
             string typeName;
 
@@ -192,7 +192,7 @@ namespace Ryujinx.Graphics.Shader
                 typeName += "_array";
             }
 
-            return typeName + "<float>";
+            return $"{typeName} <float{(image ? ", access::read_write" : "")}>";
         }
     }
 }