Metal: Compute Shaders (#19)

* check for too bix texture bindings

* implement lod query

* print shader stage name

* always have fragment input

* resolve merge conflicts

* fix: lod query

* fix: casting texture coords

* support non-array memories

* use structure types for buffers

* implement compute pipeline cache

* compute dispatch

* improve error message

* rebind compute state

* bind compute textures

* pass local size as an argument to dispatch

* implement texture buffers

* hack: change vertex index to vertex id

* pass support buffer as an argument to every function

* return at the end of function

* fix: certain missing compute bindings

* implement texture base

* improve texture binding system

* remove useless exception

* move texture handle to texture base

* fix: segfault when using disposed textures

---------

Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
Isaac Marovitz 2024-05-29 16:21:59 +01:00
parent 131ab75d55
commit b064d76a4f
26 changed files with 718 additions and 224 deletions

View file

@ -25,7 +25,7 @@ namespace Ryujinx.Graphics.GAL
void CopyBuffer(BufferHandle source, BufferHandle destination, int srcOffset, int dstOffset, int size); void CopyBuffer(BufferHandle source, BufferHandle destination, int srcOffset, int dstOffset, int size);
void DispatchCompute(int groupsX, int groupsY, int groupsZ); void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ);
void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance); void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance);
void DrawIndexed( void DrawIndexed(

View file

@ -6,17 +6,23 @@ namespace Ryujinx.Graphics.GAL.Multithreading.Commands
private int _groupsX; private int _groupsX;
private int _groupsY; private int _groupsY;
private int _groupsZ; private int _groupsZ;
private int _groupSizeX;
private int _groupSizeY;
private int _groupSizeZ;
public void Set(int groupsX, int groupsY, int groupsZ) public void Set(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{ {
_groupsX = groupsX; _groupsX = groupsX;
_groupsY = groupsY; _groupsY = groupsY;
_groupsZ = groupsZ; _groupsZ = groupsZ;
_groupSizeX = groupSizeX;
_groupSizeY = groupSizeY;
_groupSizeZ = groupSizeZ;
} }
public static void Run(ref DispatchComputeCommand command, ThreadedRenderer threaded, IRenderer renderer) public static void Run(ref DispatchComputeCommand command, ThreadedRenderer threaded, IRenderer renderer)
{ {
renderer.Pipeline.DispatchCompute(command._groupsX, command._groupsY, command._groupsZ); renderer.Pipeline.DispatchCompute(command._groupsX, command._groupsY, command._groupsZ, command._groupSizeX, command._groupSizeY, command._groupSizeZ);
} }
} }
} }

View file

@ -63,9 +63,9 @@ namespace Ryujinx.Graphics.GAL.Multithreading
_renderer.QueueCommand(); _renderer.QueueCommand();
} }
public void DispatchCompute(int groupsX, int groupsY, int groupsZ) public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{ {
_renderer.New<DispatchComputeCommand>().Set(groupsX, groupsY, groupsZ); _renderer.New<DispatchComputeCommand>().Set(groupsX, groupsY, groupsZ, groupSizeX, groupSizeY, groupSizeZ);
_renderer.QueueCommand(); _renderer.QueueCommand();
} }

View file

@ -200,7 +200,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Compute
_channel.BufferManager.CommitComputeBindings(); _channel.BufferManager.CommitComputeBindings();
_context.Renderer.Pipeline.DispatchCompute(qmd.CtaRasterWidth, qmd.CtaRasterHeight, qmd.CtaRasterDepth); _context.Renderer.Pipeline.DispatchCompute(qmd.CtaRasterWidth, qmd.CtaRasterHeight, qmd.CtaRasterDepth, qmd.CtaThreadDimension0, qmd.CtaThreadDimension1, qmd.CtaThreadDimension2);
_3dEngine.ForceShaderUpdate(); _3dEngine.ForceShaderUpdate();
} }

View file

@ -211,7 +211,10 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed.ComputeDraw
_context.Renderer.Pipeline.DispatchCompute( _context.Renderer.Pipeline.DispatchCompute(
BitUtils.DivRoundUp(_count, ComputeLocalSize), BitUtils.DivRoundUp(_count, ComputeLocalSize),
BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize), BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize),
1); 1,
ComputeLocalSize,
ComputeLocalSize,
ComputeLocalSize);
} }
/// <summary> /// <summary>
@ -260,7 +263,10 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed.ComputeDraw
_context.Renderer.Pipeline.DispatchCompute( _context.Renderer.Pipeline.DispatchCompute(
BitUtils.DivRoundUp(primitivesCount, ComputeLocalSize), BitUtils.DivRoundUp(primitivesCount, ComputeLocalSize),
BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize), BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize),
_geometryAsCompute.Info.ThreadsPerInputPrimitive); _geometryAsCompute.Info.ThreadsPerInputPrimitive,
ComputeLocalSize,
ComputeLocalSize,
ComputeLocalSize);
} }
/// <summary> /// <summary>

View file

@ -0,0 +1,36 @@
using Ryujinx.Common.Logging;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
public class ComputePipelineCache : StateCache<MTLComputePipelineState, MTLFunction, MTLFunction>
{
private readonly MTLDevice _device;
public ComputePipelineCache(MTLDevice device)
{
_device = device;
}
protected override MTLFunction GetHash(MTLFunction function)
{
return function;
}
protected override MTLComputePipelineState CreateValue(MTLFunction function)
{
var error = new NSError(IntPtr.Zero);
var pipelineState = _device.NewComputePipelineState(function, ref error);
if (error != IntPtr.Zero)
{
Logger.Error?.PrintMsg(LogClass.Gpu, $"Failed to create Compute Pipeline State: {StringHelper.String(error.LocalizedDescription)}");
}
return pipelineState;
}
}
}

View file

@ -15,6 +15,6 @@ namespace Ryujinx.Graphics.Metal
// TODO: Check this value // TODO: Check this value
public const int MaxVertexLayouts = 16; public const int MaxVertexLayouts = 16;
public const int MaxTextures = 31; public const int MaxTextures = 31;
public const int MaxSamplers = 31; public const int MaxSamplers = 16;
} }
} }

View file

