Vulkan: Simplify MultiFenceHolder and managing them (#4845)

* Vulkan: Simplify waitable add/remove

Removal of unnecessary hashset and dictionary

* Thread safety for GetBufferData in PersistentFlushBuffer

* Fix WaitForFencesImpl thread safety

* Proper methods for risky reference increments

* Wrong type of CB.

* Address feedback
This commit is contained in:
riperiperi 2023-05-08 11:45:12 +01:00 committed by GitHub
parent 895d9b53bc
commit 1b28ecd63e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 138 additions and 48 deletions

View file

@ -105,6 +105,23 @@ namespace Ryujinx.Graphics.Vulkan
} }
} }
public bool TryIncrementReferenceCount()
{
int lastValue;
do
{
lastValue = _referenceCount;
if (lastValue == 0)
{
return false;
}
}
while (Interlocked.CompareExchange(ref _referenceCount, lastValue + 1, lastValue) != lastValue);
return true;
}
public void IncrementReferenceCount() public void IncrementReferenceCount()
{ {
if (Interlocked.Increment(ref _referenceCount) == 1) if (Interlocked.Increment(ref _referenceCount) == 1)

View file

@ -599,9 +599,10 @@ namespace Ryujinx.Graphics.Vulkan
Auto<DisposableBuffer> dst, Auto<DisposableBuffer> dst,
int srcOffset, int srcOffset,
int dstOffset, int dstOffset,
int size) int size,
bool registerSrcUsage = true)
{ {
var srcBuffer = src.Get(cbs, srcOffset, size).Value; var srcBuffer = registerSrcUsage ? src.Get(cbs, srcOffset, size).Value : src.GetUnsafe().Value;
var dstBuffer = dst.Get(cbs, dstOffset, size).Value; var dstBuffer = dst.Get(cbs, dstOffset, size).Value;
InsertBufferBarrier( InsertBufferBarrier(

View file

@ -31,7 +31,7 @@ namespace Ryujinx.Graphics.Vulkan
public SemaphoreHolder Semaphore; public SemaphoreHolder Semaphore;
public List<IAuto> Dependants; public List<IAuto> Dependants;
public HashSet<MultiFenceHolder> Waitables; public List<MultiFenceHolder> Waitables;
public HashSet<SemaphoreHolder> Dependencies; public HashSet<SemaphoreHolder> Dependencies;
public void Initialize(Vk api, Device device, CommandPool pool) public void Initialize(Vk api, Device device, CommandPool pool)
@ -47,7 +47,7 @@ namespace Ryujinx.Graphics.Vulkan
api.AllocateCommandBuffers(device, allocateInfo, out CommandBuffer); api.AllocateCommandBuffers(device, allocateInfo, out CommandBuffer);
Dependants = new List<IAuto>(); Dependants = new List<IAuto>();
Waitables = new HashSet<MultiFenceHolder>(); Waitables = new List<MultiFenceHolder>();
Dependencies = new HashSet<SemaphoreHolder>(); Dependencies = new HashSet<SemaphoreHolder>();
} }
} }
@ -143,9 +143,11 @@ namespace Ryujinx.Graphics.Vulkan
public void AddWaitable(int cbIndex, MultiFenceHolder waitable) public void AddWaitable(int cbIndex, MultiFenceHolder waitable)
{ {
ref var entry = ref _commandBuffers[cbIndex]; ref var entry = ref _commandBuffers[cbIndex];
waitable.AddFence(cbIndex, entry.Fence); if (waitable.AddFence(cbIndex, entry.Fence))
{
entry.Waitables.Add(waitable); entry.Waitables.Add(waitable);
} }
}
public bool HasWaitableOnRentedCommandBuffer(MultiFenceHolder waitable, int offset, int size) public bool HasWaitableOnRentedCommandBuffer(MultiFenceHolder waitable, int offset, int size)
{ {
@ -156,7 +158,7 @@ namespace Ryujinx.Graphics.Vulkan
ref var entry = ref _commandBuffers[i]; ref var entry = ref _commandBuffers[i];
if (entry.InUse && if (entry.InUse &&
entry.Waitables.Contains(waitable) && waitable.HasFence(i) &&
waitable.IsBufferRangeInUse(i, offset, size)) waitable.IsBufferRangeInUse(i, offset, size))
{ {
return true; return true;
@ -331,7 +333,7 @@ namespace Ryujinx.Graphics.Vulkan
foreach (var waitable in entry.Waitables) foreach (var waitable in entry.Waitables)
{ {
waitable.RemoveFence(cbIndex, entry.Fence); waitable.RemoveFence(cbIndex);
waitable.RemoveBufferUses(cbIndex); waitable.RemoveBufferUses(cbIndex);
} }

View file

@ -32,6 +32,25 @@ namespace Ryujinx.Graphics.Vulkan
return _fence; return _fence;
} }
public bool TryGet(out Fence fence)
{
int lastValue;
do
{
lastValue = _referenceCount;
if (lastValue == 0)
{
fence = default;
return false;
}
}
while (Interlocked.CompareExchange(ref _referenceCount, lastValue + 1, lastValue) != lastValue);
fence = _fence;
return true;
}
public Fence Get() public Fence Get()
{ {
Interlocked.Increment(ref _referenceCount); Interlocked.Increment(ref _referenceCount);

View file

@ -1,6 +1,5 @@
using Silk.NET.Vulkan; using Silk.NET.Vulkan;
using System.Collections.Generic; using System;
using System.Linq;
namespace Ryujinx.Graphics.Vulkan namespace Ryujinx.Graphics.Vulkan
{ {
@ -11,7 +10,7 @@ namespace Ryujinx.Graphics.Vulkan
{ {
private static int BufferUsageTrackingGranularity = 4096; private static int BufferUsageTrackingGranularity = 4096;
private readonly Dictionary<FenceHolder, int> _fences; private readonly FenceHolder[] _fences;
private BufferUsageBitmap _bufferUsageBitmap; private BufferUsageBitmap _bufferUsageBitmap;
/// <summary> /// <summary>
@ -19,7 +18,7 @@ namespace Ryujinx.Graphics.Vulkan
/// </summary> /// </summary>
public MultiFenceHolder() public MultiFenceHolder()
{ {
_fences = new Dictionary<FenceHolder, int>(); _fences = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
} }
/// <summary> /// <summary>
@ -28,7 +27,7 @@ namespace Ryujinx.Graphics.Vulkan
/// <param name="size">Size of the buffer</param> /// <param name="size">Size of the buffer</param>
public MultiFenceHolder(int size) public MultiFenceHolder(int size)
{ {
_fences = new Dictionary<FenceHolder, int>(); _fences = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
_bufferUsageBitmap = new BufferUsageBitmap(size, BufferUsageTrackingGranularity); _bufferUsageBitmap = new BufferUsageBitmap(size, BufferUsageTrackingGranularity);
} }
@ -80,25 +79,37 @@ namespace Ryujinx.Graphics.Vulkan
/// </summary> /// </summary>
/// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param> /// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param>
/// <param name="fence">Fence to be added</param> /// <param name="fence">Fence to be added</param>
public void AddFence(int cbIndex, FenceHolder fence) /// <returns>True if the command buffer's previous fence value was null</returns>
public bool AddFence(int cbIndex, FenceHolder fence)
{ {
lock (_fences) ref FenceHolder fenceRef = ref _fences[cbIndex];
if (fenceRef == null)
{ {
_fences.TryAdd(fence, cbIndex); fenceRef = fence;
return true;
} }
return false;
} }
/// <summary> /// <summary>
/// Removes a fence from the holder. /// Removes a fence from the holder.
/// </summary> /// </summary>
/// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param> /// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param>
/// <param name="fence">Fence to be removed</param> public void RemoveFence(int cbIndex)
public void RemoveFence(int cbIndex, FenceHolder fence)
{ {
lock (_fences) _fences[cbIndex] = null;
{
_fences.Remove(fence);
} }
/// <summary>
/// Determines if a fence referenced on the given command buffer.
/// </summary>
/// <param name="cbIndex">Index of the command buffer to check if it's used</param>
/// <returns>True if referenced, false otherwise</returns>
public bool HasFence(int cbIndex)
{
return _fences[cbIndex] != null;
} }
/// <summary> /// <summary>
@ -147,21 +158,29 @@ namespace Ryujinx.Graphics.Vulkan
/// <returns>True if all fences were signaled before the timeout expired, false otherwise</returns> /// <returns>True if all fences were signaled before the timeout expired, false otherwise</returns>
private bool WaitForFencesImpl(Vk api, Device device, int offset, int size, bool hasTimeout, ulong timeout) private bool WaitForFencesImpl(Vk api, Device device, int offset, int size, bool hasTimeout, ulong timeout)
{ {
FenceHolder[] fenceHolders; Span<FenceHolder> fenceHolders = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
Fence[] fences;
lock (_fences) int count = size != 0 ? GetOverlappingFences(fenceHolders, offset, size) : GetFences(fenceHolders);
{ Span<Fence> fences = stackalloc Fence[count];
fenceHolders = size != 0 ? GetOverlappingFences(offset, size) : _fences.Keys.ToArray();
fences = new Fence[fenceHolders.Length];
for (int i = 0; i < fenceHolders.Length; i++) int fenceCount = 0;
for (int i = 0; i < count; i++)
{ {
fences[i] = fenceHolders[i].Get(); if (fenceHolders[i].TryGet(out Fence fence))
{
fences[fenceCount] = fence;
if (fenceCount < i)
{
fenceHolders[fenceCount] = fenceHolders[i];
}
fenceCount++;
} }
} }
if (fences.Length == 0) if (fenceCount == 0)
{ {
return true; return true;
} }
@ -170,14 +189,14 @@ namespace Ryujinx.Graphics.Vulkan
if (hasTimeout) if (hasTimeout)
{ {
signaled = FenceHelper.AllSignaled(api, device, fences, timeout); signaled = FenceHelper.AllSignaled(api, device, fences.Slice(0, fenceCount), timeout);
} }
else else
{ {
FenceHelper.WaitAllIndefinitely(api, device, fences); FenceHelper.WaitAllIndefinitely(api, device, fences.Slice(0, fenceCount));
} }
for (int i = 0; i < fenceHolders.Length; i++) for (int i = 0; i < fenceCount; i++)
{ {
fenceHolders[i].Put(); fenceHolders[i].Put();
} }
@ -185,28 +204,50 @@ namespace Ryujinx.Graphics.Vulkan
return signaled; return signaled;
} }
/// <summary>
/// Gets fences to wait for.
/// </summary>
/// <param name="storage">Span to store fences in</param>
/// <returns>Number of fences placed in storage</returns>
private int GetFences(Span<FenceHolder> storage)
{
int count = 0;
for (int i = 0; i < _fences.Length; i++)
{
var fence = _fences[i];
if (fence != null)
{
storage[count++] = fence;
}
}
return count;
}
/// <summary> /// <summary>
/// Gets fences to wait for use of a given buffer region. /// Gets fences to wait for use of a given buffer region.
/// </summary> /// </summary>
/// <param name="storage">Span to store overlapping fences in</param>
/// <param name="offset">Offset of the range</param> /// <param name="offset">Offset of the range</param>
/// <param name="size">Size of the range in bytes</param> /// <param name="size">Size of the range in bytes</param>
/// <returns>Fences for the specified region</returns> /// <returns>Number of fences for the specified region placed in storage</returns>
private FenceHolder[] GetOverlappingFences(int offset, int size) private int GetOverlappingFences(Span<FenceHolder> storage, int offset, int size)
{ {
List<FenceHolder> overlapping = new List<FenceHolder>(); int count = 0;
foreach (var kv in _fences) for (int i = 0; i < _fences.Length; i++)
{ {
var fence = kv.Key; var fence = _fences[i];
var ownerCbIndex = kv.Value;
if (_bufferUsageBitmap.OverlapsWith(ownerCbIndex, offset, size)) if (fence != null && _bufferUsageBitmap.OverlapsWith(i, offset, size))
{ {
overlapping.Add(fence); storage[count++] = fence;
} }
} }
return overlapping.ToArray(); return count;
} }
} }
} }

View file

@ -34,16 +34,26 @@ namespace Ryujinx.Graphics.Vulkan
public Span<byte> GetBufferData(CommandBufferPool cbp, BufferHolder buffer, int offset, int size) public Span<byte> GetBufferData(CommandBufferPool cbp, BufferHolder buffer, int offset, int size)
{ {
var flushStorage = ResizeIfNeeded(size); var flushStorage = ResizeIfNeeded(size);
Auto<DisposableBuffer> srcBuffer;
using (var cbs = cbp.Rent()) using (var cbs = cbp.Rent())
{ {
var srcBuffer = buffer.GetBuffer(cbs.CommandBuffer); srcBuffer = buffer.GetBuffer(cbs.CommandBuffer);
var dstBuffer = flushStorage.GetBuffer(cbs.CommandBuffer); var dstBuffer = flushStorage.GetBuffer(cbs.CommandBuffer);
BufferHolder.Copy(_gd, cbs, srcBuffer, dstBuffer, offset, 0, size); if (srcBuffer.TryIncrementReferenceCount())
{
BufferHolder.Copy(_gd, cbs, srcBuffer, dstBuffer, offset, 0, size, registerSrcUsage: false);
}
else
{
// Source buffer is no longer alive, don't copy anything to flush storage.
srcBuffer = null;
}
} }
flushStorage.WaitForFences(); flushStorage.WaitForFences();
srcBuffer?.DecrementReferenceCount();
return flushStorage.GetDataStorage(0, size); return flushStorage.GetDataStorage(0, size);
} }