LightningJit: Reduce stack usage for Arm32 code (#6245)

* Write/read guest state to context for sync points, stop reserving stack for them

* Fix UsedGprsMask not being updated when allocating with preferencing

* POP should be also considered a return
This commit is contained in:
gdkchan 2024-02-08 16:17:47 -03:00 committed by GitHub
parent a0b3d82ee0
commit ea07328aea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 86 additions and 36 deletions

View file

@ -10,6 +10,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
public readonly List<InstInfo> Instructions;
public readonly bool EndsWithBranch;
public readonly bool HasHostCall;
public readonly bool HasHostCallSkipContext;
public readonly bool IsTruncated;
public readonly bool IsLoopEnd;
public readonly bool IsThumb;
@ -20,6 +21,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
List<InstInfo> instructions,
bool endsWithBranch,
bool hasHostCall,
bool hasHostCallSkipContext,
bool isTruncated,
bool isLoopEnd,
bool isThumb)
@ -31,6 +33,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions = instructions;
EndsWithBranch = endsWithBranch;
HasHostCall = hasHostCall;
HasHostCallSkipContext = hasHostCallSkipContext;
IsTruncated = isTruncated;
IsLoopEnd = isLoopEnd;
IsThumb = isThumb;
@ -57,6 +60,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions.GetRange(0, splitIndex),
false,
HasHostCall,
HasHostCallSkipContext,
false,
false,
IsThumb);
@ -67,6 +71,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions.GetRange(splitIndex, splitCount),
EndsWithBranch,
HasHostCall,
HasHostCallSkipContext,
IsTruncated,
IsLoopEnd,
IsThumb);

View file

@ -208,6 +208,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
InstMeta meta;
InstFlags extraFlags = InstFlags.None;
bool hasHostCall = false;
bool hasHostCallSkipContext = false;
bool isTruncated = false;
do
@ -246,9 +247,17 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features);
}
if (meta.Name.IsSystemOrCall() && !hasHostCall)
if (meta.Name.IsSystemOrCall())
{
hasHostCall = meta.Name.IsCall() || InstEmitSystem.NeedsCall(meta.Name);
if (!hasHostCall)
{
hasHostCall = InstEmitSystem.NeedsCall(meta.Name);
}
if (!hasHostCallSkipContext)
{
hasHostCallSkipContext = meta.Name.IsCall() || InstEmitSystem.NeedsCallSkipContext(meta.Name);
}
}
insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags));
@ -259,8 +268,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
if (!isTruncated && IsBackwardsBranch(meta.Name, encoding))
{
hasHostCall = true;
isLoopEnd = true;
hasHostCallSkipContext = true;
}
return new(
@ -269,6 +278,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
insts,
!isTruncated,
hasHostCall,
hasHostCallSkipContext,
isTruncated,
isLoopEnd,
isThumb);

View file