@ -8,20 +8,23 @@ namespace Ryujinx.Graphics.Metal
{ {
public struct DirtyFlags public struct DirtyFlags
{ {
public bool Pipeline = false; public bool RenderPipeline = false;
public bool ComputePipeline = false;
public bool DepthStencil = false; public bool DepthStencil = false;
public DirtyFlags() { } public DirtyFlags() { }
public void MarkAll() public void MarkAll()
{ {
Pipeline = true; RenderPipeline = true;
ComputePipeline = true;
DepthStencil = true; DepthStencil = true;
} }
public void Clear() public void Clear()
{ {
Pipeline = false; RenderPipeline = false;
ComputePipeline = false;
DepthStencil = false; DepthStencil = false;
} }
} }
@ -31,13 +34,17 @@ namespace Ryujinx.Graphics.Metal
{ {
public MTLFunction? VertexFunction = null; public MTLFunction? VertexFunction = null;
public MTLFunction? FragmentFunction = null; public MTLFunction? FragmentFunction = null;
public MTLFunction? ComputeFunction = null;
public MTLTexture[] FragmentTextures = new MTLTexture[Constants.MaxTextures]; public TextureBase[] FragmentTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] FragmentSamplers = new MTLSamplerState[Constants.MaxSamplers]; public MTLSamplerState[] FragmentSamplers = new MTLSamplerState[Constants.MaxSamplers];
public MTLTexture[] VertexTextures = new MTLTexture[Constants.MaxTextures]; public TextureBase[] VertexTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] VertexSamplers = new MTLSamplerState[Constants.MaxSamplers]; public MTLSamplerState[] VertexSamplers = new MTLSamplerState[Constants.MaxSamplers];
public TextureBase[] ComputeTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] ComputeSamplers = new MTLSamplerState[Constants.MaxSamplers];
public List<BufferInfo> UniformBuffers = []; public List<BufferInfo> UniformBuffers = [];
public List<BufferInfo> StorageBuffers = []; public List<BufferInfo> StorageBuffers = [];
@ -87,10 +94,12 @@ namespace Ryujinx.Graphics.Metal
{ {
// Certain state (like viewport and scissor) doesn't need to be cloned, as it is always reacreated when assigned to // Certain state (like viewport and scissor) doesn't need to be cloned, as it is always reacreated when assigned to
EncoderState clone = this; EncoderState clone = this;
clone.FragmentTextures = (MTLTexture[])FragmentTextures.Clone(); clone.FragmentTextures = (TextureBase[])FragmentTextures.Clone();
clone.FragmentSamplers = (MTLSamplerState[])FragmentSamplers.Clone(); clone.FragmentSamplers = (MTLSamplerState[])FragmentSamplers.Clone();
clone.VertexTextures = (MTLTexture[])VertexTextures.Clone(); clone.VertexTextures = (TextureBase[])VertexTextures.Clone();
clone.VertexSamplers = (MTLSamplerState[])VertexSamplers.Clone(); clone.VertexSamplers = (MTLSamplerState[])VertexSamplers.Clone();
clone.ComputeTextures = (TextureBase[])ComputeTextures.Clone();
clone.ComputeSamplers = (MTLSamplerState[])ComputeSamplers.Clone();
clone.BlendDescriptors = (BlendDescriptor?[])BlendDescriptors.Clone(); clone.BlendDescriptors = (BlendDescriptor?[])BlendDescriptors.Clone();
clone.VertexBuffers = (VertexBufferDescriptor[])VertexBuffers.Clone(); clone.VertexBuffers = (VertexBufferDescriptor[])VertexBuffers.Clone();
clone.VertexAttribs = (VertexAttribDescriptor[])VertexAttribs.Clone(); clone.VertexAttribs = (VertexAttribDescriptor[])VertexAttribs.Clone();

View file

@ -15,6 +15,7 @@ namespace Ryujinx.Graphics.Metal
private readonly Pipeline _pipeline; private readonly Pipeline _pipeline;
private readonly RenderPipelineCache _renderPipelineCache; private readonly RenderPipelineCache _renderPipelineCache;
private readonly ComputePipelineCache _computePipelineCache;
private readonly DepthStencilCache _depthStencilCache; private readonly DepthStencilCache _depthStencilCache;
private EncoderState _currentState = new(); private EncoderState _currentState = new();
@ -33,6 +34,7 @@ namespace Ryujinx.Graphics.Metal
{ {
_pipeline = pipeline; _pipeline = pipeline;
_renderPipelineCache = new(device); _renderPipelineCache = new(device);
_computePipelineCache = new(device);
_depthStencilCache = new(device); _depthStencilCache = new(device);
// Zero buffer // Zero buffer
@ -50,6 +52,7 @@ namespace Ryujinx.Graphics.Metal
_currentState.BackFaceStencil.Dispose(); _currentState.BackFaceStencil.Dispose();
_renderPipelineCache.Dispose(); _renderPipelineCache.Dispose();
_computePipelineCache.Dispose();
_depthStencilCache.Dispose(); _depthStencilCache.Dispose();
} }
@ -77,8 +80,8 @@ namespace Ryujinx.Graphics.Metal
SetScissors(renderCommandEncoder); SetScissors(renderCommandEncoder);
SetViewports(renderCommandEncoder); SetViewports(renderCommandEncoder);
SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers); SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers);
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true); SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true); SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
SetCullMode(renderCommandEncoder); SetCullMode(renderCommandEncoder);
SetFrontFace(renderCommandEncoder); SetFrontFace(renderCommandEncoder);
SetStencilRefValue(renderCommandEncoder); SetStencilRefValue(renderCommandEncoder);
@ -107,7 +110,7 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.RenderTargets[i] != null) if (_currentState.RenderTargets[i] != null)
{ {
var passAttachment = renderPassDescriptor.ColorAttachments.Object((ulong)i); var passAttachment = renderPassDescriptor.ColorAttachments.Object((ulong)i);
passAttachment.Texture = _currentState.RenderTargets[i].MTLTexture; passAttachment.Texture = _currentState.RenderTargets[i].GetHandle();
passAttachment.LoadAction = _currentState.ClearLoadAction ? MTLLoadAction.Clear : MTLLoadAction.Load; passAttachment.LoadAction = _currentState.ClearLoadAction ? MTLLoadAction.Clear : MTLLoadAction.Load;
passAttachment.StoreAction = MTLStoreAction.Store; passAttachment.StoreAction = MTLStoreAction.Store;
} }
@ -118,19 +121,19 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.DepthStencil != null) if (_currentState.DepthStencil != null)
{ {
switch (_currentState.DepthStencil.MTLTexture.PixelFormat) switch (_currentState.DepthStencil.GetHandle().PixelFormat)
{ {
// Depth Only Attachment // Depth Only Attachment
case MTLPixelFormat.Depth16Unorm: case MTLPixelFormat.Depth16Unorm:
case MTLPixelFormat.Depth32Float: case MTLPixelFormat.Depth32Float:
depthAttachment.Texture = _currentState.DepthStencil.MTLTexture; depthAttachment.Texture = _currentState.DepthStencil.GetHandle();
depthAttachment.LoadAction = MTLLoadAction.Load; depthAttachment.LoadAction = MTLLoadAction.Load;
depthAttachment.StoreAction = MTLStoreAction.Store; depthAttachment.StoreAction = MTLStoreAction.Store;
break; break;
// Stencil Only Attachment // Stencil Only Attachment
case MTLPixelFormat.Stencil8: case MTLPixelFormat.Stencil8:
stencilAttachment.Texture = _currentState.DepthStencil.MTLTexture; stencilAttachment.Texture = _currentState.DepthStencil.GetHandle();
stencilAttachment.LoadAction = MTLLoadAction.Load; stencilAttachment.LoadAction = MTLLoadAction.Load;
stencilAttachment.StoreAction = MTLStoreAction.Store; stencilAttachment.StoreAction = MTLStoreAction.Store;
break; break;
@ -138,16 +141,16 @@ namespace Ryujinx.Graphics.Metal
// Combined Attachment // Combined Attachment
case MTLPixelFormat.Depth24UnormStencil8: case MTLPixelFormat.Depth24UnormStencil8:
case MTLPixelFormat.Depth32FloatStencil8: case MTLPixelFormat.Depth32FloatStencil8:
depthAttachment.Texture = _currentState.DepthStencil.MTLTexture; depthAttachment.Texture = _currentState.DepthStencil.GetHandle();
depthAttachment.LoadAction = MTLLoadAction.Load; depthAttachment.LoadAction = MTLLoadAction.Load;
depthAttachment.StoreAction = MTLStoreAction.Store; depthAttachment.StoreAction = MTLStoreAction.Store;
stencilAttachment.Texture = _currentState.DepthStencil.MTLTexture; stencilAttachment.Texture = _currentState.DepthStencil.GetHandle();
stencilAttachment.LoadAction = MTLLoadAction.Load; stencilAttachment.LoadAction = MTLLoadAction.Load;
stencilAttachment.StoreAction = MTLStoreAction.Store; stencilAttachment.StoreAction = MTLStoreAction.Store;
break; break;
default: default:
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.MTLTexture.PixelFormat}!"); Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.GetHandle().PixelFormat}!");
break; break;
} }
} }
@ -166,10 +169,18 @@ namespace Ryujinx.Graphics.Metal
SetViewports(renderCommandEncoder); SetViewports(renderCommandEncoder);
SetScissors(renderCommandEncoder); SetScissors(renderCommandEncoder);
SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers); SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers);
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true); SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true); SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Vertex, _currentState.VertexTextures, _currentState.VertexSamplers); for (ulong i = 0; i < Constants.MaxTextures; i++)
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Fragment, _currentState.FragmentTextures, _currentState.FragmentSamplers); {
SetRenderTexture(renderCommandEncoder, ShaderStage.Vertex, i, _currentState.VertexTextures[i]);
SetRenderTexture(renderCommandEncoder, ShaderStage.Fragment, i, _currentState.FragmentTextures[i]);
}
for (ulong i = 0; i < Constants.MaxSamplers; i++)
{
SetRenderSampler(renderCommandEncoder, ShaderStage.Vertex, i, _currentState.VertexSamplers[i]);
SetRenderSampler(renderCommandEncoder, ShaderStage.Fragment, i, _currentState.FragmentSamplers[i]);
}
// Cleanup // Cleanup
renderPassDescriptor.Dispose(); renderPassDescriptor.Dispose();
@ -177,11 +188,34 @@ namespace Ryujinx.Graphics.Metal
return renderCommandEncoder; return renderCommandEncoder;
} }
public void RebindState(MTLRenderCommandEncoder renderCommandEncoder) public MTLComputeCommandEncoder CreateComputeCommandEncoder()
{ {
if (_currentState.Dirty.Pipeline) var descriptor = new MTLComputePassDescriptor();
var computeCommandEncoder = _pipeline.CommandBuffer.ComputeCommandEncoder(descriptor);
// Rebind all the state
SetComputeBuffers(computeCommandEncoder, _currentState.UniformBuffers);
SetComputeBuffers(computeCommandEncoder, _currentState.StorageBuffers);
for (ulong i = 0; i < Constants.MaxTextures; i++)
{ {
SetPipelineState(renderCommandEncoder); SetComputeTexture(computeCommandEncoder, i, _currentState.ComputeTextures[i]);
}
for (ulong i = 0; i < Constants.MaxSamplers; i++)
{
SetComputeSampler(computeCommandEncoder, i, _currentState.ComputeSamplers[i]);
}
// Cleanup
descriptor.Dispose();
return computeCommandEncoder;
}
public void RebindRenderState(MTLRenderCommandEncoder renderCommandEncoder)
{
if (_currentState.Dirty.RenderPipeline)
{
SetRenderPipelineState(renderCommandEncoder);
} }
if (_currentState.Dirty.DepthStencil) if (_currentState.Dirty.DepthStencil)
@ -190,10 +224,22 @@ namespace Ryujinx.Graphics.Metal
} }
// Clear the dirty flags // Clear the dirty flags
_currentState.Dirty.Clear(); _currentState.Dirty.RenderPipeline = false;
_currentState.Dirty.DepthStencil = false;
} }
private readonly void SetPipelineState(MTLRenderCommandEncoder renderCommandEncoder) public void RebindComputeState(MTLComputeCommandEncoder computeCommandEncoder)
{
if (_currentState.Dirty.ComputePipeline)
{
SetComputePipelineState(computeCommandEncoder);
}
// Clear the dirty flags
_currentState.Dirty.ComputePipeline = false;
}
private readonly void SetRenderPipelineState(MTLRenderCommandEncoder renderCommandEncoder)
{ {
var renderPipelineDescriptor = new MTLRenderPipelineDescriptor(); var renderPipelineDescriptor = new MTLRenderPipelineDescriptor();
@ -202,7 +248,7 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.RenderTargets[i] != null) if (_currentState.RenderTargets[i] != null)
{ {
var pipelineAttachment = renderPipelineDescriptor.ColorAttachments.Object((ulong)i); var pipelineAttachment = renderPipelineDescriptor.ColorAttachments.Object((ulong)i);
pipelineAttachment.PixelFormat = _currentState.RenderTargets[i].MTLTexture.PixelFormat; pipelineAttachment.PixelFormat = _currentState.RenderTargets[i].GetHandle().PixelFormat;
pipelineAttachment.SourceAlphaBlendFactor = MTLBlendFactor.SourceAlpha; pipelineAttachment.SourceAlphaBlendFactor = MTLBlendFactor.SourceAlpha;
pipelineAttachment.DestinationAlphaBlendFactor = MTLBlendFactor.OneMinusSourceAlpha; pipelineAttachment.DestinationAlphaBlendFactor = MTLBlendFactor.OneMinusSourceAlpha;
pipelineAttachment.SourceRGBBlendFactor = MTLBlendFactor.SourceAlpha; pipelineAttachment.SourceRGBBlendFactor = MTLBlendFactor.SourceAlpha;
@ -225,27 +271,27 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.DepthStencil != null) if (_currentState.DepthStencil != null)
{ {
switch (_currentState.DepthStencil.MTLTexture.PixelFormat) switch (_currentState.DepthStencil.GetHandle().PixelFormat)
{ {
// Depth Only Attachment // Depth Only Attachment
case MTLPixelFormat.Depth16Unorm: case MTLPixelFormat.Depth16Unorm:
case MTLPixelFormat.Depth32Float: case MTLPixelFormat.Depth32Float:
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat; renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break; break;
// Stencil Only Attachment // Stencil Only Attachment
case MTLPixelFormat.Stencil8: case MTLPixelFormat.Stencil8:
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat; renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break; break;
// Combined Attachment // Combined Attachment
case MTLPixelFormat.Depth24UnormStencil8: case MTLPixelFormat.Depth24UnormStencil8:
case MTLPixelFormat.Depth32FloatStencil8: case MTLPixelFormat.Depth32FloatStencil8:
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat; renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat; renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break; break;
default: default:
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.MTLTexture.PixelFormat}!"); Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.GetHandle().PixelFormat}!");
break; break;
} }
} }
@ -287,6 +333,18 @@ namespace Ryujinx.Graphics.Metal
} }
} }
private readonly void SetComputePipelineState(MTLComputeCommandEncoder computeCommandEncoder)
{
if (_currentState.ComputeFunction == null)
{
return;
}
var pipelineState = _computePipelineCache.GetOrCreate(_currentState.ComputeFunction.Value);
computeCommandEncoder.SetComputePipelineState(pipelineState);
}
public void UpdateIndexBuffer(BufferRange buffer, IndexType type) public void UpdateIndexBuffer(BufferRange buffer, IndexType type)
{ {
if (buffer.Handle != BufferHandle.Null) if (buffer.Handle != BufferHandle.Null)
@ -307,17 +365,34 @@ namespace Ryujinx.Graphics.Metal
{ {
Program prg = (Program)program; Program prg = (Program)program;
if (prg.VertexFunction == IntPtr.Zero) if (prg.VertexFunction == IntPtr.Zero && prg.ComputeFunction == IntPtr.Zero)
{ {
Logger.Error?.PrintMsg(LogClass.Gpu, "Invalid Vertex Function!"); if (prg.FragmentFunction == IntPtr.Zero)
{
Logger.Error?.PrintMsg(LogClass.Gpu, "No compute function");
}
else
{
Logger.Error?.PrintMsg(LogClass.Gpu, "No vertex function");
}
return; return;
} }
_currentState.VertexFunction = prg.VertexFunction; if (prg.VertexFunction != IntPtr.Zero)
_currentState.FragmentFunction = prg.FragmentFunction; {
_currentState.VertexFunction = prg.VertexFunction;
_currentState.FragmentFunction = prg.FragmentFunction;
// Mark dirty // Mark dirty
_currentState.Dirty.Pipeline = true; _currentState.Dirty.RenderPipeline = true;
}
if (prg.ComputeFunction != IntPtr.Zero)
{
_currentState.ComputeFunction = prg.ComputeFunction;
// Mark dirty
_currentState.Dirty.ComputePipeline = true;
}
} }
public void UpdateRenderTargets(ITexture[] colors, ITexture depthStencil) public void UpdateRenderTargets(ITexture[] colors, ITexture depthStencil)
@ -383,7 +458,7 @@ namespace Ryujinx.Graphics.Metal
_currentState.VertexAttribs = vertexAttribs.ToArray(); _currentState.VertexAttribs = vertexAttribs.ToArray();
// Mark dirty // Mark dirty
_currentState.Dirty.Pipeline = true; _currentState.Dirty.RenderPipeline = true;
} }
public void UpdateBlendDescriptors(int index, BlendDescriptor blend) public void UpdateBlendDescriptors(int index, BlendDescriptor blend)
@ -557,7 +632,7 @@ namespace Ryujinx.Graphics.Metal
} }
// Mark dirty // Mark dirty
_currentState.Dirty.Pipeline = true; _currentState.Dirty.RenderPipeline = true;
} }
// Inlineable // Inlineable
@ -579,10 +654,18 @@ namespace Ryujinx.Graphics.Metal
} }
// Inline update // Inline update
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null) if (_pipeline.CurrentEncoder != null)
{ {
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value); if (_pipeline.CurrentEncoderType == EncoderType.Render)
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true); {
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeBuffers(computeCommandEncoder, _currentState.UniformBuffers);
}
} }
} }
@ -606,10 +689,18 @@ namespace Ryujinx.Graphics.Metal
} }
// Inline update // Inline update
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null) if (_pipeline.CurrentEncoder != null)
{ {
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value); if (_pipeline.CurrentEncoderType == EncoderType.Render)
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true); {
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeBuffers(computeCommandEncoder, _currentState.StorageBuffers);
}
} }
} }
@ -653,29 +744,86 @@ namespace Ryujinx.Graphics.Metal
} }
// Inlineable // Inlineable
public readonly void UpdateTextureAndSampler(ShaderStage stage, ulong binding, MTLTexture texture, MTLSamplerState sampler) public readonly void UpdateTexture(ShaderStage stage, ulong binding, TextureBase texture)
{ {
if (binding > 30)
{
Logger.Warning?.Print(LogClass.Gpu, $"Texture binding ({binding}) must be <= 30");
return;
}
switch (stage) switch (stage)
{ {
case ShaderStage.Fragment: case ShaderStage.Fragment:
_currentState.FragmentTextures[binding] = texture; _currentState.FragmentTextures[binding] = texture;
_currentState.FragmentSamplers[binding] = sampler;
break; break;
case ShaderStage.Vertex: case ShaderStage.Vertex:
_currentState.VertexTextures[binding] = texture; _currentState.VertexTextures[binding] = texture;
_currentState.VertexSamplers[binding] = sampler; break;
case ShaderStage.Compute:
_currentState.ComputeTextures[binding] = texture;
break; break;
} }
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null) if (_pipeline.CurrentEncoder != null)
{ {
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value); if (_pipeline.CurrentEncoderType == EncoderType.Render)
// TODO: Only update the new ones {
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Vertex, _currentState.VertexTextures, _currentState.VertexSamplers); var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Fragment, _currentState.FragmentTextures, _currentState.FragmentSamplers); SetRenderTexture(renderCommandEncoder, ShaderStage.Vertex, binding, texture);
SetRenderTexture(renderCommandEncoder, ShaderStage.Fragment, binding, texture);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeTexture(computeCommandEncoder, binding, texture);
}
} }
} }
// Inlineable
public readonly void UpdateSampler(ShaderStage stage, ulong binding, MTLSamplerState sampler)
{
if (binding > 15)
{
Logger.Warning?.Print(LogClass.Gpu, $"Sampler binding ({binding}) must be <= 15");
return;
}
switch (stage)
{
case ShaderStage.Fragment:
_currentState.FragmentSamplers[binding] = sampler;
break;
case ShaderStage.Vertex:
_currentState.VertexSamplers[binding] = sampler;
break;
case ShaderStage.Compute:
_currentState.ComputeSamplers[binding] = sampler;
break;
}
if (_pipeline.CurrentEncoder != null)
{
if (_pipeline.CurrentEncoderType == EncoderType.Render)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderSampler(renderCommandEncoder, ShaderStage.Vertex, binding, sampler);
SetRenderSampler(renderCommandEncoder, ShaderStage.Fragment, binding, sampler);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeSampler(computeCommandEncoder, binding, sampler);
}
}
}
// Inlineable
public readonly void UpdateTextureAndSampler(ShaderStage stage, ulong binding, TextureBase texture, MTLSamplerState sampler)
{
UpdateTexture(stage, binding, texture);
UpdateSampler(stage, binding, sampler);
}
private readonly void SetDepthStencilState(MTLRenderCommandEncoder renderCommandEncoder) private readonly void SetDepthStencilState(MTLRenderCommandEncoder renderCommandEncoder)
{ {
if (_currentState.DepthStencilState != null) if (_currentState.DepthStencilState != null)
@ -807,10 +955,10 @@ namespace Ryujinx.Graphics.Metal
Index = bufferDescriptors.Length Index = bufferDescriptors.Length
}); });
SetBuffers(renderCommandEncoder, buffers); SetRenderBuffers(renderCommandEncoder, buffers);
} }
private readonly void SetBuffers(MTLRenderCommandEncoder renderCommandEncoder, List<BufferInfo> buffers, bool fragment = false) private readonly void SetRenderBuffers(MTLRenderCommandEncoder renderCommandEncoder, List<BufferInfo> buffers, bool fragment = false)
{ {
foreach (var buffer in buffers) foreach (var buffer in buffers)
{ {
@ -823,6 +971,14 @@ namespace Ryujinx.Graphics.Metal
} }
} }
private readonly void SetComputeBuffers(MTLComputeCommandEncoder computeCommandEncoder, List<BufferInfo> buffers)
{
foreach (var buffer in buffers)
{
computeCommandEncoder.SetBuffer(new MTLBuffer(buffer.Handle), (ulong)buffer.Offset, (ulong)buffer.Index);
}
}
private readonly void SetCullMode(MTLRenderCommandEncoder renderCommandEncoder) private readonly void SetCullMode(MTLRenderCommandEncoder renderCommandEncoder)
{ {
renderCommandEncoder.SetCullMode(_currentState.CullMode); renderCommandEncoder.SetCullMode(_currentState.CullMode);
@ -838,41 +994,64 @@ namespace Ryujinx.Graphics.Metal
renderCommandEncoder.SetStencilReferenceValues((uint)_currentState.FrontRefValue, (uint)_currentState.BackRefValue); renderCommandEncoder.SetStencilReferenceValues((uint)_currentState.FrontRefValue, (uint)_currentState.BackRefValue);
} }
private static void SetTextureAndSampler(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, MTLTexture[] textures, MTLSamplerState[] samplers) private static void SetRenderTexture(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, ulong binding, TextureBase texture)
{ {
for (int i = 0; i < textures.Length; i++) if (texture == null)
{ {
var texture = textures[i]; return;
if (texture != IntPtr.Zero)
{
switch (stage)
{
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexTexture(texture, (ulong)i);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentTexture(texture, (ulong)i);
break;
}
}
} }
for (int i = 0; i < samplers.Length; i++) var textureHandle = texture.GetHandle();
if (textureHandle != IntPtr.Zero)
{ {
var sampler = samplers[i]; switch (stage)
if (sampler != IntPtr.Zero)
{ {
switch (stage) case ShaderStage.Vertex:
{ renderCommandEncoder.SetVertexTexture(textureHandle, binding);
case ShaderStage.Vertex: break;
renderCommandEncoder.SetVertexSamplerState(sampler, (ulong)i); case ShaderStage.Fragment:
break; renderCommandEncoder.SetFragmentTexture(textureHandle, binding);
case ShaderStage.Fragment: break;
renderCommandEncoder.SetFragmentSamplerState(sampler, (ulong)i);
break;
}
} }
} }
} }
private static void SetRenderSampler(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, ulong binding, MTLSamplerState sampler)
{
if (sampler != IntPtr.Zero)
{
switch (stage)
{
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexSamplerState(sampler, binding);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentSamplerState(sampler, binding);
break;
}
}
}
private static void SetComputeTexture(MTLComputeCommandEncoder computeCommandEncoder, ulong binding, TextureBase texture)
{
if (texture == null)
{
return;
}
var textureHandle = texture.GetHandle();
if (textureHandle != IntPtr.Zero)
{
computeCommandEncoder.SetTexture(textureHandle, binding);
}
}
private static void SetComputeSampler(MTLComputeCommandEncoder computeCommandEncoder, ulong binding, MTLSamplerState sampler)
{
if (sampler != IntPtr.Zero)
{
computeCommandEncoder.SetSamplerState(sampler, binding);
}
}
} }
} }

