From 097562bc6c227c42f803ce1078fcb4adf06cd20c Mon Sep 17 00:00:00 2001 From: gdkchan Date: Tue, 25 Apr 2023 19:33:14 -0300 Subject: [PATCH] Add missing check for thread termination on ArbitrateLock (#4722) * Add missing check for thread termination on ArbitrateLock * Use TerminationRequested in all places where it can be used --- Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs | 6 ++---- .../HOS/Kernel/Threading/KAddressArbiter.cs | 16 ++++++++++------ .../HOS/Kernel/Threading/KConditionVariable.cs | 3 +-- .../HOS/Kernel/Threading/KSynchronization.cs | 5 ++--- Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs | 12 ++++-------- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs b/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs index 9c2184d923..86469c03ae 100644 --- a/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs +++ b/Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs @@ -188,8 +188,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc if (request.AsyncEvent == null) { - if (request.ClientThread.ShallBeTerminated || - request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending) + if (request.ClientThread.TerminationRequested) { return KernelResult.ThreadTerminating; } @@ -1104,8 +1103,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc { foreach (KSessionRequest request in IterateWithRemovalOfAllRequests()) { - if (request.ClientThread.ShallBeTerminated || - request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending) + if (request.ClientThread.TerminationRequested) { continue; } diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs index a5f9df5ef2..74867b44eb 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs @@ -31,6 +31,13 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading _context.CriticalSection.Enter(); + if (currentThread.TerminationRequested) + { + _context.CriticalSection.Leave(); + + return KernelResult.ThreadTerminating; + } + currentThread.SignaledObj = null; currentThread.ObjSyncResult = Result.Success; @@ -114,8 +121,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading currentThread.SignaledObj = null; currentThread.ObjSyncResult = KernelResult.TimedOut; - if (currentThread.ShallBeTerminated || - currentThread.SchedFlags == ThreadSchedState.TerminationPending) + if (currentThread.TerminationRequested) { _context.CriticalSection.Leave(); @@ -280,8 +286,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading _context.CriticalSection.Enter(); - if (currentThread.ShallBeTerminated || - currentThread.SchedFlags == ThreadSchedState.TerminationPending) + if (currentThread.TerminationRequested) { _context.CriticalSection.Leave(); @@ -351,8 +356,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading _context.CriticalSection.Enter(); - if (currentThread.ShallBeTerminated || - currentThread.SchedFlags == ThreadSchedState.TerminationPending) + if (currentThread.TerminationRequested) { _context.CriticalSection.Leave(); diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KConditionVariable.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KConditionVariable.cs index d146bff01b..891e632f9f 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KConditionVariable.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KConditionVariable.cs @@ -19,8 +19,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading currentThread.WithholderNode = threadList.AddLast(currentThread); - if (currentThread.ShallBeTerminated || - currentThread.SchedFlags == ThreadSchedState.TerminationPending) + if (currentThread.TerminationRequested) { threadList.Remove(currentThread.WithholderNode); diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs index d42f900320..9c196810c3 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs @@ -47,8 +47,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading KThread currentThread = KernelStatic.GetCurrentThread(); - if (currentThread.ShallBeTerminated || - currentThread.SchedFlags == ThreadSchedState.TerminationPending) + if (currentThread.TerminationRequested) { result = KernelResult.ThreadTerminating; } @@ -61,7 +60,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading else { LinkedListNode[] syncNodesArray = ArrayPool>.Shared.Rent(syncObjs.Length); - + Span> syncNodes = syncNodesArray.AsSpan(0, syncObjs.Length); for (int index = 0; index < syncObjs.Length; index++) diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs index 6fd496058c..6339646861 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs @@ -99,11 +99,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading private int _shallBeTerminated; - public bool ShallBeTerminated - { - get => _shallBeTerminated != 0; - set => _shallBeTerminated = value ? 1 : 0; - } + private bool ShallBeTerminated => _shallBeTerminated != 0; public bool TerminationRequested => ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending; @@ -322,7 +318,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading ThreadSchedState result; - if (Interlocked.CompareExchange(ref _shallBeTerminated, 1, 0) == 0) + if (Interlocked.Exchange(ref _shallBeTerminated, 1) == 0) { if ((SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.None) { @@ -470,7 +466,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading { KernelContext.CriticalSection.Enter(); - if (ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending) + if (TerminationRequested) { KernelContext.CriticalSection.Leave(); @@ -552,7 +548,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading return KernelResult.InvalidState; } - if (!ShallBeTerminated && SchedFlags != ThreadSchedState.TerminationPending) + if (!TerminationRequested) { if (pause) {