ServerBase thread safety (#4577)

* Add guard against ServerBase.Dispose() being called multiple times. Add reset event to avoid Dispose() being called while the ServerLoop is still running.

* remove unused usings

* rework ServerBase to use one collection each for sessions and ports, and make all accesses thread-safe.

* fix Logger call

* use GetSessionObj(int) instead of using _sessions directly

* move _threadStopped check inside "dispose once" test

* - Replace _threadStopped event with attempt to Join() the ending thread (if that isn't the current thread) instead.

- Use the instance-local _selfProcess and (new) _selfThread variables to avoid suggesting that the current KProcess and KThread could change. Per gdkchan, they can't currently, and this old IPC system will be removed before that changes.

- Re-order Dispose() so that the Interlocked _isDisposed check is the last check before disposing, to increase the likelihood that multiple callers will result in one of them succeeding.

* code style suggestions per AcK77

* add infinite wait for thread termination
This commit is contained in:
jhorv 2023-05-21 15:28:51 -04:00 committed by GitHub
parent 5626f2ca1c
commit 21e88f17f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,4 +1,5 @@
using Ryujinx.Common; using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.Common.Memory; using Ryujinx.Common.Memory;
using Ryujinx.HLE.HOS.Ipc; using Ryujinx.HLE.HOS.Ipc;
using Ryujinx.HLE.HOS.Kernel; using Ryujinx.HLE.HOS.Kernel;
@ -32,13 +33,14 @@ namespace Ryujinx.HLE.HOS.Services
0x01007FFF 0x01007FFF
}; };
private readonly object _handleLock = new(); // The amount of time Dispose() will wait to Join() the thread executing the ServerLoop()
private static readonly TimeSpan ThreadJoinTimeout = TimeSpan.FromSeconds(3);
private readonly KernelContext _context; private readonly KernelContext _context;
private KProcess _selfProcess; private KProcess _selfProcess;
private KThread _selfThread;
private readonly List<int> _sessionHandles = new List<int>(); private readonly ReaderWriterLockSlim _handleLock = new ReaderWriterLockSlim();
private readonly List<int> _portHandles = new List<int>();
private readonly Dictionary<int, IpcService> _sessions = new Dictionary<int, IpcService>(); private readonly Dictionary<int, IpcService> _sessions = new Dictionary<int, IpcService>();
private readonly Dictionary<int, Func<IpcService>> _ports = new Dictionary<int, Func<IpcService>>(); private readonly Dictionary<int, Func<IpcService>> _ports = new Dictionary<int, Func<IpcService>>();
@ -48,6 +50,8 @@ namespace Ryujinx.HLE.HOS.Services
private readonly MemoryStream _responseDataStream; private readonly MemoryStream _responseDataStream;
private readonly BinaryWriter _responseDataWriter; private readonly BinaryWriter _responseDataWriter;
private int _isDisposed = 0;
public ManualResetEvent InitDone { get; } public ManualResetEvent InitDone { get; }
public string Name { get; } public string Name { get; }
public Func<IpcService> SmObjectFactory { get; } public Func<IpcService> SmObjectFactory { get; }
@ -79,12 +83,21 @@ namespace Ryujinx.HLE.HOS.Services
private void AddPort(int serverPortHandle, Func<IpcService> objectFactory) private void AddPort(int serverPortHandle, Func<IpcService> objectFactory)
{ {
lock (_handleLock) bool lockTaken = false;
try
{ {
_portHandles.Add(serverPortHandle); lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
}
_ports.Add(serverPortHandle, objectFactory); _ports.Add(serverPortHandle, objectFactory);
} }
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
}
public void AddSessionObj(KServerSession serverSession, IpcService obj) public void AddSessionObj(KServerSession serverSession, IpcService obj)
{ {
@ -92,17 +105,63 @@ namespace Ryujinx.HLE.HOS.Services
InitDone.WaitOne(); InitDone.WaitOne();
_selfProcess.HandleTable.GenerateHandle(serverSession, out int serverSessionHandle); _selfProcess.HandleTable.GenerateHandle(serverSession, out int serverSessionHandle);
AddSessionObj(serverSessionHandle, obj); AddSessionObj(serverSessionHandle, obj);
} }
public void AddSessionObj(int serverSessionHandle, IpcService obj) public void AddSessionObj(int serverSessionHandle, IpcService obj)
{ {
lock (_handleLock) bool lockTaken = false;
try
{ {
_sessionHandles.Add(serverSessionHandle); lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
}
_sessions.Add(serverSessionHandle, obj); _sessions.Add(serverSessionHandle, obj);
} }
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
}
private IpcService GetSessionObj(int serverSessionHandle)
{
bool lockTaken = false;
try
{
lockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
return _sessions[serverSessionHandle];
}
finally
{
if (lockTaken)
{
_handleLock.ExitReadLock();
}
}
}
private bool RemoveSessionObj(int serverSessionHandle, out IpcService obj)
{
bool lockTaken = false;
try
{
lockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
return _sessions.Remove(serverSessionHandle, out obj);
}
finally
{
if (lockTaken)
{
_handleLock.ExitWriteLock();
}
}
}
private void Main() private void Main()
{ {
@ -112,6 +171,7 @@ namespace Ryujinx.HLE.HOS.Services
private void ServerLoop() private void ServerLoop()
{ {
_selfProcess = KernelStatic.GetCurrentProcess(); _selfProcess = KernelStatic.GetCurrentProcess();
_selfThread = KernelStatic.GetCurrentThread();
if (SmObjectFactory != null) if (SmObjectFactory != null)
{ {
@ -122,8 +182,7 @@ namespace Ryujinx.HLE.HOS.Services
InitDone.Set(); InitDone.Set();
KThread thread = KernelStatic.GetCurrentThread(); ulong messagePtr = _selfThread.TlsAddress;
ulong messagePtr = thread.TlsAddress;
_context.Syscall.SetHeapSize(out ulong heapAddr, 0x200000); _context.Syscall.SetHeapSize(out ulong heapAddr, 0x200000);
_selfProcess.CpuMemory.Write(messagePtr + 0x0, 0); _selfProcess.CpuMemory.Write(messagePtr + 0x0, 0);
@ -134,27 +193,39 @@ namespace Ryujinx.HLE.HOS.Services
while (true) while (true)
{ {
int handleCount;
int portHandleCount; int portHandleCount;
int handleCount;
int[] handles; int[] handles;
lock (_handleLock) bool handleLockTaken = false;
try
{ {
portHandleCount = _portHandles.Count; handleLockTaken = _handleLock.TryEnterReadLock(Timeout.Infinite);
handleCount = portHandleCount + _sessionHandles.Count;
portHandleCount = _ports.Count;
handleCount = portHandleCount + _sessions.Count;
handles = ArrayPool<int>.Shared.Rent(handleCount); handles = ArrayPool<int>.Shared.Rent(handleCount);
_portHandles.CopyTo(handles, 0); _ports.Keys.CopyTo(handles, 0);
_sessionHandles.CopyTo(handles, portHandleCount);
_sessions.Keys.CopyTo(handles, portHandleCount);
}
finally
{
if (handleLockTaken)
{
_handleLock.ExitReadLock();
}
} }
// We still need a timeout here to allow the service to pick up and listen new sessions... // We still need a timeout here to allow the service to pick up and listen new sessions...
var rc = _context.Syscall.ReplyAndReceive(out int signaledIndex, handles.AsSpan(0, handleCount), replyTargetHandle, 1000000L); var rc = _context.Syscall.ReplyAndReceive(out int signaledIndex, handles.AsSpan(0, handleCount), replyTargetHandle, 1000000L);
thread.HandlePostSyscall(); _selfThread.HandlePostSyscall();
if (!thread.Context.Running) if (!_selfThread.Context.Running)
{ {
break; break;
} }
@ -178,9 +249,20 @@ namespace Ryujinx.HLE.HOS.Services
// We got a new connection, accept the session to allow servicing future requests. // We got a new connection, accept the session to allow servicing future requests.
if (_context.Syscall.AcceptSession(out int serverSessionHandle, handles[signaledIndex]) == Result.Success) if (_context.Syscall.AcceptSession(out int serverSessionHandle, handles[signaledIndex]) == Result.Success)
{ {
bool handleWriteLockTaken = false;
try
{
handleWriteLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
IpcService obj = _ports[handles[signaledIndex]].Invoke(); IpcService obj = _ports[handles[signaledIndex]].Invoke();
_sessions.Add(serverSessionHandle, obj);
AddSessionObj(serverSessionHandle, obj); }
finally
{
if (handleWriteLockTaken)
{
_handleLock.ExitWriteLock();
}
}
} }
} }
@ -197,11 +279,7 @@ namespace Ryujinx.HLE.HOS.Services
private bool Process(int serverSessionHandle, ulong recvListAddr) private bool Process(int serverSessionHandle, ulong recvListAddr)
{ {
KProcess process = KernelStatic.GetCurrentProcess(); IpcMessage request = ReadRequest();
KThread thread = KernelStatic.GetCurrentThread();
ulong messagePtr = thread.TlsAddress;
IpcMessage request = ReadRequest(process, messagePtr);
IpcMessage response = new IpcMessage(); IpcMessage response = new IpcMessage();
@ -247,15 +325,15 @@ namespace Ryujinx.HLE.HOS.Services
ServiceCtx context = new ServiceCtx( ServiceCtx context = new ServiceCtx(
_context.Device, _context.Device,
process, _selfProcess,
process.CpuMemory, _selfProcess.CpuMemory,
thread, _selfThread,
request, request,
response, response,
_requestDataReader, _requestDataReader,
_responseDataWriter); _responseDataWriter);
_sessions[serverSessionHandle].CallCmifMethod(context); GetSessionObj(serverSessionHandle).CallCmifMethod(context);
response.RawData = _responseDataStream.ToArray(); response.RawData = _responseDataStream.ToArray();
} }
@ -268,7 +346,7 @@ namespace Ryujinx.HLE.HOS.Services
switch (cmdId) switch (cmdId)
{ {
case 0: case 0:
FillHipcResponse(response, 0, _sessions[serverSessionHandle].ConvertToDomain()); FillHipcResponse(response, 0, GetSessionObj(serverSessionHandle).ConvertToDomain());
break; break;
case 3: case 3:
@ -278,17 +356,31 @@ namespace Ryujinx.HLE.HOS.Services
// TODO: Whats the difference between IpcDuplicateSession/Ex? // TODO: Whats the difference between IpcDuplicateSession/Ex?
case 2: case 2:
case 4: case 4:
int unknown = _requestDataReader.ReadInt32(); {
_ = _requestDataReader.ReadInt32();
_context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0); _context.Syscall.CreateSession(out int dupServerSessionHandle, out int dupClientSessionHandle, false, 0);
AddSessionObj(dupServerSessionHandle, _sessions[serverSessionHandle]); bool writeLockTaken = false;
try
{
writeLockTaken = _handleLock.TryEnterWriteLock(Timeout.Infinite);
_sessions[dupServerSessionHandle] = _sessions[serverSessionHandle];
}
finally
{
if (writeLockTaken)
{
_handleLock.ExitWriteLock();
}
}
response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle); response.HandleDesc = IpcHandleDesc.MakeMove(dupClientSessionHandle);
FillHipcResponse(response, 0); FillHipcResponse(response, 0);
break; break;
}
default: throw new NotImplementedException(cmdId.ToString()); default: throw new NotImplementedException(cmdId.ToString());
} }
@ -296,13 +388,10 @@ namespace Ryujinx.HLE.HOS.Services
else if (request.Type == IpcMessageType.CmifCloseSession || request.Type == IpcMessageType.TipcCloseSession) else if (request.Type == IpcMessageType.CmifCloseSession || request.Type == IpcMessageType.TipcCloseSession)
{ {
_context.Syscall.CloseHandle(serverSessionHandle); _context.Syscall.CloseHandle(serverSessionHandle);
lock (_handleLock) if (RemoveSessionObj(serverSessionHandle, out var session))
{ {
_sessionHandles.Remove(serverSessionHandle); (session as IDisposable)?.Dispose();
} }
IpcService service = _sessions[serverSessionHandle];
(service as IDisposable)?.Dispose();
_sessions.Remove(serverSessionHandle);
shouldReply = false; shouldReply = false;
} }
// If the type is past 0xF, we are using TIPC // If the type is past 0xF, we are using TIPC
@ -317,20 +406,20 @@ namespace Ryujinx.HLE.HOS.Services
ServiceCtx context = new ServiceCtx( ServiceCtx context = new ServiceCtx(
_context.Device, _context.Device,
process, _selfProcess,
process.CpuMemory, _selfProcess.CpuMemory,
thread, _selfThread,
request, request,
response, response,
_requestDataReader, _requestDataReader,
_responseDataWriter); _responseDataWriter);
_sessions[serverSessionHandle].CallTipcMethod(context); GetSessionObj(serverSessionHandle).CallTipcMethod(context);
response.RawData = _responseDataStream.ToArray(); response.RawData = _responseDataStream.ToArray();
using var responseStream = response.GetStreamTipc(); using var responseStream = response.GetStreamTipc();
process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence()); _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
} }
else else
{ {
@ -339,27 +428,24 @@ namespace Ryujinx.HLE.HOS.Services
if (!isTipcCommunication) if (!isTipcCommunication)
{ {
using var responseStream = response.GetStream((long)messagePtr, recvListAddr | ((ulong)PointerBufferSize << 48)); using var responseStream = response.GetStream((long)_selfThread.TlsAddress, recvListAddr | ((ulong)PointerBufferSize << 48));
process.CpuMemory.Write(messagePtr, responseStream.GetReadOnlySequence()); _selfProcess.CpuMemory.Write(_selfThread.TlsAddress, responseStream.GetReadOnlySequence());
} }
return shouldReply; return shouldReply;
} }
private static IpcMessage ReadRequest(KProcess process, ulong messagePtr) private IpcMessage ReadRequest()
{ {
const int messageSize = 0x100; const int messageSize = 0x100;
byte[] reqData = ArrayPool<byte>.Shared.Rent(messageSize); using IMemoryOwner<byte> reqDataOwner = ByteMemoryPool.Shared.Rent(messageSize);
Span<byte> reqDataSpan = reqData.AsSpan(0, messageSize); Span<byte> reqDataSpan = reqDataOwner.Memory.Span;
reqDataSpan.Clear();
process.CpuMemory.Read(messagePtr, reqDataSpan); _selfProcess.CpuMemory.Read(_selfThread.TlsAddress, reqDataSpan);
IpcMessage request = new IpcMessage(reqDataSpan, (long)messagePtr); IpcMessage request = new IpcMessage(reqDataSpan, (long)_selfThread.TlsAddress);
ArrayPool<byte>.Shared.Return(reqData);
return request; return request;
} }
@ -392,19 +478,27 @@ namespace Ryujinx.HLE.HOS.Services
protected virtual void Dispose(bool disposing) protected virtual void Dispose(bool disposing)
{ {
if (disposing) if (disposing && _selfThread != null)
{
if (_selfThread.HostThread.ManagedThreadId != Environment.CurrentManagedThreadId && _selfThread.HostThread.Join(ThreadJoinTimeout) == false)
{
Logger.Warning?.Print(LogClass.Service, $"The ServerBase thread didn't terminate within {ThreadJoinTimeout:g}, waiting longer.");
_selfThread.HostThread.Join(Timeout.Infinite);
}
if (Interlocked.Exchange(ref _isDisposed, 1) == 0)
{ {
foreach (IpcService service in _sessions.Values) foreach (IpcService service in _sessions.Values)
{ {
if (service is IDisposable disposableObj) (service as IDisposable)?.Dispose();
{
disposableObj.Dispose();
}
service.DestroyAtExit(); service.DestroyAtExit();
} }
_sessions.Clear(); _sessions.Clear();
_ports.Clear();
_handleLock.Dispose();
_requestDataReader.Dispose(); _requestDataReader.Dispose();
_requestDataStream.Dispose(); _requestDataStream.Dispose();
@ -414,6 +508,7 @@ namespace Ryujinx.HLE.HOS.Services
InitDone.Dispose(); InitDone.Dispose();
} }
} }
}
public void Dispose() public void Dispose()
{ {