From 6e3aaa6360c81c38129b41ae09c77d973ff57685 Mon Sep 17 00:00:00 2001 From: Isaac Marovitz <42140194+IsaacMarovitz@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:31:24 +0100 Subject: [PATCH] Metal: Argument Buffer Pre-Pass (#38) * Init * Fix missing flags * Cleanup --- .../CommandBufferEncoder.cs | 6 +- .../EncoderResources.cs | 63 +++++++ src/Ryujinx.Graphics.Metal/EncoderState.cs | 3 + .../EncoderStateManager.cs | 154 +++++++++++------- src/Ryujinx.Graphics.Metal/Pipeline.cs | 22 +++ 5 files changed, 187 insertions(+), 61 deletions(-) create mode 100644 src/Ryujinx.Graphics.Metal/EncoderResources.cs diff --git a/src/Ryujinx.Graphics.Metal/CommandBufferEncoder.cs b/src/Ryujinx.Graphics.Metal/CommandBufferEncoder.cs index 26cb4f5c79..ec41500304 100644 --- a/src/Ryujinx.Graphics.Metal/CommandBufferEncoder.cs +++ b/src/Ryujinx.Graphics.Metal/CommandBufferEncoder.cs @@ -18,11 +18,11 @@ class CommandBufferEncoder { public EncoderType CurrentEncoderType { get; private set; } = EncoderType.None; - public MTLBlitCommandEncoder BlitEncoder => new MTLBlitCommandEncoder(CurrentEncoder.Value); + public MTLBlitCommandEncoder BlitEncoder => new(CurrentEncoder.Value); - public MTLComputeCommandEncoder ComputeEncoder => new MTLComputeCommandEncoder(CurrentEncoder.Value); + public MTLComputeCommandEncoder ComputeEncoder => new(CurrentEncoder.Value); - public MTLRenderCommandEncoder RenderEncoder => new MTLRenderCommandEncoder(CurrentEncoder.Value); + public MTLRenderCommandEncoder RenderEncoder => new(CurrentEncoder.Value); internal MTLCommandEncoder? CurrentEncoder { get; private set; } diff --git a/src/Ryujinx.Graphics.Metal/EncoderResources.cs b/src/Ryujinx.Graphics.Metal/EncoderResources.cs new file mode 100644 index 0000000000..4fbb9b2821 --- /dev/null +++ b/src/Ryujinx.Graphics.Metal/EncoderResources.cs @@ -0,0 +1,63 @@ +using SharpMetal.Metal; +using System.Collections.Generic; + +namespace Ryujinx.Graphics.Metal +{ + public struct RenderEncoderResources + { + public List Resources = new(); + public List VertexBuffers = new(); + public List FragmentBuffers = new(); + + public RenderEncoderResources() { } + + public void Clear() + { + Resources.Clear(); + VertexBuffers.Clear(); + FragmentBuffers.Clear(); + } + } + + public struct ComputeEncoderResources + { + public List Resources = new(); + public List Buffers = new(); + + public ComputeEncoderResources() { } + + public void Clear() + { + Resources.Clear(); + Buffers.Clear(); + } + } + + public struct BufferResource + { + public MTLBuffer Buffer; + public ulong Offset; + public ulong Binding; + + public BufferResource(MTLBuffer buffer, ulong offset, ulong binding) + { + Buffer = buffer; + Offset = offset; + Binding = binding; + } + } + + public struct Resource + { + public MTLResource MtlResource; + public MTLResourceUsage ResourceUsage; + public MTLRenderStages Stages; + + public Resource(MTLResource resource, MTLResourceUsage resourceUsage, MTLRenderStages stages) + { + MtlResource = resource; + ResourceUsage = resourceUsage; + Stages = stages; + } + } +} diff --git a/src/Ryujinx.Graphics.Metal/EncoderState.cs b/src/Ryujinx.Graphics.Metal/EncoderState.cs index a99a6f35a6..d5dd5123b7 100644 --- a/src/Ryujinx.Graphics.Metal/EncoderState.cs +++ b/src/Ryujinx.Graphics.Metal/EncoderState.cs @@ -152,6 +152,9 @@ namespace Ryujinx.Graphics.Metal // Only to be used for present public bool ClearLoadAction = false; + public RenderEncoderResources RenderEncoderResources = new(); + public ComputeEncoderResources ComputeEncoderResources = new(); + public EncoderState() { Pipeline.Initialize(); diff --git a/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs b/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs index 88731a5042..e50438a10a 100644 --- a/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs +++ b/src/Ryujinx.Graphics.Metal/EncoderStateManager.cs @@ -62,6 +62,16 @@ namespace Ryujinx.Graphics.Metal _currentState.Dirty |= flags; } + public void SignalRenderDirty() + { + SignalDirty(DirtyFlags.RenderAll); + } + + public void SignalComputeDirty() + { + SignalDirty(DirtyFlags.ComputeAll); + } + public EncoderState SwapState(EncoderState state, DirtyFlags flags = DirtyFlags.All) { _currentState = state ?? _mainState; @@ -110,7 +120,7 @@ namespace Ryujinx.Graphics.Metal public readonly MTLRenderCommandEncoder CreateRenderCommandEncoder() { // Initialise Pass & State - var renderPassDescriptor = new MTLRenderPassDescriptor(); + using var renderPassDescriptor = new MTLRenderPassDescriptor(); for (int i = 0; i < Constants.MaxColorAttachments; i++) { @@ -165,12 +175,6 @@ namespace Ryujinx.Graphics.Metal // Initialise Encoder var renderCommandEncoder = _pipeline.CommandBuffer.RenderCommandEncoder(renderPassDescriptor); - // Mark all state as dirty to ensure it is set on the encoder - SignalDirty(DirtyFlags.RenderAll); - - // Cleanup - renderPassDescriptor.Dispose(); - return renderCommandEncoder; } @@ -179,18 +183,69 @@ namespace Ryujinx.Graphics.Metal using var descriptor = new MTLComputePassDescriptor(); var computeCommandEncoder = _pipeline.CommandBuffer.ComputeCommandEncoder(descriptor); - // Mark all state as dirty to ensure it is set on the encoder - SignalDirty(DirtyFlags.ComputeAll); - return computeCommandEncoder; } + public void RenderResourcesPrepass() + { + _currentState.RenderEncoderResources.Clear(); + + if ((_currentState.Dirty & DirtyFlags.RenderPipeline) != 0) + { + SetVertexBuffers(_currentState.VertexBuffers, ref _currentState.RenderEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Uniforms) != 0) + { + UpdateAndBind(_currentState.RenderProgram, Constants.ConstantBuffersSetIndex, ref _currentState.RenderEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Storages) != 0) + { + UpdateAndBind(_currentState.RenderProgram, Constants.StorageBuffersSetIndex, ref _currentState.RenderEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Textures) != 0) + { + UpdateAndBind(_currentState.RenderProgram, Constants.TexturesSetIndex, ref _currentState.RenderEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Images) != 0) + { + UpdateAndBind(_currentState.RenderProgram, Constants.ImagesSetIndex, ref _currentState.RenderEncoderResources); + } + } + + public void ComputeResourcesPrepass() + { + _currentState.ComputeEncoderResources.Clear(); + + if ((_currentState.Dirty & DirtyFlags.Uniforms) != 0) + { + UpdateAndBind(_currentState.ComputeProgram, Constants.ConstantBuffersSetIndex, ref _currentState.ComputeEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Storages) != 0) + { + UpdateAndBind(_currentState.ComputeProgram, Constants.StorageBuffersSetIndex, ref _currentState.ComputeEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Textures) != 0) + { + UpdateAndBind(_currentState.ComputeProgram, Constants.TexturesSetIndex, ref _currentState.ComputeEncoderResources); + } + + if ((_currentState.Dirty & DirtyFlags.Images) != 0) + { + UpdateAndBind(_currentState.ComputeProgram, Constants.ImagesSetIndex, ref _currentState.ComputeEncoderResources); + } + } + public void RebindRenderState(MTLRenderCommandEncoder renderCommandEncoder) { if ((_currentState.Dirty & DirtyFlags.RenderPipeline) != 0) { SetRenderPipelineState(renderCommandEncoder); - SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers); } if ((_currentState.Dirty & DirtyFlags.DepthStencil) != 0) @@ -233,24 +288,19 @@ namespace Ryujinx.Graphics.Metal SetScissors(renderCommandEncoder); } - if ((_currentState.Dirty & DirtyFlags.Uniforms) != 0) + foreach (var resource in _currentState.RenderEncoderResources.Resources) { - UpdateAndBind(renderCommandEncoder, _currentState.RenderProgram, Constants.ConstantBuffersSetIndex); + renderCommandEncoder.UseResource(resource.MtlResource, resource.ResourceUsage, resource.Stages); } - if ((_currentState.Dirty & DirtyFlags.Storages) != 0) + foreach (var buffer in _currentState.RenderEncoderResources.VertexBuffers) { - UpdateAndBind(renderCommandEncoder, _currentState.RenderProgram, Constants.StorageBuffersSetIndex); + renderCommandEncoder.SetVertexBuffer(buffer.Buffer, buffer.Offset, buffer.Binding); } - if ((_currentState.Dirty & DirtyFlags.Textures) != 0) + foreach (var buffer in _currentState.RenderEncoderResources.FragmentBuffers) { - UpdateAndBind(renderCommandEncoder, _currentState.RenderProgram, Constants.TexturesSetIndex); - } - - if ((_currentState.Dirty & DirtyFlags.Images) != 0) - { - UpdateAndBind(renderCommandEncoder, _currentState.RenderProgram, Constants.ImagesSetIndex); + renderCommandEncoder.SetFragmentBuffer(buffer.Buffer, buffer.Offset, buffer.Binding); } _currentState.Dirty &= ~DirtyFlags.RenderAll; @@ -263,24 +313,14 @@ namespace Ryujinx.Graphics.Metal SetComputePipelineState(computeCommandEncoder); } - if ((_currentState.Dirty & DirtyFlags.Uniforms) != 0) + foreach (var resource in _currentState.ComputeEncoderResources.Resources) { - UpdateAndBind(computeCommandEncoder, _currentState.ComputeProgram, Constants.ConstantBuffersSetIndex); + computeCommandEncoder.UseResource(resource.MtlResource, resource.ResourceUsage); } - if ((_currentState.Dirty & DirtyFlags.Storages) != 0) + foreach (var buffer in _currentState.ComputeEncoderResources.Buffers) { - UpdateAndBind(computeCommandEncoder, _currentState.ComputeProgram, Constants.StorageBuffersSetIndex); - } - - if ((_currentState.Dirty & DirtyFlags.Textures) != 0) - { - UpdateAndBind(computeCommandEncoder, _currentState.ComputeProgram, Constants.TexturesSetIndex); - } - - if ((_currentState.Dirty & DirtyFlags.Images) != 0) - { - UpdateAndBind(computeCommandEncoder, _currentState.ComputeProgram, Constants.ImagesSetIndex); + computeCommandEncoder.SetBuffer(buffer.Buffer, buffer.Offset, buffer.Binding); } _currentState.Dirty &= ~DirtyFlags.ComputeAll; @@ -1013,7 +1053,7 @@ namespace Ryujinx.Graphics.Metal pipeline.VertexBindingDescriptionsCount = Constants.ZeroBufferIndex + 1; // TODO: move this out? } - private readonly void SetVertexBuffers(MTLRenderCommandEncoder renderCommandEncoder, VertexBufferState[] bufferStates) + private readonly void SetVertexBuffers(VertexBufferState[] bufferStates, ref readonly RenderEncoderResources resources) { for (int i = 0; i < bufferStates.Length; i++) { @@ -1021,7 +1061,7 @@ namespace Ryujinx.Graphics.Metal if (mtlBuffer.NativePtr != IntPtr.Zero) { - renderCommandEncoder.SetVertexBuffer(mtlBuffer, (ulong)offset, (ulong)i); + resources.VertexBuffers.Add(new BufferResource(mtlBuffer, (ulong)offset, (ulong)i)); } } @@ -1035,10 +1075,10 @@ namespace Ryujinx.Graphics.Metal } var zeroMtlBuffer = autoZeroBuffer.Get(_pipeline.Cbs).Value; - renderCommandEncoder.SetVertexBuffer(zeroMtlBuffer, 0, Constants.ZeroBufferIndex); + resources.VertexBuffers.Add(new BufferResource(zeroMtlBuffer, 0, Constants.ZeroBufferIndex)); } - private readonly void UpdateAndBind(MTLRenderCommandEncoder renderCommandEncoder, Program program, uint setIndex) + private readonly void UpdateAndBind(Program program, uint setIndex, ref readonly RenderEncoderResources resources) { var bindingSegments = program.BindingSegments[setIndex]; @@ -1120,7 +1160,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read, renderStages)); } break; case Constants.StorageBuffersSetIndex: @@ -1170,7 +1210,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read, renderStages)); } break; case Constants.TexturesSetIndex: @@ -1226,7 +1266,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, renderStages)); } } else @@ -1268,8 +1308,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), - MTLResourceUsage.Read, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, renderStages)); } foreach (var sampler in samplers) @@ -1325,7 +1364,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, renderStages)); } } } @@ -1364,7 +1403,7 @@ namespace Ryujinx.Graphics.Metal renderStages |= MTLRenderStages.RenderStageFragment; } - renderCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write, renderStages); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write, renderStages)); } } break; @@ -1375,18 +1414,18 @@ namespace Ryujinx.Graphics.Metal { vertArgBuffer.Holder.SetDataUnchecked(vertArgBuffer.Offset, MemoryMarshal.AsBytes(vertResourceIds)); var mtlVertArgBuffer = _bufferManager.GetBuffer(vertArgBuffer.Handle, false).Get(_pipeline.Cbs).Value; - renderCommandEncoder.SetVertexBuffer(mtlVertArgBuffer, (uint)vertArgBuffer.Range.Offset, SetIndexToBindingIndex(setIndex)); + resources.VertexBuffers.Add(new BufferResource(mtlVertArgBuffer, (uint)vertArgBuffer.Range.Offset, SetIndexToBindingIndex(setIndex))); } if (program.FragArgumentBufferSizes[setIndex] > 0) { fragArgBuffer.Holder.SetDataUnchecked(fragArgBuffer.Offset, MemoryMarshal.AsBytes(fragResourceIds)); var mtlFragArgBuffer = _bufferManager.GetBuffer(fragArgBuffer.Handle, false).Get(_pipeline.Cbs).Value; - renderCommandEncoder.SetFragmentBuffer(mtlFragArgBuffer, (uint)fragArgBuffer.Range.Offset, SetIndexToBindingIndex(setIndex)); + resources.FragmentBuffers.Add(new BufferResource(mtlFragArgBuffer, (uint)fragArgBuffer.Range.Offset, SetIndexToBindingIndex(setIndex))); } } - private readonly void UpdateAndBind(MTLComputeCommandEncoder computeCommandEncoder, Program program, uint setIndex) + private readonly void UpdateAndBind(Program program, uint setIndex, ref readonly ComputeEncoderResources resources) { var bindingSegments = program.BindingSegments[setIndex]; @@ -1443,7 +1482,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read); + resources.Resources.Add(new Resource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read, 0)); resourceIds[resourceIdIndex] = mtlBuffer.GpuAddress + (ulong)offset; resourceIdIndex++; } @@ -1480,7 +1519,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write); + resources.Resources.Add(new Resource(new MTLResource(mtlBuffer.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write, 0)); resourceIds[resourceIdIndex] = mtlBuffer.GpuAddress + (ulong)offset; resourceIdIndex++; } @@ -1511,7 +1550,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, 0)); resourceIds[resourceIdIndex] = mtlTexture.GpuResourceID._impl; resourceIdIndex++; @@ -1545,8 +1584,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), - MTLResourceUsage.Read); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, 0)); resourceIds[resourceIdIndex] = mtlTexture.GpuResourceID._impl; resourceIdIndex++; @@ -1580,7 +1618,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read, 0)); resourceIds[resourceIdIndex] = mtlTexture.GpuResourceID._impl; resourceIdIndex++; } @@ -1610,7 +1648,7 @@ namespace Ryujinx.Graphics.Metal if ((segment.Stages & ResourceStages.Compute) != 0) { - computeCommandEncoder.UseResource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write); + resources.Resources.Add(new Resource(new MTLResource(mtlTexture.NativePtr), MTLResourceUsage.Read | MTLResourceUsage.Write, 0)); resourceIds[resourceIdIndex] = mtlTexture.GpuResourceID._impl; resourceIdIndex++; } @@ -1625,7 +1663,7 @@ namespace Ryujinx.Graphics.Metal { argBuffer.Holder.SetDataUnchecked(argBuffer.Offset, MemoryMarshal.AsBytes(resourceIds)); var mtlArgBuffer = _bufferManager.GetBuffer(argBuffer.Handle, false).Get(_pipeline.Cbs).Value; - computeCommandEncoder.SetBuffer(mtlArgBuffer, (uint)argBuffer.Range.Offset, SetIndexToBindingIndex(setIndex)); + resources.Buffers.Add(new BufferResource(mtlArgBuffer, (uint)argBuffer.Range.Offset, SetIndexToBindingIndex(setIndex))); } } diff --git a/src/Ryujinx.Graphics.Metal/Pipeline.cs b/src/Ryujinx.Graphics.Metal/Pipeline.cs index 8fb407905a..87306abcf3 100644 --- a/src/Ryujinx.Graphics.Metal/Pipeline.cs +++ b/src/Ryujinx.Graphics.Metal/Pipeline.cs @@ -82,6 +82,17 @@ namespace Ryujinx.Graphics.Metal public MTLRenderCommandEncoder GetOrCreateRenderEncoder(bool forDraw = false) { + // Mark all state as dirty to ensure it is set on the new encoder + if (Cbs.Encoders.CurrentEncoderType != EncoderType.Render) + { + _encoderStateManager.SignalRenderDirty(); + } + + if (forDraw) + { + _encoderStateManager.RenderResourcesPrepass(); + } + MTLRenderCommandEncoder renderCommandEncoder = Cbs.Encoders.EnsureRenderEncoder(); if (forDraw) @@ -99,6 +110,17 @@ namespace Ryujinx.Graphics.Metal public MTLComputeCommandEncoder GetOrCreateComputeEncoder(bool forDispatch = false) { + // Mark all state as dirty to ensure it is set on the new encoder + if (Cbs.Encoders.CurrentEncoderType != EncoderType.Compute) + { + _encoderStateManager.SignalComputeDirty(); + } + + if (forDispatch) + { + _encoderStateManager.ComputeResourcesPrepass(); + } + MTLComputeCommandEncoder computeCommandEncoder = Cbs.Encoders.EnsureComputeEncoder(); if (forDispatch)