View file

@ -5,6 +5,8 @@ using Ryujinx.Graphics.Shader.Translation;
using SharpMetal.Metal; using SharpMetal.Metal;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning; using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal namespace Ryujinx.Graphics.Metal

View file

@ -97,9 +97,12 @@ namespace Ryujinx.Graphics.Metal
public ITexture CreateTexture(TextureCreateInfo info) public ITexture CreateTexture(TextureCreateInfo info)
{ {
var texture = new Texture(_device, _pipeline, info); if (info.Target == Target.TextureBuffer)
{
return new TextureBuffer(_device, _pipeline, info);
}
return texture; return new Texture(_device, _pipeline, info);
} }
public ITextureArray CreateTextureArray(int size, bool isBuffer) public ITextureArray CreateTextureArray(int size, bool isBuffer)

View file

@ -69,7 +69,6 @@ namespace Ryujinx.Graphics.Metal
public MTLRenderCommandEncoder GetOrCreateRenderEncoder() public MTLRenderCommandEncoder GetOrCreateRenderEncoder()
{ {
MTLRenderCommandEncoder renderCommandEncoder; MTLRenderCommandEncoder renderCommandEncoder;
if (_currentEncoder == null || _currentEncoderType != EncoderType.Render) if (_currentEncoder == null || _currentEncoderType != EncoderType.Render)
{ {
renderCommandEncoder = BeginRenderPass(); renderCommandEncoder = BeginRenderPass();
@ -79,7 +78,7 @@ namespace Ryujinx.Graphics.Metal
renderCommandEncoder = new MTLRenderCommandEncoder(_currentEncoder.Value); renderCommandEncoder = new MTLRenderCommandEncoder(_currentEncoder.Value);
} }
_encoderStateManager.RebindState(renderCommandEncoder); _encoderStateManager.RebindRenderState(renderCommandEncoder);
return renderCommandEncoder; return renderCommandEncoder;
} }
@ -99,15 +98,19 @@ namespace Ryujinx.Graphics.Metal
public MTLComputeCommandEncoder GetOrCreateComputeEncoder() public MTLComputeCommandEncoder GetOrCreateComputeEncoder()
{ {
if (_currentEncoder != null) MTLComputeCommandEncoder computeCommandEncoder;
if (_currentEncoder == null || _currentEncoderType != EncoderType.Compute)
{ {
if (_currentEncoderType == EncoderType.Compute) computeCommandEncoder = BeginComputePass();
{ }
return new MTLComputeCommandEncoder(_currentEncoder.Value); else
} {
computeCommandEncoder = new MTLComputeCommandEncoder(_currentEncoder.Value);
} }
return BeginComputePass(); _encoderStateManager.RebindComputeState(computeCommandEncoder);
return computeCommandEncoder;
} }
public void EndCurrentPass() public void EndCurrentPass()
@ -164,8 +167,7 @@ namespace Ryujinx.Graphics.Metal
{ {
EndCurrentPass(); EndCurrentPass();
var descriptor = new MTLComputePassDescriptor(); var computeCommandEncoder = _encoderStateManager.CreateComputeCommandEncoder();
var computeCommandEncoder = _commandBuffer.ComputeCommandEncoder(descriptor);
_currentEncoder = computeCommandEncoder; _currentEncoder = computeCommandEncoder;
_currentEncoderType = EncoderType.Compute; _currentEncoderType = EncoderType.Compute;
@ -274,9 +276,13 @@ namespace Ryujinx.Graphics.Metal
(ulong)size); (ulong)size);
} }
public void DispatchCompute(int groupsX, int groupsY, int groupsZ) public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{ {
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!"); var computeCommandEncoder = GetOrCreateComputeEncoder();
computeCommandEncoder.DispatchThreadgroups(
new MTLSize{width = (ulong)groupsX, height = (ulong)groupsY, depth = (ulong)groupsZ},
new MTLSize{width = (ulong)groupSizeX, height = (ulong)groupSizeY, depth = (ulong)groupSizeZ});
} }
public void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance) public void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance)
@ -397,7 +403,10 @@ namespace Ryujinx.Graphics.Metal
public void SetImage(ShaderStage stage, int binding, ITexture texture, Format imageFormat) public void SetImage(ShaderStage stage, int binding, ITexture texture, Format imageFormat)
{ {
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!"); if (texture is TextureBase tex)
{
_encoderStateManager.UpdateTexture(stage, (ulong)binding, tex);
}
} }
public void SetImageArray(ShaderStage stage, int binding, IImageArray array) public void SetImageArray(ShaderStage stage, int binding, IImageArray array)
@ -491,28 +500,14 @@ namespace Ryujinx.Graphics.Metal
public void SetTextureAndSampler(ShaderStage stage, int binding, ITexture texture, ISampler sampler) public void SetTextureAndSampler(ShaderStage stage, int binding, ITexture texture, ISampler sampler)
{ {
if (texture is Texture tex) if (texture is TextureBase tex)
{ {
if (sampler is Sampler samp) if (sampler is Sampler samp)
{ {
var mtlTexture = tex.MTLTexture;
var mtlSampler = samp.GetSampler(); var mtlSampler = samp.GetSampler();
var index = (ulong)binding; var index = (ulong)binding;
switch (stage) _encoderStateManager.UpdateTextureAndSampler(stage, index, tex, mtlSampler);
{
case ShaderStage.Vertex:
case ShaderStage.Fragment:
_encoderStateManager.UpdateTextureAndSampler(stage, index, mtlTexture, mtlSampler);
break;
case ShaderStage.Compute:
var computeCommandEncoder = GetOrCreateComputeEncoder();
computeCommandEncoder.SetTexture(mtlTexture, index);
computeCommandEncoder.SetSamplerState(mtlSampler, index);
break;
default:
throw new ArgumentOutOfRangeException(nameof(stage), stage, "Unsupported shader stage!");
}
} }
} }
} }

