LightningJit: Reduce stack usage for Arm32 code (#6245)

* Write/read guest state to context for sync points, stop reserving stack for them

* Fix UsedGprsMask not being updated when allocating with preferencing

* POP should be also considered a return
This commit is contained in:
gdkchan 2024-02-08 16:17:47 -03:00 committed by GitHub
parent a0b3d82ee0
commit ea07328aea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 86 additions and 36 deletions

View file

@ -10,6 +10,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
public readonly List<InstInfo> Instructions; public readonly List<InstInfo> Instructions;
public readonly bool EndsWithBranch; public readonly bool EndsWithBranch;
public readonly bool HasHostCall; public readonly bool HasHostCall;
public readonly bool HasHostCallSkipContext;
public readonly bool IsTruncated; public readonly bool IsTruncated;
public readonly bool IsLoopEnd; public readonly bool IsLoopEnd;
public readonly bool IsThumb; public readonly bool IsThumb;
@ -20,6 +21,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
List<InstInfo> instructions, List<InstInfo> instructions,
bool endsWithBranch, bool endsWithBranch,
bool hasHostCall, bool hasHostCall,
bool hasHostCallSkipContext,
bool isTruncated, bool isTruncated,
bool isLoopEnd, bool isLoopEnd,
bool isThumb) bool isThumb)
@ -31,6 +33,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions = instructions; Instructions = instructions;
EndsWithBranch = endsWithBranch; EndsWithBranch = endsWithBranch;
HasHostCall = hasHostCall; HasHostCall = hasHostCall;
HasHostCallSkipContext = hasHostCallSkipContext;
IsTruncated = isTruncated; IsTruncated = isTruncated;
IsLoopEnd = isLoopEnd; IsLoopEnd = isLoopEnd;
IsThumb = isThumb; IsThumb = isThumb;
@ -57,6 +60,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions.GetRange(0, splitIndex), Instructions.GetRange(0, splitIndex),
false, false,
HasHostCall, HasHostCall,
HasHostCallSkipContext,
false, false,
false, false,
IsThumb); IsThumb);
@ -67,6 +71,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Instructions.GetRange(splitIndex, splitCount), Instructions.GetRange(splitIndex, splitCount),
EndsWithBranch, EndsWithBranch,
HasHostCall, HasHostCall,
HasHostCallSkipContext,
IsTruncated, IsTruncated,
IsLoopEnd, IsLoopEnd,
IsThumb); IsThumb);

View file

@ -208,6 +208,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
InstMeta meta; InstMeta meta;
InstFlags extraFlags = InstFlags.None; InstFlags extraFlags = InstFlags.None;
bool hasHostCall = false; bool hasHostCall = false;
bool hasHostCallSkipContext = false;
bool isTruncated = false; bool isTruncated = false;
do do
@ -246,9 +247,17 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features); meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features);
} }
if (meta.Name.IsSystemOrCall() && !hasHostCall) if (meta.Name.IsSystemOrCall())
{ {
hasHostCall = meta.Name.IsCall() || InstEmitSystem.NeedsCall(meta.Name); if (!hasHostCall)
{
hasHostCall = InstEmitSystem.NeedsCall(meta.Name);
}
if (!hasHostCallSkipContext)
{
hasHostCallSkipContext = meta.Name.IsCall() || InstEmitSystem.NeedsCallSkipContext(meta.Name);
}
} }
insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags)); insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags));
@ -259,8 +268,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
if (!isTruncated && IsBackwardsBranch(meta.Name, encoding)) if (!isTruncated && IsBackwardsBranch(meta.Name, encoding))
{ {
hasHostCall = true;
isLoopEnd = true; isLoopEnd = true;
hasHostCallSkipContext = true;
} }
return new( return new(
@ -269,6 +278,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
insts, insts,
!isTruncated, !isTruncated,
hasHostCall, hasHostCall,
hasHostCallSkipContext,
isTruncated, isTruncated,
isLoopEnd, isLoopEnd,
isThumb); isThumb);

View file

