diff --git a/src/core/hle/kernel/errors.h b/src/core/hle/kernel/errors.h index 29d8dfdaa..5be20c878 100644 --- a/src/core/hle/kernel/errors.h +++ b/src/core/hle/kernel/errors.h @@ -20,6 +20,7 @@ enum { MaxConnectionsReached = 52, // Confirmed Switch OS error codes + MisalignedAddress = 102, InvalidHandle = 114, Timeout = 117, SynchronizationCanceled = 118, diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp index 0b9dc700c..50a9a0805 100644 --- a/src/core/hle/kernel/mutex.cpp +++ b/src/core/hle/kernel/mutex.cpp @@ -7,6 +7,7 @@ #include #include "common/assert.h" #include "core/core.h" +#include "core/hle/kernel/errors.h" #include "core/hle/kernel/handle_table.h" #include "core/hle/kernel/kernel.h" #include "core/hle/kernel/mutex.h" @@ -15,6 +16,30 @@ namespace Kernel { +/// Returns the number of threads that are waiting for a mutex, and the highest priority one among +/// those. +static std::pair, u32> GetHighestPriorityMutexWaitingThread(VAddr mutex_addr) { + auto& thread_list = Core::System::GetInstance().Scheduler().GetThreadList(); + + SharedPtr highest_priority_thread; + u32 num_waiters = 0; + + for (auto& thread : thread_list) { + if (thread->mutex_wait_address != mutex_addr) + continue; + + ASSERT(thread->status == THREADSTATUS_WAIT_MUTEX); + + ++num_waiters; + if (highest_priority_thread == nullptr || + thread->GetPriority() < highest_priority_thread->GetPriority()) { + highest_priority_thread = thread; + } + } + + return {highest_priority_thread, num_waiters}; +} + void ReleaseThreadMutexes(Thread* thread) { for (auto& mtx : thread->held_mutexes) { mtx->SetHasWaiters(false); @@ -135,4 +160,73 @@ void Mutex::SetHasWaiters(bool has_waiters) { Memory::Write32(guest_addr, guest_state.raw); } +ResultCode Mutex::TryAcquire(VAddr address, Handle holding_thread_handle, + Handle requesting_thread_handle) { + // The mutex address must be 4-byte aligned + if ((address % sizeof(u32)) != 0) { + return ResultCode(ErrorModule::Kernel, ErrCodes::MisalignedAddress); + } + + SharedPtr holding_thread = g_handle_table.Get(holding_thread_handle); + SharedPtr requesting_thread = g_handle_table.Get(requesting_thread_handle); + + // TODO(Subv): It is currently unknown if it is possible to lock a mutex in behalf of another + // thread. + ASSERT(requesting_thread == GetCurrentThread()); + + u32 addr_value = Memory::Read32(address); + + // If the mutex isn't being held, just return success. + if (addr_value != (holding_thread_handle | Mutex::MutexHasWaitersFlag)) { + return RESULT_SUCCESS; + } + + if (holding_thread == nullptr) + return ERR_INVALID_HANDLE; + + // Wait until the mutex is released + requesting_thread->mutex_wait_address = address; + requesting_thread->wait_handle = requesting_thread_handle; + + requesting_thread->status = THREADSTATUS_WAIT_MUTEX; + requesting_thread->wakeup_callback = nullptr; + + Core::System::GetInstance().PrepareReschedule(); + + return RESULT_SUCCESS; +} + +ResultCode Mutex::Release(VAddr address) { + // The mutex address must be 4-byte aligned + if ((address % sizeof(u32)) != 0) { + return ResultCode(ErrorModule::Kernel, ErrCodes::MisalignedAddress); + } + + auto [thread, num_waiters] = GetHighestPriorityMutexWaitingThread(address); + + // There are no more threads waiting for the mutex, release it completely. + if (thread == nullptr) { + Memory::Write32(address, 0); + return RESULT_SUCCESS; + } + + u32 mutex_value = thread->wait_handle; + + if (num_waiters >= 2) { + // Notify the guest that there are still some threads waiting for the mutex + mutex_value |= Mutex::MutexHasWaitersFlag; + } + + // Grant the mutex to the next waiting thread and resume it. + Memory::Write32(address, mutex_value); + + ASSERT(thread->status == THREADSTATUS_WAIT_MUTEX); + thread->ResumeFromWait(); + + thread->condvar_wait_address = 0; + thread->mutex_wait_address = 0; + thread->wait_handle = 0; + + return RESULT_SUCCESS; +} } // namespace Kernel diff --git a/src/core/hle/kernel/mutex.h b/src/core/hle/kernel/mutex.h index 38db21005..310923087 100644 --- a/src/core/hle/kernel/mutex.h +++ b/src/core/hle/kernel/mutex.h @@ -77,6 +77,18 @@ public: /// Sets the has_waiters bit in the guest state. void SetHasWaiters(bool has_waiters); + /// Flag that indicates that a mutex still has threads waiting for it. + static constexpr u32 MutexHasWaitersFlag = 0x40000000; + /// Mask of the bits in a mutex address value that contain the mutex owner. + static constexpr u32 MutexOwnerMask = 0xBFFFFFFF; + + /// Attempts to acquire a mutex at the specified address. + static ResultCode TryAcquire(VAddr address, Handle holding_thread_handle, + Handle requesting_thread_handle); + + /// Releases the mutex at the specified address. + static ResultCode Release(VAddr address); + private: Mutex(); ~Mutex() override; diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index 6204bcaaa..92273b488 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -262,32 +262,14 @@ static ResultCode ArbitrateLock(Handle holding_thread_handle, VAddr mutex_addr, "requesting_current_thread_handle=0x%08X", holding_thread_handle, mutex_addr, requesting_thread_handle); - SharedPtr holding_thread = g_handle_table.Get(holding_thread_handle); - SharedPtr requesting_thread = g_handle_table.Get(requesting_thread_handle); - - ASSERT(requesting_thread); - ASSERT(requesting_thread == GetCurrentThread()); - - SharedPtr mutex = g_object_address_table.Get(mutex_addr); - if (!mutex) { - // Create a new mutex for the specified address if one does not already exist - mutex = Mutex::Create(holding_thread, mutex_addr); - mutex->name = Common::StringFromFormat("mutex-%llx", mutex_addr); - } - - ASSERT(holding_thread == mutex->GetHoldingThread()); - - return WaitSynchronization1(mutex, requesting_thread.get()); + return Mutex::TryAcquire(mutex_addr, holding_thread_handle, requesting_thread_handle); } /// Unlock a mutex static ResultCode ArbitrateUnlock(VAddr mutex_addr) { LOG_TRACE(Kernel_SVC, "called mutex_addr=0x%llx", mutex_addr); - SharedPtr mutex = g_object_address_table.Get(mutex_addr); - ASSERT(mutex); - - return mutex->Release(GetCurrentThread()); + return Mutex::Release(mutex_addr); } /// Break program execution diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index f3a8aa4aa..0a0ad7cfb 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -126,6 +126,14 @@ static void ThreadWakeupCallback(u64 thread_handle, int cycles_late) { resume = thread->wakeup_callback(ThreadWakeupReason::Timeout, thread, nullptr, 0); } + if (thread->mutex_wait_address != 0 || thread->condvar_wait_address != 0 || + thread->wait_handle) { + ASSERT(thread->status == THREADSTATUS_WAIT_MUTEX); + thread->mutex_wait_address = 0; + thread->condvar_wait_address = 0; + thread->wait_handle = 0; + } + if (resume) thread->ResumeFromWait(); } @@ -151,6 +159,7 @@ void Thread::ResumeFromWait() { case THREADSTATUS_WAIT_HLE_EVENT: case THREADSTATUS_WAIT_SLEEP: case THREADSTATUS_WAIT_IPC: + case THREADSTATUS_WAIT_MUTEX: break; case THREADSTATUS_READY: @@ -256,7 +265,9 @@ ResultVal> Thread::Create(std::string name, VAddr entry_point, thread->last_running_ticks = CoreTiming::GetTicks(); thread->processor_id = processor_id; thread->wait_objects.clear(); - thread->wait_address = 0; + thread->mutex_wait_address = 0; + thread->condvar_wait_address = 0; + thread->wait_handle = 0; thread->name = std::move(name); thread->callback_handle = wakeup_callback_handle_table.Create(thread).Unwrap(); thread->owner_process = owner_process; diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index dbf47e269..a3a6e6a64 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -43,6 +43,7 @@ enum ThreadStatus { THREADSTATUS_WAIT_IPC, ///< Waiting for the reply from an IPC request THREADSTATUS_WAIT_SYNCH_ANY, ///< Waiting due to WaitSynch1 or WaitSynchN with wait_all = false THREADSTATUS_WAIT_SYNCH_ALL, ///< Waiting due to WaitSynchronizationN with wait_all = true + THREADSTATUS_WAIT_MUTEX, ///< Waiting due to an ArbitrateLock/WaitProcessWideKey svc THREADSTATUS_DORMANT, ///< Created but not yet made ready THREADSTATUS_DEAD ///< Run to completion, or forcefully terminated }; @@ -217,7 +218,10 @@ public: // passed to WaitSynchronization1/N. std::vector> wait_objects; - VAddr wait_address; ///< If waiting on an AddressArbiter, this is the arbitration address + // If waiting on a ConditionVariable, this is the ConditionVariable address + VAddr condvar_wait_address; + VAddr mutex_wait_address; ///< If waiting on a Mutex, this is the mutex address + Handle wait_handle; ///< The handle used to wait for the mutex. std::string name;