View file

@ -26,7 +26,7 @@ namespace Ryujinx.Graphics.Metal
var shaderLibrary = device.NewLibrary(StringHelper.NSString(shader.Code), new MTLCompileOptions(IntPtr.Zero), ref libraryError); var shaderLibrary = device.NewLibrary(StringHelper.NSString(shader.Code), new MTLCompileOptions(IntPtr.Zero), ref libraryError);
if (libraryError != IntPtr.Zero) if (libraryError != IntPtr.Zero)
{ {
Logger.Warning?.Print(LogClass.Gpu, $"Shader linking failed: \n{StringHelper.String(libraryError.LocalizedDescription)}"); Logger.Warning?.Print(LogClass.Gpu, $"{shader.Stage} shader linking failed: \n{StringHelper.String(libraryError.LocalizedDescription)}");
_status = ProgramLinkStatus.Failure; _status = ProgramLinkStatus.Failure;
return; return;
} }
@ -34,7 +34,7 @@ namespace Ryujinx.Graphics.Metal
switch (shaders[index].Stage) switch (shaders[index].Stage)
{ {
case ShaderStage.Compute: case ShaderStage.Compute:
ComputeFunction = shaderLibrary.NewFunction(StringHelper.NSString("computeMain")); ComputeFunction = shaderLibrary.NewFunction(StringHelper.NSString("kernelMain"));
break; break;
case ShaderStage.Vertex: case ShaderStage.Vertex:
VertexFunction = shaderLibrary.NewFunction(StringHelper.NSString("vertexMain")); VertexFunction = shaderLibrary.NewFunction(StringHelper.NSString("vertexMain"));

View file

@ -10,24 +10,10 @@ using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal namespace Ryujinx.Graphics.Metal
{ {
[SupportedOSPlatform("macos")] [SupportedOSPlatform("macos")]
class Texture : ITexture, IDisposable class Texture : TextureBase, ITexture
{ {
private readonly TextureCreateInfo _info; public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info) : base(device, pipeline, info)
private readonly Pipeline _pipeline;
private readonly MTLDevice _device;
public MTLTexture MTLTexture;
public TextureCreateInfo Info => _info;
public int Width => Info.Width;
public int Height => Info.Height;
public int Depth => Info.Depth;
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info)
{ {
_device = device;
_pipeline = pipeline;
_info = info;
var descriptor = new MTLTextureDescriptor var descriptor = new MTLTextureDescriptor
{ {
PixelFormat = FormatTable.GetFormat(Info.Format), PixelFormat = FormatTable.GetFormat(Info.Format),
@ -50,15 +36,11 @@ namespace Ryujinx.Graphics.Metal
descriptor.Swizzle = GetSwizzle(info, descriptor.PixelFormat); descriptor.Swizzle = GetSwizzle(info, descriptor.PixelFormat);
MTLTexture = _device.NewTexture(descriptor); _mtlTexture = _device.NewTexture(descriptor);
} }
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info, MTLTexture sourceTexture, int firstLayer, int firstLevel) public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info, MTLTexture sourceTexture, int firstLayer, int firstLevel) : base(device, pipeline, info)
{ {
_device = device;
_pipeline = pipeline;
_info = info;
var pixelFormat = FormatTable.GetFormat(Info.Format); var pixelFormat = FormatTable.GetFormat(Info.Format);
var textureType = Info.Target.Convert(); var textureType = Info.Target.Convert();
NSRange levels; NSRange levels;
@ -75,7 +57,7 @@ namespace Ryujinx.Graphics.Metal
var swizzle = GetSwizzle(info, pixelFormat); var swizzle = GetSwizzle(info, pixelFormat);
MTLTexture = sourceTexture.NewTextureView(pixelFormat, textureType, levels, slices, swizzle); _mtlTexture = sourceTexture.NewTextureView(pixelFormat, textureType, levels, slices, swizzle);
} }
private MTLTextureSwizzleChannels GetSwizzle(TextureCreateInfo info, MTLPixelFormat pixelFormat) private MTLTextureSwizzleChannels GetSwizzle(TextureCreateInfo info, MTLPixelFormat pixelFormat)
@ -118,14 +100,14 @@ namespace Ryujinx.Graphics.Metal
if (destination is Texture destinationTexture) if (destination is Texture destinationTexture)
{ {
blitCommandEncoder.CopyFromTexture( blitCommandEncoder.CopyFromTexture(
MTLTexture, _mtlTexture,
(ulong)firstLayer, (ulong)firstLayer,
(ulong)firstLevel, (ulong)firstLevel,
destinationTexture.MTLTexture, destinationTexture._mtlTexture,
(ulong)firstLayer, (ulong)firstLayer,
(ulong)firstLevel, (ulong)firstLevel,
MTLTexture.ArrayLength, _mtlTexture.ArrayLength,
MTLTexture.MipmapLevelCount); _mtlTexture.MipmapLevelCount);
} }
} }
@ -136,14 +118,14 @@ namespace Ryujinx.Graphics.Metal
if (destination is Texture destinationTexture) if (destination is Texture destinationTexture)
{ {
blitCommandEncoder.CopyFromTexture( blitCommandEncoder.CopyFromTexture(
MTLTexture, _mtlTexture,
(ulong)srcLayer, (ulong)srcLayer,
(ulong)srcLevel, (ulong)srcLevel,
destinationTexture.MTLTexture, destinationTexture._mtlTexture,
(ulong)dstLayer, (ulong)dstLayer,
(ulong)dstLevel, (ulong)dstLevel,
MTLTexture.ArrayLength, _mtlTexture.ArrayLength,
MTLTexture.MipmapLevelCount); _mtlTexture.MipmapLevelCount);
} }
} }
@ -158,7 +140,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level); ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0; ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D) if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{ {
bytesPerImage = bytesPerRow * (ulong)Info.Height; bytesPerImage = bytesPerRow * (ulong)Info.Height;
} }
@ -167,11 +149,11 @@ namespace Ryujinx.Graphics.Metal
MTLBuffer mtlBuffer = new(Unsafe.As<BufferHandle, IntPtr>(ref handle)); MTLBuffer mtlBuffer = new(Unsafe.As<BufferHandle, IntPtr>(ref handle));
blitCommandEncoder.CopyFromTexture( blitCommandEncoder.CopyFromTexture(
MTLTexture, _mtlTexture,
(ulong)layer, (ulong)layer,
(ulong)level, (ulong)level,
new MTLOrigin(), new MTLOrigin(),
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth }, new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
mtlBuffer, mtlBuffer,
(ulong)range.Offset, (ulong)range.Offset,
bytesPerRow, bytesPerRow,
@ -180,7 +162,7 @@ namespace Ryujinx.Graphics.Metal
public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel) public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel)
{ {
return new Texture(_device, _pipeline, info, MTLTexture, firstLayer, firstLevel); return new Texture(_device, _pipeline, info, _mtlTexture, firstLayer, firstLevel);
} }
public PinnedSpan<byte> GetData() public PinnedSpan<byte> GetData()
@ -195,7 +177,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level); ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong length = bytesPerRow * (ulong)Info.Height; ulong length = bytesPerRow * (ulong)Info.Height;
ulong bytesPerImage = 0; ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D) if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{ {
bytesPerImage = length; bytesPerImage = length;
} }
@ -205,11 +187,11 @@ namespace Ryujinx.Graphics.Metal
var mtlBuffer = _device.NewBuffer(length, MTLResourceOptions.ResourceStorageModeShared); var mtlBuffer = _device.NewBuffer(length, MTLResourceOptions.ResourceStorageModeShared);
blitCommandEncoder.CopyFromTexture( blitCommandEncoder.CopyFromTexture(
MTLTexture, _mtlTexture,
(ulong)layer, (ulong)layer,
(ulong)level, (ulong)level,
new MTLOrigin(), new MTLOrigin(),
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth }, new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
mtlBuffer, mtlBuffer,
0, 0,
bytesPerRow, bytesPerRow,
@ -255,7 +237,7 @@ namespace Ryujinx.Graphics.Metal
(ulong)Info.GetMipStride(level), (ulong)Info.GetMipStride(level),
(ulong)mipSize, (ulong)mipSize,
new MTLSize { width = (ulong)width, height = (ulong)height, depth = is3D ? (ulong)depth : 1 }, new MTLSize { width = (ulong)width, height = (ulong)height, depth = is3D ? (ulong)depth : 1 },
MTLTexture, _mtlTexture,
0, 0,
(ulong)level, (ulong)level,
new MTLOrigin() new MTLOrigin()
@ -282,7 +264,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level); ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0; ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D) if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{ {
bytesPerImage = bytesPerRow * (ulong)Info.Height; bytesPerImage = bytesPerRow * (ulong)Info.Height;
} }
@ -299,8 +281,8 @@ namespace Ryujinx.Graphics.Metal
0, 0,
bytesPerRow, bytesPerRow,
bytesPerImage, bytesPerImage,
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth }, new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
MTLTexture, _mtlTexture,
(ulong)layer, (ulong)layer,
(ulong)level, (ulong)level,
new MTLOrigin() new MTLOrigin()
@ -317,7 +299,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level); ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0; ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D) if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{ {
bytesPerImage = bytesPerRow * (ulong)Info.Height; bytesPerImage = bytesPerRow * (ulong)Info.Height;
} }
@ -335,7 +317,7 @@ namespace Ryujinx.Graphics.Metal
bytesPerRow, bytesPerRow,
bytesPerImage, bytesPerImage,
new MTLSize { width = (ulong)region.Width, height = (ulong)region.Height, depth = 1 }, new MTLSize { width = (ulong)region.Width, height = (ulong)region.Height, depth = 1 },
MTLTexture, _mtlTexture,
(ulong)layer, (ulong)layer,
(ulong)level, (ulong)level,
new MTLOrigin { x = (ulong)region.X, y = (ulong)region.Y } new MTLOrigin { x = (ulong)region.X, y = (ulong)region.Y }
@ -348,18 +330,7 @@ namespace Ryujinx.Graphics.Metal
public void SetStorage(BufferRange buffer) public void SetStorage(BufferRange buffer)
{ {
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!"); throw new NotImplementedException();
}
public void Release()
{
Dispose();
}
public void Dispose()
{
MTLTexture.SetPurgeableState(MTLPurgeableState.Volatile);
MTLTexture.Dispose();
} }
} }
} }

