using Ryujinx.Common.Logging;
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;

namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
    /// <summary>
    /// <see cref="ISocket"/> implementation that forwards every operation to a
    /// <see cref="System.Net.Sockets.Socket"/>. A <see cref="SocketException"/> raised by the
    /// underlying socket is translated to a <see cref="LinuxError"/> through
    /// <c>WinSockHelper.ConvertError</c>; other exception types are not caught here.
    /// </summary>
    class ManagedSocket : ISocket
    {
        /// <summary>Reference count for this socket; set to 1 by both constructors.</summary>
        public int Refcount { get; set; }

        public AddressFamily AddressFamily => Socket.AddressFamily;

        public SocketType SocketType => Socket.SocketType;

        public ProtocolType ProtocolType => Socket.ProtocolType;

        /// <summary>Gets or sets the blocking mode of the underlying socket.</summary>
        public bool Blocking
        {
            get => Socket.Blocking;
            set => Socket.Blocking = value;
        }

        public IntPtr Handle => Socket.Handle;

        /// <summary>Remote endpoint, or null when it is not an <see cref="IPEndPoint"/>.</summary>
        public IPEndPoint RemoteEndPoint => Socket.RemoteEndPoint as IPEndPoint;

        /// <summary>Local endpoint, or null when it is not an <see cref="IPEndPoint"/>.</summary>
        public IPEndPoint LocalEndPoint => Socket.LocalEndPoint as IPEndPoint;

        /// <summary>The wrapped socket.</summary>
        public Socket Socket { get; }

        /// <summary>
        /// Creates a new underlying socket with the given family, type and protocol.
        /// </summary>
        public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
        {
            Socket = new Socket(addressFamily, socketType, protocolType);
            Refcount = 1;
        }

        /// <summary>
        /// Wraps an already existing socket (used for sockets returned by <see cref="Accept"/>).
        /// </summary>
        private ManagedSocket(Socket socket)
        {
            Socket = socket;
            Refcount = 1;
        }

        /// <summary>
        /// Maps BSD message flags to their <see cref="SocketFlags"/> equivalents.
        /// </summary>
        /// <param name="bsdSocketFlags">Flags to convert.</param>
        /// <returns>
        /// The combination of Oob, Peek, DontRoute, Trunc and CTrunc that was set.
        /// DontWait has no <see cref="SocketFlags"/> counterpart and is dropped silently
        /// (callers emulate it by toggling <see cref="Blocking"/>); any other flag is logged
        /// as unsupported and dropped.
        /// </returns>
        private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
        {
            SocketFlags socketFlags = SocketFlags.None;

            if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob))
            {
                socketFlags |= SocketFlags.OutOfBand;
            }

            if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek))
            {
                socketFlags |= SocketFlags.Peek;
            }

            if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute))
            {
                socketFlags |= SocketFlags.DontRoute;
            }

            if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc))
            {
                socketFlags |= SocketFlags.Truncated;
            }

            if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc))
            {
                socketFlags |= SocketFlags.ControlDataTruncated;
            }

            // Whatever remains after masking out the handled flags is not supported.
            bsdSocketFlags &= ~(BsdSocketFlags.Oob |
                                BsdSocketFlags.Peek |
                                BsdSocketFlags.DontRoute |
                                BsdSocketFlags.DontWait |
                                BsdSocketFlags.Trunc |
                                BsdSocketFlags.CTrunc);

            if (bsdSocketFlags != BsdSocketFlags.None)
            {
                Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}");
            }

            return socketFlags;
        }

        /// <summary>Accepts a pending connection.</summary>
        /// <param name="newSocket">The accepted socket, or null on failure.</param>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Accept(out ISocket newSocket)
        {
            try
            {
                newSocket = new ManagedSocket(Socket.Accept());

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                newSocket = null;

                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Binds the socket to <paramref name="localEndPoint"/>.</summary>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Bind(IPEndPoint localEndPoint)
        {
            try
            {
                Socket.Bind(localEndPoint);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Closes the underlying socket.</summary>
        public void Close()
        {
            Socket.Close();
        }

        /// <summary>Connects the socket to <paramref name="remoteEndPoint"/>.</summary>
        /// <returns>
        /// <c>SUCCESS</c>; <c>EINPROGRESS</c> when the socket is non-blocking and the connect
        /// reported <c>WSAEWOULDBLOCK</c>; otherwise the converted socket error.
        /// </returns>
        public LinuxError Connect(IPEndPoint remoteEndPoint)
        {
            try
            {
                Socket.Connect(remoteEndPoint);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                if (!Blocking && exception.ErrorCode == (int)WsaError.WSAEWOULDBLOCK)
                {
                    return LinuxError.EINPROGRESS;
                }
                else
                {
                    return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
                }
            }
        }

        /// <summary>Disconnects the underlying socket, requesting that it stay reusable.</summary>
        public void Disconnect()
        {
            Socket.Disconnect(true);
        }

        /// <summary>Closes and disposes the underlying socket.</summary>
        public void Dispose()
        {
            Socket.Close();
            Socket.Dispose();
        }

        /// <summary>Puts the socket in a listening state with the given backlog.</summary>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Listen(int backlog)
        {
            try
            {
                Socket.Listen(backlog);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Polls the underlying socket; see <c>Socket.Poll</c>.</summary>
        public bool Poll(int microSeconds, SelectMode mode)
        {
            return Socket.Poll(microSeconds, mode);
        }

        /// <summary>
        /// Shuts down one or both directions of the socket. <paramref name="how"/> is cast
        /// directly to <see cref="SocketShutdown"/>.
        /// </summary>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Shutdown(BsdSocketShutdownFlags how)
        {
            try
            {
                Socket.Shutdown((SocketShutdown)how);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Receives data into <paramref name="buffer"/>.</summary>
        /// <param name="receiveSize">Number of bytes received, or -1 on socket error.</param>
        /// <param name="buffer">Destination buffer.</param>
        /// <param name="flags">
        /// BSD message flags. DontWait on a blocking socket is emulated by switching the socket
        /// to non-blocking for the duration of the call, the same way
        /// <see cref="ReceiveFrom"/> does.
        /// </param>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Receive(out int receiveSize, Span<byte> buffer, BsdSocketFlags flags)
        {
            bool restoreBlocking = false;

            try
            {
                if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
                {
                    Blocking = false;
                    restoreBlocking = true;
                }

                receiveSize = Socket.Receive(buffer, ConvertBsdSocketFlags(flags));

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                receiveSize = -1;

                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
            finally
            {
                // Restore on every exit path so the socket never stays non-blocking by accident.
                if (restoreBlocking)
                {
                    Blocking = true;
                }
            }
        }

        /// <summary>
        /// Receives a datagram into the first <paramref name="size"/> bytes of
        /// <paramref name="buffer"/> and reports its sender.
        /// </summary>
        /// <param name="receiveSize">Number of bytes received, or -1 on socket error.</param>
        /// <param name="buffer">Destination buffer.</param>
        /// <param name="size">Number of bytes of <paramref name="buffer"/> that may be written.</param>
        /// <param name="flags">
        /// BSD message flags. DontWait on a blocking socket is emulated by switching the socket
        /// to non-blocking for the duration of the call.
        /// </param>
        /// <param name="remoteEndPoint">
        /// Sender of the datagram; stays at <c>IPAddress.Any</c>, port 0 when the receive fails.
        /// </param>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError ReceiveFrom(out int receiveSize, Span<byte> buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint)
        {
            remoteEndPoint = new IPEndPoint(IPAddress.Any, 0);

            bool restoreBlocking = false;

            try
            {
                EndPoint temp = new IPEndPoint(IPAddress.Any, 0);

                if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
                {
                    Blocking = false;
                    restoreBlocking = true;
                }

                receiveSize = Socket.ReceiveFrom(buffer[..size], ConvertBsdSocketFlags(flags), ref temp);

                remoteEndPoint = (IPEndPoint)temp;

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                receiveSize = -1;

                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
            finally
            {
                // A finally block (rather than code after the try/catch) also covers
                // exceptions that are not SocketException, e.g. an out-of-range slice.
                if (restoreBlocking)
                {
                    Blocking = true;
                }
            }
        }

        /// <summary>Sends the content of <paramref name="buffer"/>.</summary>
        /// <param name="sendSize">Number of bytes sent, or -1 on socket error.</param>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError Send(out int sendSize, ReadOnlySpan<byte> buffer, BsdSocketFlags flags)
        {
            try
            {
                sendSize = Socket.Send(buffer, ConvertBsdSocketFlags(flags));

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                sendSize = -1;

                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>
        /// Sends the first <paramref name="size"/> bytes of <paramref name="buffer"/> to
        /// <paramref name="remoteEndPoint"/>.
        /// </summary>
        /// <param name="sendSize">Number of bytes sent, or -1 on socket error.</param>
        /// <returns><c>SUCCESS</c>, or the converted socket error.</returns>
        public LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint)
        {
            try
            {
                sendSize = Socket.SendTo(buffer[..size], ConvertBsdSocketFlags(flags), remoteEndPoint);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                sendSize = -1;

                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>
        /// Reads a socket option into <paramref name="optionValue"/>; the span length decides
        /// how many bytes are requested.
        /// </summary>
        /// <returns>
        /// <c>SUCCESS</c>; <c>EOPNOTSUPP</c> when the option cannot be converted; otherwise the
        /// converted socket error.
        /// </returns>
        public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue)
        {
            try
            {
                if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
                {
                    Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}");

                    return LinuxError.EOPNOTSUPP;
                }

                // Socket.GetSocketOption only accepts an array, so go through a temporary one.
                byte[] tempOptionValue = new byte[optionValue.Length];

                Socket.GetSocketOption(level, optionName, tempOptionValue);

                tempOptionValue.AsSpan().CopyTo(optionValue);

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Sets a socket option from the raw bytes in <paramref name="optionValue"/>.</summary>
        /// <param name="option">BSD option to set.</param>
        /// <param name="level">Option level.</param>
        /// <param name="optionValue">
        /// Raw option value. Read as a 32-bit integer when at least 4 bytes long, otherwise as a
        /// single byte. For SoLinger, a second 32-bit integer at offset 4 holds the linger time.
        /// </param>
        /// <returns>
        /// <c>SUCCESS</c>; <c>EOPNOTSUPP</c> when the option cannot be converted; <c>EINVAL</c>
        /// when <paramref name="optionValue"/> is empty or too short for SoLinger; otherwise the
        /// converted socket error.
        /// </returns>
        public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue)
        {
            try
            {
                if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
                {
                    Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}");

                    return LinuxError.EOPNOTSUPP;
                }

                // Reject buffers that would make the reads below throw instead of
                // letting an ArgumentOutOfRangeException escape to the caller.
                if (optionValue.IsEmpty)
                {
                    return LinuxError.EINVAL;
                }

                int value = optionValue.Length >= sizeof(int)
                    ? MemoryMarshal.Read<int>(optionValue)
                    : MemoryMarshal.Read<byte>(optionValue);

                if (option == BsdSocketOption.SoLinger)
                {
                    // The linger value is two consecutive 32-bit integers: on/off, then seconds.
                    if (optionValue.Length < 2 * sizeof(int))
                    {
                        return LinuxError.EINVAL;
                    }

                    int value2 = MemoryMarshal.Read<int>(optionValue[sizeof(int)..]);

                    Socket.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2));
                }
                else
                {
                    Socket.SetSocketOption(level, optionName, value);
                }

                return LinuxError.SUCCESS;
            }
            catch (SocketException exception)
            {
                return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
            }
        }

        /// <summary>Equivalent to <see cref="Receive"/> with no flags.</summary>
        public LinuxError Read(out int readSize, Span<byte> buffer)
        {
            return Receive(out readSize, buffer, BsdSocketFlags.None);
        }

        /// <summary>Equivalent to <see cref="Send"/> with no flags.</summary>
        public LinuxError Write(out int writeSize, ReadOnlySpan<byte> buffer)
        {
            return Send(out writeSize, buffer, BsdSocketFlags.None);
        }
    }
}