@ -6,6 +6,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
{ {
public readonly List<Block> Blocks; public readonly List<Block> Blocks;
public readonly bool HasHostCall; public readonly bool HasHostCall;
public readonly bool HasHostCallSkipContext;
public readonly bool IsTruncated; public readonly bool IsTruncated;
public MultiBlock(List<Block> blocks) public MultiBlock(List<Block> blocks)
@ -15,12 +16,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
Block block = blocks[0]; Block block = blocks[0];
HasHostCall = block.HasHostCall; HasHostCall = block.HasHostCall;
HasHostCallSkipContext = block.HasHostCallSkipContext;
for (int index = 1; index < blocks.Count; index++) for (int index = 1; index < blocks.Count; index++)
{ {
block = blocks[index]; block = blocks[index];
HasHostCall |= block.HasHostCall; HasHostCall |= block.HasHostCall;
HasHostCallSkipContext |= block.HasHostCallSkipContext;
} }
block = blocks[^1]; block = blocks[^1];

View file

@ -106,6 +106,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
if ((regMask & AbiConstants.ReservedRegsMask) == 0) if ((regMask & AbiConstants.ReservedRegsMask) == 0)
{ {
_gprMask |= regMask; _gprMask |= regMask;
UsedGprsMask |= regMask;
return firstCalleeSaved; return firstCalleeSaved;
} }

View file

@ -305,12 +305,23 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp); ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp);
} }
int reservedStackSize = 0;
if (multiBlock.HasHostCall)
{
reservedStackSize = CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask);
}
else if (multiBlock.HasHostCallSkipContext)
{
reservedStackSize = 2 * sizeof(ulong); // Context and page table pointers.
}
RegisterSaveRestore rsr = new( RegisterSaveRestore rsr = new(
regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask, regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask,
regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask, regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask,
OperandType.FP64, OperandType.FP64,
multiBlock.HasHostCall, multiBlock.HasHostCall || multiBlock.HasHostCallSkipContext,
multiBlock.HasHostCall ? CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask) : 0); reservedStackSize);
TailMerger tailMerger = new(); TailMerger tailMerger = new();
@ -596,7 +607,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
name == InstName.Ldm || name == InstName.Ldm ||
name == InstName.Ldmda || name == InstName.Ldmda ||
name == InstName.Ldmdb || name == InstName.Ldmdb ||
name == InstName.Ldmib) name == InstName.Ldmib ||
name == InstName.Pop)
{ {
// Arm32 does not have a return instruction, instead returns are implemented // Arm32 does not have a return instruction, instead returns are implemented
// either using BX LR (for leaf functions), or POP { ... PC }. // either using BX LR (for leaf functions), or POP { ... PC }.
@ -711,7 +723,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
switch (type) switch (type)
{ {
case BranchType.SyncPoint: case BranchType.SyncPoint:
InstEmitSystem.WriteSyncPoint(context.Writer, context.RegisterAllocator, context.TailMerger, context.GetReservedStackOffset()); InstEmitSystem.WriteSyncPoint(
context.Writer,
ref asm,
context.RegisterAllocator,
context.TailMerger,
context.GetReservedStackOffset(),
context.StoreToContext,
context.LoadFromContext);
break; break;
case BranchType.SoftwareInterrupt: case BranchType.SoftwareInterrupt:
context.StoreToContext(); context.StoreToContext();

View file

@ -199,12 +199,12 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
} }
} }
private static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) public static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
{ {
WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true); WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true);
} }
private static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) public static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
{ {
WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false); WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false);
} }

View file

@ -354,11 +354,18 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
// All instructions that might do a host call should be included here. // All instructions that might do a host call should be included here.
// That is required to reserve space on the stack for caller saved registers. // That is required to reserve space on the stack for caller saved registers.
return name == InstName.Mrrc;
}
public static bool NeedsCallSkipContext(InstName name)
{
// All instructions that might do a host call should be included here.
// That is required to reserve space on the stack for caller saved registers.
switch (name) switch (name)
{ {
case InstName.Mcr: case InstName.Mcr:
case InstName.Mrc: case InstName.Mrc:
case InstName.Mrrc:
case InstName.Svc: case InstName.Svc:
case InstName.Udf: case InstName.Udf:
return true; return true;
@ -372,7 +379,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer); Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm); WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
} }
public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId) public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId)
@ -380,7 +387,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer); Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId); WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
} }
public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm) public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm)
@ -388,7 +395,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
Assembler asm = new(writer); Assembler asm = new(writer);
WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm); WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
} }
public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2) public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2)
@ -422,14 +429,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister); WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister);
} }
public static void WriteSyncPoint(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset) public static void WriteSyncPoint(
{ CodeWriter writer,
Assembler asm = new(writer); ref Assembler asm,
RegisterAllocator regAlloc,
WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: false, spillBaseOffset); TailMerger tailMerger,
} int spillBaseOffset,
Action storeToContext = null,
private static void WriteSyncPoint(CodeWriter writer, ref Assembler asm, RegisterAllocator regAlloc, TailMerger tailMerger, bool skipContext, int spillBaseOffset) Action loadFromContext = null)
{ {
int tempRegister = regAlloc.AllocateTempGprRegister(); int tempRegister = regAlloc.AllocateTempGprRegister();
@ -440,7 +447,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
int branchIndex = writer.InstructionPointer; int branchIndex = writer.InstructionPointer;
asm.Cbnz(rt, 0); asm.Cbnz(rt, 0);
WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister); storeToContext?.Invoke();
WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
Operand rn = Register(tempRegister == 0 ? 1 : 0); Operand rn = Register(tempRegister == 0 ? 1 : 0);
@ -449,7 +457,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32)); tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32));
WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister); WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
loadFromContext?.Invoke();
asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset); asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset);
@ -514,18 +523,31 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister) private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
{ {
WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: true); if (skipContext)
{
InstEmitFlow.WriteSpillSkipContext(ref asm, regAlloc, spillOffset);
}
else
{
WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: true);
}
} }
private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister) private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
{ {
WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: false); if (skipContext)
{
InstEmitFlow.WriteFillSkipContext(ref asm, regAlloc, spillOffset);
}
else
{
WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: false);
}
} }
private static void WriteSpillOrFill( private static void WriteSpillOrFill(
ref Assembler asm, ref Assembler asm,
RegisterAllocator regAlloc, RegisterAllocator regAlloc,
bool skipContext,
uint exceptMask, uint exceptMask,
int spillOffset, int spillOffset,
int tempRegister, int tempRegister,
@ -533,11 +555,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
{ {
uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask); uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask);
if (skipContext)
{
gprMask &= ~Compiler.UsableGprsMask;
}
if (!spill) if (!spill)
{ {
// We must reload the status register before reloading the GPRs, // We must reload the status register before reloading the GPRs,
@ -600,11 +617,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
uint fpSimdMask = regAlloc.UsedFpSimdMask; uint fpSimdMask = regAlloc.UsedFpSimdMask;
if (skipContext)
{
fpSimdMask &= ~Compiler.UsableFpSimdMask;
}
while (fpSimdMask != 0) while (fpSimdMask != 0)
{ {
int reg = BitOperations.TrailingZeroCount(fpSimdMask); int reg = BitOperations.TrailingZeroCount(fpSimdMask);