@ -6,6 +6,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
{
public readonly List<Block> Blocks;
public readonly bool HasHostCall;
public readonly bool HasHostCallSkipContext;
public readonly bool IsTruncated;
public MultiBlock(List<Block> blocks)
@ -15,12 +16,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Block block = blocks[0];
HasHostCall = block.HasHostCall;
HasHostCallSkipContext = block.HasHostCallSkipContext;
for (int index = 1; index < blocks.Count; index++)
{
block = blocks[index];
HasHostCall |= block.HasHostCall;
HasHostCallSkipContext |= block.HasHostCallSkipContext;
}
block = blocks[^1];

View file

@ -106,6 +106,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
if ((regMask & AbiConstants.ReservedRegsMask) == 0)
{
_gprMask |= regMask;
UsedGprsMask |= regMask;
return firstCalleeSaved;
}

View file

@ -305,12 +305,23 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp);
}
int reservedStackSize = 0;
if (multiBlock.HasHostCall)
{
reservedStackSize = CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask);
}
else if (multiBlock.HasHostCallSkipContext)
{
reservedStackSize = 2 * sizeof(ulong); // Context and page table pointers.
}
RegisterSaveRestore rsr = new(
regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask,
regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask,
OperandType.FP64,
multiBlock.HasHostCall,
multiBlock.HasHostCall ? CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask) : 0);
multiBlock.HasHostCall || multiBlock.HasHostCallSkipContext,
reservedStackSize);
TailMerger tailMerger = new();
@ -596,7 +607,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
name == InstName.Ldm ||
name == InstName.Ldmda ||
name == InstName.Ldmdb ||
name == InstName.Ldmib)
name == InstName.Ldmib ||
name == InstName.Pop)
{
// Arm32 does not have a return instruction, instead returns are implemented
// either using BX LR (for leaf functions), or POP { ... PC }.
@ -711,7 +723,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
switch (type)
{
case BranchType.SyncPoint:
InstEmitSystem.WriteSyncPoint(context.Writer, context.RegisterAllocator, context.TailMerger, context.GetReservedStackOffset());
InstEmitSystem.WriteSyncPoint(
context.Writer,
ref asm,
context.RegisterAllocator,
context.TailMerger,
context.GetReservedStackOffset(),
context.StoreToContext,
context.LoadFromContext);
break;
case BranchType.SoftwareInterrupt:
context.StoreToContext();

View file

@ -199,12 +199,12 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
}
}
private static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
public static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
{
WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true);
}
private static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
public static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
{
WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false);
}

View file

@ -354,11 +354,18 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
// All instructions that might do a host call should be included here.
// That is required to reserve space on the stack for caller saved registers.
return name == InstName.Mrrc;
}
public static bool NeedsCallSkipContext(InstName name)
{
// All instructions that might do a host call should be included here.
// That is required to reserve space on the stack for caller saved registers.
switch (name)
{
case InstName.Mcr:
case InstName.Mrc:
case InstName.Mrrc:
case InstName.Svc:
case InstName.Udf:
return true;
@ -372,7 +379,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
}
public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId)
@ -380,7 +387,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
}
public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm)
@ -388,7 +395,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
}
public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2)
@ -422,14 +429,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister);
}
public static void WriteSyncPoint(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset)
{
Assembler asm = new(writer);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: false, spillBaseOffset);
}
private static void WriteSyncPoint(CodeWriter writer, ref Assembler asm, RegisterAllocator regAlloc, TailMerger tailMerger, bool skipContext, int spillBaseOffset)
public static void WriteSyncPoint(
CodeWriter writer,
ref Assembler asm,
RegisterAllocator regAlloc,
TailMerger tailMerger,
int spillBaseOffset,
Action storeToContext = null,
Action loadFromContext = null)
{
int tempRegister = regAlloc.AllocateTempGprRegister();
@ -440,7 +447,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
int branchIndex = writer.InstructionPointer;
asm.Cbnz(rt, 0);
WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
storeToContext?.Invoke();
WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
Operand rn = Register(tempRegister == 0 ? 1 : 0);
@ -449,7 +457,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32));
WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
loadFromContext?.Invoke();
asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset);
@ -514,18 +523,31 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
{
WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: true);
if (skipContext)
{
InstEmitFlow.WriteSpillSkipContext(ref asm, regAlloc, spillOffset);
}
else
{
WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: true);
}
}
private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
{
WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: false);
if (skipContext)
{
InstEmitFlow.WriteFillSkipContext(ref asm, regAlloc, spillOffset);
}
else
{
WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: false);
}
}
private static void WriteSpillOrFill(
ref Assembler asm,
RegisterAllocator regAlloc,
bool skipContext,
uint exceptMask,
int spillOffset,
int tempRegister,
@ -533,11 +555,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
{
uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask);
if (skipContext)
{
gprMask &= ~Compiler.UsableGprsMask;
}
if (!spill)
{
// We must reload the status register before reloading the GPRs,
@ -600,11 +617,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
uint fpSimdMask = regAlloc.UsedFpSimdMask;
if (skipContext)
{
fpSimdMask &= ~Compiler.UsableFpSimdMask;
}
while (fpSimdMask != 0)
{
int reg = BitOperations.TrailingZeroCount(fpSimdMask);