Fix SSL GetCertificates with certificate ID set to All (#3727)

* Fix SSL GetCertificates with certificate ID set to All

* Fix last entry status value
This commit is contained in:
gdkchan 2022-09-29 12:45:25 -03:00 committed by GitHub
parent f502cfaf62
commit dbe43c1719
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 22 deletions

View file

@ -181,7 +181,11 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
} }
} }
public bool TryGetCertificates(ReadOnlySpan<CaCertificateId> ids, out CertStoreEntry[] entries) public bool TryGetCertificates(
ReadOnlySpan<CaCertificateId> ids,
out CertStoreEntry[] entries,
out bool hasAllCertificates,
out int requiredSize)
{ {
lock (_lock) lock (_lock)
{ {
@ -190,7 +194,8 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage); throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage);
} }
bool hasAllCertificates = false; requiredSize = 0;
hasAllCertificates = false;
foreach (CaCertificateId id in ids) foreach (CaCertificateId id in ids)
{ {
@ -205,12 +210,14 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
if (hasAllCertificates) if (hasAllCertificates)
{ {
entries = new CertStoreEntry[_certificates.Count]; entries = new CertStoreEntry[_certificates.Count];
requiredSize = (_certificates.Count + 1) * Unsafe.SizeOf<BuiltInCertificateInfo>();
int i = 0; int i = 0;
foreach (CertStoreEntry entry in _certificates.Values) foreach (CertStoreEntry entry in _certificates.Values)
{ {
entries[i++] = entry; entries[i++] = entry;
requiredSize += (entry.Data.Length + 3) & ~3;
} }
return true; return true;
@ -218,6 +225,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
else else
{ {
entries = new CertStoreEntry[ids.Length]; entries = new CertStoreEntry[ids.Length];
requiredSize = ids.Length * Unsafe.SizeOf<BuiltInCertificateInfo>();
for (int i = 0; i < ids.Length; i++) for (int i = 0; i < ids.Length; i++)
{ {
@ -227,6 +235,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
} }
entries[i] = entry; entries[i] = entry;
requiredSize += (entry.Data.Length + 3) & ~3;
} }
return true; return true;

View file

@ -29,42 +29,40 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
return ResultCode.Success; return ResultCode.Success;
} }
private uint ComputeCertificateBufferSizeRequired(ReadOnlySpan<BuiltInCertificateManager.CertStoreEntry> entries)
{
uint totalSize = 0;
for (int i = 0; i < entries.Length; i++)
{
totalSize += (uint)Unsafe.SizeOf<BuiltInCertificateInfo>();
totalSize += (uint)entries[i].Data.Length;
}
return totalSize;
}
[CommandHipc(2)] [CommandHipc(2)]
// GetCertificates(buffer<CaCertificateId, 5> ids) -> (u32 certificates_count, buffer<bytes, 6> certificates) // GetCertificates(buffer<CaCertificateId, 5> ids) -> (u32 certificates_count, buffer<bytes, 6> certificates)
public ResultCode GetCertificates(ServiceCtx context) public ResultCode GetCertificates(ServiceCtx context)
{ {
ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size)); ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries)) if (!BuiltInCertificateManager.Instance.TryGetCertificates(
ids,
out BuiltInCertificateManager.CertStoreEntry[] entries,
out bool hasAllCertificates,
out int requiredSize))
{ {
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
if (ComputeCertificateBufferSizeRequired(entries) > context.Request.ReceiveBuff[0].Size) if ((uint)requiredSize > (uint)context.Request.ReceiveBuff[0].Size)
{ {
return ResultCode.InvalidCertBufSize; return ResultCode.InvalidCertBufSize;
} }
int infosCount = entries.Length;
if (hasAllCertificates)
{
infosCount++;
}
using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
{ {
Span<byte> rawData = region.Memory.Span; Span<byte> rawData = region.Memory.Span;
Span<BuiltInCertificateInfo> infos = MemoryMarshal.Cast<byte, BuiltInCertificateInfo>(rawData)[..entries.Length]; Span<BuiltInCertificateInfo> infos = MemoryMarshal.Cast<byte, BuiltInCertificateInfo>(rawData)[..infosCount];
Span<byte> certificatesData = rawData[(Unsafe.SizeOf<BuiltInCertificateInfo>() * entries.Length)..]; Span<byte> certificatesData = rawData[(Unsafe.SizeOf<BuiltInCertificateInfo>() * infosCount)..];
for (int i = 0; i < infos.Length; i++) for (int i = 0; i < entries.Length; i++)
{ {
entries[i].Data.CopyTo(certificatesData); entries[i].Data.CopyTo(certificatesData);
@ -78,6 +76,17 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
certificatesData = certificatesData[entries[i].Data.Length..]; certificatesData = certificatesData[entries[i].Data.Length..];
} }
if (hasAllCertificates)
{
infos[entries.Length] = new BuiltInCertificateInfo
{
Id = CaCertificateId.All,
Status = TrustedCertStatus.Invalid,
CertificateDataSize = 0,
CertificateDataOffset = 0
};
}
} }
context.ResponseData.Write(entries.Length); context.ResponseData.Write(entries.Length);
@ -91,12 +100,12 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
{ {
ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size)); ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries)) if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out _, out _, out int requiredSize))
{ {
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
context.ResponseData.Write(ComputeCertificateBufferSizeRequired(entries)); context.ResponseData.Write(requiredSize);
return ResultCode.Success; return ResultCode.Success;
} }