Merge pull request #9305 from lioncash/request

hle_ipc: Add helper function for determining element counts
This commit is contained in:
bunnei 2022-11-25 00:38:17 -08:00 committed by GitHub
commit 64965cc658
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 78 additions and 72 deletions

View file

@ -199,7 +199,7 @@ public:
~HLERequestContext(); ~HLERequestContext();
/// Returns a pointer to the IPC command buffer for this request. /// Returns a pointer to the IPC command buffer for this request.
u32* CommandBuffer() { [[nodiscard]] u32* CommandBuffer() {
return cmd_buf.data(); return cmd_buf.data();
} }
@ -207,7 +207,7 @@ public:
* Returns the session through which this request was made. This can be used as a map key to * Returns the session through which this request was made. This can be used as a map key to
* access per-client data on services. * access per-client data on services.
*/ */
Kernel::KServerSession* Session() { [[nodiscard]] Kernel::KServerSession* Session() {
return server_session; return server_session;
} }
@ -217,61 +217,61 @@ public:
/// Writes data from this context back to the requesting process/thread. /// Writes data from this context back to the requesting process/thread.
Result WriteToOutgoingCommandBuffer(KThread& requesting_thread); Result WriteToOutgoingCommandBuffer(KThread& requesting_thread);
u32_le GetHipcCommand() const { [[nodiscard]] u32_le GetHipcCommand() const {
return command; return command;
} }
u32_le GetTipcCommand() const { [[nodiscard]] u32_le GetTipcCommand() const {
return static_cast<u32_le>(command_header->type.Value()) - return static_cast<u32_le>(command_header->type.Value()) -
static_cast<u32_le>(IPC::CommandType::TIPC_CommandRegion); static_cast<u32_le>(IPC::CommandType::TIPC_CommandRegion);
} }
u32_le GetCommand() const { [[nodiscard]] u32_le GetCommand() const {
return command_header->IsTipc() ? GetTipcCommand() : GetHipcCommand(); return command_header->IsTipc() ? GetTipcCommand() : GetHipcCommand();
} }
bool IsTipc() const { [[nodiscard]] bool IsTipc() const {
return command_header->IsTipc(); return command_header->IsTipc();
} }
IPC::CommandType GetCommandType() const { [[nodiscard]] IPC::CommandType GetCommandType() const {
return command_header->type; return command_header->type;
} }
u64 GetPID() const { [[nodiscard]] u64 GetPID() const {
return pid; return pid;
} }
u32 GetDataPayloadOffset() const { [[nodiscard]] u32 GetDataPayloadOffset() const {
return data_payload_offset; return data_payload_offset;
} }
const std::vector<IPC::BufferDescriptorX>& BufferDescriptorX() const { [[nodiscard]] const std::vector<IPC::BufferDescriptorX>& BufferDescriptorX() const {
return buffer_x_desciptors; return buffer_x_desciptors;
} }
const std::vector<IPC::BufferDescriptorABW>& BufferDescriptorA() const { [[nodiscard]] const std::vector<IPC::BufferDescriptorABW>& BufferDescriptorA() const {
return buffer_a_desciptors; return buffer_a_desciptors;
} }
const std::vector<IPC::BufferDescriptorABW>& BufferDescriptorB() const { [[nodiscard]] const std::vector<IPC::BufferDescriptorABW>& BufferDescriptorB() const {
return buffer_b_desciptors; return buffer_b_desciptors;
} }
const std::vector<IPC::BufferDescriptorC>& BufferDescriptorC() const { [[nodiscard]] const std::vector<IPC::BufferDescriptorC>& BufferDescriptorC() const {
return buffer_c_desciptors; return buffer_c_desciptors;
} }
const IPC::DomainMessageHeader& GetDomainMessageHeader() const { [[nodiscard]] const IPC::DomainMessageHeader& GetDomainMessageHeader() const {
return domain_message_header.value(); return domain_message_header.value();
} }
bool HasDomainMessageHeader() const { [[nodiscard]] bool HasDomainMessageHeader() const {
return domain_message_header.has_value(); return domain_message_header.has_value();
} }
/// Helper function to read a buffer using the appropriate buffer descriptor /// Helper function to read a buffer using the appropriate buffer descriptor
std::vector<u8> ReadBuffer(std::size_t buffer_index = 0) const; [[nodiscard]] std::vector<u8> ReadBuffer(std::size_t buffer_index = 0) const;
/// Helper function to write a buffer using the appropriate buffer descriptor /// Helper function to write a buffer using the appropriate buffer descriptor
std::size_t WriteBuffer(const void* buffer, std::size_t size, std::size_t WriteBuffer(const void* buffer, std::size_t size,
@ -308,22 +308,34 @@ public:
} }
/// Helper function to get the size of the input buffer /// Helper function to get the size of the input buffer
std::size_t GetReadBufferSize(std::size_t buffer_index = 0) const; [[nodiscard]] std::size_t GetReadBufferSize(std::size_t buffer_index = 0) const;
/// Helper function to get the size of the output buffer /// Helper function to get the size of the output buffer
std::size_t GetWriteBufferSize(std::size_t buffer_index = 0) const; [[nodiscard]] std::size_t GetWriteBufferSize(std::size_t buffer_index = 0) const;
/// Helper function to derive the number of elements able to be contained in the read buffer
template <typename T>
[[nodiscard]] std::size_t GetReadBufferNumElements(std::size_t buffer_index = 0) const {
return GetReadBufferSize(buffer_index) / sizeof(T);
}
/// Helper function to derive the number of elements able to be contained in the write buffer
template <typename T>
[[nodiscard]] std::size_t GetWriteBufferNumElements(std::size_t buffer_index = 0) const {
return GetWriteBufferSize(buffer_index) / sizeof(T);
}
/// Helper function to test whether the input buffer at buffer_index can be read /// Helper function to test whether the input buffer at buffer_index can be read
bool CanReadBuffer(std::size_t buffer_index = 0) const; [[nodiscard]] bool CanReadBuffer(std::size_t buffer_index = 0) const;
/// Helper function to test whether the output buffer at buffer_index can be written /// Helper function to test whether the output buffer at buffer_index can be written
bool CanWriteBuffer(std::size_t buffer_index = 0) const; [[nodiscard]] bool CanWriteBuffer(std::size_t buffer_index = 0) const;
Handle GetCopyHandle(std::size_t index) const { [[nodiscard]] Handle GetCopyHandle(std::size_t index) const {
return incoming_copy_handles.at(index); return incoming_copy_handles.at(index);
} }
Handle GetMoveHandle(std::size_t index) const { [[nodiscard]] Handle GetMoveHandle(std::size_t index) const {
return incoming_move_handles.at(index); return incoming_move_handles.at(index);
} }
@ -348,13 +360,13 @@ public:
manager = manager_; manager = manager_;
} }
std::string Description() const; [[nodiscard]] std::string Description() const;
KThread& GetThread() { [[nodiscard]] KThread& GetThread() {
return *thread; return *thread;
} }
std::shared_ptr<SessionRequestManager> GetManager() const { [[nodiscard]] std::shared_ptr<SessionRequestManager> GetManager() const {
return manager.lock(); return manager.lock();
} }

