diff --git a/src/core/hle/kernel/event.cpp b/src/core/hle/kernel/event.cpp index 3e116e3df..e1f42af05 100644 --- a/src/core/hle/kernel/event.cpp +++ b/src/core/hle/kernel/event.cpp @@ -30,12 +30,12 @@ SharedPtr Event::Create(ResetType reset_type, std::string name) { return evt; } -bool Event::ShouldWait() { +bool Event::ShouldWait(Thread* thread) const { return !signaled; } -void Event::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void Event::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); // Release the event if it's not sticky... if (reset_type != ResetType::Sticky) diff --git a/src/core/hle/kernel/event.h b/src/core/hle/kernel/event.h index 8dcd23edb..39452bf33 100644 --- a/src/core/hle/kernel/event.h +++ b/src/core/hle/kernel/event.h @@ -35,8 +35,8 @@ public: bool signaled; ///< Whether the event has already been signaled std::string name; ///< Name of event (optional) - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; void Signal(); void Clear(); diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index 1db8e102f..ef9dbafa5 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -39,11 +39,6 @@ SharedPtr WaitObject::GetHighestPriorityReadyThread() { thread->status == THREADSTATUS_DEAD; }); - // TODO(Subv): This call should be performed inside the loop below to check if an object can be - // acquired by a particular thread. This is useful for things like recursive locking of Mutexes. - if (ShouldWait()) - return nullptr; - Thread* candidate = nullptr; s32 candidate_priority = THREADPRIO_LOWEST + 1; @@ -51,9 +46,12 @@ SharedPtr WaitObject::GetHighestPriorityReadyThread() { if (thread->current_priority >= candidate_priority) continue; + if (ShouldWait(thread.get())) + continue; + bool ready_to_run = std::none_of(thread->wait_objects.begin(), thread->wait_objects.end(), - [](const SharedPtr& object) { return object->ShouldWait(); }); + [&thread](const SharedPtr& object) { return object->ShouldWait(thread.get()); }); if (ready_to_run) { candidate = thread.get(); candidate_priority = thread->current_priority; @@ -66,7 +64,7 @@ SharedPtr WaitObject::GetHighestPriorityReadyThread() { void WaitObject::WakeupAllWaitingThreads() { while (auto thread = GetHighestPriorityReadyThread()) { if (!thread->IsSleepingOnWaitAll()) { - Acquire(); + Acquire(thread.get()); // Set the output index of the WaitSynchronizationN call to the index of this object. if (thread->wait_set_output) { thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(this)); @@ -74,7 +72,7 @@ void WaitObject::WakeupAllWaitingThreads() { } } else { for (auto& object : thread->wait_objects) { - object->Acquire(); + object->Acquire(thread.get()); object->RemoveWaitingThread(thread.get()); } // Note: This case doesn't update the output index of WaitSynchronizationN. diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index 9503e7d04..67eae93f2 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -132,13 +132,14 @@ using SharedPtr = boost::intrusive_ptr; class WaitObject : public Object { public: /** - * Check if the current thread should wait until the object is available + * Check if the specified thread should wait until the object is available + * @param thread The thread about which we're deciding. * @return True if the current thread should wait due to this object being unavailable */ - virtual bool ShouldWait() = 0; + virtual bool ShouldWait(Thread* thread) const = 0; - /// Acquire/lock the object if it is available - virtual void Acquire() = 0; + /// Acquire/lock the object for the specified thread if it is available + virtual void Acquire(Thread* thread) = 0; /** * Add a thread to wait on this object diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp index 8d92a9b8e..072e4e7c1 100644 --- a/src/core/hle/kernel/mutex.cpp +++ b/src/core/hle/kernel/mutex.cpp @@ -40,31 +40,19 @@ SharedPtr Mutex::Create(bool initial_locked, std::string name) { mutex->name = std::move(name); mutex->holding_thread = nullptr; - // Acquire mutex with current thread if initialized as locked... + // Acquire mutex with current thread if initialized as locked if (initial_locked) - mutex->Acquire(); + mutex->Acquire(GetCurrentThread()); return mutex; } -bool Mutex::ShouldWait() { - auto thread = GetCurrentThread(); - bool wait = lock_count > 0 && holding_thread != thread; - - // If the holding thread of the mutex is lower priority than this thread, that thread should - // temporarily inherit this thread's priority - if (wait && thread->current_priority < holding_thread->current_priority) - holding_thread->BoostPriority(thread->current_priority); - - return wait; +bool Mutex::ShouldWait(Thread* thread) const { + return lock_count > 0 && thread != holding_thread; } -void Mutex::Acquire() { - Acquire(GetCurrentThread()); -} - -void Mutex::Acquire(SharedPtr thread) { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void Mutex::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); // Actually "acquire" the mutex only if we don't already have it... if (lock_count == 0) { diff --git a/src/core/hle/kernel/mutex.h b/src/core/hle/kernel/mutex.h index 53c3dc1f1..98b3d40b5 100644 --- a/src/core/hle/kernel/mutex.h +++ b/src/core/hle/kernel/mutex.h @@ -38,8 +38,9 @@ public: std::string name; ///< Name of mutex (optional) SharedPtr holding_thread; ///< Thread that has acquired the mutex - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; + /** * Acquires the specified mutex for the specified thread diff --git a/src/core/hle/kernel/semaphore.cpp b/src/core/hle/kernel/semaphore.cpp index bf7600780..5e6139265 100644 --- a/src/core/hle/kernel/semaphore.cpp +++ b/src/core/hle/kernel/semaphore.cpp @@ -30,12 +30,12 @@ ResultVal> Semaphore::Create(s32 initial_count, s32 max_cou return MakeResult>(std::move(semaphore)); } -bool Semaphore::ShouldWait() { +bool Semaphore::ShouldWait(Thread* thread) const { return available_count <= 0; } -void Semaphore::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void Semaphore::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); --available_count; } diff --git a/src/core/hle/kernel/semaphore.h b/src/core/hle/kernel/semaphore.h index e01908a25..cde94f7cc 100644 --- a/src/core/hle/kernel/semaphore.h +++ b/src/core/hle/kernel/semaphore.h @@ -39,8 +39,8 @@ public: s32 available_count; ///< Number of free slots left in the semaphore std::string name; ///< Name of semaphore (optional) - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; /** * Releases a certain number of slots from a semaphore. diff --git a/src/core/hle/kernel/server_port.cpp b/src/core/hle/kernel/server_port.cpp index 6c19aa7c0..fd3bbbcad 100644 --- a/src/core/hle/kernel/server_port.cpp +++ b/src/core/hle/kernel/server_port.cpp @@ -14,13 +14,13 @@ namespace Kernel { ServerPort::ServerPort() {} ServerPort::~ServerPort() {} -bool ServerPort::ShouldWait() { +bool ServerPort::ShouldWait(Thread* thread) const { // If there are no pending sessions, we wait until a new one is added. return pending_sessions.size() == 0; } -void ServerPort::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void ServerPort::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); } std::tuple, SharedPtr> ServerPort::CreatePortPair( diff --git a/src/core/hle/kernel/server_port.h b/src/core/hle/kernel/server_port.h index b0f8df62c..6f8bdb6a9 100644 --- a/src/core/hle/kernel/server_port.h +++ b/src/core/hle/kernel/server_port.h @@ -53,8 +53,8 @@ public: /// ServerSessions created from this port inherit a reference to this handler. std::shared_ptr hle_handler; - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; private: ServerPort(); diff --git a/src/core/hle/kernel/server_session.cpp b/src/core/hle/kernel/server_session.cpp index 146458c1c..9447ff236 100644 --- a/src/core/hle/kernel/server_session.cpp +++ b/src/core/hle/kernel/server_session.cpp @@ -29,12 +29,12 @@ ResultVal> ServerSession::Create( return MakeResult>(std::move(server_session)); } -bool ServerSession::ShouldWait() { +bool ServerSession::ShouldWait(Thread* thread) const { return !signaled; } -void ServerSession::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void ServerSession::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); signaled = false; } diff --git a/src/core/hle/kernel/server_session.h b/src/core/hle/kernel/server_session.h index 458284a5d..c088b9a19 100644 --- a/src/core/hle/kernel/server_session.h +++ b/src/core/hle/kernel/server_session.h @@ -57,9 +57,9 @@ public: */ ResultCode HandleSyncRequest(); - bool ShouldWait() override; + bool ShouldWait(Thread* thread) const override; - void Acquire() override; + void Acquire(Thread* thread) override; std::string name; ///< The name of this session (optional) bool signaled; ///< Whether there's new data available to this ServerSession diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 5fb95dada..7d03a2cf7 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -27,12 +27,12 @@ namespace Kernel { /// Event type for the thread wake up event static int ThreadWakeupEventType; -bool Thread::ShouldWait() { +bool Thread::ShouldWait(Thread* thread) const { return status != THREADSTATUS_DEAD; } -void Thread::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void Thread::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); } // TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, allowing diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index c77ac644d..f2bc1ec9c 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -72,8 +72,8 @@ public: return HANDLE_TYPE; } - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; /** * Gets the thread's current priority diff --git a/src/core/hle/kernel/timer.cpp b/src/core/hle/kernel/timer.cpp index b50cf520d..8f2bc4c7f 100644 --- a/src/core/hle/kernel/timer.cpp +++ b/src/core/hle/kernel/timer.cpp @@ -39,12 +39,12 @@ SharedPtr Timer::Create(ResetType reset_type, std::string name) { return timer; } -bool Timer::ShouldWait() { +bool Timer::ShouldWait(Thread* thread) const { return !signaled; } -void Timer::Acquire() { - ASSERT_MSG(!ShouldWait(), "object unavailable!"); +void Timer::Acquire(Thread* thread) { + ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); if (reset_type == ResetType::OneShot) signaled = false; diff --git a/src/core/hle/kernel/timer.h b/src/core/hle/kernel/timer.h index 18ea0236b..2e3b31b23 100644 --- a/src/core/hle/kernel/timer.h +++ b/src/core/hle/kernel/timer.h @@ -39,8 +39,8 @@ public: u64 initial_delay; ///< The delay until the timer fires for the first time u64 interval_delay; ///< The delay until the timer fires after the first time - bool ShouldWait() override; - void Acquire() override; + bool ShouldWait(Thread* thread) const override; + void Acquire(Thread* thread) override; /** * Starts the timer, with the specified initial delay and interval. diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp index 5b538be22..159ac0bf6 100644 --- a/src/core/hle/svc.cpp +++ b/src/core/hle/svc.cpp @@ -272,7 +272,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds) LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle, object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds); - if (object->ShouldWait()) { + if (object->ShouldWait(thread)) { if (nano_seconds == 0) return ERR_SYNC_TIMEOUT; @@ -294,7 +294,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds) return ERR_SYNC_TIMEOUT; } - object->Acquire(); + object->Acquire(thread); return RESULT_SUCCESS; } @@ -336,11 +336,11 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha if (wait_all) { bool all_available = std::all_of(objects.begin(), objects.end(), - [](const ObjectPtr& object) { return !object->ShouldWait(); }); + [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); }); if (all_available) { // We can acquire all objects right now, do so. for (auto& object : objects) - object->Acquire(); + object->Acquire(thread); // Note: In this case, the `out` parameter is not set, // and retains whatever value it had before. return RESULT_SUCCESS; @@ -380,12 +380,12 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha } else { // Find the first object that is acquirable in the provided list of objects auto itr = std::find_if(objects.begin(), objects.end(), - [](const ObjectPtr& object) { return !object->ShouldWait(); }); + [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); }); if (itr != objects.end()) { // We found a ready object, acquire it and set the result value Kernel::WaitObject* object = itr->get(); - object->Acquire(); + object->Acquire(thread); *out = std::distance(objects.begin(), itr); return RESULT_SUCCESS; }