Add missing check for thread termination on ArbitrateLock (#4722)

* Add missing check for thread termination on ArbitrateLock

* Use TerminationRequested in all places where it can be used
This commit is contained in:
gdkchan 2023-04-25 19:33:14 -03:00 committed by GitHub
parent db4242c5dc
commit 097562bc6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 23 deletions

View file

@ -188,8 +188,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
if (request.AsyncEvent == null) if (request.AsyncEvent == null)
{ {
if (request.ClientThread.ShallBeTerminated || if (request.ClientThread.TerminationRequested)
request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
return KernelResult.ThreadTerminating; return KernelResult.ThreadTerminating;
} }
@ -1104,8 +1103,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
{ {
foreach (KSessionRequest request in IterateWithRemovalOfAllRequests()) foreach (KSessionRequest request in IterateWithRemovalOfAllRequests())
{ {
if (request.ClientThread.ShallBeTerminated || if (request.ClientThread.TerminationRequested)
request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
continue; continue;
} }

View file

@ -31,6 +31,13 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
if (currentThread.TerminationRequested)
{
_context.CriticalSection.Leave();
return KernelResult.ThreadTerminating;
}
currentThread.SignaledObj = null; currentThread.SignaledObj = null;
currentThread.ObjSyncResult = Result.Success; currentThread.ObjSyncResult = Result.Success;
@ -114,8 +121,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
currentThread.SignaledObj = null; currentThread.SignaledObj = null;
currentThread.ObjSyncResult = KernelResult.TimedOut; currentThread.ObjSyncResult = KernelResult.TimedOut;
if (currentThread.ShallBeTerminated || if (currentThread.TerminationRequested)
currentThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
_context.CriticalSection.Leave(); _context.CriticalSection.Leave();
@ -280,8 +286,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
if (currentThread.ShallBeTerminated || if (currentThread.TerminationRequested)
currentThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
_context.CriticalSection.Leave(); _context.CriticalSection.Leave();
@ -351,8 +356,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
if (currentThread.ShallBeTerminated || if (currentThread.TerminationRequested)
currentThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
_context.CriticalSection.Leave(); _context.CriticalSection.Leave();

View file

@ -19,8 +19,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
currentThread.WithholderNode = threadList.AddLast(currentThread); currentThread.WithholderNode = threadList.AddLast(currentThread);
if (currentThread.ShallBeTerminated || if (currentThread.TerminationRequested)
currentThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
threadList.Remove(currentThread.WithholderNode); threadList.Remove(currentThread.WithholderNode);

View file

@ -47,8 +47,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
KThread currentThread = KernelStatic.GetCurrentThread(); KThread currentThread = KernelStatic.GetCurrentThread();
if (currentThread.ShallBeTerminated || if (currentThread.TerminationRequested)
currentThread.SchedFlags == ThreadSchedState.TerminationPending)
{ {
result = KernelResult.ThreadTerminating; result = KernelResult.ThreadTerminating;
} }
@ -61,7 +60,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
else else
{ {
LinkedListNode<KThread>[] syncNodesArray = ArrayPool<LinkedListNode<KThread>>.Shared.Rent(syncObjs.Length); LinkedListNode<KThread>[] syncNodesArray = ArrayPool<LinkedListNode<KThread>>.Shared.Rent(syncObjs.Length);
Span<LinkedListNode<KThread>> syncNodes = syncNodesArray.AsSpan(0, syncObjs.Length); Span<LinkedListNode<KThread>> syncNodes = syncNodesArray.AsSpan(0, syncObjs.Length);
for (int index = 0; index < syncObjs.Length; index++) for (int index = 0; index < syncObjs.Length; index++)

View file

@ -99,11 +99,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
private int _shallBeTerminated; private int _shallBeTerminated;
public bool ShallBeTerminated private bool ShallBeTerminated => _shallBeTerminated != 0;
{
get => _shallBeTerminated != 0;
set => _shallBeTerminated = value ? 1 : 0;
}
public bool TerminationRequested => ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending; public bool TerminationRequested => ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending;
@ -322,7 +318,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
ThreadSchedState result; ThreadSchedState result;
if (Interlocked.CompareExchange(ref _shallBeTerminated, 1, 0) == 0) if (Interlocked.Exchange(ref _shallBeTerminated, 1) == 0)
{ {
if ((SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.None) if ((SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.None)
{ {
@ -470,7 +466,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
{ {
KernelContext.CriticalSection.Enter(); KernelContext.CriticalSection.Enter();
if (ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending) if (TerminationRequested)
{ {
KernelContext.CriticalSection.Leave(); KernelContext.CriticalSection.Leave();
@ -552,7 +548,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return KernelResult.InvalidState; return KernelResult.InvalidState;
} }
if (!ShallBeTerminated && SchedFlags != ThreadSchedState.TerminationPending) if (!TerminationRequested)
{ {
if (pause) if (pause)
{ {