Bsd: Implement Select (#4017)

* bsd: Add gdkchan's Select implementation

Co-authored-by: TSRBerry <20988865+tsrberry@users.noreply.github.com>

* bsd: Fix Select() causing a crash with an ArgumentException

.NET Sockets have to be used for the Select() call

* bsd: Make Select more generic

* bsd: Adjust namespaces and remove unused imports

* bsd: Fix NullReferenceException in Select

Co-authored-by: gdkchan <gab.dark.100@gmail.com>
This commit is contained in:
TSRBerry 2022-12-12 14:59:31 +01:00 committed by GitHub
parent 403e67d983
commit ba5c0cf5d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 263 additions and 43 deletions

View file

@ -1,5 +1,8 @@
using System.Collections.Concurrent;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Numerics;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
@ -41,6 +44,27 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return null;
}
public List<IFileDescriptor> RetrieveFileDescriptorsFromMask(ReadOnlySpan<byte> mask)
{
List<IFileDescriptor> fds = new();
for (int i = 0; i < mask.Length; i++)
{
byte current = mask[i];
while (current != 0)
{
int bit = BitOperations.TrailingZeroCount(current);
current &= (byte)~(1 << bit);
int fd = i * 8 + bit;
fds.Add(RetrieveFileDescriptor(fd));
}
}
return fds;
}
public int RegisterFileDescriptor(IFileDescriptor file)
{
lock (_lock)
@ -61,6 +85,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
}
}
public void BuildMask(List<IFileDescriptor> fds, Span<byte> mask)
{
foreach (IFileDescriptor descriptor in fds)
{
int fd = _fds.IndexOf(descriptor);
mask[fd >> 3] |= (byte)(1 << (fd & 7));
}
}
public int DuplicateFileDescriptor(int fd)
{
IFileDescriptor oldFile = RetrieveFileDescriptor(fd);

View file

@ -1,10 +1,13 @@
using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using Ryujinx.Memory;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Text;
@ -202,12 +205,122 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
}
[CommandHipc(5)]
// Select(u32 nfds, nn::socket::timeout timeout, buffer<nn::socket::fd_set, 0x21, 0> readfds_in, buffer<nn::socket::fd_set, 0x21, 0> writefds_in, buffer<nn::socket::fd_set, 0x21, 0> errorfds_in) -> (i32 ret, u32 bsd_errno, buffer<nn::socket::fd_set, 0x22, 0> readfds_out, buffer<nn::socket::fd_set, 0x22, 0> writefds_out, buffer<nn::socket::fd_set, 0x22, 0> errorfds_out)
// Select(u32 nfds, nn::socket::timeval timeout, buffer<nn::socket::fd_set, 0x21, 0> readfds_in, buffer<nn::socket::fd_set, 0x21, 0> writefds_in, buffer<nn::socket::fd_set, 0x21, 0> errorfds_in)
// -> (i32 ret, u32 bsd_errno, buffer<nn::socket::fd_set, 0x22, 0> readfds_out, buffer<nn::socket::fd_set, 0x22, 0> writefds_out, buffer<nn::socket::fd_set, 0x22, 0> errorfds_out)
public ResultCode Select(ServiceCtx context)
{
WriteBsdResult(context, -1, LinuxError.EOPNOTSUPP);
int fdsCount = context.RequestData.ReadInt32();
int timeout = context.RequestData.ReadInt32();
Logger.Stub?.PrintStub(LogClass.ServiceBsd);
(ulong readFdsInBufferPosition, ulong readFdsInBufferSize) = context.Request.GetBufferType0x21(0);
(ulong writeFdsInBufferPosition, ulong writeFdsInBufferSize) = context.Request.GetBufferType0x21(1);
(ulong errorFdsInBufferPosition, ulong errorFdsInBufferSize) = context.Request.GetBufferType0x21(2);
(ulong readFdsOutBufferPosition, ulong readFdsOutBufferSize) = context.Request.GetBufferType0x22(0);
(ulong writeFdsOutBufferPosition, ulong writeFdsOutBufferSize) = context.Request.GetBufferType0x22(1);
(ulong errorFdsOutBufferPosition, ulong errorFdsOutBufferSize) = context.Request.GetBufferType0x22(2);
List<IFileDescriptor> readFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(readFdsInBufferPosition, (int)readFdsInBufferSize));
List<IFileDescriptor> writeFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(writeFdsInBufferPosition, (int)writeFdsInBufferSize));
List<IFileDescriptor> errorFds = _context.RetrieveFileDescriptorsFromMask(context.Memory.GetSpan(errorFdsInBufferPosition, (int)errorFdsInBufferSize));
int actualFdsCount = readFds.Count + writeFds.Count + errorFds.Count;
if (fdsCount == 0 || actualFdsCount == 0)
{
WriteBsdResult(context, 0);
return ResultCode.Success;
}
PollEvent[] events = new PollEvent[actualFdsCount];
int index = 0;
foreach (IFileDescriptor fd in readFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Input }, fd);
index++;
}
foreach (IFileDescriptor fd in writeFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Output }, fd);
index++;
}
foreach (IFileDescriptor fd in errorFds)
{
events[index] = new PollEvent(new PollEventData { InputEvents = PollEventTypeMask.Error }, fd);
index++;
}
List<PollEvent>[] eventsByPollManager = new List<PollEvent>[_pollManagers.Count];
for (int i = 0; i < eventsByPollManager.Length; i++)
{
eventsByPollManager[i] = new List<PollEvent>();
foreach (PollEvent evnt in events)
{
if (_pollManagers[i].IsCompatible(evnt))
{
eventsByPollManager[i].Add(evnt);
}
}
}
int updatedCount = 0;
for (int i = 0; i < _pollManagers.Count; i++)
{
if (eventsByPollManager[i].Count > 0)
{
_pollManagers[i].Select(eventsByPollManager[i], timeout, out int updatedPollCount);
updatedCount += updatedPollCount;
}
}
readFds.Clear();
writeFds.Clear();
errorFds.Clear();
foreach (PollEvent pollEvent in events)
{
for (int i = 0; i < _pollManagers.Count; i++)
{
if (eventsByPollManager[i].Contains(pollEvent))
{
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Input))
{
readFds.Add(pollEvent.FileDescriptor);
}
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Output))
{
writeFds.Add(pollEvent.FileDescriptor);
}
if (pollEvent.Data.OutputEvents.HasFlag(PollEventTypeMask.Error))
{
errorFds.Add(pollEvent.FileDescriptor);
}
}
}
}
using var readFdsOut = context.Memory.GetWritableRegion(readFdsOutBufferPosition, (int)readFdsOutBufferSize);
using var writeFdsOut = context.Memory.GetWritableRegion(writeFdsOutBufferPosition, (int)writeFdsOutBufferSize);
using var errorFdsOut = context.Memory.GetWritableRegion(errorFdsOutBufferPosition, (int)errorFdsOutBufferSize);
_context.BuildMask(readFds, readFdsOut.Memory.Span);
_context.BuildMask(writeFds, writeFdsOut.Memory.Span);
_context.BuildMask(errorFds, errorFdsOut.Memory.Span);
WriteBsdResult(context, updatedCount);
return ResultCode.Success;
}
@ -320,14 +433,14 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
break;
}
// If we are here, that mean nothing was availaible, sleep for 50ms
// If we are here, that mean nothing was available, sleep for 50ms
context.Device.System.KernelContext.Syscall.SleepThread(50 * 1000000);
}
while (PerformanceCounter.ElapsedMilliseconds < budgetLeftMilliseconds);
}
else if (timeout == -1)
{
// FIXME: If we get a timeout of -1 and there is no fds to wait on, this should kill the KProces. (need to check that with re)
// FIXME: If we get a timeout of -1 and there is no fds to wait on, this should kill the KProcess. (need to check that with re)
throw new InvalidOperationException();
}
else

