Use RunAsync in multiple socket operations (#7053)

* Use RunAsync in multiple socket operations

* EOF newline

* Fix linux compilation

* Fix compilation on macos
This commit is contained in:
PabloMK7 2023-10-09 23:59:08 +02:00 committed by GitHub
parent 6cfd00e42d
commit 6264b6d43c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1093,59 +1093,88 @@ void SOC_U::RecvFromOther(Kernel::HLERequestContext& ctx) {
#endif // _WIN32 #endif // _WIN32
u32 addr_len = rp.Pop<u32>(); u32 addr_len = rp.Pop<u32>();
rp.PopPID(); rp.PopPID();
auto& buffer = rp.PopMappedBuffer();
CTRSockAddr ctr_src_addr; bool needs_async = GetSocketBlocking(fd_info->second) && !dont_wait;
std::vector<u8> output_buff(len); struct AsyncData {
std::vector<u8> addr_buff(addr_len); // Input
sockaddr src_addr; u32 len{};
socklen_t src_addr_len = sizeof(src_addr); u32 flags{};
u32 addr_len{};
s32 ret = -1; SocketHolder* fd_info;
if (GetSocketBlocking(fd_info->second) && !dont_wait) {
PreTimerAdjust();
}
if (addr_len > 0) {
ret = static_cast<s32>(::recvfrom(fd_info->second.socket_fd,
reinterpret_cast<char*>(output_buff.data()), len, flags,
&src_addr, &src_addr_len));
if (ret >= 0 && src_addr_len > 0) {
ctr_src_addr = CTRSockAddr::FromPlatform(src_addr);
std::memcpy(addr_buff.data(), &ctr_src_addr, addr_len);
}
} else {
ret = static_cast<s32>(::recvfrom(fd_info->second.socket_fd,
reinterpret_cast<char*>(output_buff.data()), len, flags,
NULL, 0));
addr_buff.resize(0);
}
int recv_error = (ret == SOCKET_ERROR_VALUE) ? GET_ERRNO : 0;
if (GetSocketBlocking(fd_info->second) && !dont_wait) {
PostTimerAdjust(ctx, "RecvFromOther");
}
#ifdef _WIN32 #ifdef _WIN32
if (dont_wait && was_blocking) { bool dont_wait;
SetSocketBlocking(fd_info->second, true); bool was_blocking;
}
#endif #endif
if (ret == SOCKET_ERROR_VALUE) {
ret = TranslateError(recv_error);
} else {
buffer.Write(output_buff.data(), 0, ret);
}
IPC::RequestBuilder rb = rp.MakeBuilder(2, 4); // Output
rb.Push(RESULT_SUCCESS); s32 ret{};
rb.Push(ret); int recv_error;
rb.PushStaticBuffer(std::move(addr_buff), 0); Kernel::MappedBuffer* buffer;
rb.PushMappedBuffer(buffer); std::vector<u8> output_buff;
std::vector<u8> addr_buff;
};
auto async_data = std::make_shared<AsyncData>();
async_data->buffer = &rp.PopMappedBuffer();
async_data->ret = -1;
async_data->len = len;
async_data->flags = flags;
async_data->addr_len = addr_len;
async_data->output_buff.resize(len);
async_data->addr_buff.resize(addr_len);
async_data->fd_info = &fd_info->second;
#ifdef _WIN32
async_data->dont_wait = dont_wait;
async_data->was_blocking = was_blocking;
#endif
ctx.RunAsync(
[async_data](Kernel::HLERequestContext& ctx) {
sockaddr src_addr;
socklen_t src_addr_len = sizeof(src_addr);
CTRSockAddr ctr_src_addr;
if (async_data->addr_len > 0) {
async_data->ret = static_cast<s32>(
::recvfrom(async_data->fd_info->socket_fd,
reinterpret_cast<char*>(async_data->output_buff.data()),
async_data->len, async_data->flags, &src_addr, &src_addr_len));
if (async_data->ret >= 0 && src_addr_len > 0) {
ctr_src_addr = CTRSockAddr::FromPlatform(src_addr);
std::memcpy(async_data->addr_buff.data(), &ctr_src_addr, async_data->addr_len);
}
} else {
async_data->ret = static_cast<s32>(
::recvfrom(async_data->fd_info->socket_fd,
reinterpret_cast<char*>(async_data->output_buff.data()),
async_data->len, async_data->flags, NULL, 0));
async_data->addr_buff.resize(0);
}
async_data->recv_error = (async_data->ret == SOCKET_ERROR_VALUE) ? GET_ERRNO : 0;
return 0;
},
[this, async_data](Kernel::HLERequestContext& ctx) {
if (async_data->ret == SOCKET_ERROR_VALUE) {
async_data->ret = TranslateError(async_data->recv_error);
} else {
async_data->buffer->Write(async_data->output_buff.data(), 0, async_data->ret);
}
#ifdef _WIN32
if (async_data->dont_wait && async_data->was_blocking) {
SetSocketBlocking(*async_data->fd_info, true);
}
#else
(void)this;
#endif
IPC::RequestBuilder rb(ctx, 0x07, 2, 4);
rb.Push(RESULT_SUCCESS);
rb.Push(async_data->ret);
rb.PushStaticBuffer(std::move(async_data->addr_buff), 0);
rb.PushMappedBuffer(*async_data->buffer);
},
needs_async);
} }
void SOC_U::RecvFrom(Kernel::HLERequestContext& ctx) { void SOC_U::RecvFrom(Kernel::HLERequestContext& ctx) {
// TODO(Subv): Calling this function on a blocking socket will block the emu thread,
// preventing graceful shutdown when closing the emulator, this can be fixed by always
// performing nonblocking operations and spinlock until the data is available
IPC::RequestParser rp(ctx); IPC::RequestParser rp(ctx);
u32 socket_handle = rp.Pop<u32>(); u32 socket_handle = rp.Pop<u32>();
auto fd_info = open_sockets.find(socket_handle); auto fd_info = open_sockets.find(socket_handle);
@ -1172,55 +1201,89 @@ void SOC_U::RecvFrom(Kernel::HLERequestContext& ctx) {
u32 addr_len = rp.Pop<u32>(); u32 addr_len = rp.Pop<u32>();
rp.PopPID(); rp.PopPID();
CTRSockAddr ctr_src_addr; bool needs_async = GetSocketBlocking(fd_info->second) && !dont_wait;
std::vector<u8> output_buff(len); struct AsyncData {
std::vector<u8> addr_buff(addr_len); // Input
sockaddr src_addr; u32 len{};
socklen_t src_addr_len = sizeof(src_addr); u32 flags{};
u32 addr_len{};
s32 ret = -1; SocketHolder* fd_info;
if (GetSocketBlocking(fd_info->second) && !dont_wait) {
PreTimerAdjust();
}
if (addr_len > 0) {
// Only get src adr if input adr available
ret = static_cast<s32>(::recvfrom(fd_info->second.socket_fd,
reinterpret_cast<char*>(output_buff.data()), len, flags,
&src_addr, &src_addr_len));
if (ret >= 0 && src_addr_len > 0) {
ctr_src_addr = CTRSockAddr::FromPlatform(src_addr);
std::memcpy(addr_buff.data(), &ctr_src_addr, addr_len);
}
} else {
ret = static_cast<s32>(::recvfrom(fd_info->second.socket_fd,
reinterpret_cast<char*>(output_buff.data()), len, flags,
NULL, 0));
addr_buff.resize(0);
}
int recv_error = (ret == SOCKET_ERROR_VALUE) ? GET_ERRNO : 0;
if (GetSocketBlocking(fd_info->second) && !dont_wait) {
PostTimerAdjust(ctx, "RecvFrom");
}
#ifdef _WIN32 #ifdef _WIN32
if (dont_wait && was_blocking) { bool dont_wait;
SetSocketBlocking(fd_info->second, true); bool was_blocking;
}
#endif #endif
s32 total_received = ret;
if (ret == SOCKET_ERROR_VALUE) {
ret = TranslateError(recv_error);
total_received = 0;
}
// Write only the data we received to avoid overwriting parts of the buffer with zeros // Output
output_buff.resize(total_received); s32 ret{};
int recv_error;
std::vector<u8> output_buff;
std::vector<u8> addr_buff;
};
IPC::RequestBuilder rb = rp.MakeBuilder(3, 4); auto async_data = std::make_shared<AsyncData>();
rb.Push(RESULT_SUCCESS); async_data->ret = -1;
rb.Push(ret); async_data->len = len;
rb.Push(total_received); async_data->flags = flags;
rb.PushStaticBuffer(std::move(output_buff), 0); async_data->addr_len = addr_len;
rb.PushStaticBuffer(std::move(addr_buff), 1); async_data->output_buff.resize(len);
async_data->addr_buff.resize(addr_len);
async_data->fd_info = &fd_info->second;
#ifdef _WIN32
async_data->dont_wait = dont_wait;
async_data->was_blocking = was_blocking;
#endif
ctx.RunAsync(
[async_data](Kernel::HLERequestContext& ctx) {
sockaddr src_addr;
socklen_t src_addr_len = sizeof(src_addr);
CTRSockAddr ctr_src_addr;
if (async_data->addr_len > 0) {
// Only get src adr if input adr available
async_data->ret = static_cast<s32>(
::recvfrom(async_data->fd_info->socket_fd,
reinterpret_cast<char*>(async_data->output_buff.data()),
async_data->len, async_data->flags, &src_addr, &src_addr_len));
if (async_data->ret >= 0 && src_addr_len > 0) {
ctr_src_addr = CTRSockAddr::FromPlatform(src_addr);
std::memcpy(async_data->addr_buff.data(), &ctr_src_addr, async_data->addr_len);
}
} else {
async_data->ret = static_cast<s32>(
::recvfrom(async_data->fd_info->socket_fd,
reinterpret_cast<char*>(async_data->output_buff.data()),
async_data->len, async_data->flags, NULL, 0));
async_data->addr_buff.resize(0);
}
async_data->recv_error = (async_data->ret == SOCKET_ERROR_VALUE) ? GET_ERRNO : 0;
return 0;
},
[this, async_data](Kernel::HLERequestContext& ctx) {
#ifdef _WIN32
if (async_data->dont_wait && async_data->was_blocking) {
SetSocketBlocking(*async_data->fd_info, true);
}
#else
(void)this;
#endif
s32 total_received = async_data->ret;
if (async_data->ret == SOCKET_ERROR_VALUE) {
async_data->ret = TranslateError(async_data->recv_error);
total_received = 0;
}
// Write only the data we received to avoid overwriting parts of the buffer with zeros
async_data->output_buff.resize(total_received);
IPC::RequestBuilder rb(ctx, 0x08, 3, 4);
rb.Push(RESULT_SUCCESS);
rb.Push(async_data->ret);
rb.Push(total_received);
rb.PushStaticBuffer(std::move(async_data->output_buff), 0);
rb.PushStaticBuffer(std::move(async_data->addr_buff), 1);
},
needs_async);
} }
void SOC_U::Poll(Kernel::HLERequestContext& ctx) { void SOC_U::Poll(Kernel::HLERequestContext& ctx) {
@ -1230,45 +1293,71 @@ void SOC_U::Poll(Kernel::HLERequestContext& ctx) {
rp.PopPID(); rp.PopPID();
auto input_fds = rp.PopStaticBuffer(); auto input_fds = rp.PopStaticBuffer();
std::vector<CTRPollFD> ctr_fds(nfds); struct AsyncData {
std::memcpy(ctr_fds.data(), input_fds.data(), nfds * sizeof(CTRPollFD)); // Input
s32 timeout;
u32 nfds;
// Input/Output
std::vector<pollfd> platform_pollfd;
std::vector<u8> has_libctru_bug;
std::vector<CTRPollFD> ctr_fds;
// Output
s32 ret;
int poll_error;
};
auto async_data = std::make_shared<AsyncData>();
async_data->timeout = timeout;
async_data->nfds = nfds;
async_data->ctr_fds.resize(nfds);
std::memcpy(async_data->ctr_fds.data(), input_fds.data(), nfds * sizeof(CTRPollFD));
// The 3ds_pollfd and the pollfd structures may be different (Windows/Linux have different // The 3ds_pollfd and the pollfd structures may be different (Windows/Linux have different
// sizes) // sizes)
// so we have to copy the data in order // so we have to copy the data in order
std::vector<pollfd> platform_pollfd(nfds); async_data->platform_pollfd.resize(nfds);
std::vector<u8> has_libctru_bug(nfds, false); async_data->has_libctru_bug.resize(nfds, false);
for (u32 i = 0; i < nfds; i++) { for (u32 i = 0; i < nfds; i++) {
platform_pollfd[i] = CTRPollFD::ToPlatform(*this, ctr_fds[i], has_libctru_bug[i]); async_data->platform_pollfd[i] =
CTRPollFD::ToPlatform(*this, async_data->ctr_fds[i], async_data->has_libctru_bug[i]);
} }
if (timeout) { ctx.RunAsync(
PreTimerAdjust(); [async_data](Kernel::HLERequestContext& ctx) {
} async_data->ret =
s32 ret = ::poll(platform_pollfd.data(), nfds, timeout); ::poll(async_data->platform_pollfd.data(), async_data->nfds, async_data->timeout);
if (timeout) { if (async_data->ret == SOCKET_ERROR_VALUE) {
PostTimerAdjust(ctx, "Poll"); async_data->poll_error = GET_ERRNO;
} }
return 0;
},
[this, async_data](Kernel::HLERequestContext& ctx) {
// Now update the output 3ds_pollfd structure
for (u32 i = 0; i < async_data->nfds; i++) {
async_data->ctr_fds[i] = CTRPollFD::FromPlatform(
*this, async_data->platform_pollfd[i], async_data->has_libctru_bug[i]);
}
// Now update the output 3ds_pollfd structure std::vector<u8> output_fds(async_data->nfds * sizeof(CTRPollFD));
for (u32 i = 0; i < nfds; i++) { std::memcpy(output_fds.data(), async_data->ctr_fds.data(),
ctr_fds[i] = CTRPollFD::FromPlatform(*this, platform_pollfd[i], has_libctru_bug[i]); async_data->nfds * sizeof(CTRPollFD));
}
std::vector<u8> output_fds(nfds * sizeof(CTRPollFD)); if (async_data->ret == SOCKET_ERROR_VALUE) {
std::memcpy(output_fds.data(), ctr_fds.data(), nfds * sizeof(CTRPollFD)); int err = async_data->poll_error;
LOG_DEBUG(Service_SOC, "Socket error: {}", err);
if (ret == SOCKET_ERROR_VALUE) { async_data->ret = TranslateError(GET_ERRNO);
int err = GET_ERRNO; }
LOG_ERROR(Service_SOC, "Socket error: {}", err);
ret = TranslateError(GET_ERRNO); IPC::RequestBuilder rb(ctx, static_cast<u16>(ctx.CommandHeader().command_id.Value()), 2,
} 2);
rb.Push(RESULT_SUCCESS);
IPC::RequestBuilder rb = rp.MakeBuilder(2, 2); rb.Push(async_data->ret);
rb.Push(RESULT_SUCCESS); rb.PushStaticBuffer(std::move(output_fds), 0);
rb.Push(ret); },
rb.PushStaticBuffer(std::move(output_fds), 0); timeout != 0);
} }
void SOC_U::GetSockName(Kernel::HLERequestContext& ctx) { void SOC_U::GetSockName(Kernel::HLERequestContext& ctx) {