From 83e0cc45f4ccf1d170a6aa6054a99f61db8b1dbd Mon Sep 17 00:00:00 2001 From: zhupengfei Date: Fri, 7 Feb 2020 12:45:54 +0800 Subject: [PATCH 01/41] core/file_sys: Make RomFSReader an abstract interface The original RomFSReader is renamed to DirectRomFSReader that directly reads the RomFS. --- src/core/file_sys/romfs_reader.cpp | 2 +- src/core/file_sys/romfs_reader.h | 24 ++++++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/core/file_sys/romfs_reader.cpp b/src/core/file_sys/romfs_reader.cpp index 8e624cbfc..64374684a 100644 --- a/src/core/file_sys/romfs_reader.cpp +++ b/src/core/file_sys/romfs_reader.cpp @@ -5,7 +5,7 @@ namespace FileSys { -std::size_t RomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { +std::size_t DirectRomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { if (length == 0) return 0; // Crypto++ does not like zero size buffer file.Seek(file_offset + offset, SEEK_SET); diff --git a/src/core/file_sys/romfs_reader.h b/src/core/file_sys/romfs_reader.h index 72a02cde3..5ee39015b 100644 --- a/src/core/file_sys/romfs_reader.h +++ b/src/core/file_sys/romfs_reader.h @@ -6,23 +6,35 @@ namespace FileSys { +/** + * Interface for reading RomFS data. + */ class RomFSReader { public: - RomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size) + virtual std::size_t GetSize() const = 0; + virtual std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) = 0; +}; + +/** + * A RomFS reader that directly reads the RomFS file. + */ +class DirectRomFSReader : public RomFSReader { +public: + DirectRomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size) : is_encrypted(false), file(std::move(file)), file_offset(file_offset), data_size(data_size) {} - RomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size, - const std::array& key, const std::array& ctr, - std::size_t crypto_offset) + DirectRomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size, + const std::array& key, const std::array& ctr, + std::size_t crypto_offset) : is_encrypted(true), file(std::move(file)), key(key), ctr(ctr), file_offset(file_offset), crypto_offset(crypto_offset), data_size(data_size) {} - std::size_t GetSize() const { + std::size_t GetSize() const override { return data_size; } - std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer); + std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; private: bool is_encrypted; From 890405bb7cdb404a0f84189d09d04ee7930559c6 Mon Sep 17 00:00:00 2001 From: zhupengfei Date: Fri, 7 Feb 2020 12:54:07 +0800 Subject: [PATCH 02/41] core/file_sys: LayeredFS implementation This implementation is different from Luma3DS's which directly hooks the SDK functions. Instead, we read the RomFS's metadata and figure out the directory and file structure. Then, relocations (i.e. replacements/deletions/patches) are applied. Afterwards, we rebuild the metadata, and assign 'fake' data offsets to the files. When we want to read file data from this rebuilt RomFS, we use binary search to find the last data offset smaller or equal to the given offset and read from that file (either from the original RomFS, or from replacement files, or from buffered data with patches applied) and any later files when length is not enough. The code that rebuilds the metadata is pretty complex and uses quite a few variables to keep track of necessary information like metadata offsets. According to my tests, it is able to build RomFS-es identical to the original (but without trailing garbage data) when no relocations are applied. --- src/core/CMakeLists.txt | 2 + src/core/file_sys/layered_fs.cpp | 551 +++++++++++++++++++++++++++++++ src/core/file_sys/layered_fs.h | 117 +++++++ 3 files changed, 670 insertions(+) create mode 100644 src/core/file_sys/layered_fs.cpp create mode 100644 src/core/file_sys/layered_fs.h diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 064e44f94..d5b2d2f47 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -72,6 +72,8 @@ add_library(core STATIC file_sys/delay_generator.h file_sys/ivfc_archive.cpp file_sys/ivfc_archive.h + file_sys/layered_fs.cpp + file_sys/layered_fs.h file_sys/ncch_container.cpp file_sys/ncch_container.h file_sys/patch.cpp diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp new file mode 100644 index 000000000..9a194569a --- /dev/null +++ b/src/core/file_sys/layered_fs.cpp @@ -0,0 +1,551 @@ +// Copyright 2020 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include +#include "common/alignment.h" +#include "common/assert.h" +#include "common/common_paths.h" +#include "common/file_util.h" +#include "common/string_util.h" +#include "common/swap.h" +#include "core/file_sys/layered_fs.h" +#include "core/file_sys/patch.h" + +namespace FileSys { + +struct FileRelocationInfo { + int type; // 0 - none, 1 - replaced / created, 2 - patched, 3 - removed + u64 original_offset; // Type 0. Offset is absolute + FileUtil::IOFile replace_file; // Type 1 + std::vector patched_file; // Type 2 + u64 size; // Relocated file size +}; +struct LayeredFS::File { + std::string name; + std::string path; + FileRelocationInfo relocation{}; + Directory* parent; +}; + +struct DirectoryMetadata { + u32_le parent_directory_offset; + u32_le next_sibling_offset; + u32_le first_child_directory_offset; + u32_le first_file_offset; + u32_le hash_bucket_next; + u32_le name_length; + // Followed by a name of name length (aligned up to 4) +}; +static_assert(sizeof(DirectoryMetadata) == 0x18, "Size of DirectoryMetadata is not correct"); + +struct FileMetadata { + u32_le parent_directory_offset; + u32_le next_sibling_offset; + u64_le file_data_offset; + u64_le file_data_length; + u32_le hash_bucket_next; + u32_le name_length; + // Followed by a name of name length (aligned up to 4) +}; +static_assert(sizeof(FileMetadata) == 0x20, "Size of FileMetadata is not correct"); + +LayeredFS::LayeredFS(std::shared_ptr romfs_, std::string patch_path_, + std::string patch_ext_path_) + : romfs(std::move(romfs_)), patch_path(std::move(patch_path_)), + patch_ext_path(std::move(patch_ext_path_)) { + + romfs->ReadFile(0, sizeof(header), reinterpret_cast(&header)); + + ASSERT_MSG(header.header_length == sizeof(header), "Header size is incorrect"); + + // TODO: is root always the first directory in table? + root.parent = &root; + LoadDirectory(root, 0); + + LoadRelocations(); + LoadExtRelocations(); + + RebuildMetadata(); +} + +LayeredFS::~LayeredFS() = default; + +void LayeredFS::LoadDirectory(Directory& current, u32 offset) { + DirectoryMetadata metadata; + romfs->ReadFile(header.directory_metadata_table.offset + offset, sizeof(metadata), + reinterpret_cast(&metadata)); + + current.name = ReadName(header.directory_metadata_table.offset + offset + sizeof(metadata), + metadata.name_length); + current.path = current.parent->path + current.name + DIR_SEP; + directory_path_map.emplace(current.path, ¤t); + + if (metadata.first_file_offset != 0xFFFFFFFF) { + LoadFile(current, metadata.first_file_offset); + } + + if (metadata.first_child_directory_offset != 0xFFFFFFFF) { + auto child = std::make_unique(); + auto& directory = *child; + directory.parent = ¤t; + current.directories.emplace_back(std::move(child)); + LoadDirectory(directory, metadata.first_child_directory_offset); + } + + if (metadata.next_sibling_offset != 0xFFFFFFFF) { + auto sibling = std::make_unique(); + auto& directory = *sibling; + directory.parent = current.parent; + current.parent->directories.emplace_back(std::move(sibling)); + LoadDirectory(directory, metadata.next_sibling_offset); + } +} + +void LayeredFS::LoadFile(Directory& parent, u32 offset) { + FileMetadata metadata; + romfs->ReadFile(header.file_metadata_table.offset + offset, sizeof(metadata), + reinterpret_cast(&metadata)); + + auto file = std::make_unique(); + file->name = ReadName(header.file_metadata_table.offset + offset + sizeof(metadata), + metadata.name_length); + file->path = parent.path + file->name; + file->relocation.original_offset = header.file_data_offset + metadata.file_data_offset; + file->relocation.size = metadata.file_data_length; + file->parent = &parent; + + file_path_map.emplace(file->path, file.get()); + parent.files.emplace_back(std::move(file)); + + if (metadata.next_sibling_offset != 0xFFFFFFFF) { + LoadFile(parent, metadata.next_sibling_offset); + } +} + +std::string LayeredFS::ReadName(u32 offset, u32 name_length) { + std::vector buffer(name_length / sizeof(u16_le)); + romfs->ReadFile(offset, name_length, reinterpret_cast(buffer.data())); + + std::u16string name(buffer.size(), 0); + std::transform(buffer.begin(), buffer.end(), name.begin(), [](u16_le character) { + return static_cast(static_cast(character)); + }); + return Common::UTF16ToUTF8(name); +} + +void LayeredFS::LoadRelocations() { + if (!FileUtil::Exists(patch_path)) { + return; + } + + const FileUtil::DirectoryEntryCallable callback = [this, + &callback](u64* /*num_entries_out*/, + const std::string& directory, + const std::string& virtual_name) { + auto* parent = directory_path_map.at(directory.substr(patch_path.size() - 1)); + + if (FileUtil::IsDirectory(directory + virtual_name + DIR_SEP)) { + const auto path = (directory + virtual_name + DIR_SEP).substr(patch_path.size() - 1); + if (!directory_path_map.count(path)) { // Add this directory + auto directory = std::make_unique(); + directory->name = virtual_name; + directory->path = path; + directory->parent = parent; + directory_path_map.emplace(path, directory.get()); + parent->directories.emplace_back(std::move(directory)); + LOG_INFO(Service_FS, "LayeredFS created directory {}", path); + } + return FileUtil::ForeachDirectoryEntry(nullptr, directory + virtual_name + DIR_SEP, + callback); + } + + const auto path = (directory + virtual_name).substr(patch_path.size() - 1); + if (!file_path_map.count(path)) { // Newly created file + auto file = std::make_unique(); + file->name = virtual_name; + file->path = path; + file->parent = parent; + file_path_map.emplace(path, file.get()); + parent->files.emplace_back(std::move(file)); + LOG_INFO(Service_FS, "LayeredFS created file {}", path); + } + + auto* file = file_path_map.at(path); + file->relocation.replace_file = FileUtil::IOFile(directory + virtual_name, "rb"); + if (file->relocation.replace_file) { + file->relocation.type = 1; + file->relocation.size = file->relocation.replace_file.GetSize(); + LOG_INFO(Service_FS, "LayeredFS replacement file in use for {}", path); + } else { + LOG_ERROR(Service_FS, "Could not open replacement file for {}", path); + } + return true; + }; + + FileUtil::ForeachDirectoryEntry(nullptr, patch_path, callback); +} + +void LayeredFS::LoadExtRelocations() { + if (!FileUtil::Exists(patch_ext_path)) { + return; + } + + if (patch_ext_path.back() == '/' || patch_ext_path.back() == '\\') { + // ScanDirectoryTree expects a path without trailing '/' + patch_ext_path.erase(patch_ext_path.size() - 1, 1); + } + + FileUtil::FSTEntry result; + FileUtil::ScanDirectoryTree(patch_ext_path, result, 256); + + for (const auto& entry : result.children) { + if (FileUtil::IsDirectory(entry.physicalName)) { + continue; + } + + const auto path = entry.physicalName.substr(patch_ext_path.size()); + if (path.size() >= 5 && path.substr(path.size() - 5) == ".stub") { + // Remove the corresponding file if exists + const auto file_path = path.substr(0, path.size() - 5); + if (file_path_map.count(file_path)) { + auto& file = *file_path_map[file_path]; + file.relocation.type = 3; + file.relocation.size = 0; + file_path_map.erase(file_path); + LOG_INFO(Service_FS, "LayeredFS removed file {}", file_path); + } else { + LOG_WARNING(Service_FS, "LayeredFS file for stub {} not found", path); + } + } else if (path.size() >= 4) { + const auto extension = path.substr(path.size() - 4); + if (extension != ".ips" && extension != ".bps") { + LOG_WARNING(Service_FS, "LayeredFS unknown ext file {}", path); + } + + const auto file_path = path.substr(0, path.size() - 4); + if (!file_path_map.count(file_path)) { + LOG_WARNING(Service_FS, "LayeredFS original file for patch {} not found", path); + continue; + } + + FileUtil::IOFile patch_file(entry.physicalName, "rb"); + if (!patch_file) { + LOG_ERROR(Service_FS, "LayeredFS Could not open file {}", entry.physicalName); + continue; + } + + const auto size = patch_file.GetSize(); + std::vector patch(size); + if (patch_file.ReadBytes(patch.data(), size) != size) { + LOG_ERROR(Service_FS, "LayeredFS Could not read file {}", entry.physicalName); + continue; + } + + auto& file = *file_path_map[file_path]; + std::vector buffer(file.relocation.size); // Original size + romfs->ReadFile(file.relocation.original_offset, buffer.size(), buffer.data()); + + bool ret = false; + if (extension == ".ips") { + ret = Patch::ApplyIpsPatch(patch, buffer); + } else { + ret = Patch::ApplyBpsPatch(patch, buffer); + } + + if (ret) { + LOG_INFO(Service_FS, "LayeredFS patched file {}", file_path); + + file.relocation.type = 2; + file.relocation.size = buffer.size(); + file.relocation.patched_file = std::move(buffer); + } else { + LOG_ERROR(Service_FS, "LayeredFS failed to patch file {}", file_path); + } + } else { + LOG_WARNING(Service_FS, "LayeredFS unknown ext file {}", path); + } + } +} + +std::size_t GetNameSize(const std::string& name) { + std::u16string u16name = Common::UTF8ToUTF16(name); + return Common::AlignUp(u16name.size() * 2, 4); +} + +void LayeredFS::PrepareBuildDirectory(Directory& current) { + directory_metadata_offset_map.emplace(¤t, current_directory_offset); + directory_list.emplace_back(¤t); + current_directory_offset += sizeof(DirectoryMetadata) + GetNameSize(current.name); +} + +void LayeredFS::PrepareBuildFile(File& current) { + if (current.relocation.type == 3) { // Deleted files are not counted + return; + } + file_metadata_offset_map.emplace(¤t, current_file_offset); + file_list.emplace_back(¤t); + current_file_offset += sizeof(FileMetadata) + GetNameSize(current.name); +} + +void LayeredFS::PrepareBuild(Directory& current) { + for (const auto& child : current.files) { + PrepareBuildFile(*child); + } + + for (const auto& child : current.directories) { + PrepareBuildDirectory(*child); + } + + for (const auto& child : current.directories) { + PrepareBuild(*child); + } +} + +// Implementation from 3dbrew +u32 CalcHash(const std::string& name, u32 parent_offset) { + u32 hash = parent_offset ^ 123456789; + + std::u16string u16name = Common::UTF8ToUTF16(name); + std::vector tmp_buffer(u16name.size()); + std::transform(u16name.begin(), u16name.end(), tmp_buffer.begin(), [](char16_t character) { + return static_cast(static_cast(character)); + }); + + std::vector buffer(tmp_buffer.size() * 2); + std::memcpy(buffer.data(), tmp_buffer.data(), buffer.size()); + for (std::size_t i = 0; i < buffer.size(); i += 2) { + hash = (hash >> 5) | (hash << 27); + hash ^= static_cast((buffer[i]) | (buffer[i + 1] << 8)); + } + return hash; +} + +std::size_t WriteName(u8* dest, std::u16string name) { + const auto buffer_size = Common::AlignUp(name.size() * 2, 4); + std::vector buffer(buffer_size / 2); + std::transform(name.begin(), name.end(), buffer.begin(), [](char16_t character) { + return static_cast(static_cast(character)); + }); + std::memcpy(dest, buffer.data(), buffer_size); + + return buffer_size; +} + +void LayeredFS::BuildDirectories() { + directory_metadata_table.resize(current_directory_offset, 0xFF); + + std::size_t written = 0; + for (const auto& directory : directory_list) { + DirectoryMetadata metadata; + std::memset(&metadata, 0xFF, sizeof(metadata)); + metadata.parent_directory_offset = directory_metadata_offset_map.at(directory->parent); + + if (directory->parent != directory) { + bool flag = false; + for (const auto& sibling : directory->parent->directories) { + if (flag) { + metadata.next_sibling_offset = directory_metadata_offset_map.at(sibling.get()); + break; + } else if (sibling.get() == directory) { + flag = true; + } + } + } + + if (!directory->directories.empty()) { + metadata.first_child_directory_offset = + directory_metadata_offset_map.at(directory->directories.front().get()); + } + + if (!directory->files.empty()) { + metadata.first_file_offset = + file_metadata_offset_map.at(directory->files.front().get()); + } + + const auto bucket = CalcHash(directory->name, metadata.parent_directory_offset) % + directory_hash_table.size(); + metadata.hash_bucket_next = directory_hash_table[bucket]; + directory_hash_table[bucket] = directory_metadata_offset_map.at(directory); + + // Write metadata and name + std::u16string u16name = Common::UTF8ToUTF16(directory->name); + metadata.name_length = u16name.size() * 2; + + std::memcpy(directory_metadata_table.data() + written, &metadata, sizeof(metadata)); + written += sizeof(metadata); + + written += WriteName(directory_metadata_table.data() + written, u16name); + } + + ASSERT_MSG(written == directory_metadata_table.size(), + "Calculated size for directory metadata table is wrong"); +} + +void LayeredFS::BuildFiles() { + file_metadata_table.resize(current_file_offset, 0xFF); + + std::size_t written = 0; + for (const auto& file : file_list) { + FileMetadata metadata; + std::memset(&metadata, 0xFF, sizeof(metadata)); + + metadata.parent_directory_offset = directory_metadata_offset_map.at(file->parent); + + bool flag = false; + for (const auto& sibling : file->parent->files) { + if (sibling->relocation.type == 3) { // removed file + continue; + } + if (flag) { + metadata.next_sibling_offset = file_metadata_offset_map.at(sibling.get()); + break; + } else if (sibling.get() == file) { + flag = true; + } + } + + metadata.file_data_offset = current_data_offset; + metadata.file_data_length = file->relocation.size; + current_data_offset += Common::AlignUp(metadata.file_data_length, 16); + if (metadata.file_data_length != 0) { + data_offset_map.emplace(metadata.file_data_offset, file); + } + + const auto bucket = + CalcHash(file->name, metadata.parent_directory_offset) % file_hash_table.size(); + metadata.hash_bucket_next = file_hash_table[bucket]; + file_hash_table[bucket] = file_metadata_offset_map.at(file); + + // Write metadata and name + std::u16string u16name = Common::UTF8ToUTF16(file->name); + metadata.name_length = u16name.size() * 2; + + std::memcpy(file_metadata_table.data() + written, &metadata, sizeof(metadata)); + written += sizeof(metadata); + + written += WriteName(file_metadata_table.data() + written, u16name); + } + + ASSERT_MSG(written == file_metadata_table.size(), + "Calculated size for file metadata table is wrong"); +} + +// Implementation from 3dbrew +std::size_t GetHashTableSize(std::size_t entry_count) { + if (entry_count < 3) { + return 3; + } else if (entry_count < 19) { + return entry_count | 1; + } else { + std::size_t count = entry_count; + while (count % 2 == 0 || count % 3 == 0 || count % 5 == 0 || count % 7 == 0 || + count % 11 == 0 || count % 13 == 0 || count % 17 == 0) { + count++; + } + return count; + } +} + +void LayeredFS::RebuildMetadata() { + PrepareBuildDirectory(root); + PrepareBuild(root); + + directory_hash_table.resize(GetHashTableSize(directory_list.size()), 0xFFFFFFFF); + file_hash_table.resize(GetHashTableSize(file_list.size()), 0xFFFFFFFF); + + BuildDirectories(); + BuildFiles(); + + // Create header + RomFSHeader header; + header.header_length = sizeof(header); + header.directory_hash_table = { + /*offset*/ sizeof(header), + /*length*/ static_cast(directory_hash_table.size() * sizeof(u32_le))}; + header.directory_metadata_table = { + /*offset*/ + header.directory_hash_table.offset + header.directory_hash_table.length, + /*length*/ static_cast(directory_metadata_table.size())}; + header.file_hash_table = { + /*offset*/ + header.directory_metadata_table.offset + header.directory_metadata_table.length, + /*length*/ static_cast(file_hash_table.size() * sizeof(u32_le))}; + header.file_metadata_table = {/*offset*/ header.file_hash_table.offset + + header.file_hash_table.length, + /*length*/ static_cast(file_metadata_table.size())}; + header.file_data_offset = + Common::AlignUp(header.file_metadata_table.offset + header.file_metadata_table.length, 16); + + // Write hash table and metadata table + metadata.resize(header.file_data_offset); + std::memcpy(metadata.data(), &header, header.header_length); + std::memcpy(metadata.data() + header.directory_hash_table.offset, directory_hash_table.data(), + header.directory_hash_table.length); + std::memcpy(metadata.data() + header.directory_metadata_table.offset, + directory_metadata_table.data(), header.directory_metadata_table.length); + std::memcpy(metadata.data() + header.file_hash_table.offset, file_hash_table.data(), + header.file_hash_table.length); + std::memcpy(metadata.data() + header.file_metadata_table.offset, file_metadata_table.data(), + header.file_metadata_table.length); +} + +std::size_t LayeredFS::GetSize() const { + return metadata.size() + current_data_offset; +} + +std::size_t LayeredFS::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { + ASSERT_MSG(offset + length <= GetSize(), "Out of bound"); + + std::size_t read_size = 0; + if (offset < metadata.size()) { + // First read the metadata + const auto to_read = std::min(metadata.size() - offset, length); + std::memcpy(buffer, metadata.data() + offset, to_read); + read_size += to_read; + offset = 0; + } else { + offset -= metadata.size(); + } + + // Read files + auto current = (--data_offset_map.upper_bound(offset)); + while (read_size < length) { + const auto relative_offset = offset - current->first; + std::size_t to_read{}; + if (current->second->relocation.size > relative_offset) { + to_read = + std::min(current->second->relocation.size - relative_offset, length - read_size); + } + const auto alignment = + std::min(Common::AlignUp(current->second->relocation.size, 16) - relative_offset, + length - read_size) - + to_read; + + // Read the file in different ways depending on relocation type + auto& relocation = current->second->relocation; + if (relocation.type == 0) { // none + romfs->ReadFile(relocation.original_offset + relative_offset, to_read, + buffer + read_size); + } else if (relocation.type == 1) { // replace + relocation.replace_file.Seek(relative_offset, SEEK_SET); + relocation.replace_file.ReadBytes(buffer + read_size, to_read); + } else if (relocation.type == 2) { // patch + std::memcpy(buffer + read_size, relocation.patched_file.data() + relative_offset, + to_read); + } else { + UNREACHABLE(); + } + + std::memset(buffer + read_size + to_read, 0, alignment); + + read_size += to_read + alignment; + offset += to_read + alignment; + current++; + } + + return read_size; +} + +} // namespace FileSys diff --git a/src/core/file_sys/layered_fs.h b/src/core/file_sys/layered_fs.h new file mode 100644 index 000000000..b9dcb831f --- /dev/null +++ b/src/core/file_sys/layered_fs.h @@ -0,0 +1,117 @@ +// Copyright 2020 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include +#include +#include +#include +#include "common/common_types.h" +#include "common/swap.h" +#include "core/file_sys/romfs_reader.h" + +namespace FileSys { + +struct RomFSHeader { + struct Descriptor { + u32_le offset; + u32_le length; + }; + u32_le header_length; + Descriptor directory_hash_table; + Descriptor directory_metadata_table; + Descriptor file_hash_table; + Descriptor file_metadata_table; + u32_le file_data_offset; +}; +static_assert(sizeof(RomFSHeader) == 0x28, "Size of RomFSHeader is not correct"); + +/** + * LayeredFS implementation. This basically adds a layer to another RomFSReader. + * + * patch_path: Path for RomFS replacements. Files present in this path replace or create + * corresponding files in RomFS. + * patch_ext_path: Path for RomFS extensions. Files present in this path: + * - When with an extension of ".stub", remove the corresponding file in the RomFS. + * - When with an extension of ".ips" or ".bps", patch the file in the RomFS. + */ +class LayeredFS : public RomFSReader { +public: + explicit LayeredFS(std::shared_ptr romfs, std::string patch_path, + std::string patch_ext_path); + ~LayeredFS(); + + std::size_t GetSize() const override; + std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; + +private: + struct File; + struct Directory { + std::string name; + std::string path; // with trailing '/' + std::vector> files; + std::vector> directories; + Directory* parent; + }; + + std::string ReadName(u32 offset, u32 name_length); + + // Loads the current directory, then its siblings, and then its children. + void LoadDirectory(Directory& current, u32 offset); + + // Load the file at offset, and then its siblings. + void LoadFile(Directory& parent, u32 offset); + + // Load replace/create relocations + void LoadRelocations(); + + // Load patch/remove relocations + void LoadExtRelocations(); + + // Calculate the offset of a single directory add it to the map and list of directories + void PrepareBuildDirectory(Directory& current); + + // Calculate the offset of a single file add it to the map and list of files + void PrepareBuildFile(File& current); + + // Recursively generate a sequence of files and directories and their offsets for all + // children of current. (The current directory itself is not handled.) + void PrepareBuild(Directory& current); + + void BuildDirectories(); + void BuildFiles(); + + void RebuildMetadata(); + + std::shared_ptr romfs; + std::string patch_path; + std::string patch_ext_path; + + RomFSHeader header; + Directory root; + std::unordered_map file_path_map; + std::unordered_map directory_path_map; + std::map data_offset_map; // assigned data offset -> file + std::vector metadata; // Includes header, hash table and metadata + + // Used for rebuilding header + std::vector directory_hash_table; + std::vector file_hash_table; + + std::unordered_map + directory_metadata_offset_map; // directory -> metadata offset + std::vector directory_list; // sequence of directories to be written to metadata + u64 current_directory_offset{}; // current directory metadata offset + std::vector directory_metadata_table; // rebuilt directory metadata table + + std::unordered_map file_metadata_offset_map; // file -> metadata offset + std::vector file_list; // sequence of files to be written to metadata + u64 current_file_offset{}; // current file metadata offset + std::vector file_metadata_table; // rebuilt file metadata table + u64 current_data_offset{}; // current assigned data offset +}; + +} // namespace FileSys From 8a570bf00c7b0e709a73f64b78ad4daeb9765d17 Mon Sep 17 00:00:00 2001 From: zhupengfei Date: Fri, 7 Feb 2020 13:45:10 +0800 Subject: [PATCH 03/41] core: Use LayeredFS while reading RomFS Only enabled for NCCHs that do not have an override romfs. LayeredFS files should be put in the `load` directory in User Directory. The directory structure is similar to yuzu's but currently does not allow named mods yet. Replacement files should be put in `load/mods//romfs` while patches/stubs should be put in `load/mods/<Title ID>/romfs_ext`. --- src/core/file_sys/ncch_container.cpp | 28 +++++++++++++++++++++------- src/core/loader/3dsx.cpp | 4 ++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index f0687fa9e..53432f945 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -11,6 +11,7 @@ #include "common/common_types.h" #include "common/logging/log.h" #include "core/core.h" +#include "core/file_sys/layered_fs.h" #include "core/file_sys/ncch_container.h" #include "core/file_sys/patch.h" #include "core/file_sys/seed_db.h" @@ -597,12 +598,24 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romf if (!romfs_file_inner.IsOpen()) return Loader::ResultStatus::Error; + std::shared_ptr<RomFSReader> direct_romfs; if (is_encrypted) { - romfs_file = std::make_shared<RomFSReader>(std::move(romfs_file_inner), romfs_offset, - romfs_size, secondary_key, romfs_ctr, 0x1000); + direct_romfs = + std::make_shared<DirectRomFSReader>(std::move(romfs_file_inner), romfs_offset, + romfs_size, secondary_key, romfs_ctr, 0x1000); } else { - romfs_file = - std::make_shared<RomFSReader>(std::move(romfs_file_inner), romfs_offset, romfs_size); + direct_romfs = std::make_shared<DirectRomFSReader>(std::move(romfs_file_inner), + romfs_offset, romfs_size); + } + + const auto path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + ncch_header.program_id); + if (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/")) { + romfs_file = std::make_shared<LayeredFS>(std::move(direct_romfs), path + "romfs/", + path + "romfs_ext/"); + } else { + romfs_file = std::move(direct_romfs); } return Loader::ResultStatus::Success; @@ -614,9 +627,10 @@ Loader::ResultStatus NCCHContainer::ReadOverrideRomFS(std::shared_ptr<RomFSReade if (FileUtil::Exists(split_filepath)) { FileUtil::IOFile romfs_file_inner(split_filepath, "rb"); if (romfs_file_inner.IsOpen()) { - LOG_WARNING(Service_FS, "File {} overriding built-in RomFS", split_filepath); - romfs_file = std::make_shared<RomFSReader>(std::move(romfs_file_inner), 0, - romfs_file_inner.GetSize()); + LOG_WARNING(Service_FS, "File {} overriding built-in RomFS; LayeredFS not enabled", + split_filepath); + romfs_file = std::make_shared<DirectRomFSReader>(std::move(romfs_file_inner), 0, + romfs_file_inner.GetSize()); return Loader::ResultStatus::Success; } } diff --git a/src/core/loader/3dsx.cpp b/src/core/loader/3dsx.cpp index 7629ad376..84321011b 100644 --- a/src/core/loader/3dsx.cpp +++ b/src/core/loader/3dsx.cpp @@ -309,8 +309,8 @@ ResultStatus AppLoader_THREEDSX::ReadRomFS(std::shared_ptr<FileSys::RomFSReader> if (!romfs_file_inner.IsOpen()) return ResultStatus::Error; - romfs_file = std::make_shared<FileSys::RomFSReader>(std::move(romfs_file_inner), - romfs_offset, romfs_size); + romfs_file = std::make_shared<FileSys::DirectRomFSReader>(std::move(romfs_file_inner), + romfs_offset, romfs_size); return ResultStatus::Success; } From 91e5a39a08adbb7d149fae29046a48a5c7329ff3 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 13:50:29 +0800 Subject: [PATCH 04/41] core/file_sys: Allow exefs mods to be read from mods path The original path (file_name.exefsdir) is still supported, but alternatively users can choose to put exefs patches in the same place as LayeredFS files (`load/mods/<Title ID>/exefs`). --- src/core/file_sys/ncch_container.cpp | 35 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 53432f945..dcc8dd57b 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -513,7 +513,13 @@ Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector<u8>& code) const std::string path; bool (*patch_fn)(const std::vector<u8>& patch, std::vector<u8>& code); }; - const std::array<PatchLocation, 2> patch_paths{{ + + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + ncch_header.program_id); + const std::array<PatchLocation, 4> patch_paths{{ + {mods_path + "exefs/code.ips", Patch::ApplyIpsPatch}, + {mods_path + "exefs/code.bps", Patch::ApplyBpsPatch}, {filepath + ".exefsdir/code.ips", Patch::ApplyIpsPatch}, {filepath + ".exefsdir/code.bps", Patch::ApplyBpsPatch}, }}; @@ -552,17 +558,26 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, else return Loader::ResultStatus::Error; - std::string section_override = filepath + ".exefsdir/" + override_name; - FileUtil::IOFile section_file(section_override, "rb"); + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + ncch_header.program_id); + std::array<std::string, 2> override_paths{{ + mods_path + "exefs/" + override_name, + filepath + ".exefsdir/" + override_name, + }}; - if (section_file.IsOpen()) { - auto section_size = section_file.GetSize(); - buffer.resize(section_size); + for (const auto& path : override_paths) { + FileUtil::IOFile section_file(path, "rb"); - section_file.Seek(0, SEEK_SET); - if (section_file.ReadBytes(&buffer[0], section_size) == section_size) { - LOG_WARNING(Service_FS, "File {} overriding built-in ExeFS file", section_override); - return Loader::ResultStatus::Success; + if (section_file.IsOpen()) { + auto section_size = section_file.GetSize(); + buffer.resize(section_size); + + section_file.Seek(0, SEEK_SET); + if (section_file.ReadBytes(&buffer[0], section_size) == section_size) { + LOG_WARNING(Service_FS, "File {} overriding built-in ExeFS file", path); + return Loader::ResultStatus::Success; + } } } return Loader::ResultStatus::ErrorNotUsed; From 7c652a0479406eac115b503b5edc67658a825448 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 14:44:34 +0800 Subject: [PATCH 05/41] citra_qt: Add 'Open Mods Location' --- src/citra_qt/game_list.cpp | 9 +++++++++ src/citra_qt/game_list.h | 3 ++- src/citra_qt/main.cpp | 5 +++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/citra_qt/game_list.cpp b/src/citra_qt/game_list.cpp index 2d5af1b0c..4ba7e99d6 100644 --- a/src/citra_qt/game_list.cpp +++ b/src/citra_qt/game_list.cpp @@ -468,6 +468,7 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra QAction* open_texture_dump_location = context_menu.addAction(tr("Open Texture Dump Location")); QAction* open_texture_load_location = context_menu.addAction(tr("Open Custom Texture Location")); + QAction* open_mods_location = context_menu.addAction(tr("Open Mods Location")); QAction* navigate_to_gamedb_entry = context_menu.addAction(tr("Navigate to GameDB entry")); const bool is_application = @@ -497,6 +498,7 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra open_texture_dump_location->setVisible(is_application); open_texture_load_location->setVisible(is_application); + open_mods_location->setVisible(is_application); navigate_to_gamedb_entry->setVisible(it != compatibility_list.end()); @@ -526,6 +528,13 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra emit OpenFolderRequested(program_id, GameListOpenTarget::TEXTURE_LOAD); } }); + connect(open_mods_location, &QAction::triggered, [this, program_id] { + if (FileUtil::CreateFullPath(fmt::format("{}mods/{:016X}/", + FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + program_id))) { + emit OpenFolderRequested(program_id, GameListOpenTarget::MODS); + } + }); connect(navigate_to_gamedb_entry, &QAction::triggered, [this, program_id]() { emit NavigateToGamedbEntryRequested(program_id, compatibility_list); }); diff --git a/src/citra_qt/game_list.h b/src/citra_qt/game_list.h index ef280ef04..635fbb39b 100644 --- a/src/citra_qt/game_list.h +++ b/src/citra_qt/game_list.h @@ -35,7 +35,8 @@ enum class GameListOpenTarget { APPLICATION = 2, UPDATE_DATA = 3, TEXTURE_DUMP = 4, - TEXTURE_LOAD = 5 + TEXTURE_LOAD = 5, + MODS = 6, }; class GameList : public QWidget { diff --git a/src/citra_qt/main.cpp b/src/citra_qt/main.cpp index 0a47c7728..cbb29e6d7 100644 --- a/src/citra_qt/main.cpp +++ b/src/citra_qt/main.cpp @@ -1141,6 +1141,11 @@ void GMainWindow::OnGameListOpenFolder(u64 data_id, GameListOpenTarget target) { path = fmt::format("{}textures/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), data_id); break; + case GameListOpenTarget::MODS: + open_target = "Mods"; + path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + data_id); + break; default: LOG_ERROR(Frontend, "Unexpected target {}", static_cast<int>(target)); return; From 53d0c618a03cfa472001d456ddabccf7ab33eba2 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 14:57:32 +0800 Subject: [PATCH 06/41] core/file_sys: Read mods for the original title for updates Updates can override RomFS and ExeFS, therefore we should apply the mods to them as well. --- src/core/file_sys/ncch_container.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index dcc8dd57b..8ddde9e60 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -516,7 +516,7 @@ Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector<u8>& code) const const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id); + ncch_header.program_id & 0x00040000'FFFFFFFF); const std::array<PatchLocation, 4> patch_paths{{ {mods_path + "exefs/code.ips", Patch::ApplyIpsPatch}, {mods_path + "exefs/code.bps", Patch::ApplyBpsPatch}, @@ -560,7 +560,7 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id); + ncch_header.program_id & 0x00040000'FFFFFFFF); std::array<std::string, 2> override_paths{{ mods_path + "exefs/" + override_name, filepath + ".exefsdir/" + override_name, @@ -625,7 +625,7 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romf const auto path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id); + ncch_header.program_id & 0x00040000'FFFFFFFF); if (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/")) { romfs_file = std::make_shared<LayeredFS>(std::move(direct_romfs), path + "romfs/", path + "romfs_ext/"); From eed9de23369a170e4ab841409edb03cb1ffa1550 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 15:55:35 +0800 Subject: [PATCH 07/41] core/file_sys: Allow exheader replacement to be read from mods path The previous method (filename.exheader) can still be used. --- src/core/file_sys/ncch_container.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 8ddde9e60..11be43c0e 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -304,8 +304,22 @@ Loader::ResultStatus NCCHContainer::Load() { } } - FileUtil::IOFile exheader_override_file{filepath + ".exheader", "rb"}; - const bool has_exheader_override = read_exheader(exheader_override_file); + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + ncch_header.program_id & 0x00040000'FFFFFFFF); + std::array<std::string, 2> exheader_override_paths{{ + mods_path + "exheader.bin", + filepath + ".exheader", + }}; + + bool has_exheader_override = false; + for (const auto& path : exheader_override_paths) { + FileUtil::IOFile exheader_override_file{path, "rb"}; + if (read_exheader(exheader_override_file)) { + has_exheader_override = true; + break; + } + } if (has_exheader_override) { if (exheader_header.system_info.jump_id != exheader_header.arm11_system_local_caps.program_id) { From 6e0afbaa19d646fb0b6295c88c3656793c471704 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 16:26:33 +0800 Subject: [PATCH 08/41] Fix build Explicitly use `std::min<std::size_t>` Added virtual destructor --- src/core/file_sys/layered_fs.cpp | 9 +++++---- src/core/file_sys/layered_fs.h | 2 +- src/core/file_sys/romfs_reader.h | 4 ++++ 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp index 9a194569a..80719d3fa 100644 --- a/src/core/file_sys/layered_fs.cpp +++ b/src/core/file_sys/layered_fs.cpp @@ -515,12 +515,13 @@ std::size_t LayeredFS::ReadFile(std::size_t offset, std::size_t length, u8* buff const auto relative_offset = offset - current->first; std::size_t to_read{}; if (current->second->relocation.size > relative_offset) { - to_read = - std::min(current->second->relocation.size - relative_offset, length - read_size); + to_read = std::min<std::size_t>(current->second->relocation.size - relative_offset, + length - read_size); } const auto alignment = - std::min(Common::AlignUp(current->second->relocation.size, 16) - relative_offset, - length - read_size) - + std::min<std::size_t>(Common::AlignUp(current->second->relocation.size, 16) - + relative_offset, + length - read_size) - to_read; // Read the file in different ways depending on relocation type diff --git a/src/core/file_sys/layered_fs.h b/src/core/file_sys/layered_fs.h index b9dcb831f..4f8844d98 100644 --- a/src/core/file_sys/layered_fs.h +++ b/src/core/file_sys/layered_fs.h @@ -42,7 +42,7 @@ class LayeredFS : public RomFSReader { public: explicit LayeredFS(std::shared_ptr<RomFSReader> romfs, std::string patch_path, std::string patch_ext_path); - ~LayeredFS(); + ~LayeredFS() override; std::size_t GetSize() const override; std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; diff --git a/src/core/file_sys/romfs_reader.h b/src/core/file_sys/romfs_reader.h index 5ee39015b..df0318c99 100644 --- a/src/core/file_sys/romfs_reader.h +++ b/src/core/file_sys/romfs_reader.h @@ -11,6 +11,8 @@ namespace FileSys { */ class RomFSReader { public: + virtual ~RomFSReader() = default; + virtual std::size_t GetSize() const = 0; virtual std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) = 0; }; @@ -30,6 +32,8 @@ public: : is_encrypted(true), file(std::move(file)), key(key), ctr(ctr), file_offset(file_offset), crypto_offset(crypto_offset), data_size(data_size) {} + ~DirectRomFSReader() override = default; + std::size_t GetSize() const override { return data_size; } From 2ec99b83aa4efa75019ccdf1094307c2869fd74b Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 23:45:02 +0800 Subject: [PATCH 09/41] core: Reset archive_manager on shutdown. This holds the archives which include the SelfNCCH archive which holds the RomFS files. If we don't reset it the LayeredFS class can't get destructed and mods files won't be released. --- src/core/core.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/core.cpp b/src/core/core.cpp index ebaee4f87..d217ad003 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -360,6 +360,7 @@ void System::Shutdown() { perf_stats.reset(); rpc_server.reset(); cheat_engine.reset(); + archive_manager.reset(); service_manager.reset(); dsp_core.reset(); cpu_core.reset(); From db18f6c79af679fcd07c75e3355639ae52da78f7 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 7 Feb 2020 23:53:00 +0800 Subject: [PATCH 10/41] Address review simplify code --- src/core/file_sys/layered_fs.cpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp index 80719d3fa..65ececd31 100644 --- a/src/core/file_sys/layered_fs.cpp +++ b/src/core/file_sys/layered_fs.cpp @@ -306,18 +306,10 @@ void LayeredFS::PrepareBuild(Directory& current) { // Implementation from 3dbrew u32 CalcHash(const std::string& name, u32 parent_offset) { u32 hash = parent_offset ^ 123456789; - std::u16string u16name = Common::UTF8ToUTF16(name); - std::vector<u16_le> tmp_buffer(u16name.size()); - std::transform(u16name.begin(), u16name.end(), tmp_buffer.begin(), [](char16_t character) { - return static_cast<u16_le>(static_cast<u16>(character)); - }); - - std::vector<u8> buffer(tmp_buffer.size() * 2); - std::memcpy(buffer.data(), tmp_buffer.data(), buffer.size()); - for (std::size_t i = 0; i < buffer.size(); i += 2) { + for (char16_t c : u16name) { hash = (hash >> 5) | (hash << 27); - hash ^= static_cast<u16>((buffer[i]) | (buffer[i + 1] << 8)); + hash ^= static_cast<u16>(c); } return hash; } From 13e2d534e9dc9ec6208df09aa91fb272003844a4 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Sun, 9 Feb 2020 20:59:31 +0800 Subject: [PATCH 11/41] core: Add dump RomFS support This is added to LayeredFS, then the NCCH container and then the loader interface. --- src/core/file_sys/layered_fs.cpp | 64 ++++++++++++++++++++++++++-- src/core/file_sys/layered_fs.h | 8 +++- src/core/file_sys/ncch_container.cpp | 22 +++++++++- src/core/file_sys/ncch_container.h | 10 ++++- src/core/loader/loader.h | 18 ++++++++ src/core/loader/ncch.cpp | 12 ++++++ src/core/loader/ncch.h | 4 ++ 7 files changed, 131 insertions(+), 7 deletions(-) diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp index 65ececd31..8cfe73846 100644 --- a/src/core/file_sys/layered_fs.cpp +++ b/src/core/file_sys/layered_fs.cpp @@ -52,7 +52,7 @@ struct FileMetadata { static_assert(sizeof(FileMetadata) == 0x20, "Size of FileMetadata is not correct"); LayeredFS::LayeredFS(std::shared_ptr<RomFSReader> romfs_, std::string patch_path_, - std::string patch_ext_path_) + std::string patch_ext_path_, bool load_relocations) : romfs(std::move(romfs_)), patch_path(std::move(patch_path_)), patch_ext_path(std::move(patch_ext_path_)) { @@ -64,8 +64,10 @@ LayeredFS::LayeredFS(std::shared_ptr<RomFSReader> romfs_, std::string patch_path root.parent = &root; LoadDirectory(root, 0); - LoadRelocations(); - LoadExtRelocations(); + if (load_relocations) { + LoadRelocations(); + LoadExtRelocations(); + } RebuildMetadata(); } @@ -541,4 +543,60 @@ std::size_t LayeredFS::ReadFile(std::size_t offset, std::size_t length, u8* buff return read_size; } +bool LayeredFS::ExtractDirectory(Directory& current, const std::string& target_path) { + if (!FileUtil::CreateFullPath(target_path + current.path)) { + LOG_ERROR(Service_FS, "Could not create path {}", target_path + current.path); + return false; + } + + constexpr std::size_t BufferSize = 0x10000; + std::array<u8, BufferSize> buffer; + for (const auto& file : current.files) { + // Extract file + const auto path = target_path + file->path; + LOG_INFO(Service_FS, "Extracting {} to {}", file->path, path); + + FileUtil::IOFile target_file(path, "wb"); + if (!target_file) { + LOG_ERROR(Service_FS, "Could not open file {}", path); + return false; + } + + std::size_t written = 0; + while (written < file->relocation.size) { + const auto to_read = + std::min<std::size_t>(buffer.size(), file->relocation.size - written); + if (romfs->ReadFile(file->relocation.original_offset + written, to_read, + buffer.data()) != to_read) { + LOG_ERROR(Service_FS, "Could not read from RomFS"); + return false; + } + + if (target_file.WriteBytes(buffer.data(), to_read) != to_read) { + LOG_ERROR(Service_FS, "Could not write to file {}", path); + return false; + } + + written += to_read; + } + } + + for (const auto& directory : current.directories) { + if (!ExtractDirectory(*directory, target_path)) { + return false; + } + } + + return true; +} + +bool LayeredFS::DumpRomFS(const std::string& target_path) { + std::string path = target_path; + if (path.back() == '/' || path.back() == '\\') { + path.erase(path.size() - 1, 1); + } + + return ExtractDirectory(root, path); +} + } // namespace FileSys diff --git a/src/core/file_sys/layered_fs.h b/src/core/file_sys/layered_fs.h index 4f8844d98..956eedcfa 100644 --- a/src/core/file_sys/layered_fs.h +++ b/src/core/file_sys/layered_fs.h @@ -41,12 +41,14 @@ static_assert(sizeof(RomFSHeader) == 0x28, "Size of RomFSHeader is not correct") class LayeredFS : public RomFSReader { public: explicit LayeredFS(std::shared_ptr<RomFSReader> romfs, std::string patch_path, - std::string patch_ext_path); + std::string patch_ext_path, bool load_relocations = true); ~LayeredFS() override; std::size_t GetSize() const override; std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; + bool DumpRomFS(const std::string& target_path); + private: struct File; struct Directory { @@ -84,6 +86,10 @@ private: void BuildDirectories(); void BuildFiles(); + // Recursively extract a directory and all its contents to target_path + // target_path should be without trailing '/'. + bool ExtractDirectory(Directory& current, const std::string& target_path); + void RebuildMetadata(); std::shared_ptr<RomFSReader> romfs; diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 11be43c0e..2e549a894 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -597,7 +597,8 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, return Loader::ResultStatus::ErrorNotUsed; } -Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romfs_file) { +Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romfs_file, + bool use_layered_fs) { Loader::ResultStatus result = Load(); if (result != Loader::ResultStatus::Success) return result; @@ -640,7 +641,9 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romf const auto path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), ncch_header.program_id & 0x00040000'FFFFFFFF); - if (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/")) { + if (use_layered_fs && + (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/"))) { + romfs_file = std::make_shared<LayeredFS>(std::move(direct_romfs), path + "romfs/", path + "romfs_ext/"); } else { @@ -650,6 +653,21 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romf return Loader::ResultStatus::Success; } +Loader::ResultStatus NCCHContainer::DumpRomFS(const std::string& target_path) { + std::shared_ptr<RomFSReader> direct_romfs; + Loader::ResultStatus result = ReadRomFS(direct_romfs, false); + if (result != Loader::ResultStatus::Success) + return result; + + std::shared_ptr<LayeredFS> layered_fs = + std::make_shared<LayeredFS>(std::move(direct_romfs), "", "", false); + + if (!layered_fs->DumpRomFS(target_path)) { + return Loader::ResultStatus::Error; + } + return Loader::ResultStatus::Success; +} + Loader::ResultStatus NCCHContainer::ReadOverrideRomFS(std::shared_ptr<RomFSReader>& romfs_file) { // Check for RomFS overrides std::string split_filepath = filepath + ".romfs"; diff --git a/src/core/file_sys/ncch_container.h b/src/core/file_sys/ncch_container.h index f06ee8ef6..8deda1fff 100644 --- a/src/core/file_sys/ncch_container.h +++ b/src/core/file_sys/ncch_container.h @@ -247,7 +247,15 @@ public: * @param size The size of the romfs * @return ResultStatus result of function */ - Loader::ResultStatus ReadRomFS(std::shared_ptr<RomFSReader>& romfs_file); + Loader::ResultStatus ReadRomFS(std::shared_ptr<RomFSReader>& romfs_file, + bool use_layered_fs = true); + + /** + * Dump the RomFS of the NCCH container to the user folder. + * @param target_path target path to dump to + * @return ResultStatus result of function. + */ + Loader::ResultStatus DumpRomFS(const std::string& target_path); /** * Get the override RomFS of the NCCH container diff --git a/src/core/loader/loader.h b/src/core/loader/loader.h index 20e84c6a9..0414f181c 100644 --- a/src/core/loader/loader.h +++ b/src/core/loader/loader.h @@ -186,6 +186,15 @@ public: return ResultStatus::ErrorNotImplemented; } + /** + * Dump the RomFS of the applciation + * @param target_path The target path to dump to + * @return ResultStatus result of function + */ + virtual ResultStatus DumpRomFS(const std::string& target_path) { + return ResultStatus::ErrorNotImplemented; + } + /** * Get the update RomFS of the application * Since the RomFS can be huge, we return a file reference instead of copying to a buffer @@ -196,6 +205,15 @@ public: return ResultStatus::ErrorNotImplemented; } + /** + * Dump the update RomFS of the applciation + * @param target_path The target path to dump to + * @return ResultStatus result of function + */ + virtual ResultStatus DumpUpdateRomFS(const std::string& target_path) { + return ResultStatus::ErrorNotImplemented; + } + /** * Get the title of the application * @param title Reference to store the application title into diff --git a/src/core/loader/ncch.cpp b/src/core/loader/ncch.cpp index 2e688e011..21e607ad5 100644 --- a/src/core/loader/ncch.cpp +++ b/src/core/loader/ncch.cpp @@ -254,6 +254,18 @@ ResultStatus AppLoader_NCCH::ReadUpdateRomFS(std::shared_ptr<FileSys::RomFSReade return ResultStatus::Success; } +ResultStatus AppLoader_NCCH::DumpRomFS(const std::string& target_path) { + return base_ncch.DumpRomFS(target_path); +} + +ResultStatus AppLoader_NCCH::DumpUpdateRomFS(const std::string& target_path) { + u64 program_id; + ReadProgramId(program_id); + update_ncch.OpenFile(Service::AM::GetTitleContentPath(Service::FS::MediaType::SDMC, + program_id | UPDATE_MASK)); + return update_ncch.DumpRomFS(target_path); +} + ResultStatus AppLoader_NCCH::ReadTitle(std::string& title) { std::vector<u8> data; Loader::SMDH smdh; diff --git a/src/core/loader/ncch.h b/src/core/loader/ncch.h index 7c86f85d8..041cfddbd 100644 --- a/src/core/loader/ncch.h +++ b/src/core/loader/ncch.h @@ -59,6 +59,10 @@ public: ResultStatus ReadUpdateRomFS(std::shared_ptr<FileSys::RomFSReader>& romfs_file) override; + ResultStatus DumpRomFS(const std::string& target_path) override; + + ResultStatus DumpUpdateRomFS(const std::string& target_path) override; + ResultStatus ReadTitle(std::string& title) override; private: From b87bc5d35157cc70c0284cc521d6d4a9ec4f636b Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Sun, 9 Feb 2020 21:01:56 +0800 Subject: [PATCH 12/41] citra_qt: Add 'Dump RomFS' menu action A progress dialog will be displayed. However no progress is reported and the user also cannot cancel it. --- src/citra_qt/game_list.cpp | 4 ++++ src/citra_qt/game_list.h | 1 + src/citra_qt/main.cpp | 41 ++++++++++++++++++++++++++++++++++++++ src/citra_qt/main.h | 1 + 4 files changed, 47 insertions(+) diff --git a/src/citra_qt/game_list.cpp b/src/citra_qt/game_list.cpp index 4ba7e99d6..8a5df5172 100644 --- a/src/citra_qt/game_list.cpp +++ b/src/citra_qt/game_list.cpp @@ -469,6 +469,7 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra QAction* open_texture_load_location = context_menu.addAction(tr("Open Custom Texture Location")); QAction* open_mods_location = context_menu.addAction(tr("Open Mods Location")); + QAction* dump_romfs = context_menu.addAction(tr("Dump RomFS")); QAction* navigate_to_gamedb_entry = context_menu.addAction(tr("Navigate to GameDB entry")); const bool is_application = @@ -499,6 +500,7 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra open_texture_dump_location->setVisible(is_application); open_texture_load_location->setVisible(is_application); open_mods_location->setVisible(is_application); + dump_romfs->setVisible(is_application); navigate_to_gamedb_entry->setVisible(it != compatibility_list.end()); @@ -535,6 +537,8 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra emit OpenFolderRequested(program_id, GameListOpenTarget::MODS); } }); + connect(dump_romfs, &QAction::triggered, + [this, path, program_id] { emit DumpRomFSRequested(path, program_id); }); connect(navigate_to_gamedb_entry, &QAction::triggered, [this, program_id]() { emit NavigateToGamedbEntryRequested(program_id, compatibility_list); }); diff --git a/src/citra_qt/game_list.h b/src/citra_qt/game_list.h index 635fbb39b..334089037 100644 --- a/src/citra_qt/game_list.h +++ b/src/citra_qt/game_list.h @@ -82,6 +82,7 @@ signals: void OpenFolderRequested(u64 program_id, GameListOpenTarget target); void NavigateToGamedbEntryRequested(u64 program_id, const CompatibilityList& compatibility_list); + void DumpRomFSRequested(QString game_path, u64 program_id); void OpenDirectory(const QString& directory); void AddDirectory(); void ShowList(bool show); diff --git a/src/citra_qt/main.cpp b/src/citra_qt/main.cpp index cbb29e6d7..8534007f6 100644 --- a/src/citra_qt/main.cpp +++ b/src/citra_qt/main.cpp @@ -568,6 +568,7 @@ void GMainWindow::ConnectWidgetEvents() { connect(game_list, &GameList::OpenFolderRequested, this, &GMainWindow::OnGameListOpenFolder); connect(game_list, &GameList::NavigateToGamedbEntryRequested, this, &GMainWindow::OnGameListNavigateToGamedbEntry); + connect(game_list, &GameList::DumpRomFSRequested, this, &GMainWindow::OnGameListDumpRomFS); connect(game_list, &GameList::AddDirectory, this, &GMainWindow::OnGameListAddDirectory); connect(game_list_placeholder, &GameListPlaceholder::AddDirectory, this, &GMainWindow::OnGameListAddDirectory); @@ -1177,6 +1178,46 @@ void GMainWindow::OnGameListNavigateToGamedbEntry(u64 program_id, QDesktopServices::openUrl(QUrl("https://citra-emu.org/game/" + directory)); } +void GMainWindow::OnGameListDumpRomFS(QString game_path, u64 program_id) { + auto* dialog = new QProgressDialog(tr("Dumping..."), tr("Cancel"), 0, 0, this); + dialog->setWindowModality(Qt::WindowModal); + dialog->setWindowFlags(dialog->windowFlags() & + ~(Qt::WindowCloseButtonHint | Qt::WindowContextHelpButtonHint)); + dialog->setCancelButton(nullptr); + dialog->setMinimumDuration(0); + dialog->setValue(0); + + const auto base_path = fmt::format( + "{}romfs/{:016X}", FileUtil::GetUserPath(FileUtil::UserPath::DumpDir), program_id); + const auto update_path = + fmt::format("{}romfs/{:016X}", FileUtil::GetUserPath(FileUtil::UserPath::DumpDir), + program_id | 0x0004000e00000000); + using FutureWatcher = QFutureWatcher<std::pair<Loader::ResultStatus, Loader::ResultStatus>>; + auto* future_watcher = new FutureWatcher(this); + connect(future_watcher, &FutureWatcher::finished, + [this, program_id, dialog, base_path, update_path, future_watcher] { + dialog->hide(); + const auto& [base, update] = future_watcher->result(); + if (base != Loader::ResultStatus::Success) { + QMessageBox::critical( + this, tr("Citra"), + tr("Could not dump base RomFS.\nRefer to the log for details.")); + return; + } + QDesktopServices::openUrl(QUrl::fromLocalFile(QString::fromStdString(base_path))); + if (update == Loader::ResultStatus::Success) { + QDesktopServices::openUrl( + QUrl::fromLocalFile(QString::fromStdString(update_path))); + } + }); + + auto future = QtConcurrent::run([game_path, base_path, update_path] { + std::unique_ptr<Loader::AppLoader> loader = Loader::GetLoader(game_path.toStdString()); + return std::make_pair(loader->DumpRomFS(base_path), loader->DumpUpdateRomFS(update_path)); + }); + future_watcher->setFuture(future); +} + void GMainWindow::OnGameListOpenDirectory(const QString& directory) { QString path; if (directory == QStringLiteral("INSTALLED")) { diff --git a/src/citra_qt/main.h b/src/citra_qt/main.h index 91d7eed1b..75a1bbb3d 100644 --- a/src/citra_qt/main.h +++ b/src/citra_qt/main.h @@ -169,6 +169,7 @@ private slots: void OnGameListOpenFolder(u64 program_id, GameListOpenTarget target); void OnGameListNavigateToGamedbEntry(u64 program_id, const CompatibilityList& compatibility_list); + void OnGameListDumpRomFS(QString game_path, u64 program_id); void OnGameListOpenDirectory(const QString& directory); void OnGameListAddDirectory(); void OnGameListShowList(bool show); From d9ae4c332d11466351dc231531bf2a4f424ae8c1 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Sun, 9 Feb 2020 21:48:42 +0800 Subject: [PATCH 13/41] layered_fs: Do not open all replacement files on load Instead open them when we want to read them. This is because the standard library has a limit on the number of opened files. --- src/core/file_sys/layered_fs.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp index 8cfe73846..a77ab38f1 100644 --- a/src/core/file_sys/layered_fs.cpp +++ b/src/core/file_sys/layered_fs.cpp @@ -18,7 +18,7 @@ namespace FileSys { struct FileRelocationInfo { int type; // 0 - none, 1 - replaced / created, 2 - patched, 3 - removed u64 original_offset; // Type 0. Offset is absolute - FileUtil::IOFile replace_file; // Type 1 + std::string replace_file_path; // Type 1 std::vector<u8> patched_file; // Type 2 u64 size; // Relocated file size }; @@ -175,14 +175,9 @@ void LayeredFS::LoadRelocations() { } auto* file = file_path_map.at(path); - file->relocation.replace_file = FileUtil::IOFile(directory + virtual_name, "rb"); - if (file->relocation.replace_file) { - file->relocation.type = 1; - file->relocation.size = file->relocation.replace_file.GetSize(); - LOG_INFO(Service_FS, "LayeredFS replacement file in use for {}", path); - } else { - LOG_ERROR(Service_FS, "Could not open replacement file for {}", path); - } + file->relocation.type = 1; + file->relocation.replace_file_path = directory + virtual_name; + LOG_INFO(Service_FS, "LayeredFS replacement file in use for {}", path); return true; }; @@ -524,8 +519,14 @@ std::size_t LayeredFS::ReadFile(std::size_t offset, std::size_t length, u8* buff romfs->ReadFile(relocation.original_offset + relative_offset, to_read, buffer + read_size); } else if (relocation.type == 1) { // replace - relocation.replace_file.Seek(relative_offset, SEEK_SET); - relocation.replace_file.ReadBytes(buffer + read_size, to_read); + FileUtil::IOFile replace_file(relocation.replace_file_path, "rb"); + if (replace_file) { + replace_file.Seek(relative_offset, SEEK_SET); + replace_file.ReadBytes(buffer + read_size, to_read); + } else { + LOG_ERROR(Service_FS, "Could not open replacement file for {}", + current->second->path); + } } else if (relocation.type == 2) { // patch std::memcpy(buffer + read_size, relocation.patched_file.data() + relative_offset, to_read); From b39a611a3d79912476161da4efffa259c2dd0223 Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:09:06 -0500 Subject: [PATCH 14/41] input_common/udp: Add missing header guard --- src/input_common/udp/udp.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/input_common/udp/udp.h b/src/input_common/udp/udp.h index ea3de60bb..0e56e431b 100644 --- a/src/input_common/udp/udp.h +++ b/src/input_common/udp/udp.h @@ -2,6 +2,8 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#pragma once + #include <memory> #include <unordered_map> #include "input_common/main.h" From d7a58fe24dd0170a8920df3d8bf20e7c9745f2e8 Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:10:59 -0500 Subject: [PATCH 15/41] input_common/udp: Remove unnecessary inclusions --- src/input_common/udp/client.h | 1 - src/input_common/udp/protocol.h | 1 - src/input_common/udp/udp.cpp | 4 +++- src/input_common/udp/udp.h | 6 +----- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/input_common/udp/client.h b/src/input_common/udp/client.h index 5177f46be..82bdb68d9 100644 --- a/src/input_common/udp/client.h +++ b/src/input_common/udp/client.h @@ -11,7 +11,6 @@ #include <string> #include <thread> #include <tuple> -#include <vector> #include "common/common_types.h" #include "common/thread.h" #include "common/vector_math.h" diff --git a/src/input_common/udp/protocol.h b/src/input_common/udp/protocol.h index d31bbeb89..5b1852d55 100644 --- a/src/input_common/udp/protocol.h +++ b/src/input_common/udp/protocol.h @@ -7,7 +7,6 @@ #include <array> #include <optional> #include <type_traits> -#include <vector> #include <boost/crc.hpp> #include "common/bit_field.h" #include "common/swap.h" diff --git a/src/input_common/udp/udp.cpp b/src/input_common/udp/udp.cpp index 43691ae2c..08aee5f02 100644 --- a/src/input_common/udp/udp.cpp +++ b/src/input_common/udp/udp.cpp @@ -2,7 +2,9 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. -#include "common/logging/log.h" +#include <mutex> +#include <tuple> + #include "common/param_package.h" #include "core/frontend/input.h" #include "core/settings.h" diff --git a/src/input_common/udp/udp.h b/src/input_common/udp/udp.h index 0e56e431b..4f83f0441 100644 --- a/src/input_common/udp/udp.h +++ b/src/input_common/udp/udp.h @@ -5,14 +5,10 @@ #pragma once #include <memory> -#include <unordered_map> -#include "input_common/main.h" -#include "input_common/udp/client.h" namespace InputCommon::CemuhookUDP { -class UDPTouchDevice; -class UDPMotionDevice; +class Client; class State { public: From 7d45fdc1df23b4417d5787a191292c2055d87f81 Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:15:30 -0500 Subject: [PATCH 16/41] input_common/udp: Silence -Wreorder warning for Socket Amends the constructor initializer list to specify the order of its elements in the same order that initialization would occur. --- src/input_common/udp/client.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/input_common/udp/client.cpp b/src/input_common/udp/client.cpp index 887436550..ba46a4370 100644 --- a/src/input_common/udp/client.cpp +++ b/src/input_common/udp/client.cpp @@ -31,10 +31,9 @@ public: explicit Socket(const std::string& host, u16 port, u8 pad_index, u32 client_id, SocketCallback callback) - : client_id(client_id), timer(io_service), - send_endpoint(udp::endpoint(address_v4::from_string(host), port)), - socket(io_service, udp::endpoint(udp::v4(), 0)), pad_index(pad_index), - callback(std::move(callback)) {} + : callback(std::move(callback)), timer(io_service), + socket(io_service, udp::endpoint(udp::v4(), 0)), client_id(client_id), + pad_index(pad_index), send_endpoint(udp::endpoint(address_v4::from_string(host), port)) {} void Stop() { io_service.stop(); From 8a0f8c3a4f74b0e552c0d312360ef1b9b57e85ee Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:19:03 -0500 Subject: [PATCH 17/41] udp/client: Replace deprecated from_string() call with make_address_v4() Future-proofs code if boost is ever updated. --- src/input_common/udp/client.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/input_common/udp/client.cpp b/src/input_common/udp/client.cpp index ba46a4370..5f5d27124 100644 --- a/src/input_common/udp/client.cpp +++ b/src/input_common/udp/client.cpp @@ -14,7 +14,6 @@ #include "input_common/udp/client.h" #include "input_common/udp/protocol.h" -using boost::asio::ip::address_v4; using boost::asio::ip::udp; namespace InputCommon::CemuhookUDP { @@ -33,7 +32,8 @@ public: SocketCallback callback) : callback(std::move(callback)), timer(io_service), socket(io_service, udp::endpoint(udp::v4(), 0)), client_id(client_id), - pad_index(pad_index), send_endpoint(udp::endpoint(address_v4::from_string(host), port)) {} + pad_index(pad_index), + send_endpoint(udp::endpoint(boost::asio::ip::make_address_v4(host), port)) {} void Stop() { io_service.stop(); From fcdc1911076f3f6e478f6c579600ebea57e6c481 Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:21:44 -0500 Subject: [PATCH 18/41] input_common/udp: std::move shared_ptr within Client constructor Gets rid of a trivially avoidable atomic reference count increment and decrement. --- src/input_common/udp/client.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/input_common/udp/client.cpp b/src/input_common/udp/client.cpp index 5f5d27124..b6946d9dc 100644 --- a/src/input_common/udp/client.cpp +++ b/src/input_common/udp/client.cpp @@ -125,7 +125,7 @@ static void SocketLoop(Socket* socket) { Client::Client(std::shared_ptr<DeviceStatus> status, const std::string& host, u16 port, u8 pad_index, u32 client_id) - : status(status) { + : status(std::move(status)) { StartCommunication(host, port, pad_index, client_id); } From 575ab92a760f453111b91646049872feef9862e8 Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:24:03 -0500 Subject: [PATCH 19/41] input_common/udp: std::move SocketCallback instances where applicable std::function is allowed to heap allocate if the size of the captures associated with each lambda exceed a certain threshold. This prevents potentially unnecessary reallocations from occurring. --- src/input_common/udp/client.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/input_common/udp/client.cpp b/src/input_common/udp/client.cpp index b6946d9dc..c9ffe899c 100644 --- a/src/input_common/udp/client.cpp +++ b/src/input_common/udp/client.cpp @@ -207,7 +207,7 @@ void TestCommunication(const std::string& host, u16 port, u8 pad_index, u32 clie Common::Event success_event; SocketCallback callback{[](Response::Version version) {}, [](Response::PortInfo info) {}, [&](Response::PadData data) { success_event.Set(); }}; - Socket socket{host, port, pad_index, client_id, callback}; + Socket socket{host, port, pad_index, client_id, std::move(callback)}; std::thread worker_thread{SocketLoop, &socket}; bool result = success_event.WaitFor(std::chrono::seconds(8)); socket.Stop(); @@ -263,7 +263,7 @@ CalibrationConfigurationJob::CalibrationConfigurationJob( complete_event.Set(); } }}; - Socket socket{host, port, pad_index, client_id, callback}; + Socket socket{host, port, pad_index, client_id, std::move(callback)}; std::thread worker_thread{SocketLoop, &socket}; complete_event.Wait(); socket.Stop(); From 7362fe48acca810786251fe77eca4283fd3e81cb Mon Sep 17 00:00:00 2001 From: Lioncash <mathew1800@gmail.com> Date: Mon, 3 Feb 2020 09:26:50 -0500 Subject: [PATCH 20/41] input_common/udp: Add missing override specifiers Prevents trivial warnings and ensures interfaces are properly maintained between the base class. --- src/input_common/udp/udp.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/input_common/udp/udp.cpp b/src/input_common/udp/udp.cpp index 08aee5f02..1b83de373 100644 --- a/src/input_common/udp/udp.cpp +++ b/src/input_common/udp/udp.cpp @@ -16,7 +16,7 @@ namespace InputCommon::CemuhookUDP { class UDPTouchDevice final : public Input::TouchDevice { public: explicit UDPTouchDevice(std::shared_ptr<DeviceStatus> status_) : status(std::move(status_)) {} - std::tuple<float, float, bool> GetStatus() const { + std::tuple<float, float, bool> GetStatus() const override { std::lock_guard guard(status->update_mutex); return status->touch_status; } @@ -28,7 +28,7 @@ private: class UDPMotionDevice final : public Input::MotionDevice { public: explicit UDPMotionDevice(std::shared_ptr<DeviceStatus> status_) : status(std::move(status_)) {} - std::tuple<Common::Vec3<float>, Common::Vec3<float>> GetStatus() const { + std::tuple<Common::Vec3<float>, Common::Vec3<float>> GetStatus() const override { std::lock_guard guard(status->update_mutex); return status->motion_status; } From b81c9bd73853f281b2f223a4469fa9dfd264d380 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Mon, 10 Feb 2020 07:41:31 +0800 Subject: [PATCH 21/41] fix clang format --- src/core/loader/ncch.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/loader/ncch.cpp b/src/core/loader/ncch.cpp index 21e607ad5..1a966da5e 100644 --- a/src/core/loader/ncch.cpp +++ b/src/core/loader/ncch.cpp @@ -261,8 +261,8 @@ ResultStatus AppLoader_NCCH::DumpRomFS(const std::string& target_path) { ResultStatus AppLoader_NCCH::DumpUpdateRomFS(const std::string& target_path) { u64 program_id; ReadProgramId(program_id); - update_ncch.OpenFile(Service::AM::GetTitleContentPath(Service::FS::MediaType::SDMC, - program_id | UPDATE_MASK)); + update_ncch.OpenFile( + Service::AM::GetTitleContentPath(Service::FS::MediaType::SDMC, program_id | UPDATE_MASK)); return update_ncch.DumpRomFS(target_path); } From 4c2c27046dae9c98a8a45a4130cc4a7a941c8310 Mon Sep 17 00:00:00 2001 From: FearlessTobi <thm.frey@gmail.com> Date: Sun, 9 Feb 2020 23:09:19 +0100 Subject: [PATCH 22/41] Fix compilation --- src/input_common/udp/udp.cpp | 1 - src/input_common/udp/udp.h | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/input_common/udp/udp.cpp b/src/input_common/udp/udp.cpp index 1b83de373..c4d5121b9 100644 --- a/src/input_common/udp/udp.cpp +++ b/src/input_common/udp/udp.cpp @@ -4,7 +4,6 @@ #include <mutex> #include <tuple> - #include "common/param_package.h" #include "core/frontend/input.h" #include "core/settings.h" diff --git a/src/input_common/udp/udp.h b/src/input_common/udp/udp.h index 4f83f0441..3eac8c7ea 100644 --- a/src/input_common/udp/udp.h +++ b/src/input_common/udp/udp.h @@ -5,11 +5,10 @@ #pragma once #include <memory> +#include "input_common/udp/client.h" namespace InputCommon::CemuhookUDP { -class Client; - class State { public: State(); From 4273b967b5a0b07df18ebcc0131eb11dda497bc8 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Tue, 11 Feb 2020 14:03:07 +0800 Subject: [PATCH 23/41] core/file_sys: Do not apply the same mods to DLCs Now you can apply separate mods to DLCs and mods for the original title won't be applied. --- src/core/file_sys/ncch_container.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 2e549a894..81525786f 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -26,6 +26,14 @@ namespace FileSys { static const int kMaxSections = 8; ///< Maximum number of sections (files) in an ExeFs static const int kBlockSize = 0x200; ///< Size of ExeFS blocks (in bytes) +u64 GetModId(u64 program_id) { + constexpr u64 UPDATE_MASK = 0x0000000e'00000000; + if ((program_id & 0x000000ff'00000000) == UPDATE_MASK) { // Apply the mods to updates + return program_id & ~UPDATE_MASK; + } + return program_id; +} + /** * Get the decompressed size of an LZSS compressed ExeFS file * @param buffer Buffer of compressed file @@ -306,7 +314,7 @@ Loader::ResultStatus NCCHContainer::Load() { const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id & 0x00040000'FFFFFFFF); + GetModId(ncch_header.program_id)); std::array<std::string, 2> exheader_override_paths{{ mods_path + "exheader.bin", filepath + ".exheader", @@ -530,7 +538,7 @@ Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector<u8>& code) const const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id & 0x00040000'FFFFFFFF); + GetModId(ncch_header.program_id)); const std::array<PatchLocation, 4> patch_paths{{ {mods_path + "exefs/code.ips", Patch::ApplyIpsPatch}, {mods_path + "exefs/code.bps", Patch::ApplyBpsPatch}, @@ -574,7 +582,7 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id & 0x00040000'FFFFFFFF); + GetModId(ncch_header.program_id)); std::array<std::string, 2> override_paths{{ mods_path + "exefs/" + override_name, filepath + ".exefsdir/" + override_name, @@ -640,7 +648,7 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr<RomFSReader>& romf const auto path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), - ncch_header.program_id & 0x00040000'FFFFFFFF); + GetModId(ncch_header.program_id)); if (use_layered_fs && (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/"))) { From 4991c0121a1aeb0930f2318a4800325574391e8e Mon Sep 17 00:00:00 2001 From: Vitor K <vitor-kiguchi@hotmail.com> Date: Sat, 15 Feb 2020 10:38:20 -0300 Subject: [PATCH 24/41] =?UTF-8?q?Remove=20duplicate=20code=20from=20the=20?= =?UTF-8?q?migration=20of=20frame=20limit=20to=20gene=E2=80=A6=20(#5091)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/citra_qt/configuration/configure_general.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/citra_qt/configuration/configure_general.cpp b/src/citra_qt/configuration/configure_general.cpp index e4b6be073..47d559e3d 100644 --- a/src/citra_qt/configuration/configure_general.cpp +++ b/src/citra_qt/configuration/configure_general.cpp @@ -25,10 +25,6 @@ ConfigureGeneral::ConfigureGeneral(QWidget* parent) ConfigureGeneral::~ConfigureGeneral() = default; void ConfigureGeneral::SetConfiguration() { - ui->toggle_frame_limit->setChecked(Settings::values.use_frame_limit); - ui->frame_limit->setEnabled(ui->toggle_frame_limit->isChecked()); - ui->frame_limit->setValue(Settings::values.frame_limit); - ui->toggle_check_exit->setChecked(UISettings::values.confirm_before_closing); ui->toggle_background_pause->setChecked(UISettings::values.pause_when_in_background); @@ -57,9 +53,6 @@ void ConfigureGeneral::ResetDefaults() { } void ConfigureGeneral::ApplyConfiguration() { - Settings::values.use_frame_limit = ui->toggle_frame_limit->isChecked(); - Settings::values.frame_limit = ui->frame_limit->value(); - UISettings::values.confirm_before_closing = ui->toggle_check_exit->isChecked(); UISettings::values.pause_when_in_background = ui->toggle_background_pause->isChecked(); From 996f1546b26b6137ff62474051c53406b9932208 Mon Sep 17 00:00:00 2001 From: Marshall Mohror <mohror64@gmail.com> Date: Thu, 20 Feb 2020 13:40:21 -0600 Subject: [PATCH 25/41] core: Remove outdated MSVC workarounds (#5099) * core/hw/gpu: Remove outdated MSVC workarounds * core/hle/service/hid: Remove MSVC workaround --- src/core/hle/service/hid/hid.h | 7 ------- src/core/hw/gpu.cpp | 4 ++-- src/core/hw/gpu.h | 31 ------------------------------- 3 files changed, 2 insertions(+), 40 deletions(-) diff --git a/src/core/hle/service/hid/hid.h b/src/core/hle/service/hid/hid.h index 8d217f835..7e8a8c527 100644 --- a/src/core/hle/service/hid/hid.h +++ b/src/core/hle/service/hid/hid.h @@ -6,9 +6,7 @@ #include <array> #include <atomic> -#ifndef _MSC_VER #include <cstddef> -#endif #include <memory> #include "common/bit_field.h" #include "common/common_funcs.h" @@ -177,10 +175,6 @@ struct GyroscopeCalibrateParam { } x, y, z; }; -// TODO: MSVC does not support using offsetof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#ifndef _MSC_VER #define ASSERT_REG_POSITION(field_name, position) \ static_assert(offsetof(SharedMem, field_name) == position * 4, \ "Field " #field_name " has invalid position") @@ -189,7 +183,6 @@ ASSERT_REG_POSITION(pad.index_reset_ticks, 0x0); ASSERT_REG_POSITION(touch.index_reset_ticks, 0x2A); #undef ASSERT_REG_POSITION -#endif // !defined(_MSC_VER) struct DirectionState { bool up; diff --git a/src/core/hw/gpu.cpp b/src/core/hw/gpu.cpp index fcb60e6fe..f641afae4 100644 --- a/src/core/hw/gpu.cpp +++ b/src/core/hw/gpu.cpp @@ -402,8 +402,8 @@ inline void Write(u32 addr, const T data) { switch (index) { // Memory fills are triggered once the fill value is written. - case GPU_REG_INDEX_WORKAROUND(memory_fill_config[0].trigger, 0x00004 + 0x3): - case GPU_REG_INDEX_WORKAROUND(memory_fill_config[1].trigger, 0x00008 + 0x3): { + case GPU_REG_INDEX(memory_fill_config[0].trigger): + case GPU_REG_INDEX(memory_fill_config[1].trigger): { const bool is_second_filler = (index != GPU_REG_INDEX(memory_fill_config[0].trigger)); auto& config = g_regs.memory_fill_config[is_second_filler]; diff --git a/src/core/hw/gpu.h b/src/core/hw/gpu.h index 606ab9504..ac30bc22e 100644 --- a/src/core/hw/gpu.h +++ b/src/core/hw/gpu.h @@ -20,41 +20,15 @@ namespace GPU { constexpr float SCREEN_REFRESH_RATE = 60; // Returns index corresponding to the Regs member labeled by field_name -// TODO: Due to Visual studio bug 209229, offsetof does not return constant expressions -// when used with array elements (e.g. GPU_REG_INDEX(memory_fill_config[0])). -// For details cf. -// https://connect.microsoft.com/VisualStudio/feedback/details/209229/offsetof-does-not-produce-a-constant-expression-for-array-members -// Hopefully, this will be fixed sometime in the future. -// For lack of better alternatives, we currently hardcode the offsets when constant -// expressions are needed via GPU_REG_INDEX_WORKAROUND (on sane compilers, static_asserts -// will then make sure the offsets indeed match the automatically calculated ones). #define GPU_REG_INDEX(field_name) (offsetof(GPU::Regs, field_name) / sizeof(u32)) -#if defined(_MSC_VER) -#define GPU_REG_INDEX_WORKAROUND(field_name, backup_workaround_index) (backup_workaround_index) -#else -// NOTE: Yeah, hacking in a static_assert here just to workaround the lacking MSVC compiler -// really is this annoying. This macro just forwards its first argument to GPU_REG_INDEX -// and then performs a (no-op) cast to std::size_t iff the second argument matches the -// expected field offset. Otherwise, the compiler will fail to compile this code. -#define GPU_REG_INDEX_WORKAROUND(field_name, backup_workaround_index) \ - ((typename std::enable_if<backup_workaround_index == GPU_REG_INDEX(field_name), \ - std::size_t>::type) GPU_REG_INDEX(field_name)) -#endif // MMIO region 0x1EFxxxxx struct Regs { // helper macro to make sure the defined structures are of the expected size. -#if defined(_MSC_VER) -// TODO: MSVC does not support using sizeof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#define ASSERT_MEMBER_SIZE(name, size_in_bytes) -#else #define ASSERT_MEMBER_SIZE(name, size_in_bytes) \ static_assert(sizeof(name) == size_in_bytes, \ "Structure size and register block length don't match") -#endif // Components are laid out in reverse byte order, most significant bits first. enum class PixelFormat : u32 { @@ -299,10 +273,6 @@ private: }; static_assert(std::is_standard_layout<Regs>::value, "Structure does not use standard layout"); -// TODO: MSVC does not support using offsetof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#ifndef _MSC_VER #define ASSERT_REG_POSITION(field_name, position) \ static_assert(offsetof(Regs, field_name) == position * 4, \ "Field " #field_name " has invalid position") @@ -315,7 +285,6 @@ ASSERT_REG_POSITION(display_transfer_config, 0x00300); ASSERT_REG_POSITION(command_processor_config, 0x00638); #undef ASSERT_REG_POSITION -#endif // !defined(_MSC_VER) // The total number of registers is chosen arbitrarily, but let's make sure it's not some odd value // anyway. From e3dbdcbdff5ef03c41bef1ec3997256703d4621a Mon Sep 17 00:00:00 2001 From: Ben <benediktthomas@gmail.com> Date: Fri, 21 Feb 2020 19:04:04 +0100 Subject: [PATCH 26/41] HTTP_C::Implement Context::MakeRequest (#4754) * HTTP_C::Implement Context::MakeRequest * httplib: Add add_client_cert_ASN1 and set_verify * HTTP_C: Fix request methode strings case in MakeRequest * HTTP_C: clang-format and cleanups * HTTP_C: Add comment about async in BeginRequest and BeginRequestAsync * Update httplib to contain all the changes we need; adapt http_c and web_services to the changes in httplib; addressed minor review comments * Add android-ifaddrs --- externals/CMakeLists.txt | 5 + externals/android-ifaddrs/CMakeLists.txt | 8 + externals/android-ifaddrs/ifaddrs.c | 600 +++ externals/android-ifaddrs/ifaddrs.h | 54 + externals/httplib/README.md | 3 +- externals/httplib/httplib.h | 6186 +++++++++++++++------- src/core/CMakeLists.txt | 12 +- src/core/hle/service/http_c.cpp | 106 +- src/core/hle/service/http_c.h | 19 +- src/web_service/CMakeLists.txt | 3 + src/web_service/web_backend.cpp | 11 +- 11 files changed, 5103 insertions(+), 1904 deletions(-) create mode 100644 externals/android-ifaddrs/CMakeLists.txt create mode 100644 externals/android-ifaddrs/ifaddrs.c create mode 100644 externals/android-ifaddrs/ifaddrs.h diff --git a/externals/CMakeLists.txt b/externals/CMakeLists.txt index 63159a52b..2aeaa5688 100644 --- a/externals/CMakeLists.txt +++ b/externals/CMakeLists.txt @@ -96,9 +96,14 @@ if (ENABLE_WEB_SERVICE) # lurlparser add_subdirectory(lurlparser EXCLUDE_FROM_ALL) + if(ANDROID) + add_subdirectory(android-ifaddrs) + endif() + # httplib add_library(httplib INTERFACE) target_include_directories(httplib INTERFACE ./httplib) + target_compile_options(httplib INTERFACE -DCPPHTTPLIB_OPENSSL_SUPPORT) # cpp-jwt add_library(cpp-jwt INTERFACE) diff --git a/externals/android-ifaddrs/CMakeLists.txt b/externals/android-ifaddrs/CMakeLists.txt new file mode 100644 index 000000000..25f243679 --- /dev/null +++ b/externals/android-ifaddrs/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(ifaddrs + ifaddrs.c + ifaddrs.h +) + +create_target_directory_groups(ifaddrs) + +target_include_directories(ifaddrs INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/externals/android-ifaddrs/ifaddrs.c b/externals/android-ifaddrs/ifaddrs.c new file mode 100644 index 000000000..3a3878a68 --- /dev/null +++ b/externals/android-ifaddrs/ifaddrs.c @@ -0,0 +1,600 @@ +/* +Copyright (c) 2013, Kenneth MacKay +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include "ifaddrs.h" + +#include <string.h> +#include <stdlib.h> +#include <errno.h> +#include <unistd.h> +#include <sys/socket.h> +#include <net/if_arp.h> +#include <netinet/in.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +typedef struct NetlinkList +{ + struct NetlinkList *m_next; + struct nlmsghdr *m_data; + unsigned int m_size; +} NetlinkList; + +static int netlink_socket(void) +{ + int l_socket = socket(PF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if(l_socket < 0) + { + return -1; + } + + struct sockaddr_nl l_addr; + memset(&l_addr, 0, sizeof(l_addr)); + l_addr.nl_family = AF_NETLINK; + if(bind(l_socket, (struct sockaddr *)&l_addr, sizeof(l_addr)) < 0) + { + close(l_socket); + return -1; + } + + return l_socket; +} + +static int netlink_send(int p_socket, int p_request) +{ + char l_buffer[NLMSG_ALIGN(sizeof(struct nlmsghdr)) + NLMSG_ALIGN(sizeof(struct rtgenmsg))]; + memset(l_buffer, 0, sizeof(l_buffer)); + struct nlmsghdr *l_hdr = (struct nlmsghdr *)l_buffer; + struct rtgenmsg *l_msg = (struct rtgenmsg *)NLMSG_DATA(l_hdr); + + l_hdr->nlmsg_len = NLMSG_LENGTH(sizeof(*l_msg)); + l_hdr->nlmsg_type = p_request; + l_hdr->nlmsg_flags = NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST; + l_hdr->nlmsg_pid = 0; + l_hdr->nlmsg_seq = p_socket; + l_msg->rtgen_family = AF_UNSPEC; + + struct sockaddr_nl l_addr; + memset(&l_addr, 0, sizeof(l_addr)); + l_addr.nl_family = AF_NETLINK; + return (sendto(p_socket, l_hdr, l_hdr->nlmsg_len, 0, (struct sockaddr *)&l_addr, sizeof(l_addr))); +} + +static int netlink_recv(int p_socket, void *p_buffer, size_t p_len) +{ + struct msghdr l_msg; + struct iovec l_iov = { p_buffer, p_len }; + struct sockaddr_nl l_addr; + int l_result; + + for(;;) + { + l_msg.msg_name = (void *)&l_addr; + l_msg.msg_namelen = sizeof(l_addr); + l_msg.msg_iov = &l_iov; + l_msg.msg_iovlen = 1; + l_msg.msg_control = NULL; + l_msg.msg_controllen = 0; + l_msg.msg_flags = 0; + int l_result = recvmsg(p_socket, &l_msg, 0); + + if(l_result < 0) + { + if(errno == EINTR) + { + continue; + } + return -2; + } + + if(l_msg.msg_flags & MSG_TRUNC) + { // buffer was too small + return -1; + } + return l_result; + } +} + +static struct nlmsghdr *getNetlinkResponse(int p_socket, int *p_size, int *p_done) +{ + size_t l_size = 4096; + void *l_buffer = NULL; + + for(;;) + { + free(l_buffer); + l_buffer = malloc(l_size); + + int l_read = netlink_recv(p_socket, l_buffer, l_size); + *p_size = l_read; + if(l_read == -2) + { + free(l_buffer); + return NULL; + } + if(l_read >= 0) + { + pid_t l_pid = getpid(); + struct nlmsghdr *l_hdr; + for(l_hdr = (struct nlmsghdr *)l_buffer; NLMSG_OK(l_hdr, (unsigned int)l_read); l_hdr = (struct nlmsghdr *)NLMSG_NEXT(l_hdr, l_read)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + *p_done = 1; + break; + } + + if(l_hdr->nlmsg_type == NLMSG_ERROR) + { + free(l_buffer); + return NULL; + } + } + return l_buffer; + } + + l_size *= 2; + } +} + +static NetlinkList *newListItem(struct nlmsghdr *p_data, unsigned int p_size) +{ + NetlinkList *l_item = malloc(sizeof(NetlinkList)); + l_item->m_next = NULL; + l_item->m_data = p_data; + l_item->m_size = p_size; + return l_item; +} + +static void freeResultList(NetlinkList *p_list) +{ + NetlinkList *l_cur; + while(p_list) + { + l_cur = p_list; + p_list = p_list->m_next; + free(l_cur->m_data); + free(l_cur); + } +} + +static NetlinkList *getResultList(int p_socket, int p_request) +{ + if(netlink_send(p_socket, p_request) < 0) + { + return NULL; + } + + NetlinkList *l_list = NULL; + NetlinkList *l_end = NULL; + int l_size; + int l_done = 0; + while(!l_done) + { + struct nlmsghdr *l_hdr = getNetlinkResponse(p_socket, &l_size, &l_done); + if(!l_hdr) + { // error + freeResultList(l_list); + return NULL; + } + + NetlinkList *l_item = newListItem(l_hdr, l_size); + if(!l_list) + { + l_list = l_item; + } + else + { + l_end->m_next = l_item; + } + l_end = l_item; + } + return l_list; +} + +static size_t maxSize(size_t a, size_t b) +{ + return (a > b ? a : b); +} + +static size_t calcAddrLen(sa_family_t p_family, int p_dataSize) +{ + switch(p_family) + { + case AF_INET: + return sizeof(struct sockaddr_in); + case AF_INET6: + return sizeof(struct sockaddr_in6); + case AF_PACKET: + return maxSize(sizeof(struct sockaddr_ll), offsetof(struct sockaddr_ll, sll_addr) + p_dataSize); + default: + return maxSize(sizeof(struct sockaddr), offsetof(struct sockaddr, sa_data) + p_dataSize); + } +} + +static void makeSockaddr(sa_family_t p_family, struct sockaddr *p_dest, void *p_data, size_t p_size) +{ + switch(p_family) + { + case AF_INET: + memcpy(&((struct sockaddr_in*)p_dest)->sin_addr, p_data, p_size); + break; + case AF_INET6: + memcpy(&((struct sockaddr_in6*)p_dest)->sin6_addr, p_data, p_size); + break; + case AF_PACKET: + memcpy(((struct sockaddr_ll*)p_dest)->sll_addr, p_data, p_size); + ((struct sockaddr_ll*)p_dest)->sll_halen = p_size; + break; + default: + memcpy(p_dest->sa_data, p_data, p_size); + break; + } + p_dest->sa_family = p_family; +} + +static void addToEnd(struct ifaddrs **p_resultList, struct ifaddrs *p_entry) +{ + if(!*p_resultList) + { + *p_resultList = p_entry; + } + else + { + struct ifaddrs *l_cur = *p_resultList; + while(l_cur->ifa_next) + { + l_cur = l_cur->ifa_next; + } + l_cur->ifa_next = p_entry; + } +} + +static void interpretLink(struct nlmsghdr *p_hdr, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + struct ifinfomsg *l_info = (struct ifinfomsg *)NLMSG_DATA(p_hdr); + + size_t l_nameSize = 0; + size_t l_addrSize = 0; + size_t l_dataSize = 0; + + size_t l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifinfomsg)); + struct rtattr *l_rta; + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifinfomsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFLA_ADDRESS: + case IFLA_BROADCAST: + l_addrSize += NLMSG_ALIGN(calcAddrLen(AF_PACKET, l_rtaDataSize)); + break; + case IFLA_IFNAME: + l_nameSize += NLMSG_ALIGN(l_rtaSize + 1); + break; + case IFLA_STATS: + l_dataSize += NLMSG_ALIGN(l_rtaSize); + break; + default: + break; + } + } + + struct ifaddrs *l_entry = malloc(sizeof(struct ifaddrs) + l_nameSize + l_addrSize + l_dataSize); + memset(l_entry, 0, sizeof(struct ifaddrs)); + l_entry->ifa_name = ""; + + char *l_name = ((char *)l_entry) + sizeof(struct ifaddrs); + char *l_addr = l_name + l_nameSize; + char *l_data = l_addr + l_addrSize; + + l_entry->ifa_flags = l_info->ifi_flags; + + l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifinfomsg)); + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifinfomsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFLA_ADDRESS: + case IFLA_BROADCAST: + { + size_t l_addrLen = calcAddrLen(AF_PACKET, l_rtaDataSize); + makeSockaddr(AF_PACKET, (struct sockaddr *)l_addr, l_rtaData, l_rtaDataSize); + ((struct sockaddr_ll *)l_addr)->sll_ifindex = l_info->ifi_index; + ((struct sockaddr_ll *)l_addr)->sll_hatype = l_info->ifi_type; + if(l_rta->rta_type == IFLA_ADDRESS) + { + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_broadaddr = (struct sockaddr *)l_addr; + } + l_addr += NLMSG_ALIGN(l_addrLen); + break; + } + case IFLA_IFNAME: + strncpy(l_name, l_rtaData, l_rtaDataSize); + l_name[l_rtaDataSize] = '\0'; + l_entry->ifa_name = l_name; + break; + case IFLA_STATS: + memcpy(l_data, l_rtaData, l_rtaDataSize); + l_entry->ifa_data = l_data; + break; + default: + break; + } + } + + addToEnd(p_resultList, l_entry); + p_links[l_info->ifi_index - 1] = l_entry; +} + +static void interpretAddr(struct nlmsghdr *p_hdr, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + struct ifaddrmsg *l_info = (struct ifaddrmsg *)NLMSG_DATA(p_hdr); + + size_t l_nameSize = 0; + size_t l_addrSize = 0; + + int l_addedNetmask = 0; + + size_t l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifaddrmsg)); + struct rtattr *l_rta; + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifaddrmsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + if(l_info->ifa_family == AF_PACKET) + { + continue; + } + + switch(l_rta->rta_type) + { + case IFA_ADDRESS: + case IFA_LOCAL: + if((l_info->ifa_family == AF_INET || l_info->ifa_family == AF_INET6) && !l_addedNetmask) + { // make room for netmask + l_addrSize += NLMSG_ALIGN(calcAddrLen(l_info->ifa_family, l_rtaDataSize)); + l_addedNetmask = 1; + } + case IFA_BROADCAST: + l_addrSize += NLMSG_ALIGN(calcAddrLen(l_info->ifa_family, l_rtaDataSize)); + break; + case IFA_LABEL: + l_nameSize += NLMSG_ALIGN(l_rtaSize + 1); + break; + default: + break; + } + } + + struct ifaddrs *l_entry = malloc(sizeof(struct ifaddrs) + l_nameSize + l_addrSize); + memset(l_entry, 0, sizeof(struct ifaddrs)); + l_entry->ifa_name = p_links[l_info->ifa_index - 1]->ifa_name; + + char *l_name = ((char *)l_entry) + sizeof(struct ifaddrs); + char *l_addr = l_name + l_nameSize; + + l_entry->ifa_flags = l_info->ifa_flags | p_links[l_info->ifa_index - 1]->ifa_flags; + + l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifaddrmsg)); + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifaddrmsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFA_ADDRESS: + case IFA_BROADCAST: + case IFA_LOCAL: + { + size_t l_addrLen = calcAddrLen(l_info->ifa_family, l_rtaDataSize); + makeSockaddr(l_info->ifa_family, (struct sockaddr *)l_addr, l_rtaData, l_rtaDataSize); + if(l_info->ifa_family == AF_INET6) + { + if(IN6_IS_ADDR_LINKLOCAL((struct in6_addr *)l_rtaData) || IN6_IS_ADDR_MC_LINKLOCAL((struct in6_addr *)l_rtaData)) + { + ((struct sockaddr_in6 *)l_addr)->sin6_scope_id = l_info->ifa_index; + } + } + + if(l_rta->rta_type == IFA_ADDRESS) + { // apparently in a point-to-point network IFA_ADDRESS contains the dest address and IFA_LOCAL contains the local address + if(l_entry->ifa_addr) + { + l_entry->ifa_dstaddr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + } + else if(l_rta->rta_type == IFA_LOCAL) + { + if(l_entry->ifa_addr) + { + l_entry->ifa_dstaddr = l_entry->ifa_addr; + } + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_broadaddr = (struct sockaddr *)l_addr; + } + l_addr += NLMSG_ALIGN(l_addrLen); + break; + } + case IFA_LABEL: + strncpy(l_name, l_rtaData, l_rtaDataSize); + l_name[l_rtaDataSize] = '\0'; + l_entry->ifa_name = l_name; + break; + default: + break; + } + } + + if(l_entry->ifa_addr && (l_entry->ifa_addr->sa_family == AF_INET || l_entry->ifa_addr->sa_family == AF_INET6)) + { + unsigned l_maxPrefix = (l_entry->ifa_addr->sa_family == AF_INET ? 32 : 128); + unsigned l_prefix = (l_info->ifa_prefixlen > l_maxPrefix ? l_maxPrefix : l_info->ifa_prefixlen); + char l_mask[16] = {0}; + unsigned i; + for(i=0; i<(l_prefix/8); ++i) + { + l_mask[i] = 0xff; + } + l_mask[i] = 0xff << (8 - (l_prefix % 8)); + + makeSockaddr(l_entry->ifa_addr->sa_family, (struct sockaddr *)l_addr, l_mask, l_maxPrefix / 8); + l_entry->ifa_netmask = (struct sockaddr *)l_addr; + } + + addToEnd(p_resultList, l_entry); +} + +static void interpret(int p_socket, NetlinkList *p_netlinkList, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + pid_t l_pid = getpid(); + for(; p_netlinkList; p_netlinkList = p_netlinkList->m_next) + { + unsigned int l_nlsize = p_netlinkList->m_size; + struct nlmsghdr *l_hdr; + for(l_hdr = p_netlinkList->m_data; NLMSG_OK(l_hdr, l_nlsize); l_hdr = NLMSG_NEXT(l_hdr, l_nlsize)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + break; + } + + if(l_hdr->nlmsg_type == RTM_NEWLINK) + { + interpretLink(l_hdr, p_links, p_resultList); + } + else if(l_hdr->nlmsg_type == RTM_NEWADDR) + { + interpretAddr(l_hdr, p_links, p_resultList); + } + } + } +} + +static unsigned countLinks(int p_socket, NetlinkList *p_netlinkList) +{ + unsigned l_links = 0; + pid_t l_pid = getpid(); + for(; p_netlinkList; p_netlinkList = p_netlinkList->m_next) + { + unsigned int l_nlsize = p_netlinkList->m_size; + struct nlmsghdr *l_hdr; + for(l_hdr = p_netlinkList->m_data; NLMSG_OK(l_hdr, l_nlsize); l_hdr = NLMSG_NEXT(l_hdr, l_nlsize)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + break; + } + + if(l_hdr->nlmsg_type == RTM_NEWLINK) + { + ++l_links; + } + } + } + + return l_links; +} + +int getifaddrs(struct ifaddrs **ifap) +{ + if(!ifap) + { + return -1; + } + *ifap = NULL; + + int l_socket = netlink_socket(); + if(l_socket < 0) + { + return -1; + } + + NetlinkList *l_linkResults = getResultList(l_socket, RTM_GETLINK); + if(!l_linkResults) + { + close(l_socket); + return -1; + } + + NetlinkList *l_addrResults = getResultList(l_socket, RTM_GETADDR); + if(!l_addrResults) + { + close(l_socket); + freeResultList(l_linkResults); + return -1; + } + + unsigned l_numLinks = countLinks(l_socket, l_linkResults) + countLinks(l_socket, l_addrResults); + struct ifaddrs *l_links[l_numLinks]; + memset(l_links, 0, l_numLinks * sizeof(struct ifaddrs *)); + + interpret(l_socket, l_linkResults, l_links, ifap); + interpret(l_socket, l_addrResults, l_links, ifap); + + freeResultList(l_linkResults); + freeResultList(l_addrResults); + close(l_socket); + return 0; +} + +void freeifaddrs(struct ifaddrs *ifa) +{ + struct ifaddrs *l_cur; + while(ifa) + { + l_cur = ifa; + ifa = ifa->ifa_next; + free(l_cur); + } +} \ No newline at end of file diff --git a/externals/android-ifaddrs/ifaddrs.h b/externals/android-ifaddrs/ifaddrs.h new file mode 100644 index 000000000..42d1b37e8 --- /dev/null +++ b/externals/android-ifaddrs/ifaddrs.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 1995, 1999 + * Berkeley Software Design, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * THIS SOFTWARE IS PROVIDED BY Berkeley Software Design, Inc. ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design, Inc. BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * BSDI ifaddrs.h,v 2.5 2000/02/23 14:51:59 dab Exp + */ + +#ifndef _IFADDRS_H_ +#define _IFADDRS_H_ + +struct ifaddrs { + struct ifaddrs *ifa_next; + char *ifa_name; + unsigned int ifa_flags; + struct sockaddr *ifa_addr; + struct sockaddr *ifa_netmask; + struct sockaddr *ifa_dstaddr; + void *ifa_data; +}; + +/* + * This may have been defined in <net/if.h>. Note that if <net/if.h> is + * to be included it must be included before this header file. + */ +#ifndef ifa_broadaddr +#define ifa_broadaddr ifa_dstaddr /* broadcast address interface */ +#endif + +#include <sys/cdefs.h> + +__BEGIN_DECLS +extern int getifaddrs(struct ifaddrs **ifap); +extern void freeifaddrs(struct ifaddrs *ifa); +__END_DECLS + +#endif \ No newline at end of file diff --git a/externals/httplib/README.md b/externals/httplib/README.md index 0e26522b5..33cbaf183 100644 --- a/externals/httplib/README.md +++ b/externals/httplib/README.md @@ -1,4 +1,4 @@ -From https://github.com/yhirose/cpp-httplib/commit/d9479bc0b12e8a1e8bce2d34da4feeef488581f3 +From https://github.com/yhirose/cpp-httplib/commit/b251668522dd459d2c6a75c10390a11b640be708 MIT License @@ -13,3 +13,4 @@ It's extremely easy to setup. Just include httplib.h file in your code! Inspired by Sinatra and express. © 2017 Yuji Hirose + diff --git a/externals/httplib/httplib.h b/externals/httplib/httplib.h index dd9afe693..ab087b184 100644 --- a/externals/httplib/httplib.h +++ b/externals/httplib/httplib.h @@ -1,73 +1,177 @@ // // httplib.h // -// Copyright (c) 2017 Yuji Hirose. All rights reserved. +// Copyright (c) 2020 Yuji Hirose. All rights reserved. // MIT License // -#ifndef _CPPHTTPLIB_HTTPLIB_H_ -#define _CPPHTTPLIB_HTTPLIB_H_ +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits<size_t>::max)() +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +// if hardware_concurrency() outputs 0 we still wants to use threads for this. +// -1 because we have one thread already in the main function. +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + (std::thread::hardware_concurrency() \ + ? std::thread::hardware_concurrency() - 1 \ + : 2) +#endif + +/* + * Headers + */ #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif +#endif //_CRT_SECURE_NO_WARNINGS + #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; #endif -#if defined(_MSC_VER) && _MSC_VER < 1900 +#if _MSC_VER < 1900 #define snprintf _snprintf_s #endif +#endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m)&S_IFREG)==S_IFREG) -#endif +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + #ifndef S_ISDIR -#define S_ISDIR(m) (((m)&S_IFDIR)==S_IFDIR) -#endif +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX #include <io.h> #include <winsock2.h> #include <ws2tcpip.h> -#undef min -#undef max +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif #ifndef strcasecmp #define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif -typedef SOCKET socket_t; -#else -#include <pthread.h> -#include <unistd.h> -#include <netdb.h> -#include <cstring> -#include <netinet/in.h> +#else // not _WIN32 + #include <arpa/inet.h> -#include <signal.h> -#include <sys/socket.h> -#include <sys/select.h> - -typedef int socket_t; -#define INVALID_SOCKET (-1) +#include <cstring> +#include <ifaddrs.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef CPPHTTPLIB_USE_POLL +#include <poll.h> #endif +#include <csignal> +#include <pthread.h> +#include <sys/select.h> +#include <sys/socket.h> +#include <unistd.h> +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include <array> +#include <atomic> +#include <cassert> +#include <condition_variable> +#include <errno.h> +#include <fcntl.h> #include <fstream> #include <functional> +#include <list> #include <map> #include <memory> #include <mutex> +#include <random> #include <regex> #include <string> -#include <thread> #include <sys/stat.h> -#include <fcntl.h> -#include <assert.h> +#include <thread> #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include <openssl/err.h> +#include <openssl/md5.h> #include <openssl/ssl.h> +#include <openssl/x509v3.h> + +#include <iomanip> +#include <sstream> + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include <openssl/crypto.h> +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -75,1198 +179,2321 @@ typedef int socket_t; #endif /* - * Configuration + * Declaration */ -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 - -namespace httplib -{ +namespace httplib { namespace detail { struct ci { - bool operator() (const std::string & s1, const std::string & s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), - s2.begin(), s2.end(), - [](char c1, char c2) { - return ::tolower(c1) < ::tolower(c2); - }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; } // namespace detail -enum class HttpVersion { v1_0 = 0, v1_1 }; +using Headers = std::multimap<std::string, std::string, detail::ci>; -typedef std::multimap<std::string, std::string, detail::ci> Headers; +using Params = std::multimap<std::string, std::string>; +using Match = std::smatch; -template<typename uint64_t, typename... Args> -std::pair<std::string, std::string> make_range_header(uint64_t value, Args... args); +using Progress = std::function<bool(uint64_t current, uint64_t total)>; -typedef std::multimap<std::string, std::string> Params; -typedef std::smatch Match; -typedef std::function<void (uint64_t current, uint64_t total)> Progress; +struct Response; +using ResponseHandler = std::function<bool(const Response &response)>; -struct MultipartFile { - std::string filename; - std::string content_type; - size_t offset = 0; - size_t length = 0; +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; }; -typedef std::multimap<std::string, MultipartFile> MultipartFiles; +using MultipartFormDataItems = std::vector<MultipartFormData>; +using MultipartFormDataMap = std::multimap<std::string, MultipartFormData>; + +class DataSink { +public: + DataSink() = default; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function<void(const char *data, size_t data_len)> write; + std::function<void()> done; + std::function<bool()> is_writable; +}; + +using ContentProvider = + std::function<void(size_t offset, size_t length, DataSink &sink)>; + +using ContentReceiver = + std::function<bool(const char *data, size_t data_length)>; + +using MultipartContentHeader = + std::function<bool(const MultipartFormData &file)>; + +class ContentReader { +public: + using Reader = std::function<bool(ContentReceiver receiver)>; + using MultipartReader = std::function<bool(MultipartContentHeader header, + ContentReceiver receiver)>; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; +}; + +using Range = std::pair<ssize_t, ssize_t>; +using Ranges = std::vector<Range>; struct Request { - std::string version; - std::string method; - std::string target; - std::string path; - Headers headers; - std::string body; - Params params; - MultipartFiles files; - Match matches; + std::string method; + std::string path; + Headers headers; + std::string body; - Progress progress; + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; - bool has_param(const char* key) const; - std::string get_param_value(const char* key) const; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif - bool has_file(const char* key) const; - MultipartFile get_file_value(const char* key) const; + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; }; struct Response { - std::string version; - int status; - Headers headers; - std::string body; + std::string version; + int status = -1; + Headers headers; + std::string body; - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - void set_redirect(const char* uri); - void set_content(const char* s, size_t n, const char* content_type); - void set_content(const std::string& s, const char* content_type); + void set_redirect(const char *url); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(const std::string &s, const char *content_type); - Response() : status(-1) {} + void set_content_provider( + size_t length, + std::function<void(size_t offset, size_t length, DataSink &sink)> + provider, + std::function<void()> resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function<void(size_t offset, DataSink &sink)> provider, + std::function<void()> resource_releaser = [] {}); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function<void()> content_provider_resource_releaser; }; class Stream { public: - virtual ~Stream() {} - virtual int read(char* ptr, size_t size) = 0; - virtual int write(const char* ptr, size_t size1) = 0; - virtual int write(const char* ptr) = 0; - virtual std::string get_remote_addr() = 0; + virtual ~Stream() = default; - template <typename ...Args> - void write_format(const char* fmt, const Args& ...args); + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual int read(char *ptr, size_t size) = 0; + virtual int write(const char *ptr, size_t size) = 0; + virtual std::string get_remote_addr() const = 0; + + template <typename... Args> + int write_format(const char *fmt, const Args &... args); + int write(const char *ptr); + int write(const std::string &s); }; -class SocketStream : public Stream { +class TaskQueue { public: - SocketStream(socket_t sock); - virtual ~SocketStream(); + TaskQueue() = default; + virtual ~TaskQueue() = default; + virtual void enqueue(std::function<void()> fn) = 0; + virtual void shutdown() = 0; +}; - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function<void()> fn) override { + std::unique_lock<std::mutex> lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock<std::mutex> lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } private: - socket_t sock_; + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function<void()> fn; + { + std::unique_lock<std::mutex> lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast<bool>(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector<std::thread> threads_; + std::list<std::function<void()>> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; }; +using Logger = std::function<void(const Request &, const Response &)>; + class Server { public: - typedef std::function<void (const Request&, Response&)> Handler; - typedef std::function<void (const Request&, const Response&)> Logger; + using Handler = std::function<void(const Request &, Response &)>; + using HandlerWithContentReader = std::function<void( + const Request &, Response &, const ContentReader &content_reader)>; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server& Get(const char* pattern, Handler handler); - Server& Post(const char* pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Options(const char *pattern, Handler handler); - Server& Put(const char* pattern, Handler handler); - Server& Delete(const char* pattern, Handler handler); - Server& Options(const char* pattern, Handler handler); + [[deprecated]] bool set_base_dir(const char *dir, + const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); - bool set_base_dir(const char* path); + void set_error_handler(Handler handler); + void set_logger(Logger logger); - void set_error_handler(Handler handler); - void set_logger(Logger logger); + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); - void set_keep_alive_max_count(size_t count); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - int bind_to_any_port(const char* host, int socket_flags = 0); - bool listen_after_bind(); + bool listen(const char *host, int port, int socket_flags = 0); - bool listen(const char* host, int port, int socket_flags = 0); + bool is_running() const; + void stop(); - bool is_running() const; - void stop(); + std::function<TaskQueue *(void)> new_task_queue; protected: - bool process_request(Stream& strm, bool last_connection, bool& connection_close); + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function<void(Request &)> &setup_request); - size_t keep_alive_max_count_; + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; private: - typedef std::vector<std::pair<std::regex, Handler>> Handlers; + using Handlers = std::vector<std::pair<std::regex, Handler>>; + using HandlersForContentReader = + std::vector<std::pair<std::regex, HandlerWithContentReader>>; - socket_t create_server_socket(const char* host, int port, int socket_flags) const; - int bind_internal(const char* host, int port, int socket_flags); - bool listen_internal(); + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); - bool routing(Request& req, Response& res); - bool handle_file_request(Request& req, Response& res); - bool dispatch_request(Request& req, Response& res, Handlers& handlers); + bool routing(Request &req, Response &res, Stream &strm, bool last_connection); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandlersForContentReader &handlers); - bool parse_request_line(const char* s, Request& req); - void write_response(Stream& strm, bool last_connection, const Request& req, Response& res); + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, bool last_connection, Request &req, + Response &res); + bool read_content_with_content_receiver( + Stream &strm, bool last_connection, Request &req, Response &res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, bool last_connection, Request &req, + Response &res, ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - bool is_running_; - socket_t svr_sock_; - std::string base_dir_; - Handlers get_handlers_; - Handlers post_handlers_; - Handlers put_handlers_; - Handlers delete_handlers_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - - // TODO: Use thread pool... - std::mutex running_threads_mutex_; - int running_threads_; + std::atomic<bool> is_running_; + std::atomic<socket_t> svr_sock_; + std::vector<std::pair<std::string, std::string>> base_dirs_; + std::map<std::string, std::string> file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; }; class Client { public: - Client( - const char* host, - int port = 80, - size_t timeout_sec = 300); + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); - virtual ~Client(); + virtual ~Client(); - virtual bool is_valid() const; + virtual bool is_valid() const; - std::shared_ptr<Response> Get(const char* path, Progress progress = nullptr); - std::shared_ptr<Response> Get(const char* path, const Headers& headers, Progress progress = nullptr); + std::shared_ptr<Response> Get(const char *path); - std::shared_ptr<Response> Head(const char* path); - std::shared_ptr<Response> Head(const char* path, const Headers& headers); + std::shared_ptr<Response> Get(const char *path, const Headers &headers); - std::shared_ptr<Response> Post(const char* path, const std::string& body, const char* content_type); - std::shared_ptr<Response> Post(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr<Response> Get(const char *path, Progress progress); - std::shared_ptr<Response> Post(const char* path, const Params& params); - std::shared_ptr<Response> Post(const char* path, const Headers& headers, const Params& params); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + Progress progress); - std::shared_ptr<Response> Put(const char* path, const std::string& body, const char* content_type); - std::shared_ptr<Response> Put(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr<Response> Get(const char *path, + ContentReceiver content_receiver); - std::shared_ptr<Response> Delete(const char* path); - std::shared_ptr<Response> Delete(const char* path, const Headers& headers); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); - std::shared_ptr<Response> Options(const char* path); - std::shared_ptr<Response> Options(const char* path, const Headers& headers); + std::shared_ptr<Response> + Get(const char *path, ContentReceiver content_receiver, Progress progress); - bool send(Request& req, Response& res); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); -protected: - bool process_request(Stream& strm, Request& req, Response& res, bool& connection_close); + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); - const std::string host_; - const int port_; - size_t timeout_sec_; - const std::string host_and_port_; + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); -private: - socket_t create_client_socket() const; - bool read_response_line(Stream& strm, Response& res); - void write_request(Stream& strm, Request& req); + std::shared_ptr<Response> Head(const char *path); - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); -}; + std::shared_ptr<Response> Head(const char *path, const Headers &headers); + + std::shared_ptr<Response> Post(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Post(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Post(const char *path, const Params ¶ms); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr<Response> Post(const char *path, + const MultipartFormDataItems &items); + + std::shared_ptr<Response> Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + std::shared_ptr<Response> Put(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Put(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Put(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Put(const char *path, const Params ¶ms); + + std::shared_ptr<Response> Put(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr<Response> Patch(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Patch(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Patch(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr<Response> Delete(const char *path); + + std::shared_ptr<Response> Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Delete(const char *path, const Headers &headers); + + std::shared_ptr<Response> Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr<Response> Options(const char *path); + + std::shared_ptr<Response> Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector<Request> &requests, + std::vector<Response> &responses); + + void set_timeout_sec(time_t timeout_sec); + + void set_read_timeout(time_t sec, time_t usec); + + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char *username, const char *password); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream : public Stream { -public: - SSLSocketStream(socket_t sock, SSL* ssl); - virtual ~SSLSocketStream(); + void set_digest_auth(const char *username, const char *password); +#endif - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); + void set_follow_location(bool on); + + void set_compress(bool on); + + void set_interface(const char *intf); + + void set_proxy(const char *host, int port); + + void set_proxy_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const char *username, const char *password); +#endif + + void set_logger(Logger logger); + +protected: + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); + + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t timeout_sec_ = 300; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool follow_location_ = false; + + bool compress_ = false; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + + Logger logger_; + + void copy_settings(const Client &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif + logger_ = rhs.logger_; + } private: - socket_t sock_; - SSL* ssl_; + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool connect(socket_t sock, Response &res, bool &error); +#endif + + std::shared_ptr<Response> send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback); + + virtual bool is_ssl() const; }; +inline void Get(std::vector<Request> &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector<Request> &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector<Request> &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector<Request> &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLServer : public Server { public: - SSLServer( - const char* cert_path, const char* private_key_path); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - virtual ~SSLServer(); + virtual ~SSLServer(); - virtual bool is_valid() const; + virtual bool is_valid() const; private: - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public Client { public: - SSLClient( - const char* host, - int port = 80, - size_t timeout_sec = 300); + SSLClient(const std::string &host, int port = 443, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); - virtual ~SSLClient(); + virtual ~SSLClient(); - virtual bool is_valid() const; + virtual bool is_valid() const; + + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const noexcept; private: - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback); + virtual bool is_ssl() const; - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector<std::string> host_components_; + + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; }; #endif +// ---------------------------------------------------------------------------- + /* * Implementation */ + namespace detail { -template <class Fn> -void split(const char* b, const char* e, char d, Fn fn) -{ - int i = 0; - int beg = 0; +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} - while (e ? (b + i != e) : (b[i] != '\0')) { - if (b[i] == d) { - fn(&b[beg], &b[i]); - beg = i + 1; - } - i++; +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = (0xC0 | ((code >> 6) & 0x1F)); + buff[1] = (0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = (0xF0 | ((code >> 18) & 0x7)); + buff[1] = (0x80 | ((code >> 12) & 0x3F)); + buff[2] = (0x80 | ((code >> 6) & 0x3F)); + buff[3] = (0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (uint8_t c : in) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; } - if (i) { - fn(&b[beg], &b[i]); + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast<size_t>(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +template <class Fn> void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream& strm, char* fixed_buffer, size_t fixed_buffer_size) - : strm_(strm) - , fixed_buffer_(fixed_buffer) - , fixed_buffer_size_(fixed_buffer_size) { - } + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} - const char* ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); - } - } - - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); - - for (size_t i = 0; ; i++) { - char byte; - auto n = strm_.read(&byte, 1); - - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; - } - } - - append(byte); - - if (byte == '\n') { - break; - } - } - - return true; - } - -private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } - } - - Stream& strm_; - char* fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_; - std::string glowable_buffer_; -}; - -inline int close_socket(socket_t sock) -{ -#ifdef _WIN32 - return closesocket(sock); -#else - return close(sock); -#endif -} - -inline int select_read(socket_t sock, size_t sec, size_t usec) -{ - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; - - return select(sock + 1, &fds, NULL, NULL, &tv); -} - -inline bool wait_until_socket_is_ready(socket_t sock, size_t sec, size_t usec) -{ - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; - - if (select(sock + 1, &fdsr, &fdsw, &fdse, &tv) < 0) { - return false; - } else if (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw)) { - int error = 0; - socklen_t len = sizeof(error); - if (getsockopt(sock, SOL_SOCKET, SO_ERROR, (char*)&error, &len) < 0 || error) { - return false; - } + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; } else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } } return true; + } + +private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast<int>(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast<long>(sec); + tv.tv_usec = static_cast<long>(usec); + + return select(static_cast<int>(sock + 1), &fds, nullptr, nullptr, &tv); +#endif +} + +inline int select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast<int>(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast<long>(sec); + tv.tv_usec = static_cast<long>(usec); + + return select(static_cast<int>(sock + 1), nullptr, &fds, nullptr, &tv); +#endif +} + +inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast<int>(sec * 1000 + usec / 1000); + + if (poll(&pfd_read, 1, timeout) > 0 && + pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast<char *>(&error), &len) >= 0 && + !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast<long>(sec); + tv.tv_usec = static_cast<long>(usec); + + if (select(static_cast<int>(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast<char *>(&error), &len) >= 0 && + !error; + } + return false; +#endif +} + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + virtual ~SSLSocketStream(); + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; +#endif + +class BufferStream : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + int position = 0; +}; + +template <typename T> +inline bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); + + auto ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { // keep_alive_max_count is 0 or 1 + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + return ret; } template <typename T> -inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, T callback) -{ - bool ret = false; - - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SocketStream strm(sock); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } - - count--; - } - } else { - SocketStream strm(sock); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } - - close_socket(sock); - return ret; +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); + close_socket(sock); + return ret; } -inline int shutdown_socket(socket_t sock) -{ +inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } template <typename Fn> -socket_t create_socket(const char* host, int port, Fn fn, int socket_flags = 0) -{ +socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { #ifdef _WIN32 #define SO_SYNCHRONOUS_NONALERT 0x20 #define SO_OPENTYPE 0x7008 - int opt = SO_SYNCHRONOUS_NONALERT; - setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt)); + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt, + sizeof(opt)); #endif - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { - return INVALID_SOCKET; - } - - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (sock == INVALID_SOCKET) { - continue; - } - - // Make 'reuse address' option available - int yes = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes)); - - // bind or connect - if (fn(sock, *rp)) { - freeaddrinfo(result); - return sock; - } - - close_socket(sock); - } - - freeaddrinfo(result); + if (getaddrinfo(host, service.c_str(), &hints, &result)) { return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif + + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char *>(&yes), + sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<char *>(&yes), + sizeof(yes)); +#endif + + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; } -inline void set_nonblocking(socket_t sock, bool nonblocking) -{ +inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } -inline bool is_connection_error() -{ +inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } +inline bool bind_ip_address(socket_t sock, const char *host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +inline std::string if2ip(const std::string &ifn) { +#ifndef _WIN32 + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); +} + +inline socket_t create_client_socket(const char *host, int port, + time_t timeout_sec, + const std::string &intf) { + return create_socket( + host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + + set_nonblocking(sock, true); + + auto ret = ::connect(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } + + set_nonblocking(sock, false); + return true; + }); +} + inline std::string get_remote_addr(socket_t sock) { - struct sockaddr_storage addr; - socklen_t len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); - if (!getpeername(sock, (struct sockaddr*)&addr, &len)) { - char ipstr[NI_MAXHOST]; + if (!getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr), &len)) { + std::array<char, NI_MAXHOST> ipstr{}; - if (!getnameinfo((struct sockaddr*)&addr, len, - ipstr, sizeof(ipstr), nullptr, 0, NI_NUMERICHOST)) { - return ipstr; - } + if (!getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len, + ipstr.data(), ipstr.size(), nullptr, 0, NI_NUMERICHOST)) { + return ipstr.data(); } + } - return std::string(); + return std::string(); } -inline bool is_file(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +inline const char * +find_content_type(const std::string &path, + const std::map<std::string, std::string> &user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second.c_str(); } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; } -inline bool is_dir(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - -inline bool is_valid_path(const std::string& path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; - } - - auto len = i - beg; - assert(len > 0); - - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { - return false; - } - level--; - } else { - level++; - } - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - } - - return true; -} - -inline void read_file(const std::string& path, std::string& out) -{ - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast<size_t>(size)); - fs.read(&out[0], size); -} - -inline std::string file_extension(const std::string& path) -{ - std::smatch m; - auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, pat)) { - return m[1].str(); - } - return std::string(); -} - -inline const char* find_content_type(const std::string& path) -{ - auto ext = file_extension(path); - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; -} - -inline const char* status_message(int status) -{ - switch (status) { - case 200: return "OK"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 400: return "Bad Request"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 415: return "Unsupported Media Type"; - default: - case 500: return "Internal Server Error"; - } -} - -inline const char* get_header_value(const Headers& headers, const char* key, const char* def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return it->second.c_str(); - } - return def; -} - -inline int get_header_value_int(const Headers& headers, const char* key, int def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return std::stoi(it->second); - } - return def; -} - -inline bool read_headers(Stream& strm, Headers& headers) -{ - static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); - - const auto bufsiz = 2048; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); - - for (;;) { - if (!reader.getline()) { - return false; - } - if (!strcmp(reader.ptr(), "\r\n")) { - break; - } - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - auto key = std::string(m[1]); - auto val = std::string(m[2]); - headers.emplace(key, val); - } - } - - return true; -} - -inline bool read_content_with_length(Stream& strm, std::string& out, size_t len, Progress progress) -{ - out.assign(len, 0); - size_t r = 0; - while (r < len){ - auto n = strm.read(&out[r], len - r); - if (n <= 0) { - return false; - } - - r += n; - - if (progress) { - progress(r, len); - } - } - - return true; -} - -inline bool read_content_without_length(Stream& strm, std::string& out) -{ - for (;;) { - char byte; - auto n = strm.read(&byte, 1); - if (n < 0) { - return false; - } else if (n == 0) { - return true; - } - out += byte; - } - - return true; -} - -inline bool read_content_chunked(Stream& strm, std::string& out) -{ - const auto bufsiz = 16; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); - - if (!reader.getline()) { - return false; - } - - auto chunk_len = std::stoi(reader.ptr(), 0, 16); - - while (chunk_len > 0){ - std::string chunk; - if (!read_content_with_length(strm, chunk, chunk_len, nullptr)) { - return false; - } - - if (!reader.getline()) { - return false; - } - - if (strcmp(reader.ptr(), "\r\n")) { - break; - } - - out += chunk; - - if (!reader.getline()) { - return false; - } - - chunk_len = std::stoi(reader.ptr(), 0, 16); - } - - if (chunk_len == 0) { - // Reader terminator after chunks - if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) - return false; - } - - return true; -} - -template <typename T> -bool read_content(Stream& strm, T& x, Progress progress = Progress()) -{ - auto len = get_header_value_int(x.headers, "Content-Length", 0); - - if (len) { - return read_content_with_length(strm, x.body, len, progress); - } else { - const auto& encoding = get_header_value(x.headers, "Transfer-Encoding", ""); - - if (!strcasecmp(encoding, "chunked")) { - return read_content_chunked(strm, x.body); - } else { - return read_content_without_length(strm, x.body); - } - } - - return true; -} - -template <typename T> -inline void write_headers(Stream& strm, const T& info) -{ - for (const auto& x: info.headers) { - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - } - strm.write("\r\n"); -} - -inline std::string encode_url(const std::string& s) -{ - std::string result; - - for (auto i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "+"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - case ':': result += "%3A"; break; - case ';': result += "%3B"; break; - default: - if (s[i] < 0) { - result += '%'; - char hex[4]; - size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", (unsigned char)s[i]); - assert(len == 2); - result.append(hex, len); - } else { - result += s[i]; - } - break; - } - } - - return result; -} - -inline bool is_hex(char c, int& v) -{ - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; -} - -inline bool from_hex_to_i(const std::string& s, size_t i, size_t cnt, int& val) -{ - if (i >= s.size()) { - return false; - } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { - return false; - } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; - } - } - return true; -} - -inline size_t to_utf8(int code, char* buff) -{ - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = (0xC0 | ((code >> 6) & 0x1F)); - buff[1] = (0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... - return 0; - } else if (code < 0x10000) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = (0xF0 | ((code >> 18) & 0x7)); - buff[1] = (0x80 | ((code >> 12) & 0x3F)); - buff[2] = (0x80 | ((code >> 6) & 0x3F)); - buff[3] = (0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; -} - -inline std::string decode_url(const std::string& s) -{ - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { - result.append(buff, len); - } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += val; - i += 2; // '00' - } else { - result += s[i]; - } - } - } else if (s[i] == '+') { - result += ' '; - } else { - result += s[i]; - } - } - - return result; -} - -inline void parse_query_text(const std::string& s, Params& params) -{ - split(&s[0], &s[s.size()], '&', [&](const char* b, const char* e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char* b, const char* e) { - if (key.empty()) { - key.assign(b, e); - } else { - val.assign(b, e); - } - }); - params.emplace(key, decode_url(val)); - }); -} - -inline bool parse_multipart_boundary(const std::string& content_type, std::string& boundary) -{ - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { - return false; - } - - boundary = content_type.substr(pos + 9); - return true; -} - -inline bool parse_multipart_formdata( - const std::string& boundary, const std::string& body, MultipartFiles& files) -{ - static std::string dash = "--"; - static std::string crlf = "\r\n"; - - static std::regex re_content_type( - "Content-Type: (.*?)", std::regex_constants::icase); - - static std::regex re_content_disposition( - "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", - std::regex_constants::icase); - - auto dash_boundary = dash + boundary; - - auto pos = body.find(dash_boundary); - if (pos != 0) { - return false; - } - - pos += dash_boundary.size(); - - auto next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - pos = next_pos + crlf.size(); - - while (pos < body.size()) { - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - std::string name; - MultipartFile file; - - auto header = body.substr(pos, (next_pos - pos)); - - while (pos != next_pos) { - std::smatch m; - if (std::regex_match(header, m, re_content_type)) { - file.content_type = m[1]; - } else if (std::regex_match(header, m, re_content_disposition)) { - name = m[1]; - file.filename = m[2]; - } - - pos = next_pos + crlf.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - header = body.substr(pos, (next_pos - pos)); - } - - pos = next_pos + crlf.size(); - - next_pos = body.find(crlf + dash_boundary, pos); - - if (next_pos == std::string::npos) { - return false; - } - - file.offset = pos; - file.length = next_pos - pos; - - pos = next_pos + crlf.size() + dash_boundary.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - files.emplace(name, file); - - pos = next_pos + crlf.size(); - } - - return true; -} - -inline std::string to_lower(const char* beg, const char* end) -{ - std::string out; - auto it = beg; - while (it != end) { - out += ::tolower(*it); - it++; - } - return out; -} - -inline void make_range_header_core(std::string&) {} - -template<typename uint64_t> -inline void make_range_header_core(std::string& field, uint64_t value) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value) + "-"; -} - -template<typename uint64_t, typename... Args> -inline void make_range_header_core(std::string& field, uint64_t value1, uint64_t value2, Args... args) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value1) + "-" + std::to_string(value2); - make_range_header_core(field, args...); +inline const char *status_message(int status) { + switch (status) { + case 200: return "OK"; + case 202: return "Accepted"; + case 204: return "No Content"; + case 206: return "Partial Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 413: return "Payload Too Large"; + case 414: return "Request-URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + case 503: return "Service Unavailable"; + + default: + case 500: return "Internal Server Error"; + } } #ifdef CPPHTTPLIB_ZLIB_SUPPORT -inline bool can_compress(const std::string& content_type) { - return !content_type.find("text/") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; +inline bool can_compress(const std::string &content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; } -inline void compress(std::string& content) -{ - z_stream strm; +inline bool compress(std::string &content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } + + strm.avail_in = content.size(); + strm.next_in = + const_cast<Bytef *>(reinterpret_cast<const Bytef *>(content.data())); + + std::string compressed; + + std::array<char, 16384> buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast<Bytef *>(buff.data()); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff.data(), buff.size() - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; +} + +class decompressor { +public: + decompressor() { strm.zalloc = Z_NULL; strm.zfree = Z_NULL; strm.opaque = Z_NULL; - auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY); - if (ret != Z_OK) { - return; - } + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 16 specifies + // that the stream to decompress will be formatted with a gzip wrapper. + is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK; + } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + ~decompressor() { inflateEnd(&strm); } - std::string compressed; + bool is_valid() const { return is_valid_; } - const auto bufsiz = 16384; - char buff[bufsiz]; + template <typename T> + bool decompress(const char *data, size_t data_length, T callback) { + int ret = Z_OK; + + strm.avail_in = data_length; + strm.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data)); + + std::array<char, 16384> buff{}; do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - deflate(&strm, Z_FINISH); - compressed.append(buff, bufsiz - strm.avail_out); + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast<Bytef *>(buff.data()); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } } while (strm.avail_out == 0); - content.swap(compressed); + return ret == Z_OK || ret == Z_STREAM_END; + } - deflateEnd(&strm); +private: + bool is_valid_; + z_stream strm; +}; +#endif + +inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); } -inline void decompress(std::string& content) -{ - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; +inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, id); + if (it != headers.end()) { return it->second.c_str(); } + return def; +} - // 15 is the value of wbits, which should be at the maximum possible value to ensure - // that any gzip stream can be decoded. The offset of 16 specifies that the stream - // to decompress will be formatted with a gzip wrapper. - auto ret = inflateInit2(&strm, 16 + 15); - if (ret != Z_OK) { - return; +inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + int def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { + continue; // Skip invalid line. } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } - std::string decompressed; + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"((.+?):[\t ]*(.+))"); - const auto bufsiz = 16384; - char buff[bufsiz]; - do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - inflate(&strm, Z_NO_FLUSH); - decompressed.append(buff, bufsiz - strm.avail_out); - } while (strm.avail_out == 0); + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } - content.swap(decompressed); + return true; +} - inflateEnd(&strm); +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast<size_t>(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, n)) { return false; } + + r += n; + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast<size_t>(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += n; + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, n)) { return false; } + } + + return true; +} + +inline bool read_content_chunked(Stream &strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + auto chunk_len = std::stoi(line_reader.ptr(), 0, 16); + + while (chunk_len > 0) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + + chunk_len = std::stoi(line_reader.ptr(), 0, 16); + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); +} + +template <typename T> +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor decompressor; + + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + if (x.get_header_value("Content-Encoding") == "gzip") { + out = [&](const char *buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif + + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; +} + +template <typename T> +inline int write_headers(Stream &strm, const T &info, const Headers &headers) { + auto write_len = 0; + for (const auto &x : info.headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }; + data_sink.done = [&](void) { written_length = -1; }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, end_offset - offset, data_sink); + if (written_length < 0) { return written_length; } + } + return static_cast<ssize_t>(offset - begin_offset); +} + +template <typename T> +inline ssize_t write_content_chunked(Stream &strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available && !is_shutting_down()) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }; + data_sink.done = [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, 0, data_sink); + + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; +} + +template <typename T> +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + Response new_res; + + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + +inline std::string encode_url(const std::string &s) { + std::string result; + + for (auto i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast<uint8_t>(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, len); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast<char>(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b, const char *e) { + if (key.empty()) { + key.assign(b, e); + } else { + val.assign(b, e); + } + }); + params.emplace(key, decode_url(val)); + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } + + boundary = content_type.substr(pos + 9); + return true; +} + +inline bool parse_range_header(const std::string &s, Ranges &ranges) { + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = m.position(1); + auto len = m.length(1); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re_another_range)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast<ssize_t>(std::stoll(m.str(1))); + } + + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast<ssize_t>(std::stoll(m.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +} + +class MultipartFormDataParser { +public: + MultipartFormDataParser() {} + + void set_boundary(const std::string &boundary) { boundary_ = boundary; } + + bool is_valid() const { return is_valid_; } + + template <typename T, typename U> + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { pos = buf_.size(); } + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + MultipartFormData file_; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast<char>(::tolower(*it)); + it++; + } + return out; +} + +inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; +} + +inline std::pair<size_t, size_t> +get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + if (r.first == -1) { + r.first = content_length - r.second; + r.second = content_length - 1; + } + + if (r.second == -1) { r.second = content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template <typename SToken, typename CToken, typename Content> +bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; +} + +inline size_t +get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider, offset, length) >= 0; + }); +} + +inline std::pair<size_t, size_t> +get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { r.second = res.content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +template <typename CTX, typename Init, typename Update, typename Final> +inline std::string message_digest(const std::string &s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + std::vector<unsigned char> md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest<MD5_CTX>(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest<SHA256_CTX>(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest<SHA512_CTX>(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); } #endif #ifdef _WIN32 class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { - WSACleanup(); - } + ~WSInit() { WSACleanup(); } }; static WSInit wsinit_; @@ -1275,876 +2502,1804 @@ static WSInit wsinit_; } // namespace detail // Header utilities -template<typename uint64_t, typename... Args> -inline std::pair<std::string, std::string> make_range_header(uint64_t value, Args... args) -{ - std::string field; - detail::make_range_header_core(field, value, args...); - field.insert(0, "bytes="); - return std::make_pair("Range", field); +inline std::pair<std::string, std::string> make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); +} + +inline std::pair<std::string, std::string> +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair<std::string, std::string> make_digest_authentication_header( + const Request &req, const std::map<std::string, std::string> &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + string response; + { + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const httplib::Response &res, + std::map<std::string, std::string> &auth, + bool is_proxy) { + auto key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(m.position(1), m.length(1)); + auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2)) + : s.substr(m.position(3), m.length(3)); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; } // Request implementation -inline bool Request::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); } -inline std::string Request::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); } -inline void Request::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline bool Request::has_param(const char* key) const -{ - return params.find(key) != params.end(); +inline void Request::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline std::string Request::get_param_value(const char* key) const -{ - auto it = params.find(key); - if (it != params.end()) { - return it->second; - } - return std::string(); +inline void Request::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline bool Request::has_file(const char* key) const -{ - return files.find(key) != files.end(); +inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); } -inline MultipartFile Request::get_file_value(const char* key) const -{ - auto it = files.find(key); - if (it != files.end()) { - return it->second; - } - return MultipartFile(); +inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, id); + if (it != params.end()) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return std::distance(r.first, r.second); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); +} + +inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); } // Response implementation -inline bool Response::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); } -inline std::string Response::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); } -inline void Response::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline void Response::set_redirect(const char* url) -{ - set_header("Location", url); - status = 302; +inline void Response::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline void Response::set_content(const char* s, size_t n, const char* content_type) -{ - body.assign(s, n); - set_header("Content-Type", content_type); +inline void Response::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline void Response::set_content(const std::string& s, const char* content_type) -{ - body = s; - set_header("Content-Type", content_type); +inline void Response::set_redirect(const char *url) { + set_header("Location", url); + status = 302; +} + +inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const char *content_type) { + body = s; + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t length, + std::function<void(size_t offset, size_t length, DataSink &sink)> provider, + std::function<void()> resource_releaser) { + assert(length > 0); + content_length = length; + content_provider = [provider](size_t offset, size_t length, DataSink &sink) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function<void(size_t offset, DataSink &sink)> provider, + std::function<void()> resource_releaser) { + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink &sink) { + provider(offset, sink); + }; + content_provider_resource_releaser = resource_releaser; } // Rstream implementation -template <typename ...Args> -inline void Stream::write_format(const char* fmt, const Args& ...args) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; +inline int Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); } -#if defined(_MSC_VER) && _MSC_VER < 1900 - auto n = _snprintf_s(buf, bufsiz, bufsiz - 1, fmt, args...); -#else - auto n = snprintf(buf, bufsiz - 1, fmt, args...); -#endif - if (n > 0) { - if (n >= bufsiz - 1) { - std::vector<char> glowable_buf(bufsiz); - - while (n >= static_cast<int>(glowable_buf.size() - 1)) { - glowable_buf.resize(glowable_buf.size() * 2); -#if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); -#else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); -#endif - } - write(&glowable_buf[0], n); - } else { - write(buf, n); - } - } +inline int Stream::write(const std::string &s) { + return write(s.data(), s.size()); } +template <typename... Args> +inline int Stream::write_format(const char *fmt, const Args &... args) { + std::array<char, 2048> buf; + +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto n = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); +#else + auto n = snprintf(buf.data(), buf.size() - 1, fmt, args...); +#endif + if (n <= 0) { return n; } + + if (n >= static_cast<int>(buf.size()) - 1) { + std::vector<char> glowable_buf(buf.size()); + + while (n >= static_cast<int>(glowable_buf.size() - 1)) { + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, args...); +#else + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); +#endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +namespace detail { + // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock): sock_(sock) -{ +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SocketStream::~SocketStream() {} + +inline bool SocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } -inline SocketStream::~SocketStream() -{ +inline bool SocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; } -inline int SocketStream::read(char* ptr, size_t size) -{ - return recv(sock_, ptr, size, 0); +inline int SocketStream::read(char *ptr, size_t size) { + if (is_readable()) { return recv(sock_, ptr, static_cast<int>(size), 0); } + return -1; } -inline int SocketStream::write(const char* ptr, size_t size) -{ - return send(sock_, ptr, size, 0); +inline int SocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return send(sock_, ptr, static_cast<int>(size), 0); } + return -1; } -inline int SocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); +inline std::string SocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); } -inline std::string SocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline int BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + int len_read = static_cast<int>(buffer._Copy_s(ptr, size, size, position)); +#else + int len_read = static_cast<int>(buffer.copy(ptr, size, position)); +#endif + position += len_read; + return len_read; } +inline int BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast<int>(size); +} + +inline std::string BufferStream::get_remote_addr() const { return ""; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +} // namespace detail + // HTTP server implementation inline Server::Server() - : keep_alive_max_count_(5) - , is_running_(false) - , svr_sock_(INVALID_SOCKET) - , running_threads_(0) -{ + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif + new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }; } -inline Server::~Server() -{ +inline Server::~Server() {} + +inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Get(const char* pattern, Handler handler) -{ - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Post(const char* pattern, Handler handler) -{ - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Put(const char* pattern, Handler handler) -{ - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Delete(const char* pattern, Handler handler) -{ - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Options(const char* pattern, Handler handler) -{ - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline bool Server::set_base_dir(const char* path) -{ - if (detail::is_dir(path)) { - base_dir_ = path; - return true; +inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline bool Server::set_base_dir(const char *dir, const char *mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const char *mount_point, const char *dir) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; } + } + return false; +} + +inline bool Server::remove_mount_point(const char *mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime) { + file_extension_and_mimetype_map_[ext] = mime; +} + +inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); +} + +inline void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); +} + +inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + +inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; +} + +inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; +} +inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const char *host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic<socket_t> sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) { + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3]); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; + } + + return false; +} + +inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + // Response line + if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { return false; -} + } -inline void Server::set_error_handler(Handler handler) -{ - error_handler_ = handler; -} + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } -inline void Server::set_logger(Logger logger) -{ - logger_ = logger; -} + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } -inline void Server::set_keep_alive_max_count(size_t count) -{ - keep_alive_max_count_ = count; -} + if (!res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } -inline int Server::bind_to_any_port(const char* host, int socket_flags) -{ - return bind_internal(host, 0, socket_flags); -} + if (!res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } -inline bool Server::listen_after_bind() { - return listen_internal(); -} + std::string content_type; + std::string boundary; -inline bool Server::listen(const char* host, int port, int socket_flags) -{ - if (bind_internal(host, port, socket_flags) < 0) - return false; - return listen_internal(); -} + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); -inline bool Server::is_running() const -{ - return is_running_; -} - -inline void Server::stop() -{ - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - detail::shutdown_socket(svr_sock_); - detail::close_socket(svr_sock_); - svr_sock_ = INVALID_SOCKET; - } -} - -inline bool Server::parse_request_line(const char* s, Request& req) -{ - static std::regex re("(GET|HEAD|POST|PUT|DELETE|OPTIONS) (([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[4]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3]); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { - detail::parse_query_text(m[4], req.params); - } - - return true; + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); } - return false; -} + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } -inline void Server::write_response(Stream& strm, bool last_connection, const Request& req, Response& res) -{ - assert(res.status != -1); - - if (400 <= res.status && error_handler_) { - error_handler_(req, res); + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); } - // Response line - strm.write_format("HTTP/1.1 %d %s\r\n", - res.status, - detail::status_message(res.status)); - - // Headers - if (last_connection || - req.version == "HTTP/1.0" || - req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } - - if (!res.body.empty()) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 - const auto& encodings = req.get_header_value("Accept-Encoding"); - if (encodings.find("gzip") != std::string::npos && - detail::can_compress(res.get_header_value("Content-Type"))) { - detail::compress(res.body); - res.set_header("Content-Encoding", "gzip"); - } + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + const auto &encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } #endif - if (!res.has_header("Content-Type")) { - res.set_header("Content-Type", "text/plain"); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length.c_str()); + if (!detail::write_headers(strm, res, Headers())) { return false; } + + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } } + } - detail::write_headers(strm, res); + // Log + if (logger_) { logger_(req, res); } - // Body - if (!res.body.empty() && req.method != "HEAD") { - strm.write(res.body.c_str(), res.body.size()); - } - - // Log - if (logger_) { - logger_(req, res); - } + return true; } -inline bool Server::handle_file_request(Request& req, Response& res) -{ - if (!base_dir_.empty() && detail::is_valid_path(req.path)) { - std::string path = base_dir_ + req.path; +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (detail::write_content_chunked(strm, res.content_provider, + is_shutting_down) < 0) { + return false; + } + } + return true; +} - if (!path.empty() && path.back() == '/') { - path += "index.html"; - } +inline bool Server::read_content(Stream &strm, bool last_connection, + Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto ret = read_content_core( + strm, last_connection, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + }); + + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + + return ret; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, bool last_connection, Request &req, Response &res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, last_connection, req, res, receiver, + multipart_header, multipart_receiver); +} + +inline bool Server::read_content_core(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + multipart_form_data_parser.set_boundary(boundary); + out = [&](const char *buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return write_response(strm, last_connection, req, res); + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + + return true; +} + +inline bool Server::handle_file_request(Request &req, Response &res, + bool head) { + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = detail::find_content_type(path); - if (type) { - res.set_header("Content-Type", type); - } - res.status = 200; - return true; + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; } + } } - - return false; + } + return false; } -inline socket_t Server::create_server_socket(const char* host, int port, int socket_flags) const -{ - return detail::create_socket(host, port, - [](socket_t sock, struct addrinfo& ai) -> bool { - if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }, socket_flags); +inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, + socket_flags); } -inline int Server::bind_internal(const char* host, int port, int socket_flags) -{ - if (!is_valid()) { - return -1; - } +inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } - svr_sock_ = create_server_socket(host, port, socket_flags); - if (svr_sock_ == INVALID_SOCKET) { - return -1; - } + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } - if (port == 0) { - struct sockaddr_storage address; - socklen_t len = sizeof(address); - if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&address), &len) == -1) { - return -1; + if (port == 0) { + struct sockaddr_storage address; + socklen_t len = sizeof(address); + if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&address), + &len) == -1) { + return -1; + } + if (address.ss_family == AF_INET) { + return ntohs(reinterpret_cast<struct sockaddr_in *>(&address)->sin_port); + } else if (address.ss_family == AF_INET6) { + return ntohs( + reinterpret_cast<struct sockaddr_in6 *>(&address)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + + { + std::unique_ptr<TaskQueue> task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; } - if (address.ss_family == AF_INET) { - return ntohs(reinterpret_cast<struct sockaddr_in*>(&address)->sin_port); - } else if (address.ss_family == AF_INET6) { - return ntohs(reinterpret_cast<struct sockaddr_in6*>(&address)->sin6_port); + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; } else { - return -1; + ; // The server socket was closed by user. } - } else { - return port; + break; + } + + task_queue->enqueue([=]() { process_and_close_socket(sock); }); } + + task_queue->shutdown(); + } + + is_running_ = false; + return ret; } -inline bool Server::listen_internal() -{ - auto ret = true; - - is_running_ = true; - - for (;;) { - auto val = detail::select_read(svr_sock_, 0, 100000); - - if (val == 0) { // Timeout - if (svr_sock_ == INVALID_SOCKET) { - // The server socket was closed by 'stop' method. - break; - } - continue; - } - - socket_t sock = accept(svr_sock_, NULL, NULL); - - if (sock == INVALID_SOCKET) { - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; - } - - // TODO: Use thread pool... - std::thread([=]() { - { - std::lock_guard<std::mutex> guard(running_threads_mutex_); - running_threads_++; - } - - read_and_close_socket(sock); - - { - std::lock_guard<std::mutex> guard(running_threads_mutex_); - running_threads_--; - } - }).detach(); - } - - // TODO: Use thread pool... - for (;;) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::lock_guard<std::mutex> guard(running_threads_mutex_); - if (!running_threads_) { - break; - } - } - - is_running_ = false; - - return ret; -} - -inline bool Server::routing(Request& req, Response& res) -{ - if (req.method == "GET" && handle_file_request(req, res)) { - return true; - } - - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } - return false; -} - -inline bool Server::dispatch_request(Request& req, Response& res, Handlers& handlers) -{ - for (const auto& x: handlers) { - const auto& pattern = x.first; - const auto& handler = x.second; - - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; - } - } - return false; -} - -inline bool Server::process_request(Stream& strm, bool last_connection, bool& connection_close) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; - - detail::stream_line_reader reader(strm, buf, bufsiz); - - // Connection has been closed on client - if (!reader.getline()) { - return false; - } - - Request req; - Response res; - - res.version = "HTTP/1.1"; - - // Request line and headers - if (!parse_request_line(reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return true; - } - - auto ret = true; - if (req.get_header_value("Connection") == "close") { - // ret = false; - connection_close = true; - } - - req.set_header("REMOTE_ADDR", strm.get_remote_addr().c_str()); - - // Body - if (req.method == "POST" || req.method == "PUT") { - if (!detail::read_content(strm, req)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } - - const auto& content_type = req.get_header_value("Content-Type"); - - if (req.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(req.body); -#else - res.status = 415; - write_response(strm, last_connection, req, res); - return ret; -#endif - } - - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); - } else if(!content_type.find("multipart/form-data")) { - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary) || - !detail::parse_multipart_formdata(boundary, req.body, req.files)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } - } - } - - if (routing(req, res)) { - if (res.status == -1) { - res.status = 200; - } - } else { - res.status = 404; - } - - write_response(strm, last_connection, req, res); - return ret; -} - -inline bool Server::is_valid() const -{ +inline bool Server::routing(Request &req, Response &res, Stream &strm, + bool last_connection) { + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, receiver, nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, nullptr, header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, last_connection, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; } -inline bool Server::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket( - sock, - keep_alive_max_count_, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); +inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + return false; +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + HandlersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function<void(Request &)> &setup_request) { + std::array<char, 2048> buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } + } + + if (setup_request) { setup_request(req); } + + // Rounting + if (routing(req, res, strm, last_connection)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); } // HTTP client implementation -inline Client::Client( - const char* host, int port, size_t timeout_sec) - : host_(host) - , port_(port) - , timeout_sec_(timeout_sec) - , host_and_port_(host_ + ":" + std::to_string(port_)) -{ +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(host), port_(port), + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline Client::~Client() {} + +inline bool Client::is_valid() const { return true; } + +inline socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); } -inline Client::~Client() -{ +inline bool Client::read_response_line(Stream &strm, Response &res) { + std::array<char, 2048> buf; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; } -inline bool Client::is_valid() const -{ - return true; -} +inline bool Client::send(const Request &req, Response &res) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } -inline socket_t Client::create_client_socket() const -{ - return detail::create_socket(host_.c_str(), port_, - [=](socket_t sock, struct addrinfo& ai) -> bool { - detail::set_nonblocking(sock, true); - - auto ret = connect(sock, ai.ai_addr, ai.ai_addrlen); - if (ret < 0) { - if (detail::is_connection_error() || - !detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) { - detail::close_socket(sock); - return false; - } - } - - detail::set_nonblocking(sock, false); - return true; - }); -} - -inline bool Client::read_response_line(Stream& strm, Response& res) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; - - detail::stream_line_reader reader(strm, buf, bufsiz); - - if (!reader.getline()) { - return false; - } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .+\r\n"); - - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); - } - - return true; -} - -inline bool Client::send(Request& req, Response& res) -{ - if (req.path.empty()) { - return false; - } - - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { - return false; - } - - return read_and_close_socket(sock, req, res); -} - -inline void Client::write_request(Stream& strm, Request& req) -{ - auto path = detail::encode_url(req.path); - - // Request line - strm.write_format("%s %s HTTP/1.1\r\n", - req.method.c_str(), - path.c_str()); - - // Headers - req.set_header("Host", host_and_port_.c_str()); - - if (!req.has_header("Accept")) { - req.set_header("Accept", "*/*"); - } - - if (!req.has_header("User-Agent")) { - req.set_header("User-Agent", "cpp-httplib/0.2"); - } - - // TODO: Support KeepAlive connection - // if (!req.has_header("Connection")) { - req.set_header("Connection", "close"); - // } - - if (!req.body.empty()) { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length.c_str()); - } - - detail::write_headers(strm, req); - - // Body - if (!req.body.empty()) { - if (req.get_header_value("Content-Type") == "application/x-www-form-urlencoded") { - auto str = detail::encode_url(req.body); - strm.write(str.c_str(), str.size()); - } else { - strm.write(req.body.c_str(), req.body.size()); - } - } -} - -inline bool Client::process_request(Stream& strm, Request& req, Response& res, bool& connection_close) -{ - // Send request - write_request(strm, req); - - // Receive response and headers - if (!read_response_line(strm, res) || !detail::read_headers(strm, res.headers)) { - return false; - } - - if (res.get_header_value("Connection") == "close" || res.version == "HTTP/1.0") { - connection_close = true; - } - - // Body - if (req.method != "HEAD") { - if (!detail::read_content(strm, res, req.progress)) { - return false; - } - - if (res.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(res.body); -#else - return false; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + bool error; + if (!connect(sock, res, error)) { return error; } + } #endif + + return process_and_close_socket( + sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + return handle_request(strm, req, res, last_connection, + connection_close); + }); +} + +inline bool Client::send(const std::vector<Request> &requests, + std::vector<Response> &responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock, res, error)) { return false; } + } +#endif + + if (!process_and_close_socket(sock, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, + last_connection, + connection_close); + if (ret) { + responses.emplace_back(std::move(res)); + } + return ret; + })) { + return false; + } + } + + return true; +} + +inline bool Client::handle_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + if (req.path.empty()) { return false; } + + bool ret; + + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, last_connection, connection_close); + } else { + ret = process_request(strm, req, res, last_connection, connection_close); + } + + if (!ret) { return false; } + + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map<std::string, std::string> auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool Client::connect(socket_t sock, Response &res, bool &error) { + error = true; + Response res2; + + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map<std::string, std::string> auth; + if (parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, + connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; } + } + } else { + res = res2; + return false; + } + } + + return true; +} +#endif + +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::smatch m; + if (!regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } else { + Client cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); + } + } +} + +inline bool Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + detail::BufferStream bstrm; + + // Request line + const auto &path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.5"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); } - return true; + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + auto written_length = strm.write(d, l); + offset += written_length; + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, data_sink); + } + } + } else { + strm.write(req.body); + } + + return true; } -inline bool Client::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return detail::read_and_close_socket( - sock, - 0, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); -} +inline std::shared_ptr<Response> Client::send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; -inline std::shared_ptr<Response> Client::Get(const char* path, Progress progress) -{ - return Get(path, Headers(), progress); -} + req.headers.emplace("Content-Type", content_type); -inline std::shared_ptr<Response> Client::Get(const char* path, const Headers& headers, Progress progress) -{ - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - req.progress = progress; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + if (content_provider) { + size_t offset = 0; - auto res = std::make_shared<Response>(); + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }; + data_sink.is_writable = [&](void) { return true; }; - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr<Response> Client::Head(const char* path) -{ - return Head(path, Headers()); -} - -inline std::shared_ptr<Response> Client::Head(const char* path, const Headers& headers) -{ - Request req; - req.method = "HEAD"; - req.headers = headers; - req.path = path; - - auto res = std::make_shared<Response>(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr<Response> Client::Post( - const char* path, const std::string& body, const char* content_type) -{ - return Post(path, Headers(), body, content_type); -} - -inline std::shared_ptr<Response> Client::Post( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "POST"; - req.headers = headers; - req.path = path; - - req.headers.emplace("Content-Type", content_type); - req.body = body; - - auto res = std::make_shared<Response>(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr<Response> Client::Post(const char* path, const Params& params) -{ - return Post(path, Headers(), params); -} - -inline std::shared_ptr<Response> Client::Post(const char* path, const Headers& headers, const Params& params) -{ - std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { - query += "&"; - } - query += it->first; - query += "="; - query += it->second; + while (offset < content_length) { + content_provider(offset, content_length - offset, data_sink); + } + } else { + req.body = body; } - return Post(path, headers, query, "application/x-www-form-urlencoded"); + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } + } + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr<Response> Client::Put( - const char* path, const std::string& body, const char* content_type) -{ - return Put(path, Headers(), body, content_type); +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + if (!write_request(strm, req, last_connection)) { return false; } + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + ContentReceiver out = [&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }; + + if (req.content_receiver) { + out = [&](const char *buf, size_t n) { + return req.content_receiver(buf, n); + }; + } + + int dummy_status; + if (!detail::read_content(strm, res, std::numeric_limits<size_t>::max(), + dummy_status, req.progress, out)) { + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; } -inline std::shared_ptr<Response> Client::Put( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "PUT"; - req.headers = headers; - req.path = path; - - req.headers.emplace("Content-Type", content_type); - req.body = body; - - auto res = std::make_shared<Response>(); - - return send(req, *res) ? res : nullptr; +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); } -inline std::shared_ptr<Response> Client::Delete(const char* path) -{ - return Delete(path, Headers()); +inline bool Client::is_ssl() const { return false; } + +inline std::shared_ptr<Response> Client::Get(const char *path) { + return Get(path, Headers(), Progress()); } -inline std::shared_ptr<Response> Client::Delete(const char* path, const Headers& headers) -{ - Request req; - req.method = "DELETE"; - req.path = path; - req.headers = headers; - - auto res = std::make_shared<Response>(); - - return send(req, *res) ? res : nullptr; +inline std::shared_ptr<Response> Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); } -inline std::shared_ptr<Response> Client::Options(const char* path) -{ - return Options(path, Headers()); +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers) { + return Get(path, headers, Progress()); } -inline std::shared_ptr<Response> Client::Options(const char* path, const Headers& headers) -{ - Request req; - req.method = "OPTIONS"; - req.path = path; - req.headers = headers; +inline std::shared_ptr<Response> +Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); - auto res = std::make_shared<Response>(); - - return send(req, *res) ? res : nullptr; + auto res = std::make_shared<Response>(); + return send(req, *res) ? res : nullptr; } +inline std::shared_ptr<Response> Client::Get(const char *path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr<Response> Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), content_receiver, + Progress()); +} + +inline std::shared_ptr<Response> Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); + + auto res = std::make_shared<Response>(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr<Response> Client::Head(const char *path) { + return Head(path, Headers()); +} + +inline std::shared_ptr<Response> Client::Head(const char *path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr<Response> Client::Post(const char *path, + const std::string &body, + const char *content_type) { + return Post(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Post(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr<Response> Client::Post(const char *path, + const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline std::shared_ptr<Response> Client::Post(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Post(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const Headers &headers, const Params ¶ms) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline std::shared_ptr<Response> +Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) { + auto boundary = detail::make_multipart_data_boundary(); + + std::string body; + + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + const std::string &body, + const char *content_type) { + return Put(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Put(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr<Response> +Client::Put(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr<Response> Client::Put(const char *path, + const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline std::shared_ptr<Response> +Client::Put(const char *path, const Headers &headers, const Params ¶ms) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr<Response> Client::Patch(const char *path, + const std::string &body, + const char *content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Patch(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr<Response> Client::Patch(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Patch(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr<Response> +Client::Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr<Response> Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); +} + +inline std::shared_ptr<Response> Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); +} + +inline std::shared_ptr<Response> Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} + +inline std::shared_ptr<Response> Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr<Response> Client::Options(const char *path) { + return Options(path, Headers()); +} + +inline std::shared_ptr<Response> Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared<Response>(); + + return send(req, *res) ? res : nullptr; +} + +inline void Client::set_timeout_sec(time_t timeout_sec) { + timeout_sec_ = timeout_sec; +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Client::set_basic_auth(const char *username, const char *password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const char *username, + const char *password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void Client::set_follow_location(bool on) { follow_location_ = on; } + +inline void Client::set_compress(bool on) { compress_ = on; } + +inline void Client::set_interface(const char *intf) { interface_ = intf; } + +inline void Client::set_proxy(const char *host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void Client::set_proxy_basic_auth(const char *username, + const char *password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const char *username, + const char *password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} +#endif + +inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + /* * SSL Implementation */ @@ -2152,193 +4307,446 @@ inline std::shared_ptr<Response> Client::Options(const char* path, const Headers namespace detail { template <typename U, typename V, typename T> -inline bool read_and_close_socket_ssl( - socket_t sock, size_t keep_alive_max_count, - // TODO: OpenSSL 1.0.2 occasionally crashes... - // The upcoming 1.1.0 is going to be thread safe. - SSL_CTX* ctx, std::mutex& ctx_mutex, - U SSL_connect_or_accept, V setup, - T callback) -{ - SSL* ssl = nullptr; - { - std::lock_guard<std::mutex> guard(ctx_mutex); +inline bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, + std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); - ssl = SSL_new(ctx); - if (!ssl) { - return false; - } - } + SSL *ssl = nullptr; + { + std::lock_guard<std::mutex> guard(ctx_mutex); + ssl = SSL_new(ctx); + } - auto bio = BIO_new_socket(sock, BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + if (!ssl) { + close_socket(sock); + return false; + } - setup(ssl); - - SSL_connect_or_accept(ssl); - - bool ret = false; - - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SSLSocketStream strm(sock, ssl); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } - - count--; - } - } else { - SSLSocketStream strm(sock, ssl); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } + auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + if (!setup(ssl)) { SSL_shutdown(ssl); - { - std::lock_guard<std::mutex> guard(ctx_mutex); - SSL_free(ssl); + std::lock_guard<std::mutex> guard(ctx_mutex); + SSL_free(ssl); } close_socket(sock); + return false; + } - return ret; + auto ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + SSL_shutdown(ssl); + { + std::lock_guard<std::mutex> guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; } +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr<std::vector<std::mutex>> openSSL_locks_; + +class SSLThreadLocks { +public: + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared<std::vector<std::mutex>>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + +private: + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &locks = *openSSL_locks_; + if (mode & CRYPTO_LOCK) { + locks[type].lock(); + } else { + locks[type].unlock(); + } + } +}; + +#endif + class SSLInit { public: - SSLInit() { - SSL_load_error_strings(); - SSL_library_init(); - } + SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); +#endif + } + + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + +private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif }; +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SSLSocketStream::~SSLSocketStream() {} + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; +} + +inline int SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast<int>(size)); + } + return -1; +} + +inline int SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return SSL_write(ssl_, ptr, static_cast<int>(size)); } + return -1; +} + +inline std::string SSLSocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); +} + static SSLInit sslinit_; } // namespace detail -// SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL* ssl) - : sock_(sock), ssl_(ssl) -{ -} - -inline SSLSocketStream::~SSLSocketStream() -{ -} - -inline int SSLSocketStream::read(char* ptr, size_t size) -{ - return SSL_read(ssl_, ptr, size); -} - -inline int SSLSocketStream::write(const char* ptr, size_t size) -{ - return SSL_write(ssl_, ptr, size); -} - -inline int SSLSocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); -} - -inline std::string SSLSocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); -} - // SSL HTTP server implementation -inline SSLServer::SSLServer(const char* cert_path, const char* private_key_path) -{ - ctx_ = SSL_CTX_new(SSLv23_server_method()); +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); - if (SSL_CTX_use_certificate_file(ctx_, cert_path, SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); } + } } -inline SSLServer::~SSLServer() -{ - if (ctx_) { - SSL_CTX_free(ctx_); - } +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } } -inline bool SSLServer::is_valid() const -{ - return ctx_; -} +inline bool SSLServer::is_valid() const { return ctx_; } -inline bool SSLServer::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket_ssl( - sock, - keep_alive_max_count_, - ctx_, ctx_mutex_, - SSL_accept, - [](SSL* /*ssl*/) {}, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); + }); } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char* host, int port, size_t timeout_sec) - : Client(host, port, timeout_sec) -{ - ctx_ = SSL_CTX_new(SSLv23_client_method()); -} +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : Client(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); -inline SSLClient::~SSLClient() -{ - if (ctx_) { - SSL_CTX_free(ctx_); + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; } + } } -inline bool SSLClient::is_valid() const -{ - return ctx_; +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } } -inline bool SSLClient::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return is_valid() && detail::read_and_close_socket_ssl( - sock, 0, - ctx_, ctx_mutex_, - SSL_connect, - [&](SSL* ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - }, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } +} + +inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const noexcept { return ctx_; } + +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function<bool(Stream &strm, bool last_connection, + bool &connection_close)> + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (ca_cert_file_path_.empty()) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } else { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, + bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast<const struct stack_st_GENERAL_NAME *>( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { return check_host_name(name, name_len); } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector<std::string> pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; } #endif +// ---------------------------------------------------------------------------- + } // namespace httplib -#endif - -// vim: et ts=4 sw=4 cin cino={1s ff=unix +#endif // CPPHTTPLIB_HTTPLIB_H \ No newline at end of file diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 064e44f94..423d13630 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -466,9 +466,17 @@ create_target_directory_groups(core) target_link_libraries(core PUBLIC common PRIVATE audio_core network video_core) target_link_libraries(core PUBLIC Boost::boost PRIVATE cryptopp fmt open_source_archives) + if (ENABLE_WEB_SERVICE) - target_compile_definitions(core PRIVATE -DENABLE_WEB_SERVICE) - target_link_libraries(core PRIVATE web_service) + get_directory_property(OPENSSL_LIBS + DIRECTORY ${PROJECT_SOURCE_DIR}/externals/libressl + DEFINITION OPENSSL_LIBS) + + target_compile_definitions(core PRIVATE -DENABLE_WEB_SERVICE -DCPPHTTPLIB_OPENSSL_SUPPORT) + target_link_libraries(core PRIVATE web_service ${OPENSSL_LIBS} httplib lurlparser) + if (ANDROID) + target_link_libraries(core PRIVATE ifaddrs) + endif() endif() if (ARCHITECTURE_x86_64) diff --git a/src/core/hle/service/http_c.cpp b/src/core/hle/service/http_c.cpp index 221f3cc20..f0e805e51 100644 --- a/src/core/hle/service/http_c.cpp +++ b/src/core/hle/service/http_c.cpp @@ -2,8 +2,13 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include <atomic> +#ifdef ENABLE_WEB_SERVICE +#include <LUrlParser.h> +#endif #include <cryptopp/aes.h> #include <cryptopp/modes.h> +#include "common/assert.h" #include "core/core.h" #include "core/file_sys/archive_ncch.h" #include "core/file_sys/file_backend.h" @@ -48,6 +53,82 @@ const ResultCode ERROR_WRONG_CERT_HANDLE = // 0xD8A0A0C9 const ResultCode ERROR_CERT_ALREADY_SET = // 0xD8A0A03D ResultCode(61, ErrorModule::HTTP, ErrorSummary::InvalidState, ErrorLevel::Permanent); +void Context::MakeRequest() { + ASSERT(state == RequestState::NotStarted); + +#ifdef ENABLE_WEB_SERVICE + LUrlParser::clParseURL parsedUrl = LUrlParser::clParseURL::ParseURL(url); + int port; + std::unique_ptr<httplib::Client> client; + if (parsedUrl.m_Scheme == "http") { + if (!parsedUrl.GetPort(&port)) { + port = 80; + } + // TODO(B3N30): Support for setting timeout + // Figure out what the default timeout on 3DS is + client = std::make_unique<httplib::Client>(parsedUrl.m_Host.c_str(), port); + } else { + if (!parsedUrl.GetPort(&port)) { + port = 443; + } + // TODO(B3N30): Support for setting timeout + // Figure out what the default timeout on 3DS is + + auto ssl_client = std::make_unique<httplib::SSLClient>(parsedUrl.m_Host, port); + SSL_CTX* ctx = ssl_client->ssl_context(); + client = std::move(ssl_client); + + if (auto client_cert = ssl_config.client_cert_ctx.lock()) { + SSL_CTX_use_certificate_ASN1(ctx, client_cert->certificate.size(), + client_cert->certificate.data()); + SSL_CTX_use_PrivateKey_ASN1(EVP_PKEY_RSA, ctx, client_cert->private_key.data(), + client_cert->private_key.size()); + } + + // TODO(B3N30): Check for SSLOptions-Bits and set the verify method accordingly + // https://www.3dbrew.org/wiki/SSL_Services#SSLOpt + // Hack: Since for now RootCerts are not implemented we set the VerifyMode to None. + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL); + } + + state = RequestState::InProgress; + + static const std::unordered_map<RequestMethod, std::string> request_method_strings{ + {RequestMethod::Get, "GET"}, {RequestMethod::Post, "POST"}, + {RequestMethod::Head, "HEAD"}, {RequestMethod::Put, "PUT"}, + {RequestMethod::Delete, "DELETE"}, {RequestMethod::PostEmpty, "POST"}, + {RequestMethod::PutEmpty, "PUT"}, + }; + + httplib::Request request; + request.method = request_method_strings.at(method); + request.path = url; + // TODO(B3N30): Add post data body + request.progress = [this](u64 current, u64 total) -> bool { + // TODO(B3N30): Is there a state that shows response header are available + current_download_size_bytes = current; + total_download_size_bytes = total; + return true; + }; + + for (const auto& header : headers) { + request.headers.emplace(header.name, header.value); + } + + if (!client->send(request, response)) { + LOG_ERROR(Service_HTTP, "Request failed"); + state = RequestState::TimedOut; + } else { + LOG_DEBUG(Service_HTTP, "Request successful"); + // TODO(B3N30): Verify this state on HW + state = RequestState::ReadyToDownloadContent; + } +#else + LOG_ERROR(Service_HTTP, "Tried to make request but WebServices is not enabled in this build"); + state = RequestState::TimedOut; +#endif +} + void HTTP_C::Initialize(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x1, 1, 4); const u32 shmem_size = rp.Pop<u32>(); @@ -152,7 +233,15 @@ void HTTP_C::BeginRequest(Kernel::HLERequestContext& ctx) { auto itr = contexts.find(context_handle); ASSERT(itr != contexts.end()); - // TODO(B3N30): Make the request + // On a 3DS BeginRequest and BeginRequestAsync will push the Request to a worker queue. + // You can only enqueue 8 requests at the same time. + // trying to enqueue any more will either fail (BeginRequestAsync), or block (BeginRequest) + // Note that you only can have 8 Contexts at a time. So this difference shouldn't matter + // Then there are 3? worker threads that pop the requests from the queue and send them + // For now make every request async in it's own thread. + + itr->second.request_future = + std::async(std::launch::async, &Context::MakeRequest, std::ref(itr->second)); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -197,7 +286,15 @@ void HTTP_C::BeginRequestAsync(Kernel::HLERequestContext& ctx) { auto itr = contexts.find(context_handle); ASSERT(itr != contexts.end()); - // TODO(B3N30): Make the request + // On a 3DS BeginRequest and BeginRequestAsync will push the Request to a worker queue. + // You can only enqueue 8 requests at the same time. + // trying to enqueue any more will either fail (BeginRequestAsync), or block (BeginRequest) + // Note that you only can have 8 Contexts at a time. So this difference shouldn't matter + // Then there are 3? worker threads that pop the requests from the queue and send them + // For now make every request async in it's own thread. + + itr->second.request_future = + std::async(std::launch::async, &Context::MakeRequest, std::ref(itr->second)); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -260,7 +357,7 @@ void HTTP_C::CreateContext(Kernel::HLERequestContext& ctx) { return; } - contexts.emplace(++context_counter, Context()); + contexts.try_emplace(++context_counter); contexts[context_counter].url = std::move(url); contexts[context_counter].method = method; contexts[context_counter].state = RequestState::NotStarted; @@ -307,10 +404,9 @@ void HTTP_C::CloseContext(Kernel::HLERequestContext& ctx) { } // TODO(Subv): What happens if you try to close a context that's currently being used? - ASSERT(itr->second.state == RequestState::NotStarted); - // TODO(Subv): Make sure that only the session that created the context can close it. + // Note that this will block if a request is still in progress contexts.erase(itr); session_data->num_http_contexts--; diff --git a/src/core/hle/service/http_c.h b/src/core/hle/service/http_c.h index 91a455820..0742de8a7 100644 --- a/src/core/hle/service/http_c.h +++ b/src/core/hle/service/http_c.h @@ -4,11 +4,18 @@ #pragma once +#include <future> #include <memory> #include <optional> #include <string> #include <unordered_map> #include <vector> +#ifdef ENABLE_WEB_SERVICE +#if defined(__ANDROID__) +#include <ifaddrs.h> +#endif +#include <httplib.h> +#endif #include "core/hle/kernel/shared_memory.h" #include "core/hle/service/service.h" @@ -78,8 +85,7 @@ public: Context(const Context&) = delete; Context& operator=(const Context&) = delete; - Context(Context&& other) = default; - Context& operator=(Context&&) = default; + void MakeRequest(); struct Proxy { std::string url; @@ -116,13 +122,20 @@ public: u32 session_id; std::string url; RequestMethod method; - RequestState state = RequestState::NotStarted; + std::atomic<RequestState> state = RequestState::NotStarted; std::optional<Proxy> proxy; std::optional<BasicAuth> basic_auth; SSLConfig ssl_config{}; u32 socket_buffer_size; std::vector<RequestHeader> headers; std::vector<PostData> post_data; + + std::future<void> request_future; + std::atomic<u64> current_download_size_bytes; + std::atomic<u64> total_download_size_bytes; +#ifdef ENABLE_WEB_SERVICE + httplib::Response response; +#endif }; struct SessionData : public Kernel::SessionRequestHandler::SessionDataBase { diff --git a/src/web_service/CMakeLists.txt b/src/web_service/CMakeLists.txt index c3d42fe8b..5695e25f0 100644 --- a/src/web_service/CMakeLists.txt +++ b/src/web_service/CMakeLists.txt @@ -18,3 +18,6 @@ get_directory_property(OPENSSL_LIBS DEFINITION OPENSSL_LIBS) target_compile_definitions(web_service PRIVATE -DCPPHTTPLIB_OPENSSL_SUPPORT) target_link_libraries(web_service PRIVATE common network json-headers ${OPENSSL_LIBS} httplib lurlparser cpp-jwt) +if (ANDROID) + target_link_libraries(web_service PRIVATE ifaddrs) +endif() diff --git a/src/web_service/web_backend.cpp b/src/web_service/web_backend.cpp index 6683f459f..c047677f9 100644 --- a/src/web_service/web_backend.cpp +++ b/src/web_service/web_backend.cpp @@ -8,6 +8,9 @@ #include <string> #include <LUrlParser.h> #include <fmt/format.h> +#if defined(__ANDROID__) +#include <ifaddrs.h> +#endif #include <httplib.h> #include "common/common_types.h" #include "common/logging/log.h" @@ -73,14 +76,14 @@ struct Client::Impl { if (!parsedUrl.GetPort(&port)) { port = HTTP_PORT; } - cli = std::make_unique<httplib::Client>(parsedUrl.m_Host.c_str(), port, - TIMEOUT_SECONDS); + cli = std::make_unique<httplib::Client>(parsedUrl.m_Host.c_str(), port); + cli->set_timeout_sec(TIMEOUT_SECONDS); } else if (parsedUrl.m_Scheme == "https") { if (!parsedUrl.GetPort(&port)) { port = HTTPS_PORT; } - cli = std::make_unique<httplib::SSLClient>(parsedUrl.m_Host.c_str(), port, - TIMEOUT_SECONDS); + cli = std::make_unique<httplib::SSLClient>(parsedUrl.m_Host.c_str(), port); + cli->set_timeout_sec(TIMEOUT_SECONDS); } else { LOG_ERROR(WebService, "Bad URL scheme {}", parsedUrl.m_Scheme); return Common::WebResult{Common::WebResult::Code::InvalidURL, "Bad URL scheme"}; From 55ec7031ccb2943c2c507620cf4613a86d160670 Mon Sep 17 00:00:00 2001 From: Ben <benediktthomas@gmail.com> Date: Fri, 21 Feb 2020 19:31:32 +0100 Subject: [PATCH 27/41] Core timing 2.0 (#4913) * Core::Timing: Add multiple timer, one for each core * revert clang-format; work on tests for CoreTiming * Kernel:: Add support for multiple cores, asserts in HandleSyncRequest because Thread->status == WaitIPC * Add some TRACE_LOGs * fix tests * make some adjustments to qt-debugger, cheats and gdbstub(probably still broken) * Make ARM_Interface::id private, rework ARM_Interface ctor * ReRename TimingManager to Timing for smaler diff * addressed review comments --- src/citra_qt/debugger/registers.cpp | 17 +- src/citra_qt/debugger/wait_tree.cpp | 15 +- src/core/arm/arm_interface.h | 17 ++ src/core/arm/dynarmic/arm_dynarmic.cpp | 16 +- src/core/arm/dynarmic/arm_dynarmic.h | 3 +- src/core/arm/dyncom/arm_dyncom.cpp | 9 +- src/core/arm/dyncom/arm_dyncom.h | 3 +- .../arm/dyncom/arm_dyncom_interpreter.cpp | 2 +- src/core/arm/skyeye_common/armstate.cpp | 4 +- src/core/cheats/gateway_cheat.cpp | 7 +- src/core/core.cpp | 121 +++++++-- src/core/core.h | 47 +++- src/core/core_timing.cpp | 183 ++++++++------ src/core/core_timing.h | 158 ++++++------ src/core/gdbstub/gdbstub.cpp | 61 +++-- src/core/hle/kernel/handle_table.cpp | 2 +- src/core/hle/kernel/kernel.cpp | 65 ++++- src/core/hle/kernel/kernel.h | 24 +- src/core/hle/kernel/mutex.cpp | 2 +- src/core/hle/kernel/shared_page.cpp | 2 +- src/core/hle/kernel/svc.cpp | 60 +++-- src/core/hle/kernel/thread.cpp | 42 ++-- src/core/hle/kernel/thread.h | 15 +- src/core/hle/service/ldr_ro/cro_helper.cpp | 230 +++++++++--------- src/core/hle/service/ldr_ro/cro_helper.h | 21 +- src/core/hle/service/ldr_ro/ldr_ro.cpp | 16 +- src/core/rpc/rpc_server.cpp | 4 +- src/tests/core/arm/arm_test_common.cpp | 4 +- .../core/arm/dyncom/arm_dyncom_vfp_tests.cpp | 2 +- src/tests/core/core_timing.cpp | 131 ++++------ src/tests/core/hle/kernel/hle_ipc.cpp | 8 +- src/tests/core/memory/memory.cpp | 4 +- 32 files changed, 760 insertions(+), 535 deletions(-) diff --git a/src/citra_qt/debugger/registers.cpp b/src/citra_qt/debugger/registers.cpp index f689708ad..eb4d3d6dc 100644 --- a/src/citra_qt/debugger/registers.cpp +++ b/src/citra_qt/debugger/registers.cpp @@ -61,13 +61,14 @@ void RegistersWidget::OnDebugModeEntered() { if (!Core::System::GetInstance().IsPoweredOn()) return; + // Todo: Handle all cores for (int i = 0; i < core_registers->childCount(); ++i) core_registers->child(i)->setText( - 1, QStringLiteral("0x%1").arg(Core::CPU().GetReg(i), 8, 16, QLatin1Char('0'))); + 1, QStringLiteral("0x%1").arg(Core::GetCore(0).GetReg(i), 8, 16, QLatin1Char('0'))); for (int i = 0; i < vfp_registers->childCount(); ++i) vfp_registers->child(i)->setText( - 1, QStringLiteral("0x%1").arg(Core::CPU().GetVFPReg(i), 8, 16, QLatin1Char('0'))); + 1, QStringLiteral("0x%1").arg(Core::GetCore(0).GetVFPReg(i), 8, 16, QLatin1Char('0'))); UpdateCPSRValues(); UpdateVFPSystemRegisterValues(); @@ -127,7 +128,8 @@ void RegistersWidget::CreateCPSRChildren() { } void RegistersWidget::UpdateCPSRValues() { - const u32 cpsr_val = Core::CPU().GetCPSR(); + // Todo: Handle all cores + const u32 cpsr_val = Core::GetCore(0).GetCPSR(); cpsr->setText(1, QStringLiteral("0x%1").arg(cpsr_val, 8, 16, QLatin1Char('0'))); cpsr->child(0)->setText( @@ -191,10 +193,11 @@ void RegistersWidget::CreateVFPSystemRegisterChildren() { } void RegistersWidget::UpdateVFPSystemRegisterValues() { - const u32 fpscr_val = Core::CPU().GetVFPSystemReg(VFP_FPSCR); - const u32 fpexc_val = Core::CPU().GetVFPSystemReg(VFP_FPEXC); - const u32 fpinst_val = Core::CPU().GetVFPSystemReg(VFP_FPINST); - const u32 fpinst2_val = Core::CPU().GetVFPSystemReg(VFP_FPINST2); + // Todo: handle all cores + const u32 fpscr_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPSCR); + const u32 fpexc_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPEXC); + const u32 fpinst_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPINST); + const u32 fpinst2_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPINST2); QTreeWidgetItem* const fpscr = vfp_system_registers->child(0); fpscr->setText(1, QStringLiteral("0x%1").arg(fpscr_val, 8, 16, QLatin1Char('0'))); diff --git a/src/citra_qt/debugger/wait_tree.cpp b/src/citra_qt/debugger/wait_tree.cpp index 019153576..15a7ada03 100644 --- a/src/citra_qt/debugger/wait_tree.cpp +++ b/src/citra_qt/debugger/wait_tree.cpp @@ -12,6 +12,7 @@ #include "core/hle/kernel/thread.h" #include "core/hle/kernel/timer.h" #include "core/hle/kernel/wait_object.h" +#include "core/settings.h" WaitTreeItem::~WaitTreeItem() = default; @@ -51,12 +52,16 @@ std::size_t WaitTreeItem::Row() const { } std::vector<std::unique_ptr<WaitTreeThread>> WaitTreeItem::MakeThreadItemList() { - const auto& threads = Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); + u32 num_cores = Core::GetNumCores(); std::vector<std::unique_ptr<WaitTreeThread>> item_list; - item_list.reserve(threads.size()); - for (std::size_t i = 0; i < threads.size(); ++i) { - item_list.push_back(std::make_unique<WaitTreeThread>(*threads[i])); - item_list.back()->row = i; + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + item_list.reserve(item_list.size() + threads.size()); + for (std::size_t i = 0; i < threads.size(); ++i) { + item_list.push_back(std::make_unique<WaitTreeThread>(*threads[i])); + item_list.back()->row = i; + } } return item_list; } diff --git a/src/core/arm/arm_interface.h b/src/core/arm/arm_interface.h index 6e6da8626..22443295b 100644 --- a/src/core/arm/arm_interface.h +++ b/src/core/arm/arm_interface.h @@ -9,10 +9,13 @@ #include "common/common_types.h" #include "core/arm/skyeye_common/arm_regformat.h" #include "core/arm/skyeye_common/vfp/asm_vfp.h" +#include "core/core_timing.h" /// Generic ARM11 CPU interface class ARM_Interface : NonCopyable { public: + explicit ARM_Interface(u32 id, std::shared_ptr<Core::Timing::Timer> timer) + : timer(timer), id(id){}; virtual ~ARM_Interface() {} class ThreadContext { @@ -172,4 +175,18 @@ public: /// Prepare core for thread reschedule (if needed to correctly handle state) virtual void PrepareReschedule() = 0; + + std::shared_ptr<Core::Timing::Timer> GetTimer() { + return timer; + } + + u32 GetID() const { + return id; + } + +protected: + std::shared_ptr<Core::Timing::Timer> timer; + +private: + u32 id; }; diff --git a/src/core/arm/dynarmic/arm_dynarmic.cpp b/src/core/arm/dynarmic/arm_dynarmic.cpp index f494b5228..b39a4a24e 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.cpp +++ b/src/core/arm/dynarmic/arm_dynarmic.cpp @@ -72,8 +72,7 @@ private: class DynarmicUserCallbacks final : public Dynarmic::A32::UserCallbacks { public: explicit DynarmicUserCallbacks(ARM_Dynarmic& parent) - : parent(parent), timing(parent.system.CoreTiming()), svc_context(parent.system), - memory(parent.memory) {} + : parent(parent), svc_context(parent.system), memory(parent.memory) {} ~DynarmicUserCallbacks() = default; std::uint8_t MemoryRead8(VAddr vaddr) override { @@ -137,7 +136,7 @@ public: parent.jit->HaltExecution(); parent.SetPC(pc); Kernel::Thread* thread = - parent.system.Kernel().GetThreadManager().GetCurrentThread(); + parent.system.Kernel().GetCurrentThreadManager().GetCurrentThread(); parent.SaveContext(thread->context); GDBStub::Break(); GDBStub::SendTrap(thread, 5); @@ -150,22 +149,23 @@ public: } void AddTicks(std::uint64_t ticks) override { - timing.AddTicks(ticks); + parent.GetTimer()->AddTicks(ticks); } std::uint64_t GetTicksRemaining() override { - s64 ticks = timing.GetDowncount(); + s64 ticks = parent.GetTimer()->GetDowncount(); return static_cast<u64>(ticks <= 0 ? 0 : ticks); } ARM_Dynarmic& parent; - Core::Timing& timing; Kernel::SVCContext svc_context; Memory::MemorySystem& memory; }; ARM_Dynarmic::ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode) - : system(*system), memory(memory), cb(std::make_unique<DynarmicUserCallbacks>(*this)) { + PrivilegeMode initial_mode, u32 id, + std::shared_ptr<Core::Timing::Timer> timer) + : ARM_Interface(id, timer), system(*system), memory(memory), + cb(std::make_unique<DynarmicUserCallbacks>(*this)) { interpreter_state = std::make_shared<ARMul_State>(system, memory, initial_mode); PageTableChanged(); } diff --git a/src/core/arm/dynarmic/arm_dynarmic.h b/src/core/arm/dynarmic/arm_dynarmic.h index 559dbf5a8..4aaec1bf1 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.h +++ b/src/core/arm/dynarmic/arm_dynarmic.h @@ -24,7 +24,8 @@ class DynarmicUserCallbacks; class ARM_Dynarmic final : public ARM_Interface { public: - ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, PrivilegeMode initial_mode); + ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, PrivilegeMode initial_mode, + u32 id, std::shared_ptr<Core::Timing::Timer> timer); ~ARM_Dynarmic() override; void Run() override; diff --git a/src/core/arm/dyncom/arm_dyncom.cpp b/src/core/arm/dyncom/arm_dyncom.cpp index d54b0cb95..fa1aa598d 100644 --- a/src/core/arm/dyncom/arm_dyncom.cpp +++ b/src/core/arm/dyncom/arm_dyncom.cpp @@ -69,8 +69,9 @@ private: }; ARM_DynCom::ARM_DynCom(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode) - : system(system) { + PrivilegeMode initial_mode, u32 id, + std::shared_ptr<Core::Timing::Timer> timer) + : ARM_Interface(id, timer), system(system) { state = std::make_unique<ARMul_State>(system, memory, initial_mode); } @@ -78,7 +79,7 @@ ARM_DynCom::~ARM_DynCom() {} void ARM_DynCom::Run() { DEBUG_ASSERT(system != nullptr); - ExecuteInstructions(std::max<s64>(system->CoreTiming().GetDowncount(), 0)); + ExecuteInstructions(std::max<s64>(timer->GetDowncount(), 0)); } void ARM_DynCom::Step() { @@ -150,7 +151,7 @@ void ARM_DynCom::ExecuteInstructions(u64 num_instructions) { state->NumInstrsToExecute = num_instructions; unsigned ticks_executed = InterpreterMainLoop(state.get()); if (system != nullptr) { - system->CoreTiming().AddTicks(ticks_executed); + timer->AddTicks(ticks_executed); } state->ServeBreak(); } diff --git a/src/core/arm/dyncom/arm_dyncom.h b/src/core/arm/dyncom/arm_dyncom.h index 99c6ab460..f5360b307 100644 --- a/src/core/arm/dyncom/arm_dyncom.h +++ b/src/core/arm/dyncom/arm_dyncom.h @@ -21,7 +21,8 @@ class MemorySystem; class ARM_DynCom final : public ARM_Interface { public: explicit ARM_DynCom(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode); + PrivilegeMode initial_mode, u32 id, + std::shared_ptr<Core::Timing::Timer> timer); ~ARM_DynCom() override; void Run() override; diff --git a/src/core/arm/dyncom/arm_dyncom_interpreter.cpp b/src/core/arm/dyncom/arm_dyncom_interpreter.cpp index ba4a15b0c..706e0092b 100644 --- a/src/core/arm/dyncom/arm_dyncom_interpreter.cpp +++ b/src/core/arm/dyncom/arm_dyncom_interpreter.cpp @@ -3865,7 +3865,7 @@ SWI_INST : { if (inst_base->cond == ConditionCode::AL || CondPassed(cpu, inst_base->cond)) { DEBUG_ASSERT(cpu->system != nullptr); swi_inst* const inst_cream = (swi_inst*)inst_base->component; - cpu->system->CoreTiming().AddTicks(num_instrs); + cpu->system->GetRunningCore().GetTimer()->AddTicks(num_instrs); cpu->NumInstrsToExecute = num_instrs >= cpu->NumInstrsToExecute ? 0 : cpu->NumInstrsToExecute - num_instrs; num_instrs = 0; diff --git a/src/core/arm/skyeye_common/armstate.cpp b/src/core/arm/skyeye_common/armstate.cpp index 26520da76..775618a8b 100644 --- a/src/core/arm/skyeye_common/armstate.cpp +++ b/src/core/arm/skyeye_common/armstate.cpp @@ -607,8 +607,8 @@ void ARMul_State::ServeBreak() { } DEBUG_ASSERT(system != nullptr); - Kernel::Thread* thread = system->Kernel().GetThreadManager().GetCurrentThread(); - system->CPU().SaveContext(thread->context); + Kernel::Thread* thread = system->Kernel().GetCurrentThreadManager().GetCurrentThread(); + system->GetRunningCore().SaveContext(thread->context); if (last_bkpt_hit || GDBStub::IsMemoryBreak() || GDBStub::GetCpuStepFlag()) { last_bkpt_hit = false; diff --git a/src/core/cheats/gateway_cheat.cpp b/src/core/cheats/gateway_cheat.cpp index 15a0e8ca9..cba74bb7c 100644 --- a/src/core/cheats/gateway_cheat.cpp +++ b/src/core/cheats/gateway_cheat.cpp @@ -35,7 +35,7 @@ static inline std::enable_if_t<std::is_integral_v<T>> WriteOp(const GatewayCheat Core::System& system) { u32 addr = line.address + state.offset; write_func(addr, static_cast<T>(line.value)); - system.CPU().InvalidateCacheRange(addr, sizeof(T)); + system.InvalidateCacheRange(addr, sizeof(T)); } template <typename T, typename ReadFunction, typename CompareFunc> @@ -105,7 +105,7 @@ static inline std::enable_if_t<std::is_integral_v<T>> IncrementiveWriteOp( Core::System& system) { u32 addr = line.value + state.offset; write_func(addr, static_cast<T>(state.reg)); - system.CPU().InvalidateCacheRange(addr, sizeof(T)); + system.InvalidateCacheRange(addr, sizeof(T)); state.offset += sizeof(T); } @@ -143,7 +143,8 @@ static inline void PatchOp(const GatewayCheat::CheatLine& line, State& state, Co } u32 num_bytes = line.value; u32 addr = line.address + state.offset; - system.CPU().InvalidateCacheRange(addr, num_bytes); + system.InvalidateCacheRange(addr, num_bytes); + bool first = true; u32 bit_offset = 0; if (num_bytes > 0) diff --git a/src/core/core.cpp b/src/core/core.cpp index ebaee4f87..cd1799e42 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -44,7 +44,8 @@ namespace Core { System::ResultStatus System::RunLoop(bool tight_loop) { status = ResultStatus::Success; - if (!cpu_core) { + if (std::any_of(cpu_cores.begin(), cpu_cores.end(), + [](std::shared_ptr<ARM_Interface> ptr) { return ptr == nullptr; })) { return ResultStatus::ErrorNotInitialized; } @@ -62,22 +63,73 @@ System::ResultStatus System::RunLoop(bool tight_loop) { } } - // If we don't have a currently active thread then don't execute instructions, - // instead advance to the next event and try to yield to the next thread - if (kernel->GetThreadManager().GetCurrentThread() == nullptr) { - LOG_TRACE(Core_ARM11, "Idling"); - timing->Idle(); - timing->Advance(); - PrepareReschedule(); - } else { - timing->Advance(); - if (tight_loop) { - cpu_core->Run(); - } else { - cpu_core->Step(); + // All cores should have executed the same amount of ticks. If this is not the case an event was + // scheduled with a cycles_into_future smaller then the current downcount. + // So we have to get those cores to the same global time first + u64 global_ticks = timing->GetGlobalTicks(); + s64 max_delay = 0; + std::shared_ptr<ARM_Interface> current_core_to_execute = nullptr; + for (auto& cpu_core : cpu_cores) { + if (cpu_core->GetTimer()->GetTicks() < global_ticks) { + s64 delay = global_ticks - cpu_core->GetTimer()->GetTicks(); + cpu_core->GetTimer()->Advance(delay); + if (max_delay < delay) { + max_delay = delay; + current_core_to_execute = cpu_core; + } } } + if (max_delay > 0) { + LOG_TRACE(Core_ARM11, "Core {} running (delayed) for {} ticks", + current_core_to_execute->GetID(), + current_core_to_execute->GetTimer()->GetDowncount()); + running_core = current_core_to_execute.get(); + kernel->SetRunningCPU(current_core_to_execute); + if (kernel->GetCurrentThreadManager().GetCurrentThread() == nullptr) { + LOG_TRACE(Core_ARM11, "Core {} idling", current_core_to_execute->GetID()); + current_core_to_execute->GetTimer()->Idle(); + PrepareReschedule(); + } else { + if (tight_loop) { + current_core_to_execute->Run(); + } else { + current_core_to_execute->Step(); + } + } + } else { + // Now all cores are at the same global time. So we will run them one after the other + // with a max slice that is the minimum of all max slices of all cores + // TODO: Make special check for idle since we can easily revert the time of idle cores + s64 max_slice = Timing::MAX_SLICE_LENGTH; + for (const auto& cpu_core : cpu_cores) { + max_slice = std::min(max_slice, cpu_core->GetTimer()->GetMaxSliceLength()); + } + for (auto& cpu_core : cpu_cores) { + cpu_core->GetTimer()->Advance(max_slice); + } + for (auto& cpu_core : cpu_cores) { + LOG_TRACE(Core_ARM11, "Core {} running for {} ticks", cpu_core->GetID(), + cpu_core->GetTimer()->GetDowncount()); + running_core = cpu_core.get(); + kernel->SetRunningCPU(cpu_core); + // If we don't have a currently active thread then don't execute instructions, + // instead advance to the next event and try to yield to the next thread + if (kernel->GetCurrentThreadManager().GetCurrentThread() == nullptr) { + LOG_TRACE(Core_ARM11, "Core {} idling", cpu_core->GetID()); + cpu_core->GetTimer()->Idle(); + PrepareReschedule(); + } else { + if (tight_loop) { + cpu_core->Run(); + } else { + cpu_core->Step(); + } + } + } + timing->AddToGlobalTicks(max_slice); + } + if (GDBStub::IsServerEnabled()) { GDBStub::SetCpuStepFlag(false); } @@ -174,7 +226,7 @@ System::ResultStatus System::Load(Frontend::EmuWindow& emu_window, const std::st } void System::PrepareReschedule() { - cpu_core->PrepareReschedule(); + running_core->PrepareReschedule(); reschedule_pending = true; } @@ -188,31 +240,50 @@ void System::Reschedule() { } reschedule_pending = false; - kernel->GetThreadManager().Reschedule(); + for (const auto& core : cpu_cores) { + LOG_TRACE(Core_ARM11, "Reschedule core {}", core->GetID()); + kernel->GetThreadManager(core->GetID()).Reschedule(); + } } System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mode) { LOG_DEBUG(HW_Memory, "initialized OK"); + std::size_t num_cores = 2; + if (Settings::values.is_new_3ds) { + num_cores = 4; + } + memory = std::make_unique<Memory::MemorySystem>(); - timing = std::make_unique<Timing>(); + timing = std::make_unique<Timing>(num_cores); - kernel = std::make_unique<Kernel::KernelSystem>(*memory, *timing, - [this] { PrepareReschedule(); }, system_mode); + kernel = std::make_unique<Kernel::KernelSystem>( + *memory, *timing, [this] { PrepareReschedule(); }, system_mode, num_cores); if (Settings::values.use_cpu_jit) { #ifdef ARCHITECTURE_x86_64 - cpu_core = std::make_shared<ARM_Dynarmic>(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared<ARM_Dynarmic>(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } #else - cpu_core = std::make_shared<ARM_DynCom>(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared<ARM_DynCom>(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } LOG_WARNING(Core, "CPU JIT requested, but Dynarmic not available"); #endif } else { - cpu_core = std::make_shared<ARM_DynCom>(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared<ARM_DynCom>(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } } + running_core = cpu_cores[0].get(); - kernel->SetCPU(cpu_core); + kernel->SetCPUs(cpu_cores); + kernel->SetRunningCPU(cpu_cores[0]); if (Settings::values.enable_dsp_lle) { dsp_core = std::make_unique<AudioCore::DspLle>(*memory, @@ -257,6 +328,8 @@ System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mo LOG_DEBUG(Core, "Initialized OK"); + initalized = true; + return ResultStatus::Success; } @@ -362,7 +435,7 @@ void System::Shutdown() { cheat_engine.reset(); service_manager.reset(); dsp_core.reset(); - cpu_core.reset(); + cpu_cores.clear(); kernel.reset(); timing.reset(); app_loader.reset(); diff --git a/src/core/core.h b/src/core/core.h index 5b7965453..2727ea78c 100644 --- a/src/core/core.h +++ b/src/core/core.h @@ -140,7 +140,10 @@ public: * @returns True if the emulated system is powered on, otherwise false. */ bool IsPoweredOn() const { - return cpu_core != nullptr; + return cpu_cores.size() > 0 && + std::all_of(cpu_cores.begin(), cpu_cores.end(), + [](std::shared_ptr<ARM_Interface> ptr) { return ptr != nullptr; }); + ; } /** @@ -160,8 +163,29 @@ public: * Gets a reference to the emulated CPU. * @returns A reference to the emulated CPU. */ - ARM_Interface& CPU() { - return *cpu_core; + + ARM_Interface& GetRunningCore() { + return *running_core; + }; + + /** + * Gets a reference to the emulated CPU. + * @param core_id The id of the core requested. + * @returns A reference to the emulated CPU. + */ + + ARM_Interface& GetCore(u32 core_id) { + return *cpu_cores[core_id]; + }; + + u32 GetNumCores() const { + return cpu_cores.size(); + } + + void InvalidateCacheRange(u32 start_address, std::size_t length) { + for (const auto& cpu : cpu_cores) { + cpu->InvalidateCacheRange(start_address, length); + } } /** @@ -288,7 +312,8 @@ private: std::unique_ptr<Loader::AppLoader> app_loader; /// ARM11 CPU core - std::shared_ptr<ARM_Interface> cpu_core; + std::vector<std::shared_ptr<ARM_Interface>> cpu_cores; + ARM_Interface* running_core = nullptr; /// DSP core std::unique_ptr<AudioCore::DspInterface> dsp_core; @@ -330,6 +355,8 @@ private: private: static System s_instance; + bool initalized = false; + ResultStatus status = ResultStatus::Success; std::string status_details = ""; /// Saved variables for reset @@ -340,8 +367,16 @@ private: std::atomic<bool> shutdown_requested; }; -inline ARM_Interface& CPU() { - return System::GetInstance().CPU(); +inline ARM_Interface& GetRunningCore() { + return System::GetInstance().GetRunningCore(); +} + +inline ARM_Interface& GetCore(u32 core_id) { + return System::GetInstance().GetCore(core_id); +} + +inline u32 GetNumCores() { + return System::GetInstance().GetNumCores(); } inline AudioCore::DspInterface& DSP() { diff --git a/src/core/core_timing.cpp b/src/core/core_timing.cpp index df355ab27..8966bc55b 100644 --- a/src/core/core_timing.cpp +++ b/src/core/core_timing.cpp @@ -12,14 +12,22 @@ namespace Core { // Sort by time, unless the times are the same, in which case sort by the order added to the queue -bool Timing::Event::operator>(const Event& right) const { +bool Timing::Event::operator>(const Timing::Event& right) const { return std::tie(time, fifo_order) > std::tie(right.time, right.fifo_order); } -bool Timing::Event::operator<(const Event& right) const { +bool Timing::Event::operator<(const Timing::Event& right) const { return std::tie(time, fifo_order) < std::tie(right.time, right.fifo_order); } +Timing::Timing(std::size_t num_cores) { + timers.resize(num_cores); + for (std::size_t i = 0; i < num_cores; ++i) { + timers[i] = std::make_shared<Timer>(); + } + current_timer = timers[0]; +} + TimingEventType* Timing::RegisterEvent(const std::string& name, TimedCallback callback) { // check for existing type with same name. // we want event type names to remain unique so that we can use them for serialization. @@ -34,73 +42,102 @@ TimingEventType* Timing::RegisterEvent(const std::string& name, TimedCallback ca return event_type; } -Timing::~Timing() { +void Timing::ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata, + std::size_t core_id) { + ASSERT(event_type != nullptr); + std::shared_ptr<Timing::Timer> timer; + if (core_id == std::numeric_limits<std::size_t>::max()) { + timer = current_timer; + } else { + ASSERT(core_id < timers.size()); + timer = timers.at(core_id); + } + + s64 timeout = timer->GetTicks() + cycles_into_future; + if (current_timer == timer) { + // If this event needs to be scheduled before the next advance(), force one early + if (!timer->is_timer_sane) + timer->ForceExceptionCheck(cycles_into_future); + + timer->event_queue.emplace_back( + Event{timeout, timer->event_fifo_id++, userdata, event_type}); + std::push_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } else { + timer->ts_queue.Push(Event{static_cast<s64>(timer->GetTicks() + cycles_into_future), 0, + userdata, event_type}); + } +} + +void Timing::UnscheduleEvent(const TimingEventType* event_type, u64 userdata) { + for (auto timer : timers) { + auto itr = std::remove_if( + timer->event_queue.begin(), timer->event_queue.end(), + [&](const Event& e) { return e.type == event_type && e.userdata == userdata; }); + + // Removing random items breaks the invariant so we have to re-establish it. + if (itr != timer->event_queue.end()) { + timer->event_queue.erase(itr, timer->event_queue.end()); + std::make_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } + } + // TODO:remove events from ts_queue +} + +void Timing::RemoveEvent(const TimingEventType* event_type) { + for (auto timer : timers) { + auto itr = std::remove_if(timer->event_queue.begin(), timer->event_queue.end(), + [&](const Event& e) { return e.type == event_type; }); + + // Removing random items breaks the invariant so we have to re-establish it. + if (itr != timer->event_queue.end()) { + timer->event_queue.erase(itr, timer->event_queue.end()); + std::make_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } + } + // TODO:remove events from ts_queue +} + +void Timing::SetCurrentTimer(std::size_t core_id) { + current_timer = timers[core_id]; +} + +s64 Timing::GetTicks() const { + return current_timer->GetTicks(); +} + +s64 Timing::GetGlobalTicks() const { + return global_timer; +} + +std::chrono::microseconds Timing::GetGlobalTimeUs() const { + return std::chrono::microseconds{GetTicks() * 1000000 / BASE_CLOCK_RATE_ARM11}; +} + +std::shared_ptr<Timing::Timer> Timing::GetTimer(std::size_t cpu_id) { + return timers[cpu_id]; +} + +Timing::Timer::~Timer() { MoveEvents(); } -u64 Timing::GetTicks() const { - u64 ticks = static_cast<u64>(global_timer); - if (!is_global_timer_sane) { +u64 Timing::Timer::GetTicks() const { + u64 ticks = static_cast<u64>(executed_ticks); + if (!is_timer_sane) { ticks += slice_length - downcount; } return ticks; } -void Timing::AddTicks(u64 ticks) { +void Timing::Timer::AddTicks(u64 ticks) { downcount -= ticks; } -u64 Timing::GetIdleTicks() const { +u64 Timing::Timer::GetIdleTicks() const { return static_cast<u64>(idled_cycles); } -void Timing::ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata) { - ASSERT(event_type != nullptr); - s64 timeout = GetTicks() + cycles_into_future; - - // If this event needs to be scheduled before the next advance(), force one early - if (!is_global_timer_sane) - ForceExceptionCheck(cycles_into_future); - - event_queue.emplace_back(Event{timeout, event_fifo_id++, userdata, event_type}); - std::push_heap(event_queue.begin(), event_queue.end(), std::greater<>()); -} - -void Timing::ScheduleEventThreadsafe(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata) { - ts_queue.Push(Event{global_timer + cycles_into_future, 0, userdata, event_type}); -} - -void Timing::UnscheduleEvent(const TimingEventType* event_type, u64 userdata) { - auto itr = std::remove_if(event_queue.begin(), event_queue.end(), [&](const Event& e) { - return e.type == event_type && e.userdata == userdata; - }); - - // Removing random items breaks the invariant so we have to re-establish it. - if (itr != event_queue.end()) { - event_queue.erase(itr, event_queue.end()); - std::make_heap(event_queue.begin(), event_queue.end(), std::greater<>()); - } -} - -void Timing::RemoveEvent(const TimingEventType* event_type) { - auto itr = std::remove_if(event_queue.begin(), event_queue.end(), - [&](const Event& e) { return e.type == event_type; }); - - // Removing random items breaks the invariant so we have to re-establish it. - if (itr != event_queue.end()) { - event_queue.erase(itr, event_queue.end()); - std::make_heap(event_queue.begin(), event_queue.end(), std::greater<>()); - } -} - -void Timing::RemoveNormalAndThreadsafeEvent(const TimingEventType* event_type) { - MoveEvents(); - RemoveEvent(event_type); -} - -void Timing::ForceExceptionCheck(s64 cycles) { +void Timing::Timer::ForceExceptionCheck(s64 cycles) { cycles = std::max<s64>(0, cycles); if (downcount > cycles) { slice_length -= downcount - cycles; @@ -108,7 +145,7 @@ void Timing::ForceExceptionCheck(s64 cycles) { } } -void Timing::MoveEvents() { +void Timing::Timer::MoveEvents() { for (Event ev; ts_queue.Pop(ev);) { ev.fifo_order = event_fifo_id++; event_queue.emplace_back(std::move(ev)); @@ -116,43 +153,49 @@ void Timing::MoveEvents() { } } -void Timing::Advance() { +s64 Timing::Timer::GetMaxSliceLength() const { + auto next_event = std::find_if(event_queue.begin(), event_queue.end(), + [&](const Event& e) { return e.time - executed_ticks > 0; }); + if (next_event != event_queue.end()) { + return next_event->time - executed_ticks; + } + return MAX_SLICE_LENGTH; +} + +void Timing::Timer::Advance(s64 max_slice_length) { MoveEvents(); s64 cycles_executed = slice_length - downcount; - global_timer += cycles_executed; - slice_length = MAX_SLICE_LENGTH; + idled_cycles = 0; + executed_ticks += cycles_executed; + slice_length = max_slice_length; - is_global_timer_sane = true; + is_timer_sane = true; - while (!event_queue.empty() && event_queue.front().time <= global_timer) { + while (!event_queue.empty() && event_queue.front().time <= executed_ticks) { Event evt = std::move(event_queue.front()); std::pop_heap(event_queue.begin(), event_queue.end(), std::greater<>()); event_queue.pop_back(); - evt.type->callback(evt.userdata, global_timer - evt.time); + evt.type->callback(evt.userdata, executed_ticks - evt.time); } - is_global_timer_sane = false; + is_timer_sane = false; // Still events left (scheduled in the future) if (!event_queue.empty()) { slice_length = static_cast<int>( - std::min<s64>(event_queue.front().time - global_timer, MAX_SLICE_LENGTH)); + std::min<s64>(event_queue.front().time - executed_ticks, max_slice_length)); } downcount = slice_length; } -void Timing::Idle() { +void Timing::Timer::Idle() { idled_cycles += downcount; downcount = 0; } -std::chrono::microseconds Timing::GetGlobalTimeUs() const { - return std::chrono::microseconds{GetTicks() * 1000000 / BASE_CLOCK_RATE_ARM11}; -} - -s64 Timing::GetDowncount() const { +s64 Timing::Timer::GetDowncount() const { return downcount; } diff --git a/src/core/core_timing.h b/src/core/core_timing.h index 229fc37f4..30c1106bb 100644 --- a/src/core/core_timing.h +++ b/src/core/core_timing.h @@ -134,62 +134,6 @@ struct TimingEventType { class Timing { public: - ~Timing(); - - /** - * This should only be called from the emu thread, if you are calling it any other thread, you - * are doing something evil - */ - u64 GetTicks() const; - u64 GetIdleTicks() const; - void AddTicks(u64 ticks); - - /** - * Returns the event_type identifier. if name is not unique, it will assert. - */ - TimingEventType* RegisterEvent(const std::string& name, TimedCallback callback); - - /** - * After the first Advance, the slice lengths and the downcount will be reduced whenever an - * event is scheduled earlier than the current values. Scheduling from a callback will not - * update the downcount until the Advance() completes. - */ - void ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata = 0); - - /** - * This is to be called when outside of hle threads, such as the graphics thread, wants to - * schedule things to be executed on the main thread. - * Not that this doesn't change slice_length and thus events scheduled by this might be called - * with a delay of up to MAX_SLICE_LENGTH - */ - void ScheduleEventThreadsafe(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata); - - void UnscheduleEvent(const TimingEventType* event_type, u64 userdata); - - /// We only permit one event of each type in the queue at a time. - void RemoveEvent(const TimingEventType* event_type); - void RemoveNormalAndThreadsafeEvent(const TimingEventType* event_type); - - /** Advance must be called at the beginning of dispatcher loops, not the end. Advance() ends - * the previous timing slice and begins the next one, you must Advance from the previous - * slice to the current one before executing any cycles. CoreTiming starts in slice -1 so an - * Advance() is required to initialize the slice length before the first cycle of emulated - * instructions is executed. - */ - void Advance(); - void MoveEvents(); - - /// Pretend that the main CPU has executed enough cycles to reach the next event. - void Idle(); - - void ForceExceptionCheck(s64 cycles); - - std::chrono::microseconds GetGlobalTimeUs() const; - - s64 GetDowncount() const; - -private: struct Event { s64 time; u64 fifo_order; @@ -202,33 +146,93 @@ private: static constexpr int MAX_SLICE_LENGTH = 20000; + class Timer { + public: + ~Timer(); + + s64 GetMaxSliceLength() const; + + void Advance(s64 max_slice_length = MAX_SLICE_LENGTH); + + void Idle(); + + u64 GetTicks() const; + u64 GetIdleTicks() const; + + void AddTicks(u64 ticks); + + s64 GetDowncount() const; + + void ForceExceptionCheck(s64 cycles); + + void MoveEvents(); + + private: + friend class Timing; + // The queue is a min-heap using std::make_heap/push_heap/pop_heap. + // We don't use std::priority_queue because we need to be able to serialize, unserialize and + // erase arbitrary events (RemoveEvent()) regardless of the queue order. These aren't + // accomodated by the standard adaptor class. + std::vector<Event> event_queue; + u64 event_fifo_id = 0; + // the queue for storing the events from other threads threadsafe until they will be added + // to the event_queue by the emu thread + Common::MPSCQueue<Event> ts_queue; + // Are we in a function that has been called from Advance() + // If events are sheduled from a function that gets called from Advance(), + // don't change slice_length and downcount. + // The time between CoreTiming being intialized and the first call to Advance() is + // considered the slice boundary between slice -1 and slice 0. Dispatcher loops must call + // Advance() before executing the first cycle of each slice to prepare the slice length and + // downcount for that slice. + bool is_timer_sane = true; + + s64 slice_length = MAX_SLICE_LENGTH; + s64 downcount = MAX_SLICE_LENGTH; + s64 executed_ticks = 0; + u64 idled_cycles; + }; + + explicit Timing(std::size_t num_cores); + + ~Timing(){}; + + /** + * Returns the event_type identifier. if name is not unique, it will assert. + */ + TimingEventType* RegisterEvent(const std::string& name, TimedCallback callback); + + void ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata = 0, + std::size_t core_id = std::numeric_limits<std::size_t>::max()); + + void UnscheduleEvent(const TimingEventType* event_type, u64 userdata); + + /// We only permit one event of each type in the queue at a time. + void RemoveEvent(const TimingEventType* event_type); + + void SetCurrentTimer(std::size_t core_id); + + s64 GetTicks() const; + + s64 GetGlobalTicks() const; + + void AddToGlobalTicks(s64 ticks) { + global_timer += ticks; + } + + std::chrono::microseconds GetGlobalTimeUs() const; + + std::shared_ptr<Timer> GetTimer(std::size_t cpu_id); + +private: s64 global_timer = 0; - s64 slice_length = MAX_SLICE_LENGTH; - s64 downcount = MAX_SLICE_LENGTH; // unordered_map stores each element separately as a linked list node so pointers to // elements remain stable regardless of rehashes/resizing. std::unordered_map<std::string, TimingEventType> event_types; - // The queue is a min-heap using std::make_heap/push_heap/pop_heap. - // We don't use std::priority_queue because we need to be able to serialize, unserialize and - // erase arbitrary events (RemoveEvent()) regardless of the queue order. These aren't - // accomodated by the standard adaptor class. - std::vector<Event> event_queue; - u64 event_fifo_id = 0; - // the queue for storing the events from other threads threadsafe until they will be added - // to the event_queue by the emu thread - Common::MPSCQueue<Event> ts_queue; - s64 idled_cycles = 0; - - // Are we in a function that has been called from Advance() - // If events are sheduled from a function that gets called from Advance(), - // don't change slice_length and downcount. - // The time between CoreTiming being intialized and the first call to Advance() is considered - // the slice boundary between slice -1 and slice 0. Dispatcher loops must call Advance() before - // executing the first cycle of each slice to prepare the slice length and downcount for - // that slice. - bool is_global_timer_sane = true; + std::vector<std::shared_ptr<Timer>> timers; + std::shared_ptr<Timer> current_timer; }; } // namespace Core diff --git a/src/core/gdbstub/gdbstub.cpp b/src/core/gdbstub/gdbstub.cpp index 7f722ab0f..a7ed44aff 100644 --- a/src/core/gdbstub/gdbstub.cpp +++ b/src/core/gdbstub/gdbstub.cpp @@ -160,10 +160,14 @@ BreakpointMap breakpoints_write; } // Anonymous namespace static Kernel::Thread* FindThreadById(int id) { - const auto& threads = Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (auto& thread : threads) { - if (thread->GetThreadId() == static_cast<u32>(id)) { - return thread.get(); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (auto& thread : threads) { + if (thread->GetThreadId() == static_cast<u32>(id)) { + return thread.get(); + } } } return nullptr; @@ -414,7 +418,10 @@ static void RemoveBreakpoint(BreakpointType type, VAddr addr) { Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), bp->second.addr, bp->second.inst.data(), bp->second.inst.size()); - Core::CPU().ClearInstructionCache(); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + Core::GetCore(i).ClearInstructionCache(); + } } p.erase(addr); } @@ -540,10 +547,13 @@ static void HandleQuery() { SendReply(target_xml); } else if (strncmp(query, "fThreadInfo", strlen("fThreadInfo")) == 0) { std::string val = "m"; - const auto& threads = - Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (const auto& thread : threads) { - val += fmt::format("{:x},", thread->GetThreadId()); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (const auto& thread : threads) { + val += fmt::format("{:x},", thread->GetThreadId()); + } } val.pop_back(); SendReply(val.c_str()); @@ -553,11 +563,14 @@ static void HandleQuery() { std::string buffer; buffer += "l<?xml version=\"1.0\"?>"; buffer += "<threads>"; - const auto& threads = - Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (const auto& thread : threads) { - buffer += fmt::format(R"*(<thread id="{:x}" name="Thread {:x}"></thread>)*", - thread->GetThreadId(), thread->GetThreadId()); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (const auto& thread : threads) { + buffer += fmt::format(R"*(<thread id="{:x}" name="Thread {:x}"></thread>)*", + thread->GetThreadId(), thread->GetThreadId()); + } } buffer += "</threads>"; SendReply(buffer.c_str()); @@ -619,9 +632,9 @@ static void SendSignal(Kernel::Thread* thread, u32 signal, bool full = true) { if (full) { buffer = fmt::format("T{:02x}{:02x}:{:08x};{:02x}:{:08x};{:02x}:{:08x}", latest_signal, - PC_REGISTER, htonl(Core::CPU().GetPC()), SP_REGISTER, - htonl(Core::CPU().GetReg(SP_REGISTER)), LR_REGISTER, - htonl(Core::CPU().GetReg(LR_REGISTER))); + PC_REGISTER, htonl(Core::GetRunningCore().GetPC()), SP_REGISTER, + htonl(Core::GetRunningCore().GetReg(SP_REGISTER)), LR_REGISTER, + htonl(Core::GetRunningCore().GetReg(LR_REGISTER))); } else { buffer = fmt::format("T{:02x}", latest_signal); } @@ -782,7 +795,7 @@ static void WriteRegister() { return SendReply("E01"); } - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); SendReply("OK"); } @@ -812,7 +825,7 @@ static void WriteRegisters() { } } - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); SendReply("OK"); } @@ -869,7 +882,7 @@ static void WriteMemory() { GdbHexToMem(data.data(), len_pos + 1, len); Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), addr, data.data(), len); - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); SendReply("OK"); } @@ -883,12 +896,12 @@ void Break(bool is_memory_break) { static void Step() { if (command_length > 1) { RegWrite(PC_REGISTER, GdbHexToInt(command_buffer + 1), current_thread); - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); } step_loop = true; halt_loop = true; send_trap = true; - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } bool IsMemoryBreak() { @@ -904,7 +917,7 @@ static void Continue() { memory_break = false; step_loop = false; halt_loop = false; - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } /** @@ -930,7 +943,7 @@ static bool CommitBreakpoint(BreakpointType type, VAddr addr, u32 len) { Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), addr, btrap.data(), btrap.size()); - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } p.insert({addr, breakpoint}); diff --git a/src/core/hle/kernel/handle_table.cpp b/src/core/hle/kernel/handle_table.cpp index 71e18eb7c..d717c8399 100644 --- a/src/core/hle/kernel/handle_table.cpp +++ b/src/core/hle/kernel/handle_table.cpp @@ -83,7 +83,7 @@ bool HandleTable::IsValid(Handle handle) const { std::shared_ptr<Object> HandleTable::GetGeneric(Handle handle) const { if (handle == CurrentThread) { - return SharedFrom(kernel.GetThreadManager().GetCurrentThread()); + return SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()); } else if (handle == CurrentProcess) { return kernel.GetCurrentProcess(); } diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index ceb2f14f5..c0b6f8308 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -18,19 +18,27 @@ namespace Kernel { /// Initialize the kernel KernelSystem::KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, - std::function<void()> prepare_reschedule_callback, u32 system_mode) + std::function<void()> prepare_reschedule_callback, u32 system_mode, + u32 num_cores) : memory(memory), timing(timing), prepare_reschedule_callback(std::move(prepare_reschedule_callback)) { MemoryInit(system_mode); resource_limits = std::make_unique<ResourceLimitList>(*this); - thread_manager = std::make_unique<ThreadManager>(*this); + for (u32 core_id = 0; core_id < num_cores; ++core_id) { + thread_managers.push_back(std::make_unique<ThreadManager>(*this, core_id)); + } timer_manager = std::make_unique<TimerManager>(timing); ipc_recorder = std::make_unique<IPCDebugger::Recorder>(); + stored_processes.assign(num_cores, nullptr); + + next_thread_id = 1; } /// Shutdown the kernel -KernelSystem::~KernelSystem() = default; +KernelSystem::~KernelSystem() { + ResetThreadIDs(); +}; ResourceLimitList& KernelSystem::ResourceLimit() { return *resource_limits; @@ -53,6 +61,15 @@ void KernelSystem::SetCurrentProcess(std::shared_ptr<Process> process) { SetCurrentMemoryPageTable(&process->vm_manager.page_table); } +void KernelSystem::SetCurrentProcessForCPU(std::shared_ptr<Process> process, u32 core_id) { + if (current_cpu->GetID() == core_id) { + current_process = process; + SetCurrentMemoryPageTable(&process->vm_manager.page_table); + } else { + stored_processes[core_id] = process; + } +} + void KernelSystem::SetCurrentMemoryPageTable(Memory::PageTable* page_table) { memory.SetCurrentPageTable(page_table); if (current_cpu != nullptr) { @@ -60,17 +77,39 @@ void KernelSystem::SetCurrentMemoryPageTable(Memory::PageTable* page_table) { } } -void KernelSystem::SetCPU(std::shared_ptr<ARM_Interface> cpu) { +void KernelSystem::SetCPUs(std::vector<std::shared_ptr<ARM_Interface>> cpus) { + ASSERT(cpus.size() == thread_managers.size()); + u32 i = 0; + for (const auto& cpu : cpus) { + thread_managers[i++]->SetCPU(*cpu); + } +} + +void KernelSystem::SetRunningCPU(std::shared_ptr<ARM_Interface> cpu) { + if (current_process) { + stored_processes[current_cpu->GetID()] = current_process; + } current_cpu = cpu; - thread_manager->SetCPU(*cpu); + timing.SetCurrentTimer(cpu->GetID()); + if (stored_processes[current_cpu->GetID()]) { + SetCurrentProcess(stored_processes[current_cpu->GetID()]); + } } -ThreadManager& KernelSystem::GetThreadManager() { - return *thread_manager; +ThreadManager& KernelSystem::GetThreadManager(u32 core_id) { + return *thread_managers[core_id]; } -const ThreadManager& KernelSystem::GetThreadManager() const { - return *thread_manager; +const ThreadManager& KernelSystem::GetThreadManager(u32 core_id) const { + return *thread_managers[core_id]; +} + +ThreadManager& KernelSystem::GetCurrentThreadManager() { + return *thread_managers[current_cpu->GetID()]; +} + +const ThreadManager& KernelSystem::GetCurrentThreadManager() const { + return *thread_managers[current_cpu->GetID()]; } TimerManager& KernelSystem::GetTimerManager() { @@ -101,4 +140,12 @@ void KernelSystem::AddNamedPort(std::string name, std::shared_ptr<ClientPort> po named_ports.emplace(std::move(name), std::move(port)); } +u32 KernelSystem::NewThreadId() { + return next_thread_id++; +} + +void KernelSystem::ResetThreadIDs() { + next_thread_id = 0; +} + } // namespace Kernel diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index 58f63938b..fd68cbf6d 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -85,7 +85,8 @@ enum class MemoryRegion : u16 { class KernelSystem { public: explicit KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, - std::function<void()> prepare_reschedule_callback, u32 system_mode); + std::function<void()> prepare_reschedule_callback, u32 system_mode, + u32 num_cores); ~KernelSystem(); using PortPair = std::pair<std::shared_ptr<ServerPort>, std::shared_ptr<ClientPort>>; @@ -210,13 +211,19 @@ public: std::shared_ptr<Process> GetCurrentProcess() const; void SetCurrentProcess(std::shared_ptr<Process> process); + void SetCurrentProcessForCPU(std::shared_ptr<Process> process, u32 core_id); void SetCurrentMemoryPageTable(Memory::PageTable* page_table); - void SetCPU(std::shared_ptr<ARM_Interface> cpu); + void SetCPUs(std::vector<std::shared_ptr<ARM_Interface>> cpu); - ThreadManager& GetThreadManager(); - const ThreadManager& GetThreadManager() const; + void SetRunningCPU(std::shared_ptr<ARM_Interface> cpu); + + ThreadManager& GetThreadManager(u32 core_id); + const ThreadManager& GetThreadManager(u32 core_id) const; + + ThreadManager& GetCurrentThreadManager(); + const ThreadManager& GetCurrentThreadManager() const; TimerManager& GetTimerManager(); const TimerManager& GetTimerManager() const; @@ -242,6 +249,10 @@ public: prepare_reschedule_callback(); } + u32 NewThreadId(); + + void ResetThreadIDs(); + /// Map of named ports managed by the kernel, which can be retrieved using the ConnectToPort std::unordered_map<std::string, std::shared_ptr<ClientPort>> named_ports; @@ -276,13 +287,16 @@ private: std::vector<std::shared_ptr<Process>> process_list; std::shared_ptr<Process> current_process; + std::vector<std::shared_ptr<Process>> stored_processes; - std::unique_ptr<ThreadManager> thread_manager; + std::vector<std::unique_ptr<ThreadManager>> thread_managers; std::unique_ptr<ConfigMem::Handler> config_mem_handler; std::unique_ptr<SharedPage::Handler> shared_page_handler; std::unique_ptr<IPCDebugger::Recorder> ipc_recorder; + + u32 next_thread_id; }; } // namespace Kernel diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp index 467b1ae1e..30dd1eb55 100644 --- a/src/core/hle/kernel/mutex.cpp +++ b/src/core/hle/kernel/mutex.cpp @@ -35,7 +35,7 @@ std::shared_ptr<Mutex> KernelSystem::CreateMutex(bool initial_locked, std::strin // Acquire mutex with current thread if initialized as locked if (initial_locked) - mutex->Acquire(thread_manager->GetCurrentThread()); + mutex->Acquire(thread_managers[current_cpu->GetID()]->GetCurrentThread()); return mutex; } diff --git a/src/core/hle/kernel/shared_page.cpp b/src/core/hle/kernel/shared_page.cpp index 7067aace8..30de0ca1f 100644 --- a/src/core/hle/kernel/shared_page.cpp +++ b/src/core/hle/kernel/shared_page.cpp @@ -56,7 +56,7 @@ Handler::Handler(Core::Timing& timing) : timing(timing) { using namespace std::placeholders; update_time_event = timing.RegisterEvent("SharedPage::UpdateTimeCallback", std::bind(&Handler::UpdateTimeCallback, this, _1, _2)); - timing.ScheduleEvent(0, update_time_event); + timing.ScheduleEvent(0, update_time_event, 0, 0); float slidestate = Settings::values.factor_3d / 100.0f; shared_page.sliderstate_3d = static_cast<float_le>(slidestate); diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index b5ebaf936..d3a7b0626 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -280,12 +280,12 @@ void SVC::ExitProcess() { current_process->status = ProcessStatus::Exited; // Stop all the process threads that are currently waiting for objects. - auto& thread_list = kernel.GetThreadManager().GetThreadList(); + auto& thread_list = kernel.GetCurrentThreadManager().GetThreadList(); for (auto& thread : thread_list) { if (thread->owner_process != current_process.get()) continue; - if (thread.get() == kernel.GetThreadManager().GetCurrentThread()) + if (thread.get() == kernel.GetCurrentThreadManager().GetCurrentThread()) continue; // TODO(Subv): When are the other running/ready threads terminated? @@ -297,7 +297,7 @@ void SVC::ExitProcess() { } // Kill the current thread - kernel.GetThreadManager().GetCurrentThread()->Stop(); + kernel.GetCurrentThreadManager().GetCurrentThread()->Stop(); system.PrepareReschedule(); } @@ -388,7 +388,7 @@ ResultCode SVC::SendSyncRequest(Handle handle) { system.PrepareReschedule(); - auto thread = SharedFrom(kernel.GetThreadManager().GetCurrentThread()); + auto thread = SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()); if (kernel.GetIPCRecorder().IsEnabled()) { kernel.GetIPCRecorder().RegisterRequest(session, thread); @@ -406,7 +406,7 @@ ResultCode SVC::CloseHandle(Handle handle) { /// Wait for a handle to synchronize, timeout after the specified nanoseconds ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { auto object = kernel.GetCurrentProcess()->handle_table.Get<WaitObject>(handle); - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); if (object == nullptr) return ERR_INVALID_HANDLE; @@ -458,7 +458,7 @@ ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { /// Wait for the given handles to synchronize, timeout after the specified nanoseconds ResultCode SVC::WaitSynchronizationN(s32* out, VAddr handles_address, s32 handle_count, bool wait_all, s64 nano_seconds) { - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); if (!Memory::IsValidVirtualAddress(*kernel.GetCurrentProcess(), handles_address)) return ERR_INVALID_POINTER; @@ -654,7 +654,7 @@ ResultCode SVC::ReplyAndReceive(s32* index, VAddr handles_address, s32 handle_co // We are also sending a command reply. // Do not send a reply if the command id in the command buffer is 0xFFFF. - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); u32 cmd_buff_header = memory.Read32(thread->GetCommandBufferAddress()); IPC::Header header{cmd_buff_header}; if (reply_target != 0 && header.command_id != 0xFFFF) { @@ -776,7 +776,7 @@ ResultCode SVC::ArbitrateAddress(Handle handle, u32 address, u32 type, u32 value return ERR_INVALID_HANDLE; auto res = - arbiter->ArbitrateAddress(SharedFrom(kernel.GetThreadManager().GetCurrentThread()), + arbiter->ArbitrateAddress(SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()), static_cast<ArbitrationType>(type), address, value, nanoseconds); // TODO(Subv): Identify in which specific cases this call should cause a reschedule. @@ -897,14 +897,19 @@ ResultCode SVC::CreateThread(Handle* out_handle, u32 entry_point, u32 arg, VAddr break; case ThreadProcessorIdAll: LOG_INFO(Kernel_SVC, - "Newly created thread is allowed to be run in any Core, unimplemented."); + "Newly created thread is allowed to be run in any Core, for now run in core 0."); + processor_id = ThreadProcessorId0; break; case ThreadProcessorId1: - LOG_ERROR(Kernel_SVC, - "Newly created thread must run in the SysCore (Core1), unimplemented."); + case ThreadProcessorId2: + case ThreadProcessorId3: + // TODO: Check and log for: When processorid==0x2 and the process is not a BASE mem-region + // process, exheader kernel-flags bitmask 0x2000 must be set (otherwise error 0xD9001BEA is + // returned). When processorid==0x3 and the process is not a BASE mem-region process, error + // 0xD9001BEA is returned. These are the only restriction checks done by the kernel for + // processorid. break; default: - // TODO(bunnei): Implement support for other processor IDs ASSERT_MSG(false, "Unsupported thread processor ID: {}", processor_id); break; } @@ -930,9 +935,9 @@ ResultCode SVC::CreateThread(Handle* out_handle, u32 entry_point, u32 arg, VAddr /// Called when a thread exits void SVC::ExitThread() { - LOG_TRACE(Kernel_SVC, "called, pc=0x{:08X}", system.CPU().GetPC()); + LOG_TRACE(Kernel_SVC, "called, pc=0x{:08X}", system.GetRunningCore().GetPC()); - kernel.GetThreadManager().ExitCurrentThread(); + kernel.GetCurrentThreadManager().ExitCurrentThread(); system.PrepareReschedule(); } @@ -978,7 +983,7 @@ ResultCode SVC::SetThreadPriority(Handle handle, u32 priority) { /// Create a mutex ResultCode SVC::CreateMutex(Handle* out_handle, u32 initial_locked) { std::shared_ptr<Mutex> mutex = kernel.CreateMutex(initial_locked != 0); - mutex->name = fmt::format("mutex-{:08x}", system.CPU().GetReg(14)); + mutex->name = fmt::format("mutex-{:08x}", system.GetRunningCore().GetReg(14)); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(mutex))); LOG_TRACE(Kernel_SVC, "called initial_locked={} : created handle=0x{:08X}", @@ -995,7 +1000,7 @@ ResultCode SVC::ReleaseMutex(Handle handle) { if (mutex == nullptr) return ERR_INVALID_HANDLE; - return mutex->Release(kernel.GetThreadManager().GetCurrentThread()); + return mutex->Release(kernel.GetCurrentThreadManager().GetCurrentThread()); } /// Get the ID of the specified process @@ -1045,7 +1050,7 @@ ResultCode SVC::GetThreadId(u32* thread_id, Handle handle) { ResultCode SVC::CreateSemaphore(Handle* out_handle, s32 initial_count, s32 max_count) { CASCADE_RESULT(std::shared_ptr<Semaphore> semaphore, kernel.CreateSemaphore(initial_count, max_count)); - semaphore->name = fmt::format("semaphore-{:08x}", system.CPU().GetReg(14)); + semaphore->name = fmt::format("semaphore-{:08x}", system.GetRunningCore().GetReg(14)); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(semaphore))); @@ -1115,8 +1120,9 @@ ResultCode SVC::QueryMemory(MemoryInfo* memory_info, PageInfo* page_info, u32 ad /// Create an event ResultCode SVC::CreateEvent(Handle* out_handle, u32 reset_type) { - std::shared_ptr<Event> evt = kernel.CreateEvent( - static_cast<ResetType>(reset_type), fmt::format("event-{:08x}", system.CPU().GetReg(14))); + std::shared_ptr<Event> evt = + kernel.CreateEvent(static_cast<ResetType>(reset_type), + fmt::format("event-{:08x}", system.GetRunningCore().GetReg(14))); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(evt))); LOG_TRACE(Kernel_SVC, "called reset_type=0x{:08X} : created handle=0x{:08X}", reset_type, @@ -1158,8 +1164,9 @@ ResultCode SVC::ClearEvent(Handle handle) { /// Creates a timer ResultCode SVC::CreateTimer(Handle* out_handle, u32 reset_type) { - std::shared_ptr<Timer> timer = kernel.CreateTimer( - static_cast<ResetType>(reset_type), fmt ::format("timer-{:08x}", system.CPU().GetReg(14))); + std::shared_ptr<Timer> timer = + kernel.CreateTimer(static_cast<ResetType>(reset_type), + fmt ::format("timer-{:08x}", system.GetRunningCore().GetReg(14))); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(timer))); LOG_TRACE(Kernel_SVC, "called reset_type=0x{:08X} : created handle=0x{:08X}", reset_type, @@ -1213,7 +1220,7 @@ ResultCode SVC::CancelTimer(Handle handle) { void SVC::SleepThread(s64 nanoseconds) { LOG_TRACE(Kernel_SVC, "called nanoseconds={}", nanoseconds); - ThreadManager& thread_manager = kernel.GetThreadManager(); + ThreadManager& thread_manager = kernel.GetCurrentThreadManager(); // Don't attempt to yield execution if there are no available threads to run, // this way we avoid a useless reschedule to the idle thread. @@ -1231,10 +1238,11 @@ void SVC::SleepThread(s64 nanoseconds) { /// This returns the total CPU ticks elapsed since the CPU was powered-on s64 SVC::GetSystemTick() { - s64 result = system.CoreTiming().GetTicks(); + // TODO: Use globalTicks here? + s64 result = system.GetRunningCore().GetTimer()->GetTicks(); // Advance time to defeat dumb games (like Cubic Ninja) that busy-wait for the frame to end. // Measured time between two calls on a 9.2 o3DS with Ninjhax 1.1b - system.CoreTiming().AddTicks(150); + system.GetRunningCore().GetTimer()->AddTicks(150); return result; } @@ -1596,11 +1604,11 @@ void SVC::CallSVC(u32 immediate) { SVC::SVC(Core::System& system) : system(system), kernel(system.Kernel()), memory(system.Memory()) {} u32 SVC::GetReg(std::size_t n) { - return system.CPU().GetReg(static_cast<int>(n)); + return system.GetRunningCore().GetReg(static_cast<int>(n)); } void SVC::SetReg(std::size_t n, u32 value) { - system.CPU().SetReg(static_cast<int>(n), value); + system.GetRunningCore().SetReg(static_cast<int>(n), value); } SVCContext::SVCContext(Core::System& system) : impl(std::make_unique<SVC>(system)) {} diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 3b15ec35e..47d8cb1df 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -33,13 +33,9 @@ void Thread::Acquire(Thread* thread) { ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); } -u32 ThreadManager::NewThreadId() { - return next_thread_id++; -} - -Thread::Thread(KernelSystem& kernel) - : WaitObject(kernel), context(kernel.GetThreadManager().NewContext()), - thread_manager(kernel.GetThreadManager()) {} +Thread::Thread(KernelSystem& kernel, u32 core_id) + : WaitObject(kernel), context(kernel.GetThreadManager(core_id).NewContext()), + thread_manager(kernel.GetThreadManager(core_id)) {} Thread::~Thread() {} Thread* ThreadManager::GetCurrentThread() const { @@ -84,7 +80,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { // Save context for previous thread if (previous_thread) { - previous_thread->last_running_ticks = timing.GetTicks(); + previous_thread->last_running_ticks = timing.GetGlobalTicks(); cpu->SaveContext(previous_thread->context); if (previous_thread->status == ThreadStatus::Running) { @@ -111,7 +107,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { new_thread->status = ThreadStatus::Running; if (previous_process.get() != current_thread->owner_process) { - kernel.SetCurrentProcess(SharedFrom(current_thread->owner_process)); + kernel.SetCurrentProcessForCPU(SharedFrom(current_thread->owner_process), cpu->GetID()); } cpu->LoadContext(new_thread->context); @@ -124,7 +120,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { } Thread* ThreadManager::PopNextReadyThread() { - Thread* next; + Thread* next = nullptr; Thread* thread = GetCurrentThread(); if (thread && thread->status == ThreadStatus::Running) { @@ -309,22 +305,22 @@ ResultVal<std::shared_ptr<Thread>> KernelSystem::CreateThread(std::string name, ErrorSummary::InvalidArgument, ErrorLevel::Permanent); } - auto thread{std::make_shared<Thread>(*this)}; + auto thread{std::make_shared<Thread>(*this, processor_id)}; - thread_manager->thread_list.push_back(thread); - thread_manager->ready_queue.prepare(priority); + thread_managers[processor_id]->thread_list.push_back(thread); + thread_managers[processor_id]->ready_queue.prepare(priority); - thread->thread_id = thread_manager->NewThreadId(); + thread->thread_id = NewThreadId(); thread->status = ThreadStatus::Dormant; thread->entry_point = entry_point; thread->stack_top = stack_top; thread->nominal_priority = thread->current_priority = priority; - thread->last_running_ticks = timing.GetTicks(); + thread->last_running_ticks = timing.GetGlobalTicks(); thread->processor_id = processor_id; thread->wait_objects.clear(); thread->wait_address = 0; thread->name = std::move(name); - thread_manager->wakeup_callback_table[thread->thread_id] = thread.get(); + thread_managers[processor_id]->wakeup_callback_table[thread->thread_id] = thread.get(); thread->owner_process = &owner_process; // Find the next available TLS index, and mark it as used @@ -369,7 +365,7 @@ ResultVal<std::shared_ptr<Thread>> KernelSystem::CreateThread(std::string name, // to initialize the context ResetThreadContext(thread->context, stack_top, entry_point, arg); - thread_manager->ready_queue.push_back(thread->current_priority, thread.get()); + thread_managers[processor_id]->ready_queue.push_back(thread->current_priority, thread.get()); thread->status = ThreadStatus::Ready; return MakeResult<std::shared_ptr<Thread>>(std::move(thread)); @@ -435,6 +431,9 @@ void ThreadManager::Reschedule() { LOG_TRACE(Kernel, "context switch {} -> idle", cur->GetObjectId()); } else if (next) { LOG_TRACE(Kernel, "context switch idle -> {}", next->GetObjectId()); + } else { + LOG_TRACE(Kernel, "context switch idle -> idle, do nothing"); + return; } SwitchContext(next); @@ -461,11 +460,10 @@ VAddr Thread::GetCommandBufferAddress() const { return GetTLSAddress() + command_header_offset; } -ThreadManager::ThreadManager(Kernel::KernelSystem& kernel) : kernel(kernel) { - ThreadWakeupEventType = - kernel.timing.RegisterEvent("ThreadWakeupCallback", [this](u64 thread_id, s64 cycle_late) { - ThreadWakeupCallback(thread_id, cycle_late); - }); +ThreadManager::ThreadManager(Kernel::KernelSystem& kernel, u32 core_id) : kernel(kernel) { + ThreadWakeupEventType = kernel.timing.RegisterEvent( + "ThreadWakeupCallback_" + std::to_string(core_id), + [this](u64 thread_id, s64 cycle_late) { ThreadWakeupCallback(thread_id, cycle_late); }); } ThreadManager::~ThreadManager() { diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index f2ef767ef..cd47448a2 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -34,7 +34,9 @@ enum ThreadProcessorId : s32 { ThreadProcessorIdAll = -1, ///< Run thread on either core ThreadProcessorId0 = 0, ///< Run thread on core 0 (AppCore) ThreadProcessorId1 = 1, ///< Run thread on core 1 (SysCore) - ThreadProcessorIdMax = 2, ///< Processor ID must be less than this + ThreadProcessorId2 = 2, ///< Run thread on core 2 (additional n3ds core) + ThreadProcessorId3 = 3, ///< Run thread on core 3 (additional n3ds core) + ThreadProcessorIdMax = 4, ///< Processor ID must be less than this }; enum class ThreadStatus { @@ -57,15 +59,9 @@ enum class ThreadWakeupReason { class ThreadManager { public: - explicit ThreadManager(Kernel::KernelSystem& kernel); + explicit ThreadManager(Kernel::KernelSystem& kernel, u32 core_id); ~ThreadManager(); - /** - * Creates a new thread ID - * @return The new thread ID - */ - u32 NewThreadId(); - /** * Gets the current thread */ @@ -132,7 +128,6 @@ private: Kernel::KernelSystem& kernel; ARM_Interface* cpu; - u32 next_thread_id = 1; std::shared_ptr<Thread> current_thread; Common::ThreadQueueList<Thread*, ThreadPrioLowest + 1> ready_queue; std::unordered_map<u64, Thread*> wakeup_callback_table; @@ -149,7 +144,7 @@ private: class Thread final : public WaitObject { public: - explicit Thread(KernelSystem&); + explicit Thread(KernelSystem&, u32 core_id); ~Thread() override; std::string GetName() const override { diff --git a/src/core/hle/service/ldr_ro/cro_helper.cpp b/src/core/hle/service/ldr_ro/cro_helper.cpp index 86600e7a9..89e64f9d8 100644 --- a/src/core/hle/service/ldr_ro/cro_helper.cpp +++ b/src/core/hle/service/ldr_ro/cro_helper.cpp @@ -55,7 +55,7 @@ VAddr CROHelper::SegmentTagToAddress(SegmentTag segment_tag) const { return 0; SegmentEntry entry; - GetEntry(memory, segment_tag.segment_index, entry); + GetEntry(system.Memory(), segment_tag.segment_index, entry); if (segment_tag.offset_into_segment >= entry.size) return 0; @@ -71,12 +71,12 @@ ResultCode CROHelper::ApplyRelocation(VAddr target_address, RelocationType reloc break; case RelocationType::AbsoluteAddress: case RelocationType::AbsoluteAddress2: - memory.Write32(target_address, symbol_address + addend); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, symbol_address + addend); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::RelativeAddress: - memory.Write32(target_address, symbol_address + addend - target_future_address); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, symbol_address + addend - target_future_address); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::ThumbBranch: case RelocationType::ArmBranch: @@ -98,8 +98,8 @@ ResultCode CROHelper::ClearRelocation(VAddr target_address, RelocationType reloc case RelocationType::AbsoluteAddress: case RelocationType::AbsoluteAddress2: case RelocationType::RelativeAddress: - memory.Write32(target_address, 0); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, 0); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::ThumbBranch: case RelocationType::ArmBranch: @@ -121,7 +121,8 @@ ResultCode CROHelper::ApplyRelocationBatch(VAddr batch, u32 symbol_address, bool VAddr relocation_address = batch; while (true) { RelocationEntry relocation; - memory.ReadBlock(process, relocation_address, &relocation, sizeof(RelocationEntry)); + system.Memory().ReadBlock(process, relocation_address, &relocation, + sizeof(RelocationEntry)); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -142,9 +143,9 @@ ResultCode CROHelper::ApplyRelocationBatch(VAddr batch, u32 symbol_address, bool } RelocationEntry relocation; - memory.ReadBlock(process, batch, &relocation, sizeof(RelocationEntry)); + system.Memory().ReadBlock(process, batch, &relocation, sizeof(RelocationEntry)); relocation.is_batch_resolved = reset ? 0 : 1; - memory.WriteBlock(process, batch, &relocation, sizeof(RelocationEntry)); + system.Memory().WriteBlock(process, batch, &relocation, sizeof(RelocationEntry)); return RESULT_SUCCESS; } @@ -154,13 +155,13 @@ VAddr CROHelper::FindExportNamedSymbol(const std::string& name) const { std::size_t len = name.size(); ExportTreeEntry entry; - GetEntry(memory, 0, entry); + GetEntry(system.Memory(), 0, entry); ExportTreeEntry::Child next; next.raw = entry.left.raw; u32 found_id; while (true) { - GetEntry(memory, next.next_index, entry); + GetEntry(system.Memory(), next.next_index, entry); if (next.is_end) { found_id = entry.export_table_index; @@ -186,9 +187,9 @@ VAddr CROHelper::FindExportNamedSymbol(const std::string& name) const { u32 export_strings_size = GetField(ExportStringsSize); ExportNamedSymbolEntry symbol_entry; - GetEntry(memory, found_id, symbol_entry); + GetEntry(system.Memory(), found_id, symbol_entry); - if (memory.ReadCString(symbol_entry.name_offset, export_strings_size) != name) + if (system.Memory().ReadCString(symbol_entry.name_offset, export_strings_size) != name) return 0; return SegmentTagToAddress(symbol_entry.symbol_position); @@ -279,7 +280,7 @@ ResultVal<VAddr> CROHelper::RebaseSegmentTable(u32 cro_size, VAddr data_segment_ u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry segment; - GetEntry(memory, i, segment); + GetEntry(system.Memory(), i, segment); if (segment.type == SegmentType::Data) { if (segment.size != 0) { if (segment.size > data_segment_size) @@ -298,7 +299,7 @@ ResultVal<VAddr> CROHelper::RebaseSegmentTable(u32 cro_size, VAddr data_segment_ if (segment.offset > module_address + cro_size) return CROFormatError(0x19); } - SetEntry(memory, i, segment); + SetEntry(system.Memory(), i, segment); } return MakeResult<u32>(prev_data_segment + module_address); } @@ -310,7 +311,7 @@ ResultCode CROHelper::RebaseExportNamedSymbolTable() { u32 export_named_symbol_num = GetField(ExportNamedSymbolNum); for (u32 i = 0; i < export_named_symbol_num; ++i) { ExportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -320,7 +321,7 @@ ResultCode CROHelper::RebaseExportNamedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -329,7 +330,7 @@ ResultCode CROHelper::VerifyExportTreeTable() const { u32 tree_num = GetField(ExportTreeNum); for (u32 i = 0; i < tree_num; ++i) { ExportTreeEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.left.next_index >= tree_num || entry.right.next_index >= tree_num) { return CROFormatError(0x11); @@ -353,7 +354,7 @@ ResultCode CROHelper::RebaseImportModuleTable() { u32 module_num = GetField(ImportModuleNum); for (u32 i = 0; i < module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -379,7 +380,7 @@ ResultCode CROHelper::RebaseImportModuleTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -395,7 +396,7 @@ ResultCode CROHelper::RebaseImportNamedSymbolTable() { u32 num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -413,7 +414,7 @@ ResultCode CROHelper::RebaseImportNamedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -427,7 +428,7 @@ ResultCode CROHelper::RebaseImportIndexedSymbolTable() { u32 num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset += module_address; @@ -437,7 +438,7 @@ ResultCode CROHelper::RebaseImportIndexedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -451,7 +452,7 @@ ResultCode CROHelper::RebaseImportAnonymousSymbolTable() { u32 num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset += module_address; @@ -461,7 +462,7 @@ ResultCode CROHelper::RebaseImportAnonymousSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -476,14 +477,14 @@ ResultCode CROHelper::ResetExternalRelocations() { ExternalRelocationEntry relocation; // Verifies that the last relocation is the end of a batch - GetEntry(memory, external_relocation_num - 1, relocation); + GetEntry(system.Memory(), external_relocation_num - 1, relocation); if (!relocation.is_batch_end) { return CROFormatError(0x12); } bool batch_begin = true; for (u32 i = 0; i < external_relocation_num; ++i) { - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -500,7 +501,7 @@ ResultCode CROHelper::ResetExternalRelocations() { if (batch_begin) { // resets to unresolved state relocation.is_batch_resolved = 0; - SetEntry(memory, i, relocation); + SetEntry(system.Memory(), i, relocation); } // if current is an end, then the next is a beginning @@ -516,7 +517,7 @@ ResultCode CROHelper::ClearExternalRelocations() { bool batch_begin = true; for (u32 i = 0; i < external_relocation_num; ++i) { - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -532,7 +533,7 @@ ResultCode CROHelper::ClearExternalRelocations() { if (batch_begin) { // resets to unresolved state relocation.is_batch_resolved = 0; - SetEntry(memory, i, relocation); + SetEntry(system.Memory(), i, relocation); } // if current is an end, then the next is a beginning @@ -548,13 +549,13 @@ ResultCode CROHelper::ApplyStaticAnonymousSymbolToCRS(VAddr crs_address) { static_relocation_table_offset + GetField(StaticRelocationNum) * sizeof(StaticRelocationEntry); - CROHelper crs(crs_address, process, memory, cpu); + CROHelper crs(crs_address, process, system); u32 offset_export_num = GetField(StaticAnonymousSymbolNum); LOG_INFO(Service_LDR, "CRO \"{}\" exports {} static anonymous symbols", ModuleName(), offset_export_num); for (u32 i = 0; i < offset_export_num; ++i) { StaticAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); u32 batch_address = entry.relocation_batch_offset + module_address; if (batch_address < static_relocation_table_offset || @@ -579,7 +580,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { u32 internal_relocation_num = GetField(InternalRelocationNum); for (u32 i = 0; i < internal_relocation_num; ++i) { InternalRelocationEntry relocation; - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr target_addressB = SegmentTagToAddress(relocation.target_position); if (target_addressB == 0) { return CROFormatError(0x15); @@ -587,7 +588,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { VAddr target_address; SegmentEntry target_segment; - GetEntry(memory, relocation.target_position.segment_index, target_segment); + GetEntry(system.Memory(), relocation.target_position.segment_index, target_segment); if (target_segment.type == SegmentType::Data) { // If the relocation is to the .data segment, we need to relocate it in the old buffer @@ -602,7 +603,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { } SegmentEntry symbol_segment; - GetEntry(memory, relocation.symbol_segment, symbol_segment); + GetEntry(system.Memory(), relocation.symbol_segment, symbol_segment); LOG_TRACE(Service_LDR, "Internally relocates 0x{:08X} with 0x{:08X}", target_address, symbol_segment.offset); ResultCode result = ApplyRelocation(target_address, relocation.type, relocation.addend, @@ -619,7 +620,7 @@ ResultCode CROHelper::ClearInternalRelocations() { u32 internal_relocation_num = GetField(InternalRelocationNum); for (u32 i = 0; i < internal_relocation_num; ++i) { InternalRelocationEntry relocation; - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr target_address = SegmentTagToAddress(relocation.target_position); if (target_address == 0) { @@ -639,13 +640,13 @@ void CROHelper::UnrebaseImportAnonymousSymbolTable() { u32 num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -653,13 +654,13 @@ void CROHelper::UnrebaseImportIndexedSymbolTable() { u32 num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -667,7 +668,7 @@ void CROHelper::UnrebaseImportNamedSymbolTable() { u32 num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; @@ -677,7 +678,7 @@ void CROHelper::UnrebaseImportNamedSymbolTable() { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -685,7 +686,7 @@ void CROHelper::UnrebaseImportModuleTable() { u32 module_num = GetField(ImportModuleNum); for (u32 i = 0; i < module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; @@ -699,7 +700,7 @@ void CROHelper::UnrebaseImportModuleTable() { entry.import_anonymous_symbol_table_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -707,13 +708,13 @@ void CROHelper::UnrebaseExportNamedSymbolTable() { u32 export_named_symbol_num = GetField(ExportNamedSymbolNum); for (u32 i = 0; i < export_named_symbol_num; ++i) { ExportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -721,7 +722,7 @@ void CROHelper::UnrebaseSegmentTable() { u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry segment; - GetEntry(memory, i, segment); + GetEntry(system.Memory(), i, segment); if (segment.type == SegmentType::BSS) { segment.offset = 0; @@ -729,7 +730,7 @@ void CROHelper::UnrebaseSegmentTable() { segment.offset -= module_address; } - SetEntry(memory, i, segment); + SetEntry(system.Memory(), i, segment); } } @@ -751,17 +752,17 @@ ResultCode CROHelper::ApplyImportNamedSymbol(VAddr crs_address) { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (!relocation_entry.is_batch_resolved) { ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal<bool> { + process, system, crs_address, [&](CROHelper source) -> ResultVal<bool> { std::string symbol_name = - memory.ReadCString(entry.name_offset, import_strings_size); + system.Memory().ReadCString(entry.name_offset, import_strings_size); u32 symbol_address = source.FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { @@ -794,11 +795,11 @@ ResultCode CROHelper::ResetImportNamedSymbol() { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -815,11 +816,11 @@ ResultCode CROHelper::ResetImportIndexedSymbol() { u32 import_num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < import_num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -836,11 +837,11 @@ ResultCode CROHelper::ResetImportAnonymousSymbol() { u32 import_num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < import_num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -857,19 +858,20 @@ ResultCode CROHelper::ApplyModuleImport(VAddr crs_address) { u32 import_module_num = GetField(ImportModuleNum); for (u32 i = 0; i < import_module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); - std::string want_cro_name = memory.ReadCString(entry.name_offset, import_strings_size); + GetEntry(system.Memory(), i, entry); + std::string want_cro_name = + system.Memory().ReadCString(entry.name_offset, import_strings_size); ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal<bool> { + process, system, crs_address, [&](CROHelper source) -> ResultVal<bool> { if (want_cro_name == source.ModuleName()) { LOG_INFO(Service_LDR, "CRO \"{}\" imports {} indexed symbols from \"{}\"", ModuleName(), entry.import_indexed_symbol_num, source.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ExportIndexedSymbolEntry ex; - source.GetEntry(memory, im.index, ex); + source.GetEntry(system.Memory(), im.index, ex); u32 symbol_address = source.SegmentTagToAddress(ex.symbol_position); LOG_TRACE(Service_LDR, " Imports 0x{:08X}", symbol_address); ResultCode result = @@ -884,7 +886,7 @@ ResultCode CROHelper::ApplyModuleImport(VAddr crs_address) { ModuleName(), entry.import_anonymous_symbol_num, source.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); u32 symbol_address = source.SegmentTagToAddress(im.symbol_position); LOG_TRACE(Service_LDR, " Imports 0x{:08X}", symbol_address); ResultCode result = @@ -913,15 +915,15 @@ ResultCode CROHelper::ApplyExportNamedSymbol(CROHelper target) { u32 target_symbol_import_num = target.GetField(ImportNamedSymbolNum); for (u32 i = 0; i < target_symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (!relocation_entry.is_batch_resolved) { std::string symbol_name = - memory.ReadCString(entry.name_offset, target_import_strings_size); + system.Memory().ReadCString(entry.name_offset, target_import_strings_size); u32 symbol_address = FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { LOG_TRACE(Service_LDR, " exports symbol \"{}\"", symbol_name); @@ -944,15 +946,15 @@ ResultCode CROHelper::ResetExportNamedSymbol(CROHelper target) { u32 target_symbol_import_num = target.GetField(ImportNamedSymbolNum); for (u32 i = 0; i < target_symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (relocation_entry.is_batch_resolved) { std::string symbol_name = - memory.ReadCString(entry.name_offset, target_import_strings_size); + system.Memory().ReadCString(entry.name_offset, target_import_strings_size); u32 symbol_address = FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { LOG_TRACE(Service_LDR, " unexports symbol \"{}\"", symbol_name); @@ -974,18 +976,19 @@ ResultCode CROHelper::ApplyModuleExport(CROHelper target) { u32 target_import_module_num = target.GetField(ImportModuleNum); for (u32 i = 0; i < target_import_module_num; ++i) { ImportModuleEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); - if (memory.ReadCString(entry.name_offset, target_import_string_size) != module_name) + if (system.Memory().ReadCString(entry.name_offset, target_import_string_size) != + module_name) continue; LOG_INFO(Service_LDR, "CRO \"{}\" exports {} indexed symbols to \"{}\"", module_name, entry.import_indexed_symbol_num, target.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ExportIndexedSymbolEntry ex; - GetEntry(memory, im.index, ex); + GetEntry(system.Memory(), im.index, ex); u32 symbol_address = SegmentTagToAddress(ex.symbol_position); LOG_TRACE(Service_LDR, " exports symbol 0x{:08X}", symbol_address); ResultCode result = @@ -1000,7 +1003,7 @@ ResultCode CROHelper::ApplyModuleExport(CROHelper target) { entry.import_anonymous_symbol_num, target.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); u32 symbol_address = SegmentTagToAddress(im.symbol_position); LOG_TRACE(Service_LDR, " exports symbol 0x{:08X}", symbol_address); ResultCode result = @@ -1023,16 +1026,17 @@ ResultCode CROHelper::ResetModuleExport(CROHelper target) { u32 target_import_module_num = target.GetField(ImportModuleNum); for (u32 i = 0; i < target_import_module_num; ++i) { ImportModuleEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); - if (memory.ReadCString(entry.name_offset, target_import_string_size) != module_name) + if (system.Memory().ReadCString(entry.name_offset, target_import_string_size) != + module_name) continue; LOG_DEBUG(Service_LDR, "CRO \"{}\" unexports indexed symbols to \"{}\"", module_name, target.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ResultCode result = target.ApplyRelocationBatch(im.relocation_batch_offset, unresolved_symbol, true); if (result.IsError()) { @@ -1045,7 +1049,7 @@ ResultCode CROHelper::ResetModuleExport(CROHelper target) { target.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); ResultCode result = target.ApplyRelocationBatch(im.relocation_batch_offset, unresolved_symbol, true); if (result.IsError()) { @@ -1063,15 +1067,16 @@ ResultCode CROHelper::ApplyExitRelocations(VAddr crs_address) { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); - if (memory.ReadCString(entry.name_offset, import_strings_size) == "__aeabi_atexit") { + if (system.Memory().ReadCString(entry.name_offset, import_strings_size) == + "__aeabi_atexit") { ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal<bool> { + process, system, crs_address, [&](CROHelper source) -> ResultVal<bool> { u32 symbol_address = source.FindExportNamedSymbol("nnroAeabiAtexit_"); if (symbol_address != 0) { @@ -1126,7 +1131,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = VerifyStringTableLength(memory, GetField(ModuleNameOffset), GetField(ModuleNameSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ModuleNameOffset), + GetField(ModuleNameSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying module name {:08X}", result.raw); return result; @@ -1155,8 +1161,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = - VerifyStringTableLength(memory, GetField(ExportStringsOffset), GetField(ExportStringsSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ExportStringsOffset), + GetField(ExportStringsSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying export strings {:08X}", result.raw); return result; @@ -1192,8 +1198,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = - VerifyStringTableLength(memory, GetField(ImportStringsOffset), GetField(ImportStringsSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ImportStringsOffset), + GetField(ImportStringsSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying import strings {:08X}", result.raw); return result; @@ -1266,11 +1272,11 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { // so we do the same if (GetField(SegmentNum) >= 2) { // means we have .data segment SegmentEntry entry; - GetEntry(memory, 2, entry); + GetEntry(system.Memory(), 2, entry); ASSERT(entry.type == SegmentType::Data); data_segment_address = entry.offset; entry.offset = GetField(DataOffset); - SetEntry(memory, 2, entry); + SetEntry(system.Memory(), 2, entry); } } SCOPE_EXIT({ @@ -1278,9 +1284,9 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { if (link_on_load_bug_fix) { if (GetField(SegmentNum) >= 2) { SegmentEntry entry; - GetEntry(memory, 2, entry); + GetEntry(system.Memory(), 2, entry); entry.offset = data_segment_address; - SetEntry(memory, 2, entry); + SetEntry(system.Memory(), 2, entry); } } }); @@ -1301,7 +1307,7 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { } // Exports symbols to other modules - result = ForEachAutoLinkCRO(process, memory, cpu, crs_address, + result = ForEachAutoLinkCRO(process, system, crs_address, [this](CROHelper target) -> ResultVal<bool> { ResultCode result = ApplyExportNamedSymbol(target); if (result.IsError()) @@ -1346,7 +1352,7 @@ ResultCode CROHelper::Unlink(VAddr crs_address) { // Resets all symbols in other modules imported from this module // Note: the RO service seems only searching in auto-link modules - result = ForEachAutoLinkCRO(process, memory, cpu, crs_address, + result = ForEachAutoLinkCRO(process, system, crs_address, [this](CROHelper target) -> ResultVal<bool> { ResultCode result = ResetExportNamedSymbol(target); if (result.IsError()) @@ -1387,13 +1393,13 @@ void CROHelper::InitCRS() { } void CROHelper::Register(VAddr crs_address, bool auto_link) { - CROHelper crs(crs_address, process, memory, cpu); - CROHelper head(auto_link ? crs.NextModule() : crs.PreviousModule(), process, memory, cpu); + CROHelper crs(crs_address, process, system); + CROHelper head(auto_link ? crs.NextModule() : crs.PreviousModule(), process, system); if (head.module_address) { // there are already CROs registered // register as the new tail - CROHelper tail(head.PreviousModule(), process, memory, cpu); + CROHelper tail(head.PreviousModule(), process, system); // link with the old tail ASSERT(tail.NextModule() == 0); @@ -1419,11 +1425,11 @@ void CROHelper::Register(VAddr crs_address, bool auto_link) { } void CROHelper::Unregister(VAddr crs_address) { - CROHelper crs(crs_address, process, memory, cpu); - CROHelper next_head(crs.NextModule(), process, memory, cpu); - CROHelper previous_head(crs.PreviousModule(), process, memory, cpu); - CROHelper next(NextModule(), process, memory, cpu); - CROHelper previous(PreviousModule(), process, memory, cpu); + CROHelper crs(crs_address, process, system); + CROHelper next_head(crs.NextModule(), process, system); + CROHelper previous_head(crs.PreviousModule(), process, system); + CROHelper next(NextModule(), process, system); + CROHelper previous(PreviousModule(), process, system); if (module_address == next_head.module_address || module_address == previous_head.module_address) { @@ -1517,7 +1523,7 @@ std::tuple<VAddr, u32> CROHelper::GetExecutablePages() const { u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.type == SegmentType::Code && entry.size != 0) { VAddr begin = Common::AlignDown(entry.offset, Memory::PAGE_SIZE); VAddr end = Common::AlignUp(entry.offset + entry.size, Memory::PAGE_SIZE); diff --git a/src/core/hle/service/ldr_ro/cro_helper.h b/src/core/hle/service/ldr_ro/cro_helper.h index 46fbe05a6..265b6971e 100644 --- a/src/core/hle/service/ldr_ro/cro_helper.h +++ b/src/core/hle/service/ldr_ro/cro_helper.h @@ -33,12 +33,11 @@ static constexpr u32 CRO_HASH_SIZE = 0x80; class CROHelper final { public: // TODO (wwylele): pass in the process handle for memory access - explicit CROHelper(VAddr cro_address, Kernel::Process& process, Memory::MemorySystem& memory, - ARM_Interface& cpu) - : module_address(cro_address), process(process), memory(memory), cpu(cpu) {} + explicit CROHelper(VAddr cro_address, Kernel::Process& process, Core::System& system) + : module_address(cro_address), process(process), system(system) {} std::string ModuleName() const { - return memory.ReadCString(GetField(ModuleNameOffset), GetField(ModuleNameSize)); + return system.Memory().ReadCString(GetField(ModuleNameOffset), GetField(ModuleNameSize)); } u32 GetFileSize() const { @@ -144,8 +143,7 @@ public: private: const VAddr module_address; ///< the virtual address of this module Kernel::Process& process; ///< the owner process of this module - Memory::MemorySystem& memory; - ARM_Interface& cpu; + Core::System& system; /** * Each item in this enum represents a u32 field in the header begin from address+0x80, @@ -403,11 +401,11 @@ private: } u32 GetField(HeaderField field) const { - return memory.Read32(Field(field)); + return system.Memory().Read32(Field(field)); } void SetField(HeaderField field, u32 value) { - memory.Write32(Field(field), value); + system.Memory().Write32(Field(field), value); } /** @@ -474,12 +472,11 @@ private: * otherwise error code of the last iteration. */ template <typename FunctionObject> - static ResultCode ForEachAutoLinkCRO(Kernel::Process& process, Memory::MemorySystem& memory, - ARM_Interface& cpu, VAddr crs_address, - FunctionObject func) { + static ResultCode ForEachAutoLinkCRO(Kernel::Process& process, Core::System& system, + VAddr crs_address, FunctionObject func) { VAddr current = crs_address; while (current != 0) { - CROHelper cro(current, process, memory, cpu); + CROHelper cro(current, process, system); CASCADE_RESULT(bool next, func(cro)); if (!next) break; diff --git a/src/core/hle/service/ldr_ro/ldr_ro.cpp b/src/core/hle/service/ldr_ro/ldr_ro.cpp index caa063593..274d36ed5 100644 --- a/src/core/hle/service/ldr_ro/ldr_ro.cpp +++ b/src/core/hle/service/ldr_ro/ldr_ro.cpp @@ -115,7 +115,7 @@ void RO::Initialize(Kernel::HLERequestContext& ctx) { return; } - CROHelper crs(crs_address, *process, system.Memory(), system.CPU()); + CROHelper crs(crs_address, *process, system); crs.InitCRS(); result = crs.Rebase(0, crs_size, 0, 0, 0, 0, true); @@ -249,7 +249,7 @@ void RO::LoadCRO(Kernel::HLERequestContext& ctx, bool link_on_load_bug_fix) { return; } - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); result = cro.VerifyHash(cro_size, crr_address); if (result.IsError()) { @@ -313,7 +313,7 @@ void RO::LoadCRO(Kernel::HLERequestContext& ctx, bool link_on_load_bug_fix) { } } - system.CPU().InvalidateCacheRange(cro_address, cro_size); + system.InvalidateCacheRange(cro_address, cro_size); LOG_INFO(Service_LDR, "CRO \"{}\" loaded at 0x{:08X}, fixed_end=0x{:08X}", cro.ModuleName(), cro_address, cro_address + fix_size); @@ -331,7 +331,7 @@ void RO::UnloadCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}, zero={}, cro_buffer_ptr=0x{:08X}", cro_address, zero, cro_buffer_ptr); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -386,7 +386,7 @@ void RO::UnloadCRO(Kernel::HLERequestContext& ctx) { LOG_ERROR(Service_LDR, "Error unmapping CRO {:08X}", result.raw); } - system.CPU().InvalidateCacheRange(cro_address, fixed_size); + system.InvalidateCacheRange(cro_address, fixed_size); rb.Push(result); } @@ -398,7 +398,7 @@ void RO::LinkCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}", cro_address); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -438,7 +438,7 @@ void RO::UnlinkCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}", cro_address); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -487,7 +487,7 @@ void RO::Shutdown(Kernel::HLERequestContext& ctx) { return; } - CROHelper crs(slot->loaded_crs, *process, system.Memory(), system.CPU()); + CROHelper crs(slot->loaded_crs, *process, system); crs.Unrebase(true); ResultCode result = RESULT_SUCCESS; diff --git a/src/core/rpc/rpc_server.cpp b/src/core/rpc/rpc_server.cpp index aec99f273..a95c53fd7 100644 --- a/src/core/rpc/rpc_server.cpp +++ b/src/core/rpc/rpc_server.cpp @@ -46,7 +46,9 @@ void RPCServer::HandleWriteMemory(Packet& packet, u32 address, const u8* data, u Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), address, data, data_size); // If the memory happens to be executable code, make sure the changes become visible - Core::CPU().InvalidateCacheRange(address, data_size); + + // Is current core correct here? + Core::System::GetInstance().InvalidateCacheRange(address, data_size); } packet.SetPacketDataSize(0); packet.SendReply(); diff --git a/src/tests/core/arm/arm_test_common.cpp b/src/tests/core/arm/arm_test_common.cpp index dbbc21c8c..0957c7c97 100644 --- a/src/tests/core/arm/arm_test_common.cpp +++ b/src/tests/core/arm/arm_test_common.cpp @@ -15,9 +15,9 @@ static Memory::PageTable* page_table = nullptr; TestEnvironment::TestEnvironment(bool mutable_memory_) : mutable_memory(mutable_memory_), test_memory(std::make_shared<TestMemory>(this)) { - timing = std::make_unique<Core::Timing>(); + timing = std::make_unique<Core::Timing>(1); memory = std::make_unique<Memory::MemorySystem>(); - kernel = std::make_unique<Kernel::KernelSystem>(*memory, *timing, [] {}, 0); + kernel = std::make_unique<Kernel::KernelSystem>(*memory, *timing, [] {}, 0, 1); kernel->SetCurrentProcess(kernel->CreateProcess(kernel->CreateCodeSet("", 0))); page_table = &kernel->GetCurrentProcess()->vm_manager.page_table; diff --git a/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp b/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp index 5fadaf85e..f0152eb7d 100644 --- a/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp +++ b/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp @@ -23,7 +23,7 @@ TEST_CASE("ARM_DynCom (vfp): vadd", "[arm_dyncom]") { test_env.SetMemory32(0, 0xEE321A03); // vadd.f32 s2, s4, s6 test_env.SetMemory32(4, 0xEAFFFFFE); // b +#0 - ARM_DynCom dyncom(nullptr, test_env.GetMemory(), USER32MODE); + ARM_DynCom dyncom(nullptr, test_env.GetMemory(), USER32MODE, 0, nullptr); std::vector<VfpTestCase> test_cases{{ #include "vfp_vadd_f32.inc" diff --git a/src/tests/core/core_timing.cpp b/src/tests/core/core_timing.cpp index 1cfd9e971..850f13bc5 100644 --- a/src/tests/core/core_timing.cpp +++ b/src/tests/core/core_timing.cpp @@ -34,16 +34,16 @@ static void AdvanceAndCheck(Core::Timing& timing, u32 idx, int downcount, int ex expected_callback = CB_IDS[idx]; lateness = expected_lateness; - timing.AddTicks(timing.GetDowncount() - - cpu_downcount); // Pretend we executed X cycles of instructions. - timing.Advance(); + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount() - + cpu_downcount); // Pretend we executed X cycles of instructions. + timing.GetTimer(0)->Advance(); REQUIRE(decltype(callbacks_ran_flags)().set(idx) == callbacks_ran_flags); - REQUIRE(downcount == timing.GetDowncount()); + REQUIRE(downcount == timing.GetTimer(0)->GetDowncount()); } TEST_CASE("CoreTiming[BasicOrder]", "[core]") { - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); @@ -52,60 +52,19 @@ TEST_CASE("CoreTiming[BasicOrder]", "[core]") { Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", CallbackTemplate<4>); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); // D -> B -> C -> A -> E - timing.ScheduleEvent(1000, cb_a, CB_IDS[0]); - REQUIRE(1000 == timing.GetDowncount()); - timing.ScheduleEvent(500, cb_b, CB_IDS[1]); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEvent(800, cb_c, CB_IDS[2]); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEvent(100, cb_d, CB_IDS[3]); - REQUIRE(100 == timing.GetDowncount()); - timing.ScheduleEvent(1200, cb_e, CB_IDS[4]); - REQUIRE(100 == timing.GetDowncount()); - - AdvanceAndCheck(timing, 3, 400); - AdvanceAndCheck(timing, 1, 300); - AdvanceAndCheck(timing, 2, 200); - AdvanceAndCheck(timing, 0, 200); - AdvanceAndCheck(timing, 4, MAX_SLICE_LENGTH); -} - -TEST_CASE("CoreTiming[Threadsave]", "[core]") { - Core::Timing timing; - - Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); - Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); - Core::TimingEventType* cb_c = timing.RegisterEvent("callbackC", CallbackTemplate<2>); - Core::TimingEventType* cb_d = timing.RegisterEvent("callbackD", CallbackTemplate<3>); - Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", CallbackTemplate<4>); - - // Enter slice 0 - timing.Advance(); - - // D -> B -> C -> A -> E - timing.ScheduleEventThreadsafe(1000, cb_a, CB_IDS[0]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(1000); - REQUIRE(1000 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(500, cb_b, CB_IDS[1]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(500); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(800, cb_c, CB_IDS[2]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(800); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(100, cb_d, CB_IDS[3]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(100); - REQUIRE(100 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(1200, cb_e, CB_IDS[4]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(1200); - REQUIRE(100 == timing.GetDowncount()); + timing.ScheduleEvent(1000, cb_a, CB_IDS[0], 0); + REQUIRE(1000 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(500, cb_b, CB_IDS[1], 0); + REQUIRE(500 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(800, cb_c, CB_IDS[2], 0); + REQUIRE(500 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(100, cb_d, CB_IDS[3], 0); + REQUIRE(100 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(1200, cb_e, CB_IDS[4], 0); + REQUIRE(100 == timing.GetTimer(0)->GetDowncount()); AdvanceAndCheck(timing, 3, 400); AdvanceAndCheck(timing, 1, 300); @@ -131,7 +90,7 @@ void FifoCallback(u64 userdata, s64 cycles_late) { TEST_CASE("CoreTiming[SharedSlot]", "[core]") { using namespace SharedSlotTest; - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", FifoCallback<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", FifoCallback<1>); @@ -139,36 +98,36 @@ TEST_CASE("CoreTiming[SharedSlot]", "[core]") { Core::TimingEventType* cb_d = timing.RegisterEvent("callbackD", FifoCallback<3>); Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", FifoCallback<4>); - timing.ScheduleEvent(1000, cb_a, CB_IDS[0]); - timing.ScheduleEvent(1000, cb_b, CB_IDS[1]); - timing.ScheduleEvent(1000, cb_c, CB_IDS[2]); - timing.ScheduleEvent(1000, cb_d, CB_IDS[3]); - timing.ScheduleEvent(1000, cb_e, CB_IDS[4]); + timing.ScheduleEvent(1000, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(1000, cb_b, CB_IDS[1], 0); + timing.ScheduleEvent(1000, cb_c, CB_IDS[2], 0); + timing.ScheduleEvent(1000, cb_d, CB_IDS[3], 0); + timing.ScheduleEvent(1000, cb_e, CB_IDS[4], 0); // Enter slice 0 - timing.Advance(); - REQUIRE(1000 == timing.GetDowncount()); + timing.GetTimer(0)->Advance(); + REQUIRE(1000 == timing.GetTimer(0)->GetDowncount()); callbacks_ran_flags = 0; counter = 0; lateness = 0; - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); - REQUIRE(MAX_SLICE_LENGTH == timing.GetDowncount()); + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); + REQUIRE(MAX_SLICE_LENGTH == timing.GetTimer(0)->GetDowncount()); REQUIRE(0x1FULL == callbacks_ran_flags.to_ullong()); } TEST_CASE("CoreTiming[PredictableLateness]", "[core]") { - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); - timing.ScheduleEvent(100, cb_a, CB_IDS[0]); - timing.ScheduleEvent(200, cb_b, CB_IDS[1]); + timing.ScheduleEvent(100, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(200, cb_b, CB_IDS[1], 0); AdvanceAndCheck(timing, 0, 90, 10, -10); // (100 - 10) AdvanceAndCheck(timing, 1, MAX_SLICE_LENGTH, 50, -50); @@ -190,7 +149,7 @@ static void RescheduleCallback(Core::Timing& timing, u64 userdata, s64 cycles_la TEST_CASE("CoreTiming[ChainScheduling]", "[core]") { using namespace ChainSchedulingTest; - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); @@ -201,28 +160,30 @@ TEST_CASE("CoreTiming[ChainScheduling]", "[core]") { }); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); - timing.ScheduleEvent(800, cb_a, CB_IDS[0]); - timing.ScheduleEvent(1000, cb_b, CB_IDS[1]); - timing.ScheduleEvent(2200, cb_c, CB_IDS[2]); - timing.ScheduleEvent(1000, cb_rs, reinterpret_cast<u64>(cb_rs)); - REQUIRE(800 == timing.GetDowncount()); + timing.ScheduleEvent(800, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(1000, cb_b, CB_IDS[1], 0); + timing.ScheduleEvent(2200, cb_c, CB_IDS[2], 0); + timing.ScheduleEvent(1000, cb_rs, reinterpret_cast<u64>(cb_rs), 0); + REQUIRE(800 == timing.GetTimer(0)->GetDowncount()); reschedules = 3; AdvanceAndCheck(timing, 0, 200); // cb_a AdvanceAndCheck(timing, 1, 1000); // cb_b, cb_rs REQUIRE(2 == reschedules); - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); // cb_rs + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); // cb_rs REQUIRE(1 == reschedules); - REQUIRE(200 == timing.GetDowncount()); + REQUIRE(200 == timing.GetTimer(0)->GetDowncount()); AdvanceAndCheck(timing, 2, 800); // cb_c - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); // cb_rs + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); // cb_rs REQUIRE(0 == reschedules); - REQUIRE(MAX_SLICE_LENGTH == timing.GetDowncount()); + REQUIRE(MAX_SLICE_LENGTH == timing.GetTimer(0)->GetDowncount()); } + +// TODO: Add tests for multiple timers diff --git a/src/tests/core/hle/kernel/hle_ipc.cpp b/src/tests/core/hle/kernel/hle_ipc.cpp index fb549f829..59026afd6 100644 --- a/src/tests/core/hle/kernel/hle_ipc.cpp +++ b/src/tests/core/hle/kernel/hle_ipc.cpp @@ -21,9 +21,9 @@ static std::shared_ptr<Object> MakeObject(Kernel::KernelSystem& kernel) { } TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); @@ -233,9 +233,9 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel } TEST_CASE("HLERequestContext::WriteToOutgoingCommandBuffer", "[core][kernel]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); diff --git a/src/tests/core/memory/memory.cpp b/src/tests/core/memory/memory.cpp index 4a6d54bf7..2e7c71434 100644 --- a/src/tests/core/memory/memory.cpp +++ b/src/tests/core/memory/memory.cpp @@ -11,9 +11,9 @@ #include "core/memory.h" TEST_CASE("Memory::IsValidVirtualAddress", "[core][memory]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); SECTION("these regions should not be mapped on an empty process") { auto process = kernel.CreateProcess(kernel.CreateCodeSet("", 0)); CHECK(Memory::IsValidVirtualAddress(*process, Memory::PROCESS_IMAGE_VADDR) == false); From 990d27f4f94e4254d143819f8a492c4abdab7a17 Mon Sep 17 00:00:00 2001 From: Marshall Mohror <mohror64@gmail.com> Date: Thu, 20 Feb 2020 09:04:37 -0600 Subject: [PATCH 28/41] Remove C++ standard flag --- src/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4a4c39f17..33d7400fa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -38,7 +38,6 @@ if (MSVC) /Zo /permissive- /EHsc - /std:c++latest /volatile:iso /Zc:externConstexpr /Zc:inline From 670119ef86a5053b3a25a17b5c369d3cd2bd5dfd Mon Sep 17 00:00:00 2001 From: BreadFish64 <mohror64@gmail.com> Date: Fri, 21 Feb 2020 16:47:04 -0600 Subject: [PATCH 29/41] android: use cmake 3.10 bitrise still doesn't have 3.10 despite it being part of the NDK now --- bitrise.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitrise.yml b/bitrise.yml index e6ab3d22d..d5a5cb1ab 100644 --- a/bitrise.yml +++ b/bitrise.yml @@ -52,7 +52,7 @@ workflows: sudo apt remove cmake -y sudo apt purge --auto-remove cmake -y sudo apt install ninja-build -y - version=3.8 + version=3.10 build=2 mkdir ~/temp cd ~/temp @@ -97,7 +97,7 @@ workflows: sudo apt remove cmake -y sudo apt purge --auto-remove cmake -y sudo apt install ninja-build -y - version=3.8 + version=3.10 build=2 mkdir ~/temp cd ~/temp From 688e44bc8b07a81b4b39a86485dc6e3dcc7fff27 Mon Sep 17 00:00:00 2001 From: Marshall Mohror <mohror64@gmail.com> Date: Sat, 22 Feb 2020 15:37:42 -0600 Subject: [PATCH 30/41] videocore/renderer_opengl/gl_rasterizer_cache: Move bits per pixel table out of function (#5101) * videocore/renderer_opengl/gl_rasterizer_cache: Move bits per pixel table out of function GCC and MSVC copy the table at runtime with the old implementation, which is wasteful and prevents inlining. Unfortunately, static constexpr variables are not legal in constexpr functions, so the table has to be external. Also replaced non-standard assert with DEBUG_ASSERT_MSG. * fix case of table name in assert * set table to private --- .../renderer_opengl/gl_rasterizer_cache.h | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/video_core/renderer_opengl/gl_rasterizer_cache.h b/src/video_core/renderer_opengl/gl_rasterizer_cache.h index cd601ef29..0073ef1d2 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer_cache.h +++ b/src/video_core/renderer_opengl/gl_rasterizer_cache.h @@ -101,6 +101,29 @@ enum class ScaleMatch { }; struct SurfaceParams { +private: + static constexpr std::array<unsigned int, 18> BPP_TABLE = { + 32, // RGBA8 + 24, // RGB8 + 16, // RGB5A1 + 16, // RGB565 + 16, // RGBA4 + 16, // IA8 + 16, // RG8 + 8, // I8 + 8, // A8 + 8, // IA4 + 4, // I4 + 4, // A4 + 4, // ETC1 + 8, // ETC1A4 + 16, // D16 + 0, + 24, // D24 + 32, // D24S8 + }; + +public: enum class PixelFormat { // First 5 formats are shared between textures and color buffers RGBA8 = 0, @@ -139,30 +162,11 @@ struct SurfaceParams { }; static constexpr unsigned int GetFormatBpp(PixelFormat format) { - constexpr std::array<unsigned int, 18> bpp_table = { - 32, // RGBA8 - 24, // RGB8 - 16, // RGB5A1 - 16, // RGB565 - 16, // RGBA4 - 16, // IA8 - 16, // RG8 - 8, // I8 - 8, // A8 - 8, // IA4 - 4, // I4 - 4, // A4 - 4, // ETC1 - 8, // ETC1A4 - 16, // D16 - 0, - 24, // D24 - 32, // D24S8 - }; - - assert(static_cast<std::size_t>(format) < bpp_table.size()); - return bpp_table[static_cast<std::size_t>(format)]; + const auto format_idx = static_cast<std::size_t>(format); + DEBUG_ASSERT_MSG(format_idx < BPP_TABLE.size(), "Invalid pixel format {}", format_idx); + return BPP_TABLE[format_idx]; } + unsigned int GetFormatBpp() const { return GetFormatBpp(pixel_format); } From 2a616fcc5e2df683c09df7999374d324c76d0f68 Mon Sep 17 00:00:00 2001 From: "Gauvain \"GovanifY\" Roussel-Tarbouriech" <gauvain@govanify.com> Date: Sun, 23 Feb 2020 03:54:29 +0100 Subject: [PATCH 31/41] ipc_debugger: Fixing NULL ptr call on multiple clear --- src/citra_qt/debugger/ipc/recorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/citra_qt/debugger/ipc/recorder.cpp b/src/citra_qt/debugger/ipc/recorder.cpp index 24128e38a..ef634d750 100644 --- a/src/citra_qt/debugger/ipc/recorder.cpp +++ b/src/citra_qt/debugger/ipc/recorder.cpp @@ -114,7 +114,7 @@ void IPCRecorderWidget::SetEnabled(bool enabled) { } void IPCRecorderWidget::Clear() { - id_offset = records.size() + 1; + id_offset += records.size(); records.clear(); ui->main->invisibleRootItem()->takeChildren(); From 8eacfceb6a14a9862eae8e82fbfc06a0a8d7fad3 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Sun, 23 Feb 2020 15:22:41 +0800 Subject: [PATCH 32/41] layered_fs: Fix missing file size update This was a silly typo from a previous change. --- src/core/file_sys/layered_fs.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp index a77ab38f1..9d5fbf7c2 100644 --- a/src/core/file_sys/layered_fs.cpp +++ b/src/core/file_sys/layered_fs.cpp @@ -177,6 +177,7 @@ void LayeredFS::LoadRelocations() { auto* file = file_path_map.at(path); file->relocation.type = 1; file->relocation.replace_file_path = directory + virtual_name; + file->relocation.size = FileUtil::GetSize(directory + virtual_name); LOG_INFO(Service_FS, "LayeredFS replacement file in use for {}", path); return true; }; From cff00f38c57e841add0db14125c048caae7ce2fa Mon Sep 17 00:00:00 2001 From: liushuyu <liushuyu011@gmail.com> Date: Sun, 23 Feb 2020 03:01:21 -0700 Subject: [PATCH 33/41] Implements fdk_aac decoder (#4764) * audio_core: dsp_hle: implements fdk_aac decoder * audio_core: dsp_hle: clean up and add comments * audio_core: dsp_hle: move fdk include to cpp file * audio_core: dsp_hle: detects broken fdk_aac... ... and refuses to initialize if that's the case * audio_core: dsp_hle: fdk_aac: address comments... ... and rebase commits * fdk_decoder: move fdk header to cpp file --- CMakeLists.txt | 8 + src/audio_core/CMakeLists.txt | 7 + src/audio_core/hle/fdk_decoder.cpp | 233 +++++++++++++++++++++++++++++ src/audio_core/hle/fdk_decoder.h | 23 +++ src/audio_core/hle/hle.cpp | 4 + 5 files changed, 275 insertions(+) create mode 100644 src/audio_core/hle/fdk_decoder.cpp create mode 100644 src/audio_core/hle/fdk_decoder.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 41d55f375..16c7cceb1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,8 @@ CMAKE_DEPENDENT_OPTION(ENABLE_MF "Use Media Foundation decoder (preferred over F CMAKE_DEPENDENT_OPTION(COMPILE_WITH_DWARF "Add DWARF debugging information" ON "MINGW" OFF) +CMAKE_DEPENDENT_OPTION(ENABLE_FDK "Use FDK AAC decoder" OFF "NOT ENABLE_FFMPEG_AUDIO_DECODER;NOT ENABLE_MF" OFF) + if(NOT EXISTS ${PROJECT_SOURCE_DIR}/.git/hooks/pre-commit) message(STATUS "Copying pre-commit hook") file(COPY hooks/pre-commit @@ -223,6 +225,12 @@ if (ENABLE_FFMPEG_VIDEO_DUMPER) add_definitions(-DENABLE_FFMPEG_VIDEO_DUMPER) endif() +if (ENABLE_FDK) + find_library(FDK_AAC fdk-aac DOC "The path to fdk_aac library") + if(FDK_AAC STREQUAL "FDK_AAC-NOTFOUND") + message(FATAL_ERROR "fdk_aac library not found.") + endif() +endif() # Platform-specific library requirements # ====================================== diff --git a/src/audio_core/CMakeLists.txt b/src/audio_core/CMakeLists.txt index f2b3e1f3b..2caed1233 100644 --- a/src/audio_core/CMakeLists.txt +++ b/src/audio_core/CMakeLists.txt @@ -62,6 +62,13 @@ elseif(ENABLE_FFMPEG_AUDIO_DECODER) target_include_directories(audio_core PRIVATE ${FFMPEG_DIR}/include) endif() target_compile_definitions(audio_core PUBLIC HAVE_FFMPEG) +elseif(ENABLE_FDK) + target_sources(audio_core PRIVATE + hle/fdk_decoder.cpp + hle/fdk_decoder.h + ) + target_link_libraries(audio_core PRIVATE ${FDK_AAC}) + target_compile_definitions(audio_core PUBLIC HAVE_FDK) endif() if(SDL2_FOUND) diff --git a/src/audio_core/hle/fdk_decoder.cpp b/src/audio_core/hle/fdk_decoder.cpp new file mode 100644 index 000000000..c99e3d43c --- /dev/null +++ b/src/audio_core/hle/fdk_decoder.cpp @@ -0,0 +1,233 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include <fdk-aac/aacdecoder_lib.h> +#include "audio_core/hle/fdk_decoder.h" + +namespace AudioCore::HLE { + +class FDKDecoder::Impl { +public: + explicit Impl(Memory::MemorySystem& memory); + ~Impl(); + std::optional<BinaryResponse> ProcessRequest(const BinaryRequest& request); + bool IsValid() const { + return decoder != nullptr; + } + +private: + std::optional<BinaryResponse> Initalize(const BinaryRequest& request); + + std::optional<BinaryResponse> Decode(const BinaryRequest& request); + + void Clear(); + + Memory::MemorySystem& memory; + + HANDLE_AACDECODER decoder = nullptr; +}; + +FDKDecoder::Impl::Impl(Memory::MemorySystem& memory) : memory(memory) { + // allocate an array of LIB_INFO structures + // if we don't pre-fill the whole segment with zeros, when we call `aacDecoder_GetLibInfo` + // it will segfault, upon investigation, there is some code in fdk_aac depends on your initial + // values in this array + LIB_INFO decoder_info[FDK_MODULE_LAST] = {}; + // get library information and fill the struct + if (aacDecoder_GetLibInfo(decoder_info) != 0) { + LOG_ERROR(Audio_DSP, "Failed to retrieve fdk_aac library information!"); + return; + } + // This segment: identify the broken fdk_aac implementation + // and refuse to initialize if identified as broken (check for module IDs) + // although our AAC samples do not contain SBC feature, this is a way to detect + // watered down version of fdk_aac implementations + if (FDKlibInfo_getCapabilities(decoder_info, FDK_SBRDEC) == 0) { + LOG_ERROR(Audio_DSP, "Bad fdk_aac library found! Initialization aborted!"); + return; + } + + LOG_INFO(Audio_DSP, "Using fdk_aac version {} (build date: {})", decoder_info[0].versionStr, + decoder_info[0].build_date); + + // choose the input format when initializing: 1 layer of ADTS + decoder = aacDecoder_Open(TRANSPORT_TYPE::TT_MP4_ADTS, 1); + // set maximum output channel to two (stereo) + // if the input samples have more channels, fdk_aac will perform a downmix + AAC_DECODER_ERROR ret = aacDecoder_SetParam(decoder, AAC_PCM_MAX_OUTPUT_CHANNELS, 2); + if (ret != AAC_DEC_OK) { + // unable to set this parameter reflects the decoder implementation might be broken + // we'd better shuts down everything + aacDecoder_Close(decoder); + decoder = nullptr; + LOG_ERROR(Audio_DSP, "Unable to set downmix parameter: {}", ret); + return; + } +} + +std::optional<BinaryResponse> FDKDecoder::Impl::Initalize(const BinaryRequest& request) { + BinaryResponse response; + std::memcpy(&response, &request, sizeof(response)); + response.unknown1 = 0x0; + + if (decoder) { + LOG_INFO(Audio_DSP, "FDK Decoder initialized"); + Clear(); + } else { + LOG_ERROR(Audio_DSP, "Decoder not initialized"); + } + + return response; +} + +FDKDecoder::Impl::~Impl() { + if (decoder) + aacDecoder_Close(decoder); +} + +void FDKDecoder::Impl::Clear() { + s16 decoder_output[8192]; + // flush and re-sync the decoder, discarding the internal buffer + // we actually don't care if this succeeds or not + // FLUSH - flush internal buffer + // INTR - treat the current internal buffer as discontinuous + // CONCEAL - try to interpolate and smooth out the samples + if (decoder) + aacDecoder_DecodeFrame(decoder, decoder_output, 8192, + AACDEC_FLUSH & AACDEC_INTR & AACDEC_CONCEAL); +} + +std::optional<BinaryResponse> FDKDecoder::Impl::ProcessRequest(const BinaryRequest& request) { + if (request.codec != DecoderCodec::AAC) { + LOG_ERROR(Audio_DSP, "FDK AAC Decoder cannot handle such codec: {}", + static_cast<u16>(request.codec)); + return {}; + } + + switch (request.cmd) { + case DecoderCommand::Init: { + return Initalize(request); + } + case DecoderCommand::Decode: { + return Decode(request); + } + case DecoderCommand::Unknown: { + BinaryResponse response; + std::memcpy(&response, &request, sizeof(response)); + response.unknown1 = 0x0; + return response; + } + default: + LOG_ERROR(Audio_DSP, "Got unknown binary request: {}", static_cast<u16>(request.cmd)); + return {}; + } +} + +std::optional<BinaryResponse> FDKDecoder::Impl::Decode(const BinaryRequest& request) { + BinaryResponse response; + response.codec = request.codec; + response.cmd = request.cmd; + response.size = request.size; + + if (!decoder) { + LOG_DEBUG(Audio_DSP, "Decoder not initalized"); + // This is a hack to continue games that are not compiled with the aac codec + response.num_channels = 2; + response.num_samples = 1024; + return response; + } + + if (request.src_addr < Memory::FCRAM_PADDR || + request.src_addr + request.size > Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds src_addr {:08x}", request.src_addr); + return {}; + } + u8* data = memory.GetFCRAMPointer(request.src_addr - Memory::FCRAM_PADDR); + + std::array<std::vector<s16>, 2> out_streams; + + std::size_t data_size = request.size; + + // decoding loops + AAC_DECODER_ERROR result = AAC_DEC_OK; + // 8192 units of s16 are enough to hold one frame of AAC-LC or AAC-HE/v2 data + s16 decoder_output[8192]; + // note that we don't free this pointer as it is automatically freed by fdk_aac + CStreamInfo* stream_info; + // how many bytes to be queued into the decoder, decrementing from the buffer size + u32 buffer_remaining = data_size; + // alias the data_size as an u32 + u32 input_size = data_size; + + while (buffer_remaining) { + // queue the input buffer, fdk_aac will automatically slice out the buffer it needs + // from the input buffer + result = aacDecoder_Fill(decoder, &data, &input_size, &buffer_remaining); + if (result != AAC_DEC_OK) { + // there are some issues when queuing the input buffer + LOG_ERROR(Audio_DSP, "Failed to enqueue the input samples"); + return std::nullopt; + } + // get output from decoder + result = aacDecoder_DecodeFrame(decoder, decoder_output, 8192, 0); + if (result == AAC_DEC_OK) { + // get the stream information + stream_info = aacDecoder_GetStreamInfo(decoder); + // fill the stream information for binary response + response.num_channels = stream_info->aacNumChannels; + response.num_samples = stream_info->frameSize; + // fill the output + // the sample size = frame_size * channel_counts + for (int sample = 0; sample < (stream_info->frameSize * 2); sample++) { + for (int ch = 0; ch < stream_info->aacNumChannels; ch++) { + out_streams[ch].push_back(decoder_output[(sample * 2) + 1]); + } + } + } else if (result == AAC_DEC_TRANSPORT_SYNC_ERROR) { + // decoder has some synchronization problems, try again with new samples, + // using old samples might trigger this error again + continue; + } else { + LOG_ERROR(Audio_DSP, "Error decoding the sample: {}", result); + return std::nullopt; + } + } + // transfer the decoded buffer from vector to the FCRAM + if (out_streams[0].size() != 0) { + if (request.dst_addr_ch0 < Memory::FCRAM_PADDR || + request.dst_addr_ch0 + out_streams[0].size() > + Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds dst_addr_ch0 {:08x}", request.dst_addr_ch0); + return {}; + } + std::memcpy(memory.GetFCRAMPointer(request.dst_addr_ch0 - Memory::FCRAM_PADDR), + out_streams[0].data(), out_streams[0].size()); + } + + if (out_streams[1].size() != 0) { + if (request.dst_addr_ch1 < Memory::FCRAM_PADDR || + request.dst_addr_ch1 + out_streams[1].size() > + Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds dst_addr_ch1 {:08x}", request.dst_addr_ch1); + return {}; + } + std::memcpy(memory.GetFCRAMPointer(request.dst_addr_ch1 - Memory::FCRAM_PADDR), + out_streams[1].data(), out_streams[1].size()); + } + return response; +} + +FDKDecoder::FDKDecoder(Memory::MemorySystem& memory) : impl(std::make_unique<Impl>(memory)) {} + +FDKDecoder::~FDKDecoder() = default; + +std::optional<BinaryResponse> FDKDecoder::ProcessRequest(const BinaryRequest& request) { + return impl->ProcessRequest(request); +} + +bool FDKDecoder::IsValid() const { + return impl->IsValid(); +} + +} // namespace AudioCore::HLE diff --git a/src/audio_core/hle/fdk_decoder.h b/src/audio_core/hle/fdk_decoder.h new file mode 100644 index 000000000..337c6054a --- /dev/null +++ b/src/audio_core/hle/fdk_decoder.h @@ -0,0 +1,23 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include "audio_core/hle/decoder.h" + +namespace AudioCore::HLE { + +class FDKDecoder final : public DecoderBase { +public: + explicit FDKDecoder(Memory::MemorySystem& memory); + ~FDKDecoder() override; + std::optional<BinaryResponse> ProcessRequest(const BinaryRequest& request) override; + bool IsValid() const override; + +private: + class Impl; + std::unique_ptr<Impl> impl; +}; + +} // namespace AudioCore::HLE diff --git a/src/audio_core/hle/hle.cpp b/src/audio_core/hle/hle.cpp index 873d6a72e..052e507c5 100644 --- a/src/audio_core/hle/hle.cpp +++ b/src/audio_core/hle/hle.cpp @@ -7,6 +7,8 @@ #include "audio_core/hle/wmf_decoder.h" #elif HAVE_FFMPEG #include "audio_core/hle/ffmpeg_decoder.h" +#elif HAVE_FDK +#include "audio_core/hle/fdk_decoder.h" #endif #include "audio_core/hle/common.h" #include "audio_core/hle/decoder.h" @@ -97,6 +99,8 @@ DspHle::Impl::Impl(DspHle& parent_, Memory::MemorySystem& memory) : parent(paren decoder = std::make_unique<HLE::WMFDecoder>(memory); #elif defined(HAVE_FFMPEG) decoder = std::make_unique<HLE::FFMPEGDecoder>(memory); +#elif defined(HAVE_FDK) + decoder = std::make_unique<HLE::FDKDecoder>(memory); #else LOG_WARNING(Audio_DSP, "No decoder found, this could lead to missing audio"); decoder = std::make_unique<HLE::NullDecoder>(); From d8bb37fc2fb123992493aa4d22184966523c48c0 Mon Sep 17 00:00:00 2001 From: "Gauvain \"GovanifY\" Roussel-Tarbouriech" <gauvain@govanify.com> Date: Sun, 23 Feb 2020 21:33:49 +0100 Subject: [PATCH 34/41] gdbstub: Ensure gdbstub doesn't drop packets crucial to initialization --- src/core/core.cpp | 2 +- src/core/gdbstub/gdbstub.cpp | 9 ++++++++- src/core/gdbstub/gdbstub.h | 7 +++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/core/core.cpp b/src/core/core.cpp index cd1799e42..401fff01b 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -306,7 +306,7 @@ System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mo HW::Init(*memory); Service::Init(*this); - GDBStub::Init(); + GDBStub::DeferStart(); VideoCore::ResultStatus result = VideoCore::Init(emu_window, *memory); if (result != VideoCore::ResultStatus::Success) { diff --git a/src/core/gdbstub/gdbstub.cpp b/src/core/gdbstub/gdbstub.cpp index a7ed44aff..bf15f1011 100644 --- a/src/core/gdbstub/gdbstub.cpp +++ b/src/core/gdbstub/gdbstub.cpp @@ -121,6 +121,7 @@ constexpr char target_xml[] = )"; int gdbserver_socket = -1; +bool defer_start = false; u8 command_buffer[GDB_BUFFER_SIZE]; u32 command_length; @@ -1042,7 +1043,8 @@ static void RemoveBreakpoint() { } void HandlePacket() { - if (!IsConnected()) { + if (!IsConnected() && defer_start) { + ToggleServer(true); return; } @@ -1133,6 +1135,10 @@ void ToggleServer(bool status) { } } +void DeferStart() { + defer_start = true; +} + static void Init(u16 port) { if (!server_enabled) { // Set the halt loop to false in case the user enabled the gdbstub mid-execution. @@ -1216,6 +1222,7 @@ void Shutdown() { if (!server_enabled) { return; } + defer_start = false; LOG_INFO(Debug_GDBStub, "Stopping GDB ..."); if (gdbserver_socket != -1) { diff --git a/src/core/gdbstub/gdbstub.h b/src/core/gdbstub/gdbstub.h index 131b3f823..b878b7957 100644 --- a/src/core/gdbstub/gdbstub.h +++ b/src/core/gdbstub/gdbstub.h @@ -42,6 +42,13 @@ void ToggleServer(bool status); /// Start the gdbstub server. void Init(); +/** + * Defer initialization of the gdbstub to the first packet processing functions. + * This avoids a case where the gdbstub thread is frozen after initialization + * and fails to respond in time to packets. + */ +void DeferStart(); + /// Stop gdbstub server. void Shutdown(); From 8fedd5c240b863800879ed3aa8ba3073e959553b Mon Sep 17 00:00:00 2001 From: "Gauvain \"GovanifY\" Roussel-Tarbouriech" <gauvain@govanify.com> Date: Mon, 24 Feb 2020 14:30:24 +0100 Subject: [PATCH 35/41] gdbstub: small logic bug fix with defer_start --- src/core/gdbstub/gdbstub.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/core/gdbstub/gdbstub.cpp b/src/core/gdbstub/gdbstub.cpp index bf15f1011..1e42ff5e4 100644 --- a/src/core/gdbstub/gdbstub.cpp +++ b/src/core/gdbstub/gdbstub.cpp @@ -1043,8 +1043,10 @@ static void RemoveBreakpoint() { } void HandlePacket() { - if (!IsConnected() && defer_start) { - ToggleServer(true); + if (!IsConnected()) { + if (defer_start) { + ToggleServer(true); + } return; } From 9d57325a8b65cc574d3274272a2589822076cbd6 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Fri, 28 Feb 2020 23:36:17 +0800 Subject: [PATCH 36/41] core/file_sys: Add alternative override pathes for ExeFS files You can now directly place ExeFS overrides/patches inside the mod folder (instead of the exefs subfolder). This allows us to have drop-in compatibility with Luma3DS mods. --- src/core/file_sys/ncch_container.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 81525786f..177526aff 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -539,9 +539,11 @@ Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector<u8>& code) const const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), GetModId(ncch_header.program_id)); - const std::array<PatchLocation, 4> patch_paths{{ + const std::array<PatchLocation, 6> patch_paths{{ {mods_path + "exefs/code.ips", Patch::ApplyIpsPatch}, {mods_path + "exefs/code.bps", Patch::ApplyBpsPatch}, + {mods_path + "code.ips", Patch::ApplyIpsPatch}, + {mods_path + "code.bps", Patch::ApplyBpsPatch}, {filepath + ".exefsdir/code.ips", Patch::ApplyIpsPatch}, {filepath + ".exefsdir/code.bps", Patch::ApplyBpsPatch}, }}; @@ -583,8 +585,9 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), GetModId(ncch_header.program_id)); - std::array<std::string, 2> override_paths{{ + std::array<std::string, 3> override_paths{{ mods_path + "exefs/" + override_name, + mods_path + override_name, filepath + ".exefsdir/" + override_name, }}; From cfd2ab61212160c1162549ba257c4a3f0d287910 Mon Sep 17 00:00:00 2001 From: BreadFish64 <mohror64@gmail.com> Date: Fri, 28 Feb 2020 13:45:19 -0600 Subject: [PATCH 37/41] video_core: use explicit interval type in texture cache The default is discrete_interval which has dynamic open-ness. We only use right_open intervals anyway. In theory this could allow some compile-time optimizations. --- .../renderer_opengl/gl_rasterizer_cache.h | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/video_core/renderer_opengl/gl_rasterizer_cache.h b/src/video_core/renderer_opengl/gl_rasterizer_cache.h index cd601ef29..5581adb94 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer_cache.h +++ b/src/video_core/renderer_opengl/gl_rasterizer_cache.h @@ -80,11 +80,15 @@ struct CachedSurface; using Surface = std::shared_ptr<CachedSurface>; using SurfaceSet = std::set<Surface>; -using SurfaceRegions = boost::icl::interval_set<PAddr>; -using SurfaceMap = boost::icl::interval_map<PAddr, Surface>; -using SurfaceCache = boost::icl::interval_map<PAddr, SurfaceSet>; +using SurfaceInterval = boost::icl::right_open_interval<PAddr>; +using SurfaceRegions = boost::icl::interval_set<PAddr, std::less, SurfaceInterval>; +using SurfaceMap = + boost::icl::interval_map<PAddr, Surface, boost::icl::partial_absorber, std::less, + boost::icl::inplace_plus, boost::icl::inter_section, SurfaceInterval>; +using SurfaceCache = + boost::icl::interval_map<PAddr, SurfaceSet, boost::icl::partial_absorber, std::less, + boost::icl::inplace_plus, boost::icl::inter_section, SurfaceInterval>; -using SurfaceInterval = SurfaceCache::interval_type; static_assert(std::is_same<SurfaceRegions::interval_type, SurfaceCache::interval_type>() && std::is_same<SurfaceMap::interval_type, SurfaceCache::interval_type>(), "incorrect interval types"); @@ -245,7 +249,7 @@ struct SurfaceParams { } SurfaceInterval GetInterval() const { - return SurfaceInterval::right_open(addr, end); + return SurfaceInterval(addr, end); } // Returns the outer rectangle containing "interval" From 0fe832bb49eaa4c0b82d52955ddc73fa6c7fd575 Mon Sep 17 00:00:00 2001 From: zhupengfei <zhupf321@gmail.com> Date: Sat, 29 Feb 2020 09:11:34 +0800 Subject: [PATCH 38/41] Make the arrays const --- src/core/file_sys/ncch_container.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index 177526aff..056f7a901 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -315,7 +315,7 @@ Loader::ResultStatus NCCHContainer::Load() { const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), GetModId(ncch_header.program_id)); - std::array<std::string, 2> exheader_override_paths{{ + const std::array<std::string, 2> exheader_override_paths{{ mods_path + "exheader.bin", filepath + ".exheader", }}; @@ -585,7 +585,7 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, const auto mods_path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), GetModId(ncch_header.program_id)); - std::array<std::string, 3> override_paths{{ + const std::array<std::string, 3> override_paths{{ mods_path + "exefs/" + override_name, mods_path + override_name, filepath + ".exefsdir/" + override_name, From 9dfb83f1e17cb4abccf2950a588954167cb84df6 Mon Sep 17 00:00:00 2001 From: Marshall Mohror <mohror64@gmail.com> Date: Sat, 29 Feb 2020 11:13:28 -0600 Subject: [PATCH 39/41] Update README.md (#5097) * Update README.md * fix travis link --- README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index dc5c6ce30..522c482f0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Citra ============== -[![Travis CI Build Status](https://travis-ci.org/citra-emu/citra.svg?branch=master)](https://travis-ci.org/citra-emu/citra) +[![Travis CI Build Status](https://travis-ci.com/citra-emu/citra.svg?branch=master)](https://travis-ci.com/citra-emu/citra) [![AppVeyor CI Build Status](https://ci.appveyor.com/api/projects/status/sdf1o4kh3g1e68m9?svg=true)](https://ci.appveyor.com/project/bunnei/citra) [![Bitrise CI Build Status](https://app.bitrise.io/app/4ccd8e5720f0d13b/status.svg?token=H32TmbCwxb3OQ-M66KbAyw&branch=master)](https://app.bitrise.io/app/4ccd8e5720f0d13b) @@ -16,13 +16,13 @@ Check out our [website](https://citra-emu.org/)! Need help? Check out our [asking for help](https://citra-emu.org/help/reference/asking/) guide. -For development discussion, please join us at #citra-dev on freenode. +For development discussion, please join us on our [Discord server](https://citra-emu.org/discord/) or at #citra-dev on freenode. ### Development Most of the development happens on GitHub. It's also where [our central repository](https://github.com/citra-emu/citra) is hosted. -If you want to contribute please take a look at the [Contributor's Guide](https://github.com/citra-emu/citra/wiki/Contributing) and [Developer Information](https://github.com/citra-emu/citra/wiki/Developer-Information). You should as well contact any of the developers in the forum in order to know about the current state of the emulator because the [TODO list](https://docs.google.com/document/d/1SWIop0uBI9IW8VGg97TAtoT_CHNoP42FzYmvG1F4QDA) isn't maintained anymore. +If you want to contribute please take a look at the [Contributor's Guide](https://github.com/citra-emu/citra/wiki/Contributing) and [Developer Information](https://github.com/citra-emu/citra/wiki/Developer-Information). You should also contact any of the developers in the forum in order to know about the current state of the emulator because the [TODO list](https://docs.google.com/document/d/1SWIop0uBI9IW8VGg97TAtoT_CHNoP42FzYmvG1F4QDA) isn't maintained anymore. If you want to contribute to the user interface translation, please checkout [citra project on transifex](https://www.transifex.com/citra/citra). We centralize the translation work there, and periodically upstream translation. @@ -39,6 +39,5 @@ We happily accept monetary donations or donated games and hardware. Please see o * 3DS games for testing * Any equipment required for homebrew * Infrastructure setup -* Eventually 3D displays to get proper 3D output working -We also more than gladly accept used 3DS consoles, preferably ones with firmware 4.5 or lower! If you would like to give yours away, don't hesitate to join our IRC channel #citra on [Freenode](http://webchat.freenode.net/?channels=citra) and talk to neobrain or bunnei. Mind you, IRC is slow-paced, so it might be a while until people reply. If you're in a hurry you can just leave contact details in the channel or via private message and we'll get back to you. +We also more than gladly accept used 3DS consoles! If you would like to give yours away, don't hesitate to join our [Discord server](https://citra-emu.org/discord/) and talk to bunnei. From 6d3d9f7a8a81151b56e662ab88a9ec80f4d08375 Mon Sep 17 00:00:00 2001 From: Tobias <thm.frey@gmail.com> Date: Sat, 29 Feb 2020 19:48:27 +0100 Subject: [PATCH 40/41] core: Add support for N3DS memory mappings (#5103) * core: Add support for N3DS memory mappings * Address review comments --- src/core/core.cpp | 8 +++--- src/core/core.h | 2 +- src/core/file_sys/ncch_container.h | 3 ++- src/core/hle/kernel/kernel.cpp | 4 +-- src/core/hle/kernel/kernel.h | 4 +-- src/core/hle/kernel/memory.cpp | 36 ++++++++++++++++++++------ src/core/loader/loader.h | 11 +++++++- src/core/loader/ncch.cpp | 13 ++++++++++ src/core/loader/ncch.h | 2 ++ src/tests/core/arm/arm_test_common.cpp | 2 +- src/tests/core/hle/kernel/hle_ipc.cpp | 4 +-- src/tests/core/memory/memory.cpp | 2 +- 12 files changed, 69 insertions(+), 22 deletions(-) diff --git a/src/core/core.cpp b/src/core/core.cpp index 01ab8481c..03bd384f2 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -174,7 +174,9 @@ System::ResultStatus System::Load(Frontend::EmuWindow& emu_window, const std::st } ASSERT(system_mode.first); - ResultStatus init_result{Init(emu_window, *system_mode.first)}; + auto n3ds_mode = app_loader->LoadKernelN3dsMode(); + ASSERT(n3ds_mode.first); + ResultStatus init_result{Init(emu_window, *system_mode.first, *n3ds_mode.first)}; if (init_result != ResultStatus::Success) { LOG_CRITICAL(Core, "Failed to initialize system (Error {})!", static_cast<u32>(init_result)); @@ -246,7 +248,7 @@ void System::Reschedule() { } } -System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mode) { +System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mode, u8 n3ds_mode) { LOG_DEBUG(HW_Memory, "initialized OK"); std::size_t num_cores = 2; @@ -259,7 +261,7 @@ System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mo timing = std::make_unique<Timing>(num_cores); kernel = std::make_unique<Kernel::KernelSystem>( - *memory, *timing, [this] { PrepareReschedule(); }, system_mode, num_cores); + *memory, *timing, [this] { PrepareReschedule(); }, system_mode, num_cores, n3ds_mode); if (Settings::values.use_cpu_jit) { #ifdef ARCHITECTURE_x86_64 diff --git a/src/core/core.h b/src/core/core.h index 2727ea78c..4bc8e6a85 100644 --- a/src/core/core.h +++ b/src/core/core.h @@ -303,7 +303,7 @@ private: * @param system_mode The system mode. * @return ResultStatus code, indicating if the operation succeeded. */ - ResultStatus Init(Frontend::EmuWindow& emu_window, u32 system_mode); + ResultStatus Init(Frontend::EmuWindow& emu_window, u32 system_mode, u8 n3ds_mode); /// Reschedule the core emulation void Reschedule(); diff --git a/src/core/file_sys/ncch_container.h b/src/core/file_sys/ncch_container.h index 8deda1fff..7eadd9835 100644 --- a/src/core/file_sys/ncch_container.h +++ b/src/core/file_sys/ncch_container.h @@ -149,7 +149,8 @@ struct ExHeader_StorageInfo { struct ExHeader_ARM11_SystemLocalCaps { u64_le program_id; u32_le core_version; - u8 reserved_flags[2]; + u8 reserved_flag; + u8 n3ds_mode; union { u8 flags0; BitField<0, 2, u8> ideal_processor; diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index c0b6f8308..995cfc658 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -19,10 +19,10 @@ namespace Kernel { /// Initialize the kernel KernelSystem::KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, std::function<void()> prepare_reschedule_callback, u32 system_mode, - u32 num_cores) + u32 num_cores, u8 n3ds_mode) : memory(memory), timing(timing), prepare_reschedule_callback(std::move(prepare_reschedule_callback)) { - MemoryInit(system_mode); + MemoryInit(system_mode, n3ds_mode); resource_limits = std::make_unique<ResourceLimitList>(*this); for (u32 core_id = 0; core_id < num_cores; ++core_id) { diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index fd68cbf6d..c275d7ce4 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -86,7 +86,7 @@ class KernelSystem { public: explicit KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, std::function<void()> prepare_reschedule_callback, u32 system_mode, - u32 num_cores); + u32 num_cores, u8 n3ds_mode); ~KernelSystem(); using PortPair = std::pair<std::shared_ptr<ServerPort>, std::shared_ptr<ClientPort>>; @@ -263,7 +263,7 @@ public: Core::Timing& timing; private: - void MemoryInit(u32 mem_type); + void MemoryInit(u32 mem_type, u8 n3ds_mode); std::function<void()> prepare_reschedule_callback; diff --git a/src/core/hle/kernel/memory.cpp b/src/core/hle/kernel/memory.cpp index e4aae5d13..7b6425bbd 100644 --- a/src/core/hle/kernel/memory.cpp +++ b/src/core/hle/kernel/memory.cpp @@ -19,6 +19,7 @@ #include "core/hle/kernel/vm_manager.h" #include "core/hle/result.h" #include "core/memory.h" +#include "core/settings.h" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -40,11 +41,32 @@ static const u32 memory_region_sizes[8][3] = { {0x0B200000, 0x02E00000, 0x02000000}, // 7 }; -void KernelSystem::MemoryInit(u32 mem_type) { - // TODO(yuriks): On the n3DS, all o3DS configurations (<=5) are forced to 6 instead. - ASSERT_MSG(mem_type <= 5, "New 3DS memory configuration aren't supported yet!"); +namespace MemoryMode { +enum N3DSMode : u8 { + Mode6 = 1, + Mode7 = 2, + Mode6_2 = 3, +}; +} + +void KernelSystem::MemoryInit(u32 mem_type, u8 n3ds_mode) { ASSERT(mem_type != 1); + const bool is_new_3ds = Settings::values.is_new_3ds; + u32 reported_mem_type = mem_type; + if (is_new_3ds) { + if (n3ds_mode == MemoryMode::Mode6 || n3ds_mode == MemoryMode::Mode6_2) { + mem_type = 6; + reported_mem_type = 6; + } else if (n3ds_mode == MemoryMode::Mode7) { + mem_type = 7; + reported_mem_type = 7; + } else { + // On the N3ds, all O3ds configurations (<=5) are forced to 6 instead. + mem_type = 6; + } + } + // The kernel allocation regions (APPLICATION, SYSTEM and BASE) are laid out in sequence, with // the sizes specified in the memory_region_sizes table. VAddr base = 0; @@ -55,14 +77,12 @@ void KernelSystem::MemoryInit(u32 mem_type) { } // We must've allocated the entire FCRAM by the end - ASSERT(base == Memory::FCRAM_SIZE); + ASSERT(base == (is_new_3ds ? Memory::FCRAM_N3DS_SIZE : Memory::FCRAM_SIZE)); config_mem_handler = std::make_unique<ConfigMem::Handler>(); auto& config_mem = config_mem_handler->GetConfigMem(); - config_mem.app_mem_type = mem_type; - // app_mem_malloc does not always match the configured size for memory_region[0]: in case the - // n3DS type override is in effect it reports the size the game expects, not the real one. - config_mem.app_mem_alloc = memory_region_sizes[mem_type][0]; + config_mem.app_mem_type = reported_mem_type; + config_mem.app_mem_alloc = memory_region_sizes[reported_mem_type][0]; config_mem.sys_mem_alloc = memory_regions[1].size; config_mem.base_mem_alloc = memory_regions[2].size; diff --git a/src/core/loader/loader.h b/src/core/loader/loader.h index 0414f181c..f9a8e9296 100644 --- a/src/core/loader/loader.h +++ b/src/core/loader/loader.h @@ -105,13 +105,22 @@ public: * Loads the system mode that this application needs. * This function defaults to 2 (96MB allocated to the application) if it can't read the * information. - * @returns A pair with the optional system mode, and and the status. + * @returns A pair with the optional system mode, and the status. */ virtual std::pair<std::optional<u32>, ResultStatus> LoadKernelSystemMode() { // 96MB allocated to the application. return std::make_pair(2, ResultStatus::Success); } + /** + * Loads the N3ds mode that this application uses. + * It defaults to 0 (O3DS default) if it can't read the information. + * @returns A pair with the optional N3ds mode, and the status. + */ + virtual std::pair<std::optional<u8>, ResultStatus> LoadKernelN3dsMode() { + return std::make_pair(0, ResultStatus::Success); + } + /** * Get whether this application is executable. * @param out_executable Reference to store the executable flag into. diff --git a/src/core/loader/ncch.cpp b/src/core/loader/ncch.cpp index 1a966da5e..81053524f 100644 --- a/src/core/loader/ncch.cpp +++ b/src/core/loader/ncch.cpp @@ -61,6 +61,19 @@ std::pair<std::optional<u32>, ResultStatus> AppLoader_NCCH::LoadKernelSystemMode ResultStatus::Success); } +std::pair<std::optional<u8>, ResultStatus> AppLoader_NCCH::LoadKernelN3dsMode() { + if (!is_loaded) { + ResultStatus res = base_ncch.Load(); + if (res != ResultStatus::Success) { + return std::make_pair(std::optional<u8>{}, res); + } + } + + // Set the system mode as the one from the exheader. + return std::make_pair(overlay_ncch->exheader_header.arm11_system_local_caps.n3ds_mode, + ResultStatus::Success); +} + ResultStatus AppLoader_NCCH::LoadExec(std::shared_ptr<Kernel::Process>& process) { using Kernel::CodeSet; diff --git a/src/core/loader/ncch.h b/src/core/loader/ncch.h index 041cfddbd..6f680b063 100644 --- a/src/core/loader/ncch.h +++ b/src/core/loader/ncch.h @@ -41,6 +41,8 @@ public: */ std::pair<std::optional<u32>, ResultStatus> LoadKernelSystemMode() override; + std::pair<std::optional<u8>, ResultStatus> LoadKernelN3dsMode() override; + ResultStatus IsExecutable(bool& out_executable) override; ResultStatus ReadCode(std::vector<u8>& buffer) override; diff --git a/src/tests/core/arm/arm_test_common.cpp b/src/tests/core/arm/arm_test_common.cpp index 0957c7c97..583459e7c 100644 --- a/src/tests/core/arm/arm_test_common.cpp +++ b/src/tests/core/arm/arm_test_common.cpp @@ -17,7 +17,7 @@ TestEnvironment::TestEnvironment(bool mutable_memory_) timing = std::make_unique<Core::Timing>(1); memory = std::make_unique<Memory::MemorySystem>(); - kernel = std::make_unique<Kernel::KernelSystem>(*memory, *timing, [] {}, 0, 1); + kernel = std::make_unique<Kernel::KernelSystem>(*memory, *timing, [] {}, 0, 1, 0); kernel->SetCurrentProcess(kernel->CreateProcess(kernel->CreateCodeSet("", 0))); page_table = &kernel->GetCurrentProcess()->vm_manager.page_table; diff --git a/src/tests/core/hle/kernel/hle_ipc.cpp b/src/tests/core/hle/kernel/hle_ipc.cpp index 59026afd6..d1aacd739 100644 --- a/src/tests/core/hle/kernel/hle_ipc.cpp +++ b/src/tests/core/hle/kernel/hle_ipc.cpp @@ -23,7 +23,7 @@ static std::shared_ptr<Object> MakeObject(Kernel::KernelSystem& kernel) { TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel]") { Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); @@ -235,7 +235,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel TEST_CASE("HLERequestContext::WriteToOutgoingCommandBuffer", "[core][kernel]") { Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); diff --git a/src/tests/core/memory/memory.cpp b/src/tests/core/memory/memory.cpp index 2e7c71434..8f08862a1 100644 --- a/src/tests/core/memory/memory.cpp +++ b/src/tests/core/memory/memory.cpp @@ -13,7 +13,7 @@ TEST_CASE("Memory::IsValidVirtualAddress", "[core][memory]") { Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); SECTION("these regions should not be mapped on an empty process") { auto process = kernel.CreateProcess(kernel.CreateCodeSet("", 0)); CHECK(Memory::IsValidVirtualAddress(*process, Memory::PROCESS_IMAGE_VADDR) == false); From 2c0bd0f2a120144d0b7845f4e8d56fb243ba36ef Mon Sep 17 00:00:00 2001 From: Weiyi Wang <wwylele@gmail.com> Date: Sat, 29 Feb 2020 13:52:34 -0500 Subject: [PATCH 41/41] travis/transifex: use HEREDOC for initializing config (#5109) It seems that in recent bash update in CI doesn't correctly interpret the quote string any more. --- .travis/transifex/docker.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.travis/transifex/docker.sh b/.travis/transifex/docker.sh index 003b298b6..8d0d7e2d4 100644 --- a/.travis/transifex/docker.sh +++ b/.travis/transifex/docker.sh @@ -1,7 +1,13 @@ #!/bin/bash -e # Setup RC file for tx -echo $'[https://www.transifex.com]\nhostname = https://www.transifex.com\nusername = api\npassword = '"$TRANSIFEX_API_TOKEN"$'\n' > ~/.transifexrc +cat << EOF > ~/.transifexrc +[https://www.transifex.com] +hostname = https://www.transifex.com +username = api +password = $TRANSIFEX_API_TOKEN +EOF + set -x