View file

@ -1,4 +1,5 @@
using System;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{

View file

@ -1,4 +1,5 @@
using System;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Net;
using System.Net.Sockets;

View file

@ -1,8 +1,9 @@
using System;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Runtime.InteropServices;
using System.Threading;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class EventFileDescriptor : IFileDescriptor
{

View file

@ -1,8 +1,9 @@
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic;
using System.Threading;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class EventFileDescriptorPollManager : IPollManager
{
@ -109,5 +110,13 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return LinuxError.SUCCESS;
}
public LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount)
{
// TODO: Implement Select for event file descriptors
updatedCount = 0;
return LinuxError.EOPNOTSUPP;
}
}
}

View file

@ -1,4 +1,5 @@
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Collections.Generic;
using System.Diagnostics;
@ -6,7 +7,7 @@ using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class ManagedSocket : ISocket
{

View file

@ -1,8 +1,9 @@
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic;
using System.Net.Sockets;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class ManagedSocketPollManager : IPollManager
{
@ -117,5 +118,60 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
return LinuxError.SUCCESS;
}
public LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount)
{
List<Socket> readEvents = new();
List<Socket> writeEvents = new();
List<Socket> errorEvents = new();
updatedCount = 0;
foreach (PollEvent pollEvent in events)
{
ManagedSocket socket = (ManagedSocket)pollEvent.FileDescriptor;
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Input))
{
readEvents.Add(socket.Socket);
}
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Output))
{
writeEvents.Add(socket.Socket);
}
if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Error))
{
errorEvents.Add(socket.Socket);
}
}
Socket.Select(readEvents, writeEvents, errorEvents, timeout);
updatedCount = readEvents.Count + writeEvents.Count + errorEvents.Count;
foreach (PollEvent pollEvent in events)
{
ManagedSocket socket = (ManagedSocket)pollEvent.FileDescriptor;
if (readEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Input;
}
if (writeEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Output;
}
if (errorEvents.Contains(socket.Socket))
{
pollEvent.Data.OutputEvents |= PollEventTypeMask.Error;
}
}
return LinuxError.SUCCESS;
}
}
}

