diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index fc71adc25..5970d71c8 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -2,6 +2,7 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include #include "common/common_types.h" #include "common/logging/log.h" #include "core/hle/kernel/address_arbiter.h" @@ -14,6 +15,55 @@ namespace Kernel { +void AddressArbiter::WaitThread(SharedPtr thread, VAddr wait_address) { + thread->wait_address = wait_address; + thread->status = THREADSTATUS_WAIT_ARB; + waiting_threads.emplace_back(std::move(thread)); +} + +void AddressArbiter::ResumeAllThreads(VAddr address) { + // Determine which threads are waiting on this address, those should be woken up. + auto itr = std::stable_partition(waiting_threads.begin(), waiting_threads.end(), + [address](const auto& thread) { + ASSERT_MSG(thread->status == THREADSTATUS_WAIT_ARB, + "Inconsistent AddressArbiter state"); + return thread->wait_address != address; + }); + + // Wake up all the found threads + std::for_each(itr, waiting_threads.end(), [](auto& thread) { thread->ResumeFromWait(); }); + + // Remove the woken up threads from the wait list. + waiting_threads.erase(itr, waiting_threads.end()); +} + +SharedPtr AddressArbiter::ResumeHighestPriorityThread(VAddr address) { + // Determine which threads are waiting on this address, those should be considered for wakeup. + auto matches_start = std::stable_partition( + waiting_threads.begin(), waiting_threads.end(), [address](const auto& thread) { + ASSERT_MSG(thread->status == THREADSTATUS_WAIT_ARB, + "Inconsistent AddressArbiter state"); + return thread->wait_address != address; + }); + + // Iterate through threads, find highest priority thread that is waiting to be arbitrated. + // Note: The real kernel will pick the first thread in the list if more than one have the + // same highest priority value. Lower priority values mean higher priority. + auto itr = std::min_element(matches_start, waiting_threads.end(), + [](const auto& lhs, const auto& rhs) { + return lhs->current_priority < rhs->current_priority; + }); + + if (itr == waiting_threads.end()) + return nullptr; + + auto thread = *itr; + thread->ResumeFromWait(); + + waiting_threads.erase(itr); + return thread; +} + AddressArbiter::AddressArbiter() {} AddressArbiter::~AddressArbiter() {} @@ -25,32 +75,32 @@ SharedPtr AddressArbiter::Create(std::string name) { return address_arbiter; } -ResultCode AddressArbiter::ArbitrateAddress(ArbitrationType type, VAddr address, s32 value, - u64 nanoseconds) { +ResultCode AddressArbiter::ArbitrateAddress(SharedPtr thread, ArbitrationType type, + VAddr address, s32 value, u64 nanoseconds) { switch (type) { // Signal thread(s) waiting for arbitrate address... case ArbitrationType::Signal: // Negative value means resume all threads if (value < 0) { - ArbitrateAllThreads(address); + ResumeAllThreads(address); } else { // Resume first N threads for (int i = 0; i < value; i++) - ArbitrateHighestPriorityThread(address); + ResumeHighestPriorityThread(address); } break; // Wait current thread (acquire the arbiter)... case ArbitrationType::WaitIfLessThan: if ((s32)Memory::Read32(address) < value) { - Kernel::WaitCurrentThread_ArbitrateAddress(address); + WaitThread(std::move(thread), address); } break; case ArbitrationType::WaitIfLessThanWithTimeout: if ((s32)Memory::Read32(address) < value) { - Kernel::WaitCurrentThread_ArbitrateAddress(address); - GetCurrentThread()->WakeAfterDelay(nanoseconds); + thread->WakeAfterDelay(nanoseconds); + WaitThread(std::move(thread), address); } break; case ArbitrationType::DecrementAndWaitIfLessThan: { @@ -58,7 +108,7 @@ ResultCode AddressArbiter::ArbitrateAddress(ArbitrationType type, VAddr address, if (memory_value < value) { // Only change the memory value if the thread should wait Memory::Write32(address, (s32)memory_value - 1); - Kernel::WaitCurrentThread_ArbitrateAddress(address); + WaitThread(std::move(thread), address); } break; } @@ -67,8 +117,8 @@ ResultCode AddressArbiter::ArbitrateAddress(ArbitrationType type, VAddr address, if (memory_value < value) { // Only change the memory value if the thread should wait Memory::Write32(address, (s32)memory_value - 1); - Kernel::WaitCurrentThread_ArbitrateAddress(address); - GetCurrentThread()->WakeAfterDelay(nanoseconds); + thread->WakeAfterDelay(nanoseconds); + WaitThread(std::move(thread), address); } break; } diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index 1d24401b1..9b9bdd311 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -4,6 +4,7 @@ #pragma once +#include #include "common/common_types.h" #include "core/hle/kernel/kernel.h" #include "core/hle/result.h" @@ -11,13 +12,15 @@ // Address arbiters are an underlying kernel synchronization object that can be created/used via // supervisor calls (SVCs). They function as sort of a global lock. Typically, games/other CTR // applications use them as an underlying mechanism to implement thread-safe barriers, events, and -// semphores. +// semaphores. //////////////////////////////////////////////////////////////////////////////////////////////////// // Kernel namespace namespace Kernel { +class Thread; + enum class ArbitrationType : u32 { Signal, WaitIfLessThan, @@ -50,11 +53,25 @@ public: std::string name; ///< Name of address arbiter object (optional) - ResultCode ArbitrateAddress(ArbitrationType type, VAddr address, s32 value, u64 nanoseconds); + ResultCode ArbitrateAddress(SharedPtr thread, ArbitrationType type, VAddr address, + s32 value, u64 nanoseconds); private: AddressArbiter(); ~AddressArbiter() override; + + /// Puts the thread to wait on the specified arbitration address under this address arbiter. + void WaitThread(SharedPtr thread, VAddr wait_address); + + /// Resume all threads found to be waiting on the address under this address arbiter + void ResumeAllThreads(VAddr address); + + /// Resume one thread found to be waiting on the address under this address arbiter and return + /// the resumed thread. + SharedPtr ResumeHighestPriorityThread(VAddr address); + + /// Threads waiting for the address arbiter to be signaled. + std::vector> waiting_threads; }; -} // namespace FileSys +} // namespace Kernel diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 1d77fb582..b7ee35481 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -67,16 +67,6 @@ Thread* GetCurrentThread() { return current_thread.get(); } -/** - * Check if the specified thread is waiting on the specified address to be arbitrated - * @param thread The thread to test - * @param wait_address The address to test against - * @return True if the thread is waiting, false otherwise - */ -static bool CheckWait_AddressArbiter(const Thread* thread, VAddr wait_address) { - return thread->status == THREADSTATUS_WAIT_ARB && wait_address == thread->wait_address; -} - void Thread::Stop() { // Cancel any outstanding wakeup events for this thread CoreTiming::UnscheduleEvent(ThreadWakeupEventType, callback_handle); @@ -109,40 +99,6 @@ void Thread::Stop() { Kernel::g_current_process->tls_slots[tls_page].reset(tls_slot); } -Thread* ArbitrateHighestPriorityThread(u32 address) { - Thread* highest_priority_thread = nullptr; - u32 priority = THREADPRIO_LOWEST; - - // Iterate through threads, find highest priority thread that is waiting to be arbitrated... - for (auto& thread : thread_list) { - if (!CheckWait_AddressArbiter(thread.get(), address)) - continue; - - if (thread == nullptr) - continue; - - if (thread->current_priority <= priority) { - highest_priority_thread = thread.get(); - priority = thread->current_priority; - } - } - - // If a thread was arbitrated, resume it - if (nullptr != highest_priority_thread) { - highest_priority_thread->ResumeFromWait(); - } - - return highest_priority_thread; -} - -void ArbitrateAllThreads(u32 address) { - // Resume all threads found to be waiting on the address - for (auto& thread : thread_list) { - if (CheckWait_AddressArbiter(thread.get(), address)) - thread->ResumeFromWait(); - } -} - /** * Switches the CPU's active thread context to that of the specified thread * @param new_thread The thread to switch to @@ -220,12 +176,6 @@ void WaitCurrentThread_Sleep() { thread->status = THREADSTATUS_WAIT_SLEEP; } -void WaitCurrentThread_ArbitrateAddress(VAddr wait_address) { - Thread* thread = GetCurrentThread(); - thread->wait_address = wait_address; - thread->status = THREADSTATUS_WAIT_ARB; -} - void ExitCurrentThread() { Thread* thread = GetCurrentThread(); thread->Stop(); diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp index d8cb7f654..0a6b561c6 100644 --- a/src/core/hle/svc.cpp +++ b/src/core/hle/svc.cpp @@ -629,7 +629,8 @@ static ResultCode ArbitrateAddress(Kernel::Handle handle, u32 address, u32 type, if (arbiter == nullptr) return ERR_INVALID_HANDLE; - auto res = arbiter->ArbitrateAddress(static_cast(type), address, value, + auto res = arbiter->ArbitrateAddress(Kernel::GetCurrentThread(), + static_cast(type), address, value, nanoseconds); // TODO(Subv): Identify in which specific cases this call should cause a reschedule.