View file

@ -0,0 +1,59 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.GAL;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
abstract class TextureBase : IDisposable
{
private bool _disposed;
protected readonly TextureCreateInfo _info;
protected readonly Pipeline _pipeline;
protected readonly MTLDevice _device;
protected MTLTexture _mtlTexture;
public TextureCreateInfo Info => _info;
public int Width => Info.Width;
public int Height => Info.Height;
public int Depth => Info.Depth;
public TextureBase(MTLDevice device, Pipeline pipeline, TextureCreateInfo info)
{
_device = device;
_pipeline = pipeline;
_info = info;
}
public MTLTexture GetHandle()
{
if (_disposed)
{
return new MTLTexture(IntPtr.Zero);
}
return _mtlTexture;
}
public void Release()
{
Dispose();
}
public void Dispose()
{
if (_mtlTexture != IntPtr.Zero)
{
_mtlTexture.Dispose();
}
_disposed = true;
}
}
}

View file

@ -0,0 +1,112 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.GAL;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
class TextureBuffer : Texture, ITexture
{
private MTLBuffer? _bufferHandle;
private int _offset;
private int _size;
public TextureBuffer(MTLDevice device, Pipeline pipeline, TextureCreateInfo info) : base(device, pipeline, info) { }
public void CreateView()
{
var descriptor = new MTLTextureDescriptor
{
PixelFormat = FormatTable.GetFormat(Info.Format),
Usage = MTLTextureUsage.ShaderRead | MTLTextureUsage.ShaderWrite,
StorageMode = MTLStorageMode.Shared,
TextureType = Info.Target.Convert(),
Width = (ulong)Info.Width,
Height = (ulong)Info.Height
};
_mtlTexture = _bufferHandle.Value.NewTexture(descriptor, (ulong)_offset, (ulong)_size);
}
public void CopyTo(ITexture destination, int firstLayer, int firstLevel)
{
throw new NotSupportedException();
}
public void CopyTo(ITexture destination, int srcLayer, int dstLayer, int srcLevel, int dstLevel)
{
throw new NotSupportedException();
}
public void CopyTo(ITexture destination, Extents2D srcRegion, Extents2D dstRegion, bool linearFilter)
{
throw new NotSupportedException();
}
public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel)
{
throw new NotSupportedException();
}
// TODO: Implement this method
public PinnedSpan<byte> GetData()
{
throw new NotImplementedException();
}
public PinnedSpan<byte> GetData(int layer, int level)
{
return GetData();
}
public void CopyTo(BufferRange range, int layer, int level, int stride)
{
throw new NotImplementedException();
}
public void SetData(IMemoryOwner<byte> data)
{
// TODO
//_gd.SetBufferData(_bufferHandle, _offset, data.Memory.Span);
data.Dispose();
}
public void SetData(IMemoryOwner<byte> data, int layer, int level)
{
throw new NotSupportedException();
}
public void SetData(IMemoryOwner<byte> data, int layer, int level, Rectangle<int> region)
{
throw new NotSupportedException();
}
public void SetStorage(BufferRange buffer)
{
if (buffer.Handle != BufferHandle.Null)
{
var handle = buffer.Handle;
MTLBuffer bufferHandle = new(Unsafe.As<BufferHandle, IntPtr>(ref handle));
if (_bufferHandle == bufferHandle &&
_offset == buffer.Offset &&
_size == buffer.Size)
{
return;
}
_bufferHandle = bufferHandle;
_offset = buffer.Offset;
_size = buffer.Size;
Release();
CreateView();
}
}
}
}