View file

@ -1,6 +1,6 @@
using System.Diagnostics.CodeAnalysis;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
[SuppressMessage("ReSharper", "InconsistentNaming")]
enum WsaError

View file

@ -1,7 +1,8 @@
using System.Collections.Generic;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System.Collections.Generic;
using System.Net.Sockets;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
static class WinSockHelper
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdAddressFamily : uint
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdIoctl
{

View file

@ -1,6 +1,6 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
class BsdMMsgHdr
{

View file

@ -1,7 +1,7 @@
using System;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
class BsdMsgHdr
{

View file

@ -3,7 +3,7 @@ using System;
using System.Net;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
[StructLayout(LayoutKind.Sequential, Pack = 1, Size = 0x10)]
struct BsdSockAddr

View file

@ -1,6 +1,6 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
[Flags]
enum BsdSocketCreationFlags

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdSocketFlags
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdSocketOption
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdSocketShutdownFlags
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
enum BsdSocketType
{

View file

@ -1,6 +1,6 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
[Flags]
enum EventFdFlags : uint

View file

@ -1,11 +1,13 @@
using System.Collections.Generic;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
interface IPollManager
{
bool IsCompatible(PollEvent evnt);
LinuxError Poll(List<PollEvent> events, int timeoutMilliseconds, out int updatedCount);
LinuxError Select(List<PollEvent> events, int timeout, out int updatedCount);
}
}

View file

@ -1,6 +1,6 @@
using System.Diagnostics.CodeAnalysis;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
[SuppressMessage("ReSharper", "InconsistentNaming")]
enum LinuxError

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
class PollEvent
{

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
struct PollEventData
{

View file

@ -1,6 +1,6 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
[Flags]
enum PollEventTypeMask : ushort

View file

@ -1,4 +1,4 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types
{
public struct TimeVal
{

View file

@ -1,4 +1,5 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;