diff --git a/src/core/hle/service/sm/srv.cpp b/src/core/hle/service/sm/srv.cpp index f459d3784..babcb1f9e 100644 --- a/src/core/hle/service/sm/srv.cpp +++ b/src/core/hle/service/sm/srv.cpp @@ -3,7 +3,6 @@ // Refer to the license.txt file included. #include - #include "common/common_types.h" #include "common/logging/log.h" #include "core/hle/ipc.h" @@ -11,10 +10,12 @@ #include "core/hle/kernel/client_port.h" #include "core/hle/kernel/client_session.h" #include "core/hle/kernel/errors.h" +#include "core/hle/kernel/event.h" #include "core/hle/kernel/hle_ipc.h" #include "core/hle/kernel/semaphore.h" #include "core/hle/kernel/server_port.h" #include "core/hle/kernel/server_session.h" +#include "core/hle/lock.h" #include "core/hle/service/sm/sm.h" #include "core/hle/service/sm/srv.h" @@ -99,12 +100,44 @@ void SRV::GetServiceHandle(Kernel::HLERequestContext& ctx) { // TODO(yuriks): Permission checks go here + auto get_handle = [name, this](Kernel::SharedPtr thread, + Kernel::HLERequestContext& ctx, ThreadWakeupReason reason) { + LOG_ERROR(Service_SRV, "called service={} wakeup", name); + auto client_port = service_manager->GetServicePort(name); + + auto session = client_port.Unwrap()->Connect(); + if (session.Succeeded()) { + LOG_DEBUG(Service_SRV, "called service={} -> session={}", name, + (*session)->GetObjectId()); + IPC::RequestBuilder rb(ctx, 0x5, 1, 2); + rb.Push(session.Code()); + rb.PushMoveObjects(std::move(session).Unwrap()); + } else if (session.Code() == Kernel::ERR_MAX_CONNECTIONS_REACHED) { + LOG_ERROR(Service_SRV, "called service={} -> ERR_MAX_CONNECTIONS_REACHED", name); + UNREACHABLE(); + } else { + LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, session.Code().raw); + IPC::RequestBuilder rb(ctx, 0x5, 1, 0); + rb.Push(session.Code()); + } + }; + auto client_port = service_manager->GetServicePort(name); if (client_port.Failed()) { - IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); - rb.Push(client_port.Code()); - LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, client_port.Code().raw); - return; + if (wait_until_available && client_port.Code() == ERR_SERVICE_NOT_REGISTERED) { + LOG_INFO(Service_SRV, "called service={} delayed", name); + Kernel::SharedPtr get_service_handle_event = + ctx.SleepClientThread(Kernel::GetCurrentThread(), "GetServiceHandle", + std::chrono::nanoseconds(-1), get_handle); + get_service_handle_delayed_map[name] = std::move(get_service_handle_event); + return; + } else { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(client_port.Code()); + LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, + client_port.Code().raw); + return; + } } auto session = client_port.Unwrap()->Connect(); @@ -199,6 +232,12 @@ void SRV::RegisterService(Kernel::HLERequestContext& ctx) { return; } + auto it = get_service_handle_delayed_map.find(name); + if (it != get_service_handle_delayed_map.end()) { + it->second->Signal(); + get_service_handle_delayed_map.erase(it); + } + IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); rb.Push(RESULT_SUCCESS); rb.PushMoveObjects(port.Unwrap()); diff --git a/src/core/hle/service/sm/srv.h b/src/core/hle/service/sm/srv.h index ab5855620..d3525ca65 100644 --- a/src/core/hle/service/sm/srv.h +++ b/src/core/hle/service/sm/srv.h @@ -4,6 +4,7 @@ #pragma once +#include #include "core/hle/kernel/kernel.h" #include "core/hle/service/service.h" @@ -32,6 +33,8 @@ private: std::shared_ptr service_manager; Kernel::SharedPtr notification_semaphore; + std::unordered_map> + get_service_handle_delayed_map; }; } // namespace SM