Implement RomFS cache and async reads. (#7089)

* Implement RomFS cache and async reads.

* Suggestions and fix compilation.

* Apply suggestions
This commit is contained in:
PabloMK7 2023-11-03 01:19:00 +01:00 committed by GitHub
parent 79ea06b226
commit 4284893044
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 404 additions and 22 deletions

View file

@ -124,6 +124,7 @@ add_library(citra_common STATIC
serialization/boost_flat_set.h serialization/boost_flat_set.h
serialization/boost_small_vector.hpp serialization/boost_small_vector.hpp
serialization/boost_vector.hpp serialization/boost_vector.hpp
static_lru_cache.h
string_literal.h string_literal.h
string_util.cpp string_util.cpp
string_util.h string_util.h

View file

@ -1155,6 +1155,43 @@ std::size_t IOFile::ReadImpl(void* data, std::size_t length, std::size_t data_si
return std::fread(data, data_size, length, m_file); return std::fread(data, data_size, length, m_file);
} }
#ifdef _WIN32
static std::size_t pread(int fd, void* buf, size_t count, uint64_t offset) {
long unsigned int read_bytes = 0;
OVERLAPPED overlapped = {0};
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
overlapped.OffsetHigh = static_cast<uint32_t>(offset >> 32);
overlapped.Offset = static_cast<uint32_t>(offset & 0xFFFF'FFFFLL);
SetLastError(0);
bool ret = ReadFile(file, buf, static_cast<uint32_t>(count), &read_bytes, &overlapped);
if (!ret && GetLastError() != ERROR_HANDLE_EOF) {
errno = GetLastError();
return std::numeric_limits<std::size_t>::max();
}
return read_bytes;
}
#else
#define pread ::pread
#endif
std::size_t IOFile::ReadAtImpl(void* data, std::size_t length, std::size_t data_size,
std::size_t offset) {
if (!IsOpen()) {
m_good = false;
return std::numeric_limits<std::size_t>::max();
}
if (length == 0) {
return 0;
}
DEBUG_ASSERT(data != nullptr);
return pread(fileno(m_file), data, data_size * length, offset);
}
std::size_t IOFile::WriteImpl(const void* data, std::size_t length, std::size_t data_size) { std::size_t IOFile::WriteImpl(const void* data, std::size_t length, std::size_t data_size) {
if (!IsOpen()) { if (!IsOpen()) {
m_good = false; m_good = false;

View file

@ -294,6 +294,18 @@ public:
return items_read; return items_read;
} }
template <typename T>
std::size_t ReadAtArray(T* data, std::size_t length, std::size_t offset) {
static_assert(std::is_trivially_copyable_v<T>,
"Given array does not consist of trivially copyable objects");
std::size_t items_read = ReadAtImpl(data, length, sizeof(T), offset);
if (items_read != length)
m_good = false;
return items_read;
}
template <typename T> template <typename T>
std::size_t WriteArray(const T* data, std::size_t length) { std::size_t WriteArray(const T* data, std::size_t length) {
static_assert(std::is_trivially_copyable_v<T>, static_assert(std::is_trivially_copyable_v<T>,
@ -312,6 +324,12 @@ public:
return ReadArray(reinterpret_cast<char*>(data), length); return ReadArray(reinterpret_cast<char*>(data), length);
} }
template <typename T>
std::size_t ReadAtBytes(T* data, std::size_t length, std::size_t offset) {
static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
return ReadAtArray(reinterpret_cast<char*>(data), length, offset);
}
template <typename T> template <typename T>
std::size_t WriteBytes(const T* data, std::size_t length) { std::size_t WriteBytes(const T* data, std::size_t length) {
static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable"); static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
@ -363,6 +381,8 @@ public:
private: private:
std::size_t ReadImpl(void* data, std::size_t length, std::size_t data_size); std::size_t ReadImpl(void* data, std::size_t length, std::size_t data_size);
std::size_t ReadAtImpl(void* data, std::size_t length, std::size_t data_size,
std::size_t offset);
std::size_t WriteImpl(const void* data, std::size_t length, std::size_t data_size); std::size_t WriteImpl(const void* data, std::size_t length, std::size_t data_size);
bool Open(); bool Open();

View file

@ -0,0 +1,113 @@
// Modified version of: https://www.boost.org/doc/libs/1_79_0/boost/compute/detail/lru_cache.hpp
// Most important change is the use of an array instead of a map, so that elements are
// statically allocated. The insert and get methods have been merged into the request method.
// Original license:
//
//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//
#pragma once
#include <array>
#include <list>
#include <tuple>
#include <utility>
namespace Common {
// a cache which evicts the least recently used item when it is full
// the cache elements are statically allocated.
template <class Key, class Value, size_t Size>
class StaticLRUCache {
public:
using key_type = Key;
using value_type = Value;
using list_type = std::list<std::pair<Key, size_t>>;
using array_type = std::array<Value, Size>;
StaticLRUCache() = default;
~StaticLRUCache() = default;
size_t size() const {
return m_list.size();
}
constexpr size_t capacity() const {
return m_array.size();
}
bool empty() const {
return m_list.empty();
}
bool contains(const key_type& key) const {
return find(key) != m_list.end();
}
// Requests an element from the cache. If it is not found,
// the element is inserted using its key.
// Returns whether the element was present in the cache
// and a reference to the element itself.
std::pair<bool, value_type&> request(const key_type& key) {
// lookup value in the cache
auto i = find(key);
if (i == m_list.cend()) {
size_t next_index = size();
// insert item into the cache, but first check if it is full
if (next_index >= capacity()) {
// cache is full, evict the least recently used item
next_index = evict();
}
// insert the new item
m_list.push_front(std::make_pair(key, next_index));
return std::pair<bool, value_type&>(false, m_array[next_index]);
}
// return the value, but first update its place in the most
// recently used list
if (i != m_list.cbegin()) {
// move item to the front of the most recently used list
auto backup = *i;
m_list.erase(i);
m_list.push_front(backup);
// return the value
return std::pair<bool, value_type&>(true, m_array[backup.second]);
} else {
// the item is already at the front of the most recently
// used list so just return it
return std::pair<bool, value_type&>(true, m_array[i->second]);
}
}
void clear() {
m_list.clear();
}
private:
typename list_type::const_iterator find(const key_type& key) const {
return std::find_if(m_list.cbegin(), m_list.cend(),
[&key](const auto& el) { return el.first == key; });
}
size_t evict() {
// evict item from the end of most recently used list
typename list_type::iterator i = --m_list.end();
size_t evicted_index = i->second;
m_list.erase(i);
return evicted_index;
}
private:
array_type m_array;
list_type m_list;
};
} // namespace Common

View file

@ -86,6 +86,20 @@ public:
*/ */
virtual void Flush() const = 0; virtual void Flush() const = 0;
/**
* Whether the backend supports cached reads.
*/
virtual bool AllowsCachedReads() const {
return false;
}
/**
* Whether the cache is ready for a specified offset and length.
*/
virtual bool CacheReady(std::size_t file_offset, std::size_t length) {
return false;
}
protected: protected:
std::unique_ptr<DelayGenerator> delay_generator; std::unique_ptr<DelayGenerator> delay_generator;

View file

@ -131,6 +131,14 @@ public:
} }
void Flush() const override {} void Flush() const override {}
bool AllowsCachedReads() const override {
return romfs_file->AllowsCachedReads();
}
bool CacheReady(std::size_t file_offset, std::size_t length) override {
return romfs_file->CacheReady(file_offset, length);
}
private: private:
std::shared_ptr<RomFSReader> romfs_file; std::shared_ptr<RomFSReader> romfs_file;

View file

@ -53,6 +53,14 @@ public:
bool DumpRomFS(const std::string& target_path); bool DumpRomFS(const std::string& target_path);
bool AllowsCachedReads() const override {
return false;
}
bool CacheReady(std::size_t file_offset, std::size_t length) override {
return false;
}
private: private:
struct File; struct File;
struct Directory { struct Directory {

View file

@ -1,4 +1,5 @@
#include <algorithm> #include <algorithm>
#include <vector>
#include <cryptopp/aes.h> #include <cryptopp/aes.h>
#include <cryptopp/modes.h> #include <cryptopp/modes.h>
#include "common/archives.h" #include "common/archives.h"
@ -9,17 +10,102 @@ SERIALIZE_EXPORT_IMPL(FileSys::DirectRomFSReader)
namespace FileSys { namespace FileSys {
std::size_t DirectRomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { std::size_t DirectRomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) {
length = std::min(length, static_cast<std::size_t>(data_size) - offset);
if (length == 0) if (length == 0)
return 0; // Crypto++ does not like zero size buffer return 0; // Crypto++ does not like zero size buffer
file.Seek(file_offset + offset, SEEK_SET);
std::size_t read_length = std::min(length, static_cast<std::size_t>(data_size) - offset); const auto segments = BreakupRead(offset, length);
read_length = file.ReadBytes(buffer, read_length); size_t read_progress = 0;
if (is_encrypted) {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption d(key.data(), key.size(), ctr.data()); // Skip cache if the read is too big
d.Seek(crypto_offset + offset); if (segments.size() == 1 && segments[0].second > cache_line_size) {
d.ProcessData(buffer, buffer, read_length); length = file.ReadAtBytes(buffer, length, file_offset + offset);
if (is_encrypted) {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption d(key.data(), key.size(), ctr.data());
d.Seek(crypto_offset + offset);
d.ProcessData(buffer, buffer, length);
}
// LOG_INFO(Service_FS, "Cache SKIP: offset={}, length={}", offset, length);
return length;
} }
return read_length;
// TODO(PabloMK7): Make cache thread safe, read the comment in CacheReady function.
// std::unique_lock<std::shared_mutex> read_guard(cache_mutex);
for (const auto& seg : segments) {
size_t read_size = cache_line_size;
size_t page = OffsetToPage(seg.first);
// Check if segment is in cache
auto cache_entry = cache.request(page);
if (!cache_entry.first) {
// If not found, read from disk and cache the data
read_size = file.ReadAtBytes(cache_entry.second.data(), read_size, file_offset + page);
if (is_encrypted && read_size) {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption d(key.data(), key.size(), ctr.data());
d.Seek(crypto_offset + page);
d.ProcessData(cache_entry.second.data(), cache_entry.second.data(), read_size);
}
// LOG_INFO(Service_FS, "Cache MISS: page={}, length={}, into={}", page, seg.second,
// (seg.first - page));
} else {
// LOG_INFO(Service_FS, "Cache HIT: page={}, length={}, into={}", page, seg.second,
// (seg.first - page));
}
size_t copy_amount =
(read_size > (seg.first - page))
? std::min((seg.first - page) + seg.second, read_size) - (seg.first - page)
: 0;
std::memcpy(buffer + read_progress, cache_entry.second.data() + (seg.first - page),
copy_amount);
read_progress += copy_amount;
}
return read_progress;
}
bool DirectRomFSReader::AllowsCachedReads() const {
return true;
}
bool DirectRomFSReader::CacheReady(std::size_t file_offset, std::size_t length) {
auto segments = BreakupRead(file_offset, length);
if (segments.size() == 1 && segments[0].second > cache_line_size) {
return false;
} else {
// TODO(PabloMK7): Since the LRU cache is not thread safe, a lock must be used.
// However, this completely breaks the point of using a cache, because
// smaller reads may be blocked by bigger reads. For now, always return
// data being in cache to prevent the need of a lock, and only read data
// asynchronously if it is too big to use the cache.
/*
std::shared_lock<std::shared_mutex> read_guard(cache_mutex);
for (auto it = segments.begin(); it != segments.end(); it++) {
if (!cache.contains(OffsetToPage(it->first)))
return false;
}
*/
return true;
}
}
std::vector<std::pair<std::size_t, std::size_t>> DirectRomFSReader::BreakupRead(
std::size_t offset, std::size_t length) {
std::vector<std::pair<std::size_t, std::size_t>> ret;
// Reads bigger than the cache line size will probably never hit again
if (length > cache_line_size) {
ret.push_back(std::make_pair(offset, length));
return ret;
}
size_t curr_offset = offset;
while (length) {
size_t next_page = OffsetToPage(curr_offset + cache_line_size);
size_t curr_page_len = std::min(length, next_page - curr_offset);
ret.push_back(std::make_pair(curr_offset, curr_page_len));
curr_offset = next_page;
length -= curr_page_len;
}
return ret;
} }
} // namespace FileSys } // namespace FileSys

View file

@ -1,11 +1,14 @@
#pragma once #pragma once
#include <array> #include <array>
#include <shared_mutex>
#include <boost/serialization/array.hpp> #include <boost/serialization/array.hpp>
#include <boost/serialization/base_object.hpp> #include <boost/serialization/base_object.hpp>
#include <boost/serialization/export.hpp> #include <boost/serialization/export.hpp>
#include "common/alignment.h"
#include "common/common_types.h" #include "common/common_types.h"
#include "common/file_util.h" #include "common/file_util.h"
#include "common/static_lru_cache.h"
namespace FileSys { namespace FileSys {
@ -18,6 +21,8 @@ public:
virtual std::size_t GetSize() const = 0; virtual std::size_t GetSize() const = 0;
virtual std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) = 0; virtual std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) = 0;
virtual bool AllowsCachedReads() const = 0;
virtual bool CacheReady(std::size_t file_offset, std::size_t length) = 0;
private: private:
template <class Archive> template <class Archive>
@ -48,6 +53,10 @@ public:
std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override;
bool AllowsCachedReads() const override;
bool CacheReady(std::size_t file_offset, std::size_t length) override;
private: private:
bool is_encrypted; bool is_encrypted;
FileUtil::IOFile file; FileUtil::IOFile file;
@ -57,8 +66,23 @@ private:
u64 crypto_offset; u64 crypto_offset;
u64 data_size; u64 data_size;
// Total cache size: 128KB
static constexpr size_t cache_line_size = (1 << 13); // About 8KB
static constexpr size_t cache_line_count = 16;
Common::StaticLRUCache<std::size_t, std::array<u8, cache_line_size>, cache_line_count> cache;
// TODO(PabloMK7): Make cache thread safe, read the comment in CacheReady function.
// std::shared_mutex cache_mutex;
DirectRomFSReader() = default; DirectRomFSReader() = default;
std::size_t OffsetToPage(std::size_t offset) {
return Common::AlignDown<std::size_t>(offset, cache_line_size);
}
std::vector<std::pair<std::size_t, std::size_t>> BreakupRead(std::size_t offset,
std::size_t length);
template <class Archive> template <class Archive>
void serialize(Archive& ar, const unsigned int) { void serialize(Archive& ar, const unsigned int) {
ar& boost::serialization::base_object<RomFSReader>(*this); ar& boost::serialization::base_object<RomFSReader>(*this);

View file

@ -57,7 +57,6 @@ void File::Read(Kernel::HLERequestContext& ctx) {
IPC::RequestParser rp(ctx); IPC::RequestParser rp(ctx);
u64 offset = rp.Pop<u64>(); u64 offset = rp.Pop<u64>();
u32 length = rp.Pop<u32>(); u32 length = rp.Pop<u32>();
auto& buffer = rp.PopMappedBuffer();
LOG_TRACE(Service_FS, "Read {}: offset=0x{:x} length=0x{:08X}", GetName(), offset, length); LOG_TRACE(Service_FS, "Read {}: offset=0x{:x} length=0x{:08X}", GetName(), offset, length);
const FileSessionSlot* file = GetSessionData(ctx.Session()); const FileSessionSlot* file = GetSessionData(ctx.Session());
@ -76,22 +75,94 @@ void File::Read(Kernel::HLERequestContext& ctx) {
offset, length, backend->GetSize()); offset, length, backend->GetSize());
} }
IPC::RequestBuilder rb = rp.MakeBuilder(2, 2); // Conventional reading if the backend does not support cache.
if (!backend->AllowsCachedReads()) {
auto& buffer = rp.PopMappedBuffer();
IPC::RequestBuilder rb = rp.MakeBuilder(2, 2);
std::unique_ptr<u8*> data = std::make_unique<u8*>(static_cast<u8*>(operator new(length)));
const auto read = backend->Read(offset, length, *data);
if (read.Failed()) {
rb.Push(read.Code());
rb.Push<u32>(0);
} else {
buffer.Write(*data, 0, *read);
rb.Push(RESULT_SUCCESS);
rb.Push<u32>(static_cast<u32>(*read));
}
rb.PushMappedBuffer(buffer);
std::vector<u8> data(length); std::chrono::nanoseconds read_timeout_ns{backend->GetReadDelayNs(length)};
ResultVal<std::size_t> read = backend->Read(offset, data.size(), data.data()); ctx.SleepClientThread("file::read", read_timeout_ns, nullptr);
if (read.Failed()) { return;
rb.Push(read.Code());
rb.Push<u32>(0);
} else {
buffer.Write(data.data(), 0, *read);
rb.Push(RESULT_SUCCESS);
rb.Push<u32>(static_cast<u32>(*read));
} }
rb.PushMappedBuffer(buffer);
std::chrono::nanoseconds read_timeout_ns{backend->GetReadDelayNs(length)}; struct AsyncData {
ctx.SleepClientThread("file::read", read_timeout_ns, nullptr); // Input
u32 length;
u64 offset;
std::chrono::steady_clock::time_point pre_timer;
bool cache_ready;
// Output
ResultCode ret{0};
Kernel::MappedBuffer* buffer;
std::unique_ptr<u8*> data;
size_t read_size;
};
auto async_data = std::make_shared<AsyncData>();
async_data->buffer = &rp.PopMappedBuffer();
async_data->length = length;
async_data->offset = offset;
async_data->cache_ready = backend->CacheReady(offset, length);
if (!async_data->cache_ready) {
async_data->pre_timer = std::chrono::steady_clock::now();
}
// LOG_DEBUG(Service_FS, "cache={}, offset={}, length={}", cache_ready, offset, length);
ctx.RunAsync(
[this, async_data](Kernel::HLERequestContext& ctx) {
async_data->data =
std::make_unique<u8*>(static_cast<u8*>(operator new(async_data->length)));
const auto read =
backend->Read(async_data->offset, async_data->length, *async_data->data);
if (read.Failed()) {
async_data->ret = read.Code();
async_data->read_size = 0;
} else {
async_data->ret = RESULT_SUCCESS;
async_data->read_size = *read;
}
const auto read_delay = static_cast<s64>(backend->GetReadDelayNs(async_data->length));
if (!async_data->cache_ready) {
const auto time_took = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now() - async_data->pre_timer)
.count();
/*
if (time_took > read_delay) {
LOG_DEBUG(Service_FS, "Took longer! length={}, time_took={}, read_delay={}",
async_data->length, time_took, read_delay);
}
*/
return static_cast<s64>((read_delay > time_took) ? (read_delay - time_took) : 0);
} else {
return static_cast<s64>(read_delay);
}
},
[async_data](Kernel::HLERequestContext& ctx) {
IPC::RequestBuilder rb(ctx, 0x0802, 2, 2);
if (async_data->ret.IsError()) {
rb.Push(async_data->ret);
rb.Push<u32>(0);
} else {
async_data->buffer->Write(*async_data->data, 0, async_data->read_size);
rb.Push(RESULT_SUCCESS);
rb.Push<u32>(static_cast<u32>(async_data->read_size));
}
rb.PushMappedBuffer(*async_data->buffer);
},
!async_data->cache_ready);
} }
void File::Write(Kernel::HLERequestContext& ctx) { void File::Write(Kernel::HLERequestContext& ctx) {