View file

@ -205,7 +205,7 @@ namespace Ryujinx.Graphics.OpenGL
Buffer.Copy(source, destination, srcOffset, dstOffset, size); Buffer.Copy(source, destination, srcOffset, dstOffset, size);
} }
public void DispatchCompute(int groupsX, int groupsY, int groupsZ) public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{ {
if (!_program.IsLinked) if (!_program.IsLinked)
{ {

View file

@ -8,6 +8,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
public const string Tab = " "; public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int additionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; } public StructuredFunction CurrentFunction { get; set; }
public StructuredProgramInfo Info { get; } public StructuredProgramInfo Info { get; }

View file

@ -54,6 +54,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input))); DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input)));
context.AppendLine(); context.AppendLine();
DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output)); DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output));
context.AppendLine();
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values);
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values);
} }
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind) static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
@ -111,8 +114,41 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
foreach (var memory in memories) foreach (var memory in memories)
{ {
string arraySize = "";
if ((memory.Type & AggregateType.Array) != 0)
{
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array); var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}[{memory.ArrayLength}];"); context.AppendLine($"{typeName} {memory.Name}{arraySize};");
}
}
private static void DeclareBufferStructures(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
{
foreach (BufferDefinition buffer in buffers)
{
context.AppendLine($"struct Struct_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
}
else
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name};");
}
}
context.LeaveScope(";");
context.AppendLine();
} }
} }
@ -124,7 +160,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
else else
{ {
if (inputs.Any()) if (inputs.Any() || context.Definitions.Stage == ShaderStage.Fragment)
{ {
string prefix = ""; string prefix = "";
@ -136,9 +172,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
case ShaderStage.Fragment: case ShaderStage.Fragment:
context.AppendLine($"struct FragmentIn"); context.AppendLine($"struct FragmentIn");
break; break;
case ShaderStage.Compute:
context.AppendLine($"struct KernelIn");
break;
} }
context.EnterScope(); context.EnterScope();

View file

@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.Load: case Instruction.Load:
return Load(context, operation); return Load(context, operation);
case Instruction.Lod: case Instruction.Lod:
return "|| LOD ||"; return Lod(context, operation);
case Instruction.MemoryBarrier: case Instruction.MemoryBarrier:
return "|| MEMORY BARRIER ||"; return "|| MEMORY BARRIER ||";
case Instruction.Store: case Instruction.Store:

View file

@ -12,11 +12,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value); var functon = context.GetFunction(funcId.Value);
string[] args = new string[operation.SourcesCount - 1]; int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.additionalArgCount];
for (int i = 0; i < args.Length; i++) // Additional arguments
args[0] = "support_buffer";
int argIndex = CodeGenContext.additionalArgCount;
for (int i = 0; i < argCount; i++)
{ {
args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
} }
return $"{functon.Name}({string.Join(", ", args)})"; return $"{functon.Name}({string.Join(", ", args)})";

