From 4540bcfaf7f9e3a87aca5a2423911575e8483b37 Mon Sep 17 00:00:00 2001 From: Liam Date: Sat, 1 Jul 2023 15:03:48 -0400 Subject: [PATCH] k_server_session: translate special header for non-HLE requests --- src/core/CMakeLists.txt | 1 + src/core/hle/kernel/k_server_session.cpp | 165 +++++- src/core/hle/kernel/message_buffer.h | 612 +++++++++++++++++++++++ 3 files changed, 771 insertions(+), 7 deletions(-) create mode 100644 src/core/hle/kernel/message_buffer.h diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 3655b8478..28cb6f86f 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -285,6 +285,7 @@ add_library(core STATIC hle/kernel/kernel.cpp hle/kernel/kernel.h hle/kernel/memory_types.h + hle/kernel/message_buffer.h hle/kernel/physical_core.cpp hle/kernel/physical_core.h hle/kernel/physical_memory.h diff --git a/src/core/hle/kernel/k_server_session.cpp b/src/core/hle/kernel/k_server_session.cpp index c66aff501..c64ceb530 100644 --- a/src/core/hle/kernel/k_server_session.cpp +++ b/src/core/hle/kernel/k_server_session.cpp @@ -20,12 +20,132 @@ #include "core/hle/kernel/k_thread.h" #include "core/hle/kernel/k_thread_queue.h" #include "core/hle/kernel/kernel.h" +#include "core/hle/kernel/message_buffer.h" #include "core/hle/service/hle_ipc.h" #include "core/hle/service/ipc_helpers.h" #include "core/memory.h" namespace Kernel { +namespace { + +template +Result ProcessMessageSpecialData(KProcess& dst_process, KProcess& src_process, KThread& src_thread, + MessageBuffer& dst_msg, const MessageBuffer& src_msg, + MessageBuffer::SpecialHeader& src_special_header) { + // Copy the special header to the destination. + s32 offset = dst_msg.Set(src_special_header); + + // Copy the process ID. + if (src_special_header.GetHasProcessId()) { + offset = dst_msg.SetProcessId(offset, src_process.GetProcessId()); + } + + // Prepare to process handles. + auto& dst_handle_table = dst_process.GetHandleTable(); + auto& src_handle_table = src_process.GetHandleTable(); + Result result = ResultSuccess; + + // Process copy handles. + for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) { + // Get the handles. + const Handle src_handle = src_msg.GetHandle(offset); + Handle dst_handle = Svc::InvalidHandle; + + // If we're in a success state, try to move the handle to the new table. + if (R_SUCCEEDED(result) && src_handle != Svc::InvalidHandle) { + KScopedAutoObject obj = + src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread)); + if (obj.IsNotNull()) { + Result add_result = + dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe()); + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = Svc::InvalidHandle; + } + } else { + result = ResultInvalidHandle; + } + } + + // Set the handle. + offset = dst_msg.SetHandle(offset, dst_handle); + } + + // Process move handles. + if constexpr (MoveHandleAllowed) { + for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) { + // Get the handles. + const Handle src_handle = src_msg.GetHandle(offset); + Handle dst_handle = Svc::InvalidHandle; + + // Whether or not we've succeeded, we need to remove the handles from the source table. + if (src_handle != Svc::InvalidHandle) { + if (R_SUCCEEDED(result)) { + KScopedAutoObject obj = + src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle); + if (obj.IsNotNull()) { + Result add_result = dst_handle_table.Add(std::addressof(dst_handle), + obj.GetPointerUnsafe()); + + src_handle_table.Remove(src_handle); + + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = Svc::InvalidHandle; + } + } else { + result = ResultInvalidHandle; + } + } else { + src_handle_table.Remove(src_handle); + } + } + + // Set the handle. + offset = dst_msg.SetHandle(offset, dst_handle); + } + } + + R_RETURN(result); +} + +void CleanupSpecialData(KProcess& dst_process, u32* dst_msg_ptr, size_t dst_buffer_size) { + // Parse the message. + const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); + const MessageBuffer::MessageHeader dst_header(dst_msg); + const MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header); + + // Check that the size is big enough. + if (MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) > dst_buffer_size) { + return; + } + + // Set the special header. + int offset = dst_msg.Set(dst_special_header); + + // Clear the process id, if needed. + if (dst_special_header.GetHasProcessId()) { + offset = dst_msg.SetProcessId(offset, 0); + } + + // Clear handles, as relevant. + auto& dst_handle_table = dst_process.GetHandleTable(); + for (auto i = 0; + i < (dst_special_header.GetCopyHandleCount() + dst_special_header.GetMoveHandleCount()); + ++i) { + const Handle handle = dst_msg.GetHandle(offset); + + if (handle != Svc::InvalidHandle) { + dst_handle_table.Remove(handle); + } + + offset = dst_msg.SetHandle(offset, Svc::InvalidHandle); + } +} + +} // namespace + using ThreadQueueImplForKServerSessionRequest = KThreadQueue; KServerSession::KServerSession(KernelCore& kernel) @@ -223,12 +343,27 @@ Result KServerSession::SendReply(bool is_hle) { // the reply has already been written in this case. } else { Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()}; - KThread* server_thread{GetCurrentThreadPointer(m_kernel)}; + KThread* server_thread = GetCurrentThreadPointer(m_kernel); + KProcess& src_process = *client_thread->GetOwnerProcess(); + KProcess& dst_process = *server_thread->GetOwnerProcess(); UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess()); - auto* src_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); - auto* dst_msg_buffer = memory.GetPointer(client_message); + auto* src_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); + auto* dst_msg_buffer = memory.GetPointer(client_message); std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); + + // Translate special header ad-hoc. + MessageBuffer src_msg(src_msg_buffer, client_buffer_size); + MessageBuffer::MessageHeader src_header(src_msg); + MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + if (src_header.GetHasSpecialHeader()) { + MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); + result = ProcessMessageSpecialData(dst_process, src_process, *server_thread, + dst_msg, src_msg, src_special_header); + if (R_FAILED(result)) { + CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); + } + } } } else { result = ResultSessionClosed; @@ -330,12 +465,28 @@ Result KServerSession::ReceiveRequest(std::shared_ptrPopulateFromIncomingCommandBuffer(client_thread->GetOwnerProcess()->GetHandleTable(), cmd_buf); } else { - KThread* server_thread{GetCurrentThreadPointer(m_kernel)}; - UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess()); + KThread* server_thread = GetCurrentThreadPointer(m_kernel); + KProcess& src_process = *client_thread->GetOwnerProcess(); + KProcess& dst_process = *server_thread->GetOwnerProcess(); + UNIMPLEMENTED_IF(client_thread->GetOwnerProcess() != server_thread->GetOwnerProcess()); - auto* src_msg_buffer = memory.GetPointer(client_message); - auto* dst_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); + auto* src_msg_buffer = memory.GetPointer(client_message); + auto* dst_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); + + // Translate special header ad-hoc. + // TODO: fix this mess + MessageBuffer src_msg(src_msg_buffer, client_buffer_size); + MessageBuffer::MessageHeader src_header(src_msg); + MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + if (src_header.GetHasSpecialHeader()) { + MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); + Result res = ProcessMessageSpecialData(dst_process, src_process, *client_thread, + dst_msg, src_msg, src_special_header); + if (R_FAILED(res)) { + CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); + } + } } // We succeeded. diff --git a/src/core/hle/kernel/message_buffer.h b/src/core/hle/kernel/message_buffer.h new file mode 100644 index 000000000..75b275310 --- /dev/null +++ b/src/core/hle/kernel/message_buffer.h @@ -0,0 +1,612 @@ +// SPDX-FileCopyrightText: 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include "common/alignment.h" +#include "common/bit_field.h" +#include "core/hle/kernel/k_thread.h" + +namespace Kernel { + +constexpr inline size_t MessageBufferSize = 0x100; + +class MessageBuffer { +public: + class MessageHeader { + private: + static constexpr inline u64 NullTag = 0; + + public: + enum class ReceiveListCountType : u32 { + None = 0, + ToMessageBuffer = 1, + ToSingleBuffer = 2, + + CountOffset = 2, + CountMax = 13, + }; + + private: + union { + std::array raw; + + struct { + // Define fields for the first header word. + union { + BitField<0, 16, u16> tag; + BitField<16, 4, u32> pointer_count; + BitField<20, 4, u32> send_count; + BitField<24, 4, u32> receive_count; + BitField<28, 4, u32> exchange_count; + }; + + // Define fields for the second header word. + union { + BitField<0, 10, u32> raw_count; + BitField<10, 4, ReceiveListCountType> receive_list_count; + BitField<14, 6, u32> reserved0; + BitField<20, 11, u32> receive_list_offset; + BitField<31, 1, u32> has_special_header; + }; + }; + } m_header; + + public: + constexpr MessageHeader() : m_header{} {} + + constexpr MessageHeader(u16 tag, bool special, s32 ptr, s32 send, s32 recv, s32 exch, + s32 raw, ReceiveListCountType recv_list) + : m_header{} { + m_header.raw[0] = 0; + m_header.raw[1] = 0; + + m_header.tag.Assign(tag); + m_header.pointer_count.Assign(ptr); + m_header.send_count.Assign(send); + m_header.receive_count.Assign(recv); + m_header.exchange_count.Assign(exch); + + m_header.raw_count.Assign(raw); + m_header.receive_list_count.Assign(recv_list); + m_header.has_special_header.Assign(special); + } + + explicit MessageHeader(const MessageBuffer& buf) : m_header{} { + buf.Get(0, m_header.raw.data(), 2); + } + + explicit MessageHeader(const u32* msg) : m_header{{msg[0], msg[1]}} {} + + constexpr u16 GetTag() const { + return m_header.tag; + } + + constexpr s32 GetPointerCount() const { + return m_header.pointer_count; + } + + constexpr s32 GetSendCount() const { + return m_header.send_count; + } + + constexpr s32 GetReceiveCount() const { + return m_header.receive_count; + } + + constexpr s32 GetExchangeCount() const { + return m_header.exchange_count; + } + + constexpr s32 GetMapAliasCount() const { + return this->GetSendCount() + this->GetReceiveCount() + this->GetExchangeCount(); + } + + constexpr s32 GetRawCount() const { + return m_header.raw_count; + } + + constexpr ReceiveListCountType GetReceiveListCount() const { + return m_header.receive_list_count; + } + + constexpr s32 GetReceiveListOffset() const { + return m_header.receive_list_offset; + } + + constexpr bool GetHasSpecialHeader() const { + return m_header.has_special_header.Value() != 0; + } + + constexpr void SetReceiveListCount(ReceiveListCountType recv_list) { + m_header.receive_list_count.Assign(recv_list); + } + + constexpr const u32* GetData() const { + return m_header.raw.data(); + } + + static constexpr size_t GetDataSize() { + return sizeof(m_header); + } + }; + + class SpecialHeader { + private: + union { + std::array raw; + + // Define fields for the header word. + BitField<0, 1, u32> has_process_id; + BitField<1, 4, u32> copy_handle_count; + BitField<5, 4, u32> move_handle_count; + } m_header; + bool m_has_header; + + public: + constexpr explicit SpecialHeader(bool pid, s32 copy, s32 move) + : m_header{}, m_has_header(true) { + m_header.has_process_id.Assign(pid); + m_header.copy_handle_count.Assign(copy); + m_header.move_handle_count.Assign(move); + } + + constexpr explicit SpecialHeader(bool pid, s32 copy, s32 move, bool _has_header) + : m_header{}, m_has_header(_has_header) { + m_header.has_process_id.Assign(pid); + m_header.copy_handle_count.Assign(copy); + m_header.move_handle_count.Assign(move); + } + + explicit SpecialHeader(const MessageBuffer& buf, const MessageHeader& hdr) + : m_header{}, m_has_header(hdr.GetHasSpecialHeader()) { + if (m_has_header) { + buf.Get(static_cast(MessageHeader::GetDataSize() / sizeof(u32)), + m_header.raw.data(), sizeof(m_header) / sizeof(u32)); + } + } + + constexpr bool GetHasProcessId() const { + return m_header.has_process_id.Value() != 0; + } + + constexpr s32 GetCopyHandleCount() const { + return m_header.copy_handle_count; + } + + constexpr s32 GetMoveHandleCount() const { + return m_header.move_handle_count; + } + + constexpr const u32* GetHeader() const { + return m_header.raw.data(); + } + + constexpr size_t GetHeaderSize() const { + if (m_has_header) { + return sizeof(m_header); + } else { + return 0; + } + } + + constexpr size_t GetDataSize() const { + if (m_has_header) { + return (this->GetHasProcessId() ? sizeof(u64) : 0) + + (this->GetCopyHandleCount() * sizeof(Handle)) + + (this->GetMoveHandleCount() * sizeof(Handle)); + } else { + return 0; + } + } + }; + + class MapAliasDescriptor { + public: + enum class Attribute : u32 { + Ipc = 0, + NonSecureIpc = 1, + NonDeviceIpc = 3, + }; + + private: + static constexpr u32 SizeLowCount = 32; + static constexpr u32 SizeHighCount = 4; + static constexpr u32 AddressLowCount = 32; + static constexpr u32 AddressMidCount = 4; + + constexpr u32 GetAddressMid(u64 address) { + return static_cast(address >> AddressLowCount) & ((1U << AddressMidCount) - 1); + } + + constexpr u32 GetAddressHigh(u64 address) { + return static_cast(address >> (AddressLowCount + AddressMidCount)); + } + + private: + union { + std::array raw; + + struct { + // Define fields for the first two words. + u32 size_low; + u32 address_low; + + // Define fields for the packed descriptor word. + union { + BitField<0, 2, Attribute> attributes; + BitField<2, 3, u32> address_high; + BitField<5, 19, u32> reserved; + BitField<24, 4, u32> size_high; + BitField<28, 4, u32> address_mid; + }; + }; + } m_data; + + public: + constexpr MapAliasDescriptor() : m_data{} {} + + MapAliasDescriptor(const void* buffer, size_t _size, Attribute attr = Attribute::Ipc) + : m_data{} { + const u64 address = reinterpret_cast(buffer); + const u64 size = static_cast(_size); + m_data.size_low = static_cast(size); + m_data.address_low = static_cast(address); + m_data.attributes.Assign(attr); + m_data.address_mid.Assign(GetAddressMid(address)); + m_data.size_high.Assign(static_cast(size >> SizeLowCount)); + m_data.address_high.Assign(GetAddressHigh(address)); + } + + MapAliasDescriptor(const MessageBuffer& buf, s32 index) : m_data{} { + buf.Get(index, m_data.raw.data(), 3); + } + + constexpr uintptr_t GetAddress() const { + return (static_cast((m_data.address_high << AddressMidCount) | m_data.address_mid) + << AddressLowCount) | + m_data.address_low; + } + + constexpr uintptr_t GetSize() const { + return (static_cast(m_data.size_high) << SizeLowCount) | m_data.size_low; + } + + constexpr Attribute GetAttribute() const { + return m_data.attributes; + } + + constexpr const u32* GetData() const { + return m_data.raw.data(); + } + + static constexpr size_t GetDataSize() { + return sizeof(m_data); + } + }; + + class PointerDescriptor { + private: + static constexpr u32 AddressLowCount = 32; + static constexpr u32 AddressMidCount = 4; + + constexpr u32 GetAddressMid(u64 address) { + return static_cast(address >> AddressLowCount) & ((1u << AddressMidCount) - 1); + } + + constexpr u32 GetAddressHigh(u64 address) { + return static_cast(address >> (AddressLowCount + AddressMidCount)); + } + + private: + union { + std::array raw; + + struct { + // Define fields for the packed descriptor word. + union { + BitField<0, 4, u32> index; + BitField<4, 2, u32> reserved0; + BitField<6, 3, u32> address_high; + BitField<9, 3, u32> reserved1; + BitField<12, 4, u32> address_mid; + BitField<16, 16, u32> size; + }; + + // Define fields for the second word. + u32 address_low; + }; + } m_data; + + public: + constexpr PointerDescriptor() : m_data{} {} + + PointerDescriptor(const void* buffer, size_t size, s32 index) : m_data{} { + const u64 address = reinterpret_cast(buffer); + + m_data.index.Assign(index); + m_data.address_high.Assign(GetAddressHigh(address)); + m_data.address_mid.Assign(GetAddressMid(address)); + m_data.size.Assign(static_cast(size)); + + m_data.address_low = static_cast(address); + } + + PointerDescriptor(const MessageBuffer& buf, s32 index) : m_data{} { + buf.Get(index, m_data.raw.data(), 2); + } + + constexpr s32 GetIndex() const { + return m_data.index; + } + + constexpr uintptr_t GetAddress() const { + return (static_cast((m_data.address_high << AddressMidCount) | m_data.address_mid) + << AddressLowCount) | + m_data.address_low; + } + + constexpr size_t GetSize() const { + return m_data.size; + } + + constexpr const u32* GetData() const { + return m_data.raw.data(); + } + + static constexpr size_t GetDataSize() { + return sizeof(m_data); + } + }; + + class ReceiveListEntry { + private: + static constexpr u32 AddressLowCount = 32; + + constexpr u32 GetAddressHigh(u64 address) { + return static_cast(address >> (AddressLowCount)); + } + + private: + union { + std::array raw; + + struct { + // Define fields for the first word. + u32 address_low; + + // Define fields for the packed descriptor word. + union { + BitField<0, 7, u32> address_high; + BitField<7, 9, u32> reserved; + BitField<16, 16, u32> size; + }; + }; + } m_data; + + public: + constexpr ReceiveListEntry() : m_data{} {} + + ReceiveListEntry(const void* buffer, size_t size) : m_data{} { + const u64 address = reinterpret_cast(buffer); + + m_data.address_low = static_cast(address); + + m_data.address_high.Assign(GetAddressHigh(address)); + m_data.size.Assign(static_cast(size)); + } + + ReceiveListEntry(u32 a, u32 b) : m_data{{a, b}} {} + + constexpr uintptr_t GetAddress() const { + return (static_cast(m_data.address_high) << AddressLowCount) | m_data.address_low; + } + + constexpr size_t GetSize() const { + return m_data.size; + } + + constexpr const u32* GetData() const { + return m_data.raw.data(); + } + + static constexpr size_t GetDataSize() { + return sizeof(m_data); + } + }; + +private: + u32* m_buffer; + size_t m_size; + +public: + constexpr MessageBuffer(u32* b, size_t sz) : m_buffer(b), m_size(sz) {} + constexpr explicit MessageBuffer(u32* b) : m_buffer(b), m_size(MessageBufferSize) {} + + constexpr void* GetBufferForDebug() const { + return m_buffer; + } + + constexpr size_t GetBufferSize() const { + return m_size; + } + + void Get(s32 index, u32* dst, size_t count) const { + // Ensure that this doesn't get re-ordered. + std::atomic_thread_fence(std::memory_order_seq_cst); + + // Get the words. + static_assert(sizeof(*dst) == sizeof(*m_buffer)); + + memcpy(dst, m_buffer + index, count * sizeof(*dst)); + } + + s32 Set(s32 index, u32* src, size_t count) const { + // Ensure that this doesn't get re-ordered. + std::atomic_thread_fence(std::memory_order_seq_cst); + + // Set the words. + memcpy(m_buffer + index, src, count * sizeof(*src)); + + // Ensure that this doesn't get re-ordered. + std::atomic_thread_fence(std::memory_order_seq_cst); + + return static_cast(index + count); + } + + template + const T& GetRaw(s32 index) const { + return *reinterpret_cast(m_buffer + index); + } + + template + s32 SetRaw(s32 index, const T& val) const { + *reinterpret_cast(m_buffer + index) = val; + return index + (Common::AlignUp(sizeof(val), sizeof(*m_buffer)) / sizeof(*m_buffer)); + } + + void GetRawArray(s32 index, void* dst, size_t len) const { + memcpy(dst, m_buffer + index, len); + } + + void SetRawArray(s32 index, const void* src, size_t len) const { + memcpy(m_buffer + index, src, len); + } + + void SetNull() const { + this->Set(MessageHeader()); + } + + s32 Set(const MessageHeader& hdr) const { + memcpy(m_buffer, hdr.GetData(), hdr.GetDataSize()); + return static_cast(hdr.GetDataSize() / sizeof(*m_buffer)); + } + + s32 Set(const SpecialHeader& spc) const { + const s32 index = static_cast(MessageHeader::GetDataSize() / sizeof(*m_buffer)); + memcpy(m_buffer + index, spc.GetHeader(), spc.GetHeaderSize()); + return static_cast(index + (spc.GetHeaderSize() / sizeof(*m_buffer))); + } + + s32 SetHandle(s32 index, const Handle& hnd) const { + memcpy(m_buffer + index, std::addressof(hnd), sizeof(hnd)); + return static_cast(index + (sizeof(hnd) / sizeof(*m_buffer))); + } + + s32 SetProcessId(s32 index, const u64 pid) const { + memcpy(m_buffer + index, std::addressof(pid), sizeof(pid)); + return static_cast(index + (sizeof(pid) / sizeof(*m_buffer))); + } + + s32 Set(s32 index, const MapAliasDescriptor& desc) const { + memcpy(m_buffer + index, desc.GetData(), desc.GetDataSize()); + return static_cast(index + (desc.GetDataSize() / sizeof(*m_buffer))); + } + + s32 Set(s32 index, const PointerDescriptor& desc) const { + memcpy(m_buffer + index, desc.GetData(), desc.GetDataSize()); + return static_cast(index + (desc.GetDataSize() / sizeof(*m_buffer))); + } + + s32 Set(s32 index, const ReceiveListEntry& desc) const { + memcpy(m_buffer + index, desc.GetData(), desc.GetDataSize()); + return static_cast(index + (desc.GetDataSize() / sizeof(*m_buffer))); + } + + s32 Set(s32 index, const u32 val) const { + memcpy(m_buffer + index, std::addressof(val), sizeof(val)); + return static_cast(index + (sizeof(val) / sizeof(*m_buffer))); + } + + Result GetAsyncResult() const { + MessageHeader hdr(m_buffer); + MessageHeader null{}; + if (memcmp(hdr.GetData(), null.GetData(), MessageHeader::GetDataSize()) != 0) [[unlikely]] { + R_SUCCEED(); + } + return Result(m_buffer[MessageHeader::GetDataSize() / sizeof(*m_buffer)]); + } + + void SetAsyncResult(Result res) const { + const s32 index = this->Set(MessageHeader()); + const auto value = res.raw; + memcpy(m_buffer + index, std::addressof(value), sizeof(value)); + } + + u32 Get32(s32 index) const { + return m_buffer[index]; + } + + u64 Get64(s32 index) const { + u64 value; + memcpy(std::addressof(value), m_buffer + index, sizeof(value)); + return value; + } + + u64 GetProcessId(s32 index) const { + return this->Get64(index); + } + + Handle GetHandle(s32 index) const { + static_assert(sizeof(Handle) == sizeof(*m_buffer)); + return Handle(m_buffer[index]); + } + + static constexpr s32 GetSpecialDataIndex(const MessageHeader& hdr, const SpecialHeader& spc) { + return static_cast((MessageHeader::GetDataSize() / sizeof(u32)) + + (spc.GetHeaderSize() / sizeof(u32))); + } + + static constexpr s32 GetPointerDescriptorIndex(const MessageHeader& hdr, + const SpecialHeader& spc) { + return static_cast(GetSpecialDataIndex(hdr, spc) + (spc.GetDataSize() / sizeof(u32))); + } + + static constexpr s32 GetMapAliasDescriptorIndex(const MessageHeader& hdr, + const SpecialHeader& spc) { + return GetPointerDescriptorIndex(hdr, spc) + + static_cast(hdr.GetPointerCount() * PointerDescriptor::GetDataSize() / + sizeof(u32)); + } + + static constexpr s32 GetRawDataIndex(const MessageHeader& hdr, const SpecialHeader& spc) { + return GetMapAliasDescriptorIndex(hdr, spc) + + static_cast(hdr.GetMapAliasCount() * MapAliasDescriptor::GetDataSize() / + sizeof(u32)); + } + + static constexpr s32 GetReceiveListIndex(const MessageHeader& hdr, const SpecialHeader& spc) { + if (const s32 recv_list_index = hdr.GetReceiveListOffset()) { + return recv_list_index; + } else { + return GetRawDataIndex(hdr, spc) + hdr.GetRawCount(); + } + } + + static constexpr size_t GetMessageBufferSize(const MessageHeader& hdr, + const SpecialHeader& spc) { + // Get the size of the plain message. + size_t msg_size = GetReceiveListIndex(hdr, spc) * sizeof(u32); + + // Add the size of the receive list. + const auto count = hdr.GetReceiveListCount(); + switch (count) { + case MessageHeader::ReceiveListCountType::None: + break; + case MessageHeader::ReceiveListCountType::ToMessageBuffer: + break; + case MessageHeader::ReceiveListCountType::ToSingleBuffer: + msg_size += ReceiveListEntry::GetDataSize(); + break; + default: + msg_size += (static_cast(count) - + static_cast(MessageHeader::ReceiveListCountType::CountOffset)) * + ReceiveListEntry::GetDataSize(); + break; + } + + return msg_size; + } +}; + +} // namespace Kernel