View file

@ -122,10 +122,10 @@ private:
} }
void GetReleasedAudioInBuffer(Kernel::HLERequestContext& ctx) { void GetReleasedAudioInBuffer(Kernel::HLERequestContext& ctx) {
auto write_buffer_size = ctx.GetWriteBufferSize() / sizeof(u64); const auto write_buffer_size = ctx.GetWriteBufferNumElements<u64>();
std::vector<u64> released_buffers(write_buffer_size, 0); std::vector<u64> released_buffers(write_buffer_size);
auto count = impl->GetReleasedBuffers(released_buffers); const auto count = impl->GetReleasedBuffers(released_buffers);
[[maybe_unused]] std::string tags{}; [[maybe_unused]] std::string tags{};
for (u32 i = 0; i < count; i++) { for (u32 i = 0; i < count; i++) {
@ -228,7 +228,7 @@ void AudInU::ListAudioIns(Kernel::HLERequestContext& ctx) {
LOG_DEBUG(Service_Audio, "called"); LOG_DEBUG(Service_Audio, "called");
const auto write_count = const auto write_count =
static_cast<u32>(ctx.GetWriteBufferSize() / sizeof(AudioDevice::AudioDeviceName)); static_cast<u32>(ctx.GetWriteBufferNumElements<AudioDevice::AudioDeviceName>());
std::vector<AudioDevice::AudioDeviceName> device_names{}; std::vector<AudioDevice::AudioDeviceName> device_names{};
u32 out_count{0}; u32 out_count{0};
@ -248,7 +248,7 @@ void AudInU::ListAudioInsAutoFiltered(Kernel::HLERequestContext& ctx) {
LOG_DEBUG(Service_Audio, "called"); LOG_DEBUG(Service_Audio, "called");
const auto write_count = const auto write_count =
static_cast<u32>(ctx.GetWriteBufferSize() / sizeof(AudioDevice::AudioDeviceName)); static_cast<u32>(ctx.GetWriteBufferNumElements<AudioDevice::AudioDeviceName>());
std::vector<AudioDevice::AudioDeviceName> device_names{}; std::vector<AudioDevice::AudioDeviceName> device_names{};
u32 out_count{0}; u32 out_count{0};

View file

@ -129,16 +129,16 @@ private:
} }
void GetReleasedAudioOutBuffers(Kernel::HLERequestContext& ctx) { void GetReleasedAudioOutBuffers(Kernel::HLERequestContext& ctx) {
auto write_buffer_size = ctx.GetWriteBufferSize() / sizeof(u64); const auto write_buffer_size = ctx.GetWriteBufferNumElements<u64>();
std::vector<u64> released_buffers(write_buffer_size, 0); std::vector<u64> released_buffers(write_buffer_size);
auto count = impl->GetReleasedBuffers(released_buffers); const auto count = impl->GetReleasedBuffers(released_buffers);
[[maybe_unused]] std::string tags{}; [[maybe_unused]] std::string tags{};
for (u32 i = 0; i < count; i++) { for (u32 i = 0; i < count; i++) {
tags += fmt::format("{:08X}, ", released_buffers[i]); tags += fmt::format("{:08X}, ", released_buffers[i]);
} }
[[maybe_unused]] auto sessionid{impl->GetSystem().GetSessionId()}; [[maybe_unused]] const auto sessionid{impl->GetSystem().GetSessionId()};
LOG_TRACE(Service_Audio, "called. Session {} released {} buffers: {}", sessionid, count, LOG_TRACE(Service_Audio, "called. Session {} released {} buffers: {}", sessionid, count,
tags); tags);
@ -244,7 +244,7 @@ void AudOutU::ListAudioOuts(Kernel::HLERequestContext& ctx) {
std::scoped_lock l{impl->mutex}; std::scoped_lock l{impl->mutex};
const auto write_count = const auto write_count =
static_cast<u32>(ctx.GetWriteBufferSize() / sizeof(AudioDevice::AudioDeviceName)); static_cast<u32>(ctx.GetWriteBufferNumElements<AudioDevice::AudioDeviceName>());
std::vector<AudioDevice::AudioDeviceName> device_names{}; std::vector<AudioDevice::AudioDeviceName> device_names{};
if (write_count > 0) { if (write_count > 0) {
device_names.emplace_back("DeviceOut"); device_names.emplace_back("DeviceOut");

View file

@ -274,7 +274,7 @@ public:
private: private:
void ListAudioDeviceName(Kernel::HLERequestContext& ctx) { void ListAudioDeviceName(Kernel::HLERequestContext& ctx) {
const size_t in_count = ctx.GetWriteBufferSize() / sizeof(AudioDevice::AudioDeviceName); const size_t in_count = ctx.GetWriteBufferNumElements<AudioDevice::AudioDeviceName>();
std::vector<AudioDevice::AudioDeviceName> out_names{}; std::vector<AudioDevice::AudioDeviceName> out_names{};
@ -335,7 +335,7 @@ private:
} }
void GetActiveAudioDeviceName(Kernel::HLERequestContext& ctx) { void GetActiveAudioDeviceName(Kernel::HLERequestContext& ctx) {
const auto write_size = ctx.GetWriteBufferSize() / sizeof(char); const auto write_size = ctx.GetWriteBufferSize();
std::string out_name{"AudioTvOutput"}; std::string out_name{"AudioTvOutput"};
LOG_DEBUG(Service_Audio, "(STUBBED) called. Name={}", out_name); LOG_DEBUG(Service_Audio, "(STUBBED) called. Name={}", out_name);
@ -387,7 +387,7 @@ private:
} }
void ListAudioOutputDeviceName(Kernel::HLERequestContext& ctx) { void ListAudioOutputDeviceName(Kernel::HLERequestContext& ctx) {
const size_t in_count = ctx.GetWriteBufferSize() / sizeof(AudioDevice::AudioDeviceName); const size_t in_count = ctx.GetWriteBufferNumElements<AudioDevice::AudioDeviceName>();
std::vector<AudioDevice::AudioDeviceName> out_names{}; std::vector<AudioDevice::AudioDeviceName> out_names{};

View file

@ -68,7 +68,7 @@ private:
ExtraBehavior extra_behavior) { ExtraBehavior extra_behavior) {
u32 consumed = 0; u32 consumed = 0;
u32 sample_count = 0; u32 sample_count = 0;
std::vector<opus_int16> samples(ctx.GetWriteBufferSize() / sizeof(opus_int16)); std::vector<opus_int16> samples(ctx.GetWriteBufferNumElements<opus_int16>());
if (extra_behavior == ExtraBehavior::ResetContext) { if (extra_behavior == ExtraBehavior::ResetContext) {
ResetDecoderContext(); ResetDecoderContext();

View file

@ -443,7 +443,7 @@ private:
} }
void Read(Kernel::HLERequestContext& ctx) { void Read(Kernel::HLERequestContext& ctx) {
auto write_size = ctx.GetWriteBufferSize() / sizeof(DeliveryCacheDirectoryEntry); auto write_size = ctx.GetWriteBufferNumElements<DeliveryCacheDirectoryEntry>();
LOG_DEBUG(Service_BCAT, "called, write_size={:016X}", write_size); LOG_DEBUG(Service_BCAT, "called, write_size={:016X}", write_size);
@ -533,7 +533,7 @@ private:
} }
void EnumerateDeliveryCacheDirectory(Kernel::HLERequestContext& ctx) { void EnumerateDeliveryCacheDirectory(Kernel::HLERequestContext& ctx) {
auto size = ctx.GetWriteBufferSize() / sizeof(DirectoryName); auto size = ctx.GetWriteBufferNumElements<DirectoryName>();
LOG_DEBUG(Service_BCAT, "called, size={:016X}", size); LOG_DEBUG(Service_BCAT, "called, size={:016X}", size);

View file

@ -192,12 +192,10 @@ private:
} }
void ListCommonTicketRightsIds(Kernel::HLERequestContext& ctx) { void ListCommonTicketRightsIds(Kernel::HLERequestContext& ctx) {
u32 out_entries; size_t out_entries = 0;
if (keys.GetCommonTickets().empty()) if (!keys.GetCommonTickets().empty()) {
out_entries = 0; out_entries = ctx.GetWriteBufferNumElements<u128>();
else }
out_entries = static_cast<u32>(ctx.GetWriteBufferSize() / sizeof(u128));
LOG_DEBUG(Service_ETicket, "called, entries={:016X}", out_entries); LOG_DEBUG(Service_ETicket, "called, entries={:016X}", out_entries);
keys.PopulateTickets(); keys.PopulateTickets();
@ -206,20 +204,19 @@ private:
std::transform(tickets.begin(), tickets.end(), std::back_inserter(ids), std::transform(tickets.begin(), tickets.end(), std::back_inserter(ids),
[](const auto& pair) { return pair.first; }); [](const auto& pair) { return pair.first; });
out_entries = static_cast<u32>(std::min<std::size_t>(ids.size(), out_entries)); out_entries = std::min(ids.size(), out_entries);
ctx.WriteBuffer(ids.data(), out_entries * sizeof(u128)); ctx.WriteBuffer(ids.data(), out_entries * sizeof(u128));
IPC::ResponseBuilder rb{ctx, 3}; IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess); rb.Push(ResultSuccess);
rb.Push<u32>(out_entries); rb.Push<u32>(static_cast<u32>(out_entries));
} }
void ListPersonalizedTicketRightsIds(Kernel::HLERequestContext& ctx) { void ListPersonalizedTicketRightsIds(Kernel::HLERequestContext& ctx) {
u32 out_entries; size_t out_entries = 0;
if (keys.GetPersonalizedTickets().empty()) if (!keys.GetPersonalizedTickets().empty()) {
out_entries = 0; out_entries = ctx.GetWriteBufferNumElements<u128>();
else }
out_entries = static_cast<u32>(ctx.GetWriteBufferSize() / sizeof(u128));
LOG_DEBUG(Service_ETicket, "called, entries={:016X}", out_entries); LOG_DEBUG(Service_ETicket, "called, entries={:016X}", out_entries);
@ -229,12 +226,12 @@ private:
std::transform(tickets.begin(), tickets.end(), std::back_inserter(ids), std::transform(tickets.begin(), tickets.end(), std::back_inserter(ids),
[](const auto& pair) { return pair.first; }); [](const auto& pair) { return pair.first; });
out_entries = static_cast<u32>(std::min<std::size_t>(ids.size(), out_entries)); out_entries = std::min(ids.size(), out_entries);
ctx.WriteBuffer(ids.data(), out_entries * sizeof(u128)); ctx.WriteBuffer(ids.data(), out_entries * sizeof(u128));
IPC::ResponseBuilder rb{ctx, 3}; IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess); rb.Push(ResultSuccess);
rb.Push<u32>(out_entries); rb.Push<u32>(static_cast<u32>(out_entries));
} }
void GetCommonTicketSize(Kernel::HLERequestContext& ctx) { void GetCommonTicketSize(Kernel::HLERequestContext& ctx) {

View file

@ -277,7 +277,7 @@ private:
LOG_DEBUG(Service_FS, "called."); LOG_DEBUG(Service_FS, "called.");
// Calculate how many entries we can fit in the output buffer // Calculate how many entries we can fit in the output buffer
const u64 count_entries = ctx.GetWriteBufferSize() / sizeof(FileSys::Entry); const u64 count_entries = ctx.GetWriteBufferNumElements<FileSys::Entry>();
// Cap at total number of entries. // Cap at total number of entries.
const u64 actual_entries = std::min(count_entries, entries.size() - next_entry_index); const u64 actual_entries = std::min(count_entries, entries.size() - next_entry_index);
@ -543,7 +543,7 @@ public:
LOG_DEBUG(Service_FS, "called"); LOG_DEBUG(Service_FS, "called");
// Calculate how many entries we can fit in the output buffer // Calculate how many entries we can fit in the output buffer
const u64 count_entries = ctx.GetWriteBufferSize() / sizeof(SaveDataInfo); const u64 count_entries = ctx.GetWriteBufferNumElements<SaveDataInfo>();
// Cap at total number of entries. // Cap at total number of entries.
const u64 actual_entries = std::min(count_entries, info.size() - next_entry_index); const u64 actual_entries = std::min(count_entries, info.size() - next_entry_index);

View file

@ -292,7 +292,7 @@ public:
void GetNetworkInfoLatestUpdate(Kernel::HLERequestContext& ctx) { void GetNetworkInfoLatestUpdate(Kernel::HLERequestContext& ctx) {
const std::size_t network_buffer_size = ctx.GetWriteBufferSize(0); const std::size_t network_buffer_size = ctx.GetWriteBufferSize(0);
const std::size_t node_buffer_count = ctx.GetWriteBufferSize(1) / sizeof(NodeLatestUpdate); const std::size_t node_buffer_count = ctx.GetWriteBufferNumElements<NodeLatestUpdate>(1);
if (node_buffer_count == 0 || network_buffer_size != sizeof(NetworkInfo)) { if (node_buffer_count == 0 || network_buffer_size != sizeof(NetworkInfo)) {
LOG_ERROR(Service_LDN, "Invalid buffer, size = {}, count = {}", network_buffer_size, LOG_ERROR(Service_LDN, "Invalid buffer, size = {}, count = {}", network_buffer_size,
@ -333,7 +333,7 @@ public:
const auto channel{rp.PopEnum<WifiChannel>()}; const auto channel{rp.PopEnum<WifiChannel>()};
const auto scan_filter{rp.PopRaw<ScanFilter>()}; const auto scan_filter{rp.PopRaw<ScanFilter>()};
const std::size_t network_info_size = ctx.GetWriteBufferSize() / sizeof(NetworkInfo); const std::size_t network_info_size = ctx.GetWriteBufferNumElements<NetworkInfo>();
if (network_info_size == 0) { if (network_info_size == 0) {
LOG_ERROR(Service_LDN, "Invalid buffer size {}", network_info_size); LOG_ERROR(Service_LDN, "Invalid buffer size {}", network_info_size);

View file

@ -118,7 +118,7 @@ void IUser::ListDevices(Kernel::HLERequestContext& ctx) {
} }
std::vector<u64> nfp_devices; std::vector<u64> nfp_devices;
const std::size_t max_allowed_devices = ctx.GetWriteBufferSize() / sizeof(u64); const std::size_t max_allowed_devices = ctx.GetWriteBufferNumElements<u64>();
for (auto& device : devices) { for (auto& device : devices) {
if (nfp_devices.size() >= max_allowed_devices) { if (nfp_devices.size() >= max_allowed_devices) {

View file

@ -104,9 +104,9 @@ void IUser::ListDevices(Kernel::HLERequestContext& ctx) {
} }
std::vector<u64> nfp_devices; std::vector<u64> nfp_devices;
const std::size_t max_allowed_devices = ctx.GetWriteBufferSize() / sizeof(u64); const std::size_t max_allowed_devices = ctx.GetWriteBufferNumElements<u64>();
for (auto& device : devices) { for (const auto& device : devices) {
if (nfp_devices.size() >= max_allowed_devices) { if (nfp_devices.size() >= max_allowed_devices) {
continue; continue;
} }
@ -115,7 +115,7 @@ void IUser::ListDevices(Kernel::HLERequestContext& ctx) {
} }
} }
if (nfp_devices.size() == 0) { if (nfp_devices.empty()) {
IPC::ResponseBuilder rb{ctx, 2}; IPC::ResponseBuilder rb{ctx, 2};
rb.Push(DeviceNotFound); rb.Push(DeviceNotFound);
return; return;

View file

@ -279,13 +279,10 @@ void IPlatformServiceManager::GetSharedFontInOrderOfPriority(Kernel::HLERequestC
font_sizes.push_back(region.size); font_sizes.push_back(region.size);
} }
// Resize buffers if game requests smaller size output. // Resize buffers if game requests smaller size output
font_codes.resize( font_codes.resize(std::min(font_codes.size(), ctx.GetWriteBufferNumElements<u32>(0)));
std::min<std::size_t>(font_codes.size(), ctx.GetWriteBufferSize(0) / sizeof(u32))); font_offsets.resize(std::min(font_offsets.size(), ctx.GetWriteBufferNumElements<u32>(1)));
font_offsets.resize( font_sizes.resize(std::min(font_sizes.size(), ctx.GetWriteBufferNumElements<u32>(2)));
std::min<std::size_t>(font_offsets.size(), ctx.GetWriteBufferSize(1) / sizeof(u32)));
font_sizes.resize(
std::min<std::size_t>(font_sizes.size(), ctx.GetWriteBufferSize(2) / sizeof(u32)));
ctx.WriteBuffer(font_codes, 0); ctx.WriteBuffer(font_codes, 0);
ctx.WriteBuffer(font_offsets, 1); ctx.WriteBuffer(font_offsets, 1);

View file

@ -83,7 +83,7 @@ void PushResponseLanguageCode(Kernel::HLERequestContext& ctx, std::size_t num_la
} }
void GetAvailableLanguageCodesImpl(Kernel::HLERequestContext& ctx, std::size_t max_entries) { void GetAvailableLanguageCodesImpl(Kernel::HLERequestContext& ctx, std::size_t max_entries) {
const std::size_t requested_amount = ctx.GetWriteBufferSize() / sizeof(LanguageCode); const std::size_t requested_amount = ctx.GetWriteBufferNumElements<LanguageCode>();
const std::size_t max_amount = std::min(requested_amount, max_entries); const std::size_t max_amount = std::min(requested_amount, max_entries);
const std::size_t copy_amount = std::min(available_language_codes.size(), max_amount); const std::size_t copy_amount = std::min(available_language_codes.size(), max_amount);
const std::size_t copy_size = copy_amount * sizeof(LanguageCode); const std::size_t copy_size = copy_amount * sizeof(LanguageCode);