View file

@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
inputsCount--; inputsCount--;
} }
string fieldName = "";
switch (storageKind) switch (storageKind)
{ {
case StorageKind.ConstantBuffer: case StorageKind.ConstantBuffer:
@ -45,6 +46,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value]; StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name; varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varType = field.Type; varType = field.Type;
break; break;
@ -126,6 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]"; varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
} }
} }
varName += fieldName;
if (isStore) if (isStore)
{ {
@ -141,6 +152,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return GenerateLoadOrStore(context, operation, isStore: false); return GenerateLoadOrStore(context, operation, isStore: false);
} }
// TODO: check this
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int coordsCount = texOp.Type.GetDimensions();
int coordsIndex = 0;
string samplerName = GetSamplerName(context.Properties, texOp);
string coordsExpr;
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
}
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
}
else
{
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
}
return $"tex_{samplerName}.calculate_unclamped_lod(samp_{samplerName}, {coordsExpr}){GetMaskMultiDest(texOp.Index)}";
}
public static string Store(CodeGenContext context, AstOperation operation) public static string Store(CodeGenContext context, AstOperation operation)
{ {
return GenerateLoadOrStore(context, operation, isStore: true); return GenerateLoadOrStore(context, operation, isStore: true);
@ -176,11 +218,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
} }
else else
{ {
texCall += "sample";
if (isGather) if (isGather)
{ {
texCall += "_gather"; texCall += "gather";
}
else
{
texCall += "sample";
} }
if (isShadow) if (isShadow)
@ -188,22 +232,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
texCall += "_compare"; texCall += "_compare";
} }
texCall += $"(samp_{samplerName}"; texCall += $"(samp_{samplerName}, ";
} }
int coordsCount = texOp.Type.GetDimensions(); int coordsCount = texOp.Type.GetDimensions();
int pCount = coordsCount; int pCount = coordsCount;
bool appended = false;
void Append(string str) void Append(string str)
{ {
texCall += ", " + str; if (appended)
{
texCall += ", ";
}
else {
appended = true;
}
texCall += str;
} }
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32; AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
string AssemblePVector(int count) string AssemblePVector(int count)
{ {
string coords;
if (count > 1) if (count > 1)
{ {
string[] elems = new string[count]; string[] elems = new string[count];
@ -213,14 +266,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
elems[index] = Src(coordType); elems[index] = Src(coordType);
} }
string prefix = intCoords ? "int" : "float"; coords = string.Join(", ", elems);
return prefix + count + "(" + string.Join(", ", elems) + ")";
} }
else else
{ {
return Src(coordType); coords = Src(coordType);
} }
string prefix = intCoords ? "uint" : "float";
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
} }
Append(AssemblePVector(pCount)); Append(AssemblePVector(pCount));
@ -254,6 +309,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
private static string GetMaskMultiDest(int mask) private static string GetMaskMultiDest(int mask)
{ {
if (mask == 0x0)
{
return "";
}
string swizzle = "."; string swizzle = ".";
for (int i = 0; i < 4; i++) for (int i = 0; i < 4; i++)

View file

@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32), IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32), IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL // gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_index", AggregateType.U32), IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32), IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
_ => (null, AggregateType.Invalid), _ => (null, AggregateType.Invalid),

