Make LM skip instead of crashing for invalid messages (#5290)

This commit is contained in:
gdkchan 2023-06-12 21:12:06 -03:00 committed by GitHub
parent 52aa4b6c22
commit cf4c78b9c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 14 deletions

View file

@ -24,6 +24,24 @@ namespace Ryujinx.Common.Memory
return value;
}
public bool TryRead<T>(out T value) where T : unmanaged
{
int valueSize = Unsafe.SizeOf<T>();
if (valueSize > _input.Length)
{
value = default;
return false;
}
value = MemoryMarshal.Cast<byte, T>(_input)[0];
_input = _input.Slice(valueSize);
return true;
}
public ReadOnlySpan<byte> GetSpan(int size)
{
ReadOnlySpan<byte> data = _input.Slice(0, size);

View file

@ -17,7 +17,7 @@ namespace Ryujinx.Horizon.LogManager.Ipc
private const int MessageLengthLimit = 5000;
private readonly LogService _log;
private readonly ulong _pid;
private readonly ulong _pid;
private LogPacket _logPacket;
@ -74,8 +74,12 @@ namespace Ryujinx.Horizon.LogManager.Ipc
private bool LogImpl(ReadOnlySpan<byte> message)
{
SpanReader reader = new(message);
LogPacketHeader header = reader.Read<LogPacketHeader>();
SpanReader reader = new(message);
if (!reader.TryRead(out LogPacketHeader header))
{
return true;
}
bool isHeadPacket = (header.Flags & LogPacketFlags.IsHead) != 0;
bool isTailPacket = (header.Flags & LogPacketFlags.IsTail) != 0;
@ -84,8 +88,10 @@ namespace Ryujinx.Horizon.LogManager.Ipc
while (reader.Length > 0)
{
int type = ReadUleb128(ref reader);
int size = ReadUleb128(ref reader);
if (!TryReadUleb128(ref reader, out int type) || !TryReadUleb128(ref reader, out int size))
{
return true;
}
LogDataChunkKey key = (LogDataChunkKey)type;
@ -101,15 +107,24 @@ namespace Ryujinx.Horizon.LogManager.Ipc
}
else if (key == LogDataChunkKey.Line)
{
_logPacket.Line = reader.Read<int>();
if (!reader.TryRead<int>(out _logPacket.Line))
{
return true;
}
}
else if (key == LogDataChunkKey.DropCount)
{
_logPacket.DropCount = reader.Read<long>();
if (!reader.TryRead<long>(out _logPacket.DropCount))
{
return true;
}
}
else if (key == LogDataChunkKey.Time)
{
_logPacket.Time = reader.Read<long>();
if (!reader.TryRead<long>(out _logPacket.Time))
{
return true;
}
}
else if (key == LogDataChunkKey.Message)
{
@ -154,23 +169,25 @@ namespace Ryujinx.Horizon.LogManager.Ipc
return isTailPacket;
}
private static int ReadUleb128(ref SpanReader reader)
private static bool TryReadUleb128(ref SpanReader reader, out int result)
{
int result = 0;
int count = 0;
result = 0;
int count = 0;
byte encoded;
do
{
encoded = reader.Read<byte>();
if (!reader.TryRead<byte>(out encoded))
{
return false;
}
result += (encoded & 0x7F) << (7 * count);
count++;
} while ((encoded & 0x80) != 0);
return result;
return true;
}
}
}