Ryujinx/ARMeilleure/Signal/NativeSignalHandler.cs
riperiperi 14ce9e1567
Move partial unmap handler to the native signal handler (#3437)
* Initial commit with a lot of testing stuff.

* Partial Unmap Cleanup Part 1

* Fix some minor issues, hopefully windows tests.

* Disable partial unmap tests on macos for now

Weird issue.

* Goodbye magic number

* Add COMPlus_EnableAlternateStackCheck for tests

`COMPlus_EnableAlternateStackCheck` is needed for NullReferenceException handling to work on linux after registering the signal handler, due to how dotnet registers its own signal handler.

* Address some feedback

* Force retry when memory is mapped in memory tracking

This case existed before, but returning `false` no longer retries, so it would crash immediately after unprotecting the memory... Now, we return `true` to deliberately retry.

This case existed before (was just broken by this change) and I don't really want to look into fixing the issue right now. Technically, this means that on guest code partial unmaps will retry _due to this_ rather than hitting the handler. I don't expect this to cause any issues.

This should fix random crashes in Xenoblade Chronicles 2.

* Use IsRangeMapped

* Suppress MockMemoryManager.UnmapEvent warning

This event is not signalled by the mock memory manager.

* Remove 4kb mapping
2022-07-29 19:16:29 -03:00

339 lines
13 KiB
C#

using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
namespace ARMeilleure.Signal
{
[StructLayout(LayoutKind.Sequential, Pack = 1)]
struct SignalHandlerRange
{
public int IsActive;
public nuint RangeAddress;
public nuint RangeEndAddress;
public IntPtr ActionPointer;
}
[StructLayout(LayoutKind.Sequential, Pack = 1)]
struct SignalHandlerConfig
{
/// <summary>
/// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct.
/// </summary>
public int StructAddressOffset;
/// <summary>
/// The byte offset of the write flag in the SigInfo or ExceptionRecord struct.
/// </summary>
public int StructWriteOffset;
/// <summary>
/// The sigaction handler that was registered before this one. (unix only)
/// </summary>
public nuint UnixOldSigaction;
/// <summary>
/// The type of the previous sigaction. True for the 3 argument variant. (unix only)
/// </summary>
public int UnixOldSigaction3Arg;
public SignalHandlerRange Range0;
public SignalHandlerRange Range1;
public SignalHandlerRange Range2;
public SignalHandlerRange Range3;
public SignalHandlerRange Range4;
public SignalHandlerRange Range5;
public SignalHandlerRange Range6;
public SignalHandlerRange Range7;
}
public static class NativeSignalHandler
{
private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);
private const int MaxTrackedRanges = 8;
private const int StructAddressOffset = 0;
private const int StructWriteOffset = 4;
private const int UnixOldSigaction = 8;
private const int UnixOldSigaction3Arg = 16;
private const int RangeOffset = 20;
private const int EXCEPTION_CONTINUE_SEARCH = 0;
private const int EXCEPTION_CONTINUE_EXECUTION = -1;
private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;
private const ulong PageSize = 0x1000;
private const ulong PageMask = PageSize - 1;
private static IntPtr _handlerConfig;
private static IntPtr _signalHandlerPtr;
private static IntPtr _signalHandlerHandle;
private static readonly object _lock = new object();
private static bool _initialized;
static NativeSignalHandler()
{
_handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>());
ref SignalHandlerConfig config = ref GetConfigRef();
config = new SignalHandlerConfig();
}
public static void InitializeSignalHandler()
{
if (_initialized) return;
lock (_lock)
{
if (_initialized) return;
bool unix = OperatingSystem.IsLinux() || OperatingSystem.IsMacOS();
ref SignalHandlerConfig config = ref GetConfigRef();
if (unix)
{
// Unix siginfo struct locations.
// NOTE: These are incredibly likely to be different between kernel version and architectures.
config.StructAddressOffset = OperatingSystem.IsMacOS() ? 24 : 16; // si_addr
config.StructWriteOffset = 8; // si_code
_signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig));
SigAction old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;
config.UnixOldSigaction3Arg = old.sa_flags & 4;
}
else
{
config.StructAddressOffset = 40; // ExceptionInformation1
config.StructWriteOffset = 32; // ExceptionInformation0
_signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig));
_signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
}
_initialized = true;
}
}
private static unsafe ref SignalHandlerConfig GetConfigRef()
{
return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig);
}
public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
{
var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
for (int i = 0; i < MaxTrackedRanges; i++)
{
if (ranges[i].IsActive == 0)
{
ranges[i].RangeAddress = address;
ranges[i].RangeEndAddress = endAddress;
ranges[i].ActionPointer = action;
ranges[i].IsActive = 1;
return true;
}
}
return false;
}
public static unsafe bool RemoveTrackedRegion(nuint address)
{
var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
for (int i = 0; i < MaxTrackedRanges; i++)
{
if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address)
{
ranges[i].IsActive = 0;
return true;
}
}
return false;
}
private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
{
Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
context.Copy(inRegionLocal, Const(0));
Operand endLabel = Label();
for (int i = 0; i < MaxTrackedRanges; i++)
{
ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>());
Operand nextLabel = Label();
Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset));
context.BranchIfFalse(nextLabel, isActive);
Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 4));
Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 12));
// Is the fault address within this tracked region?
Operand inRange = context.BitwiseAnd(
context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI)
);
// Only call tracking if in range.
context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));
// Call the tracking action, with the pointer's relative offset to the base address.
Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
context.Copy(inRegionLocal, Const(0));
Operand skipActionLabel = Label();
// Tracking action should be non-null to call it, otherwise assume false return.
context.BranchIfFalse(skipActionLabel, trackingActionPtr);
Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
context.Copy(inRegionLocal, result);
context.MarkLabel(skipActionLabel);
// If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
if (OperatingSystem.IsWindows())
{
context.BranchIfTrue(endLabel, inRegionLocal);
context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
}
context.Branch(endLabel);
context.MarkLabel(nextLabel);
}
context.MarkLabel(endLabel);
return context.Copy(inRegionLocal);
}
private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
{
EmitterContext context = new EmitterContext();
// (int sig, SigInfo* sigInfo, void* ucontext)
Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
Operand structAddressOffset = context.Load(OperandType.I64, Const((ulong)signalStructPtr + StructAddressOffset));
Operand structWriteOffset = context.Load(OperandType.I64, Const((ulong)signalStructPtr + StructWriteOffset));
Operand faultAddress = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
Operand writeFlag = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
Operand endLabel = Label();
context.BranchIfTrue(endLabel, isInRegion);
Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));
Operand unixOldSigaction3Arg = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
Operand threeArgLabel = Label();
context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);
context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
context.Branch(endLabel);
context.MarkLabel(threeArgLabel);
context.Call(unixOldSigaction,
OperandType.None,
context.LoadArgument(OperandType.I32, 0),
sigInfoPtr,
context.LoadArgument(OperandType.I64, 2)
);
context.MarkLabel(endLabel);
context.Return();
ControlFlowGraph cfg = context.GetControlFlowGraph();
OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };
return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<UnixExceptionHandler>();
}
private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
{
EmitterContext context = new EmitterContext();
// (ExceptionPointers* exceptionInfo)
Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);
// First thing's first - this catches a number of exceptions, but we only want access violations.
Operand validExceptionLabel = Label();
Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);
context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);
context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.
context.MarkLabel(validExceptionLabel);
// Next, read the address of the invalid access, and whether it is a write or not.
Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));
Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
Operand endLabel = Label();
// If the region check result is false, then run the next vectored exception handler.
context.BranchIfTrue(endLabel, isInRegion);
context.Return(Const(EXCEPTION_CONTINUE_SEARCH));
context.MarkLabel(endLabel);
// Otherwise, return to execution.
context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));
// Compile and return the function.
ControlFlowGraph cfg = context.GetControlFlowGraph();
OperandType[] argTypes = new OperandType[] { OperandType.I64 };
return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<VectoredExceptionHandler>();
}
}
}