View file

@ -48,6 +48,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
PrintBlock(context, function.MainBlock, isMainFunc); PrintBlock(context, function.MainBlock, isMainFunc);
// In case the shader hasn't returned, return
if (isMainFunc && stage != ShaderStage.Compute)
{
context.AppendLine("return out;");
}
context.LeaveScope(); context.LeaveScope();
} }
@ -57,11 +63,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage, ShaderStage stage,
bool isMainFunc = false) bool isMainFunc = false)
{ {
string[] args = new string[function.InArguments.Length + function.OutArguments.Length]; int additionalArgCount = isMainFunc ? 0 : CodeGenContext.additionalArgCount;
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "constant Struct_support_buffer* support_buffer";
}
int argIndex = additionalArgCount;
for (int i = 0; i < function.InArguments.Length; i++) for (int i = 0; i < function.InArguments.Length; i++)
{ {
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}"; args[argIndex++] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
} }
for (int i = 0; i < function.OutArguments.Length; i++) for (int i = 0; i < function.OutArguments.Length; i++)
@ -69,7 +84,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
int j = i + function.InArguments.Length; int j = i + function.InArguments.Length;
// Likely need to be made into pointers // Likely need to be made into pointers
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}"; args[argIndex++] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
} }
string funcKeyword = "inline"; string funcKeyword = "inline";
@ -97,20 +112,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
returnType = "void"; returnType = "void";
} }
if (context.AttributeUsage.UsedInputAttributes != 0) if (stage == ShaderStage.Vertex)
{ {
if (stage == ShaderStage.Vertex) if (context.AttributeUsage.UsedInputAttributes != 0)
{ {
args = args.Prepend("VertexIn in [[stage_in]]").ToArray(); args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
} }
else if (stage == ShaderStage.Fragment) }
{ else if (stage == ShaderStage.Fragment)
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray(); {
} args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
else if (stage == ShaderStage.Compute)
{
args = args.Prepend("KernelIn in [[stage_in]]").ToArray();
}
} }
// TODO: add these only if they are used // TODO: add these only if they are used
@ -119,18 +130,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
args = args.Append("uint vertex_id [[vertex_id]]").ToArray(); args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray(); args = args.Append("uint instance_id [[instance_id]]").ToArray();
} }
else if (stage == ShaderStage.Compute)
{
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
}
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values) foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{ {
var varType = constantBuffer.Type.Fields[0].Type & ~AggregateType.Array; args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant {Declarations.GetVarTypeName(context, varType)} *{constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
} }
foreach (var storageBuffers in context.Properties.StorageBuffers.Values) foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{ {
var varType = storageBuffers.Type.Fields[0].Type & ~AggregateType.Array;
// Offset the binding by 15 to avoid clashing with the constant buffers // Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device {Declarations.GetVarTypeName(context, varType)} *{storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
} }
foreach (var texture in context.Properties.Textures.Values) foreach (var texture in context.Properties.Textures.Values)

View file

@ -861,7 +861,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(1, sbRanges); _pipeline.SetStorageBuffers(1, sbRanges);
_pipeline.SetProgram(_programStrideChange); _pipeline.SetProgram(_programStrideChange);
_pipeline.DispatchCompute(1 + elems / ConvertElementsPerWorkgroup, 1, 1); _pipeline.DispatchCompute(1 + elems / ConvertElementsPerWorkgroup, 1, 1, 0, 0, 0);
_pipeline.Finish(gd, cbs); _pipeline.Finish(gd, cbs);
} }
@ -1044,7 +1044,7 @@ namespace Ryujinx.Graphics.Vulkan
int dispatchX = (Math.Min(srcView.Info.Width, dstView.Info.Width) + 31) / 32; int dispatchX = (Math.Min(srcView.Info.Width, dstView.Info.Width) + 31) / 32;
int dispatchY = (Math.Min(srcView.Info.Height, dstView.Info.Height) + 31) / 32; int dispatchY = (Math.Min(srcView.Info.Height, dstView.Info.Height) + 31) / 32;
_pipeline.DispatchCompute(dispatchX, dispatchY, 1); _pipeline.DispatchCompute(dispatchX, dispatchY, 1, 0, 0, 0);
if (srcView != src) if (srcView != src)
{ {
@ -1170,7 +1170,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetTextureAndSamplerIdentitySwizzle(ShaderStage.Compute, 0, srcView, null); _pipeline.SetTextureAndSamplerIdentitySwizzle(ShaderStage.Compute, 0, srcView, null);
_pipeline.SetImage(ShaderStage.Compute, 0, dstView.GetView(format)); _pipeline.SetImage(ShaderStage.Compute, 0, dstView.GetView(format));
_pipeline.DispatchCompute(dispatchX, dispatchY, 1); _pipeline.DispatchCompute(dispatchX, dispatchY, 1, 0, 0, 0);
if (srcView != src) if (srcView != src)
{ {
@ -1582,7 +1582,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(stackalloc[] { new BufferAssignment(3, patternScoped.Range) }); _pipeline.SetStorageBuffers(stackalloc[] { new BufferAssignment(3, patternScoped.Range) });
_pipeline.SetProgram(_programConvertIndirectData); _pipeline.SetProgram(_programConvertIndirectData);
_pipeline.DispatchCompute(1, 1, 1); _pipeline.DispatchCompute(1, 1, 1, 0, 0, 0);
BufferHolder.InsertBufferBarrier( BufferHolder.InsertBufferBarrier(
gd, gd,
@ -1684,7 +1684,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(1, sbRanges); _pipeline.SetStorageBuffers(1, sbRanges);
_pipeline.SetProgram(_programConvertD32S8ToD24S8); _pipeline.SetProgram(_programConvertD32S8ToD24S8);
_pipeline.DispatchCompute(1 + inSize / ConvertElementsPerWorkgroup, 1, 1); _pipeline.DispatchCompute(1 + inSize / ConvertElementsPerWorkgroup, 1, 1, 0, 0, 0);
_pipeline.Finish(gd, cbs); _pipeline.Finish(gd, cbs);

View file

@ -295,7 +295,7 @@ namespace Ryujinx.Graphics.Vulkan
} }
} }
public void DispatchCompute(int groupsX, int groupsY, int groupsZ) public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{ {
if (!_program.IsLinked) if (!_program.IsLinked)
{ {