diff --git a/.travis/transifex/docker.sh b/.travis/transifex/docker.sh index 003b298b6..8d0d7e2d4 100644 --- a/.travis/transifex/docker.sh +++ b/.travis/transifex/docker.sh @@ -1,7 +1,13 @@ #!/bin/bash -e # Setup RC file for tx -echo $'[https://www.transifex.com]\nhostname = https://www.transifex.com\nusername = api\npassword = '"$TRANSIFEX_API_TOKEN"$'\n' > ~/.transifexrc +cat << EOF > ~/.transifexrc +[https://www.transifex.com] +hostname = https://www.transifex.com +username = api +password = $TRANSIFEX_API_TOKEN +EOF + set -x diff --git a/CMakeLists.txt b/CMakeLists.txt index 45deee61a..faa41ea54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,8 @@ CMAKE_DEPENDENT_OPTION(COMPILE_WITH_DWARF "Add DWARF debugging information" ON " option(USE_SYSTEM_BOOST "Use the system Boost libs (instead of the bundled ones)" OFF) +CMAKE_DEPENDENT_OPTION(ENABLE_FDK "Use FDK AAC decoder" OFF "NOT ENABLE_FFMPEG_AUDIO_DECODER;NOT ENABLE_MF" OFF) + if(NOT EXISTS ${PROJECT_SOURCE_DIR}/.git/hooks/pre-commit) message(STATUS "Copying pre-commit hook") file(COPY hooks/pre-commit @@ -218,6 +220,12 @@ if (ENABLE_FFMPEG_VIDEO_DUMPER) add_definitions(-DENABLE_FFMPEG_VIDEO_DUMPER) endif() +if (ENABLE_FDK) + find_library(FDK_AAC fdk-aac DOC "The path to fdk_aac library") + if(FDK_AAC STREQUAL "FDK_AAC-NOTFOUND") + message(FATAL_ERROR "fdk_aac library not found.") + endif() +endif() # Platform-specific library requirements # ====================================== diff --git a/README.md b/README.md index dc5c6ce30..522c482f0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Citra ============== -[![Travis CI Build Status](https://travis-ci.org/citra-emu/citra.svg?branch=master)](https://travis-ci.org/citra-emu/citra) +[![Travis CI Build Status](https://travis-ci.com/citra-emu/citra.svg?branch=master)](https://travis-ci.com/citra-emu/citra) [![AppVeyor CI Build Status](https://ci.appveyor.com/api/projects/status/sdf1o4kh3g1e68m9?svg=true)](https://ci.appveyor.com/project/bunnei/citra) [![Bitrise CI Build Status](https://app.bitrise.io/app/4ccd8e5720f0d13b/status.svg?token=H32TmbCwxb3OQ-M66KbAyw&branch=master)](https://app.bitrise.io/app/4ccd8e5720f0d13b) @@ -16,13 +16,13 @@ Check out our [website](https://citra-emu.org/)! Need help? Check out our [asking for help](https://citra-emu.org/help/reference/asking/) guide. -For development discussion, please join us at #citra-dev on freenode. +For development discussion, please join us on our [Discord server](https://citra-emu.org/discord/) or at #citra-dev on freenode. ### Development Most of the development happens on GitHub. It's also where [our central repository](https://github.com/citra-emu/citra) is hosted. -If you want to contribute please take a look at the [Contributor's Guide](https://github.com/citra-emu/citra/wiki/Contributing) and [Developer Information](https://github.com/citra-emu/citra/wiki/Developer-Information). You should as well contact any of the developers in the forum in order to know about the current state of the emulator because the [TODO list](https://docs.google.com/document/d/1SWIop0uBI9IW8VGg97TAtoT_CHNoP42FzYmvG1F4QDA) isn't maintained anymore. +If you want to contribute please take a look at the [Contributor's Guide](https://github.com/citra-emu/citra/wiki/Contributing) and [Developer Information](https://github.com/citra-emu/citra/wiki/Developer-Information). You should also contact any of the developers in the forum in order to know about the current state of the emulator because the [TODO list](https://docs.google.com/document/d/1SWIop0uBI9IW8VGg97TAtoT_CHNoP42FzYmvG1F4QDA) isn't maintained anymore. If you want to contribute to the user interface translation, please checkout [citra project on transifex](https://www.transifex.com/citra/citra). We centralize the translation work there, and periodically upstream translation. @@ -39,6 +39,5 @@ We happily accept monetary donations or donated games and hardware. Please see o * 3DS games for testing * Any equipment required for homebrew * Infrastructure setup -* Eventually 3D displays to get proper 3D output working -We also more than gladly accept used 3DS consoles, preferably ones with firmware 4.5 or lower! If you would like to give yours away, don't hesitate to join our IRC channel #citra on [Freenode](http://webchat.freenode.net/?channels=citra) and talk to neobrain or bunnei. Mind you, IRC is slow-paced, so it might be a while until people reply. If you're in a hurry you can just leave contact details in the channel or via private message and we'll get back to you. +We also more than gladly accept used 3DS consoles! If you would like to give yours away, don't hesitate to join our [Discord server](https://citra-emu.org/discord/) and talk to bunnei. diff --git a/bitrise.yml b/bitrise.yml index e6ab3d22d..d5a5cb1ab 100644 --- a/bitrise.yml +++ b/bitrise.yml @@ -52,7 +52,7 @@ workflows: sudo apt remove cmake -y sudo apt purge --auto-remove cmake -y sudo apt install ninja-build -y - version=3.8 + version=3.10 build=2 mkdir ~/temp cd ~/temp @@ -97,7 +97,7 @@ workflows: sudo apt remove cmake -y sudo apt purge --auto-remove cmake -y sudo apt install ninja-build -y - version=3.8 + version=3.10 build=2 mkdir ~/temp cd ~/temp diff --git a/externals/CMakeLists.txt b/externals/CMakeLists.txt index 47a92794c..49cca86a1 100644 --- a/externals/CMakeLists.txt +++ b/externals/CMakeLists.txt @@ -115,9 +115,14 @@ if (ENABLE_WEB_SERVICE) # lurlparser add_subdirectory(lurlparser EXCLUDE_FROM_ALL) + if(ANDROID) + add_subdirectory(android-ifaddrs) + endif() + # httplib add_library(httplib INTERFACE) target_include_directories(httplib INTERFACE ./httplib) + target_compile_options(httplib INTERFACE -DCPPHTTPLIB_OPENSSL_SUPPORT) # cpp-jwt add_library(cpp-jwt INTERFACE) diff --git a/externals/android-ifaddrs/CMakeLists.txt b/externals/android-ifaddrs/CMakeLists.txt new file mode 100644 index 000000000..25f243679 --- /dev/null +++ b/externals/android-ifaddrs/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(ifaddrs + ifaddrs.c + ifaddrs.h +) + +create_target_directory_groups(ifaddrs) + +target_include_directories(ifaddrs INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/externals/android-ifaddrs/ifaddrs.c b/externals/android-ifaddrs/ifaddrs.c new file mode 100644 index 000000000..3a3878a68 --- /dev/null +++ b/externals/android-ifaddrs/ifaddrs.c @@ -0,0 +1,600 @@ +/* +Copyright (c) 2013, Kenneth MacKay +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include "ifaddrs.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +typedef struct NetlinkList +{ + struct NetlinkList *m_next; + struct nlmsghdr *m_data; + unsigned int m_size; +} NetlinkList; + +static int netlink_socket(void) +{ + int l_socket = socket(PF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if(l_socket < 0) + { + return -1; + } + + struct sockaddr_nl l_addr; + memset(&l_addr, 0, sizeof(l_addr)); + l_addr.nl_family = AF_NETLINK; + if(bind(l_socket, (struct sockaddr *)&l_addr, sizeof(l_addr)) < 0) + { + close(l_socket); + return -1; + } + + return l_socket; +} + +static int netlink_send(int p_socket, int p_request) +{ + char l_buffer[NLMSG_ALIGN(sizeof(struct nlmsghdr)) + NLMSG_ALIGN(sizeof(struct rtgenmsg))]; + memset(l_buffer, 0, sizeof(l_buffer)); + struct nlmsghdr *l_hdr = (struct nlmsghdr *)l_buffer; + struct rtgenmsg *l_msg = (struct rtgenmsg *)NLMSG_DATA(l_hdr); + + l_hdr->nlmsg_len = NLMSG_LENGTH(sizeof(*l_msg)); + l_hdr->nlmsg_type = p_request; + l_hdr->nlmsg_flags = NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST; + l_hdr->nlmsg_pid = 0; + l_hdr->nlmsg_seq = p_socket; + l_msg->rtgen_family = AF_UNSPEC; + + struct sockaddr_nl l_addr; + memset(&l_addr, 0, sizeof(l_addr)); + l_addr.nl_family = AF_NETLINK; + return (sendto(p_socket, l_hdr, l_hdr->nlmsg_len, 0, (struct sockaddr *)&l_addr, sizeof(l_addr))); +} + +static int netlink_recv(int p_socket, void *p_buffer, size_t p_len) +{ + struct msghdr l_msg; + struct iovec l_iov = { p_buffer, p_len }; + struct sockaddr_nl l_addr; + int l_result; + + for(;;) + { + l_msg.msg_name = (void *)&l_addr; + l_msg.msg_namelen = sizeof(l_addr); + l_msg.msg_iov = &l_iov; + l_msg.msg_iovlen = 1; + l_msg.msg_control = NULL; + l_msg.msg_controllen = 0; + l_msg.msg_flags = 0; + int l_result = recvmsg(p_socket, &l_msg, 0); + + if(l_result < 0) + { + if(errno == EINTR) + { + continue; + } + return -2; + } + + if(l_msg.msg_flags & MSG_TRUNC) + { // buffer was too small + return -1; + } + return l_result; + } +} + +static struct nlmsghdr *getNetlinkResponse(int p_socket, int *p_size, int *p_done) +{ + size_t l_size = 4096; + void *l_buffer = NULL; + + for(;;) + { + free(l_buffer); + l_buffer = malloc(l_size); + + int l_read = netlink_recv(p_socket, l_buffer, l_size); + *p_size = l_read; + if(l_read == -2) + { + free(l_buffer); + return NULL; + } + if(l_read >= 0) + { + pid_t l_pid = getpid(); + struct nlmsghdr *l_hdr; + for(l_hdr = (struct nlmsghdr *)l_buffer; NLMSG_OK(l_hdr, (unsigned int)l_read); l_hdr = (struct nlmsghdr *)NLMSG_NEXT(l_hdr, l_read)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + *p_done = 1; + break; + } + + if(l_hdr->nlmsg_type == NLMSG_ERROR) + { + free(l_buffer); + return NULL; + } + } + return l_buffer; + } + + l_size *= 2; + } +} + +static NetlinkList *newListItem(struct nlmsghdr *p_data, unsigned int p_size) +{ + NetlinkList *l_item = malloc(sizeof(NetlinkList)); + l_item->m_next = NULL; + l_item->m_data = p_data; + l_item->m_size = p_size; + return l_item; +} + +static void freeResultList(NetlinkList *p_list) +{ + NetlinkList *l_cur; + while(p_list) + { + l_cur = p_list; + p_list = p_list->m_next; + free(l_cur->m_data); + free(l_cur); + } +} + +static NetlinkList *getResultList(int p_socket, int p_request) +{ + if(netlink_send(p_socket, p_request) < 0) + { + return NULL; + } + + NetlinkList *l_list = NULL; + NetlinkList *l_end = NULL; + int l_size; + int l_done = 0; + while(!l_done) + { + struct nlmsghdr *l_hdr = getNetlinkResponse(p_socket, &l_size, &l_done); + if(!l_hdr) + { // error + freeResultList(l_list); + return NULL; + } + + NetlinkList *l_item = newListItem(l_hdr, l_size); + if(!l_list) + { + l_list = l_item; + } + else + { + l_end->m_next = l_item; + } + l_end = l_item; + } + return l_list; +} + +static size_t maxSize(size_t a, size_t b) +{ + return (a > b ? a : b); +} + +static size_t calcAddrLen(sa_family_t p_family, int p_dataSize) +{ + switch(p_family) + { + case AF_INET: + return sizeof(struct sockaddr_in); + case AF_INET6: + return sizeof(struct sockaddr_in6); + case AF_PACKET: + return maxSize(sizeof(struct sockaddr_ll), offsetof(struct sockaddr_ll, sll_addr) + p_dataSize); + default: + return maxSize(sizeof(struct sockaddr), offsetof(struct sockaddr, sa_data) + p_dataSize); + } +} + +static void makeSockaddr(sa_family_t p_family, struct sockaddr *p_dest, void *p_data, size_t p_size) +{ + switch(p_family) + { + case AF_INET: + memcpy(&((struct sockaddr_in*)p_dest)->sin_addr, p_data, p_size); + break; + case AF_INET6: + memcpy(&((struct sockaddr_in6*)p_dest)->sin6_addr, p_data, p_size); + break; + case AF_PACKET: + memcpy(((struct sockaddr_ll*)p_dest)->sll_addr, p_data, p_size); + ((struct sockaddr_ll*)p_dest)->sll_halen = p_size; + break; + default: + memcpy(p_dest->sa_data, p_data, p_size); + break; + } + p_dest->sa_family = p_family; +} + +static void addToEnd(struct ifaddrs **p_resultList, struct ifaddrs *p_entry) +{ + if(!*p_resultList) + { + *p_resultList = p_entry; + } + else + { + struct ifaddrs *l_cur = *p_resultList; + while(l_cur->ifa_next) + { + l_cur = l_cur->ifa_next; + } + l_cur->ifa_next = p_entry; + } +} + +static void interpretLink(struct nlmsghdr *p_hdr, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + struct ifinfomsg *l_info = (struct ifinfomsg *)NLMSG_DATA(p_hdr); + + size_t l_nameSize = 0; + size_t l_addrSize = 0; + size_t l_dataSize = 0; + + size_t l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifinfomsg)); + struct rtattr *l_rta; + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifinfomsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFLA_ADDRESS: + case IFLA_BROADCAST: + l_addrSize += NLMSG_ALIGN(calcAddrLen(AF_PACKET, l_rtaDataSize)); + break; + case IFLA_IFNAME: + l_nameSize += NLMSG_ALIGN(l_rtaSize + 1); + break; + case IFLA_STATS: + l_dataSize += NLMSG_ALIGN(l_rtaSize); + break; + default: + break; + } + } + + struct ifaddrs *l_entry = malloc(sizeof(struct ifaddrs) + l_nameSize + l_addrSize + l_dataSize); + memset(l_entry, 0, sizeof(struct ifaddrs)); + l_entry->ifa_name = ""; + + char *l_name = ((char *)l_entry) + sizeof(struct ifaddrs); + char *l_addr = l_name + l_nameSize; + char *l_data = l_addr + l_addrSize; + + l_entry->ifa_flags = l_info->ifi_flags; + + l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifinfomsg)); + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifinfomsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFLA_ADDRESS: + case IFLA_BROADCAST: + { + size_t l_addrLen = calcAddrLen(AF_PACKET, l_rtaDataSize); + makeSockaddr(AF_PACKET, (struct sockaddr *)l_addr, l_rtaData, l_rtaDataSize); + ((struct sockaddr_ll *)l_addr)->sll_ifindex = l_info->ifi_index; + ((struct sockaddr_ll *)l_addr)->sll_hatype = l_info->ifi_type; + if(l_rta->rta_type == IFLA_ADDRESS) + { + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_broadaddr = (struct sockaddr *)l_addr; + } + l_addr += NLMSG_ALIGN(l_addrLen); + break; + } + case IFLA_IFNAME: + strncpy(l_name, l_rtaData, l_rtaDataSize); + l_name[l_rtaDataSize] = '\0'; + l_entry->ifa_name = l_name; + break; + case IFLA_STATS: + memcpy(l_data, l_rtaData, l_rtaDataSize); + l_entry->ifa_data = l_data; + break; + default: + break; + } + } + + addToEnd(p_resultList, l_entry); + p_links[l_info->ifi_index - 1] = l_entry; +} + +static void interpretAddr(struct nlmsghdr *p_hdr, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + struct ifaddrmsg *l_info = (struct ifaddrmsg *)NLMSG_DATA(p_hdr); + + size_t l_nameSize = 0; + size_t l_addrSize = 0; + + int l_addedNetmask = 0; + + size_t l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifaddrmsg)); + struct rtattr *l_rta; + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifaddrmsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + if(l_info->ifa_family == AF_PACKET) + { + continue; + } + + switch(l_rta->rta_type) + { + case IFA_ADDRESS: + case IFA_LOCAL: + if((l_info->ifa_family == AF_INET || l_info->ifa_family == AF_INET6) && !l_addedNetmask) + { // make room for netmask + l_addrSize += NLMSG_ALIGN(calcAddrLen(l_info->ifa_family, l_rtaDataSize)); + l_addedNetmask = 1; + } + case IFA_BROADCAST: + l_addrSize += NLMSG_ALIGN(calcAddrLen(l_info->ifa_family, l_rtaDataSize)); + break; + case IFA_LABEL: + l_nameSize += NLMSG_ALIGN(l_rtaSize + 1); + break; + default: + break; + } + } + + struct ifaddrs *l_entry = malloc(sizeof(struct ifaddrs) + l_nameSize + l_addrSize); + memset(l_entry, 0, sizeof(struct ifaddrs)); + l_entry->ifa_name = p_links[l_info->ifa_index - 1]->ifa_name; + + char *l_name = ((char *)l_entry) + sizeof(struct ifaddrs); + char *l_addr = l_name + l_nameSize; + + l_entry->ifa_flags = l_info->ifa_flags | p_links[l_info->ifa_index - 1]->ifa_flags; + + l_rtaSize = NLMSG_PAYLOAD(p_hdr, sizeof(struct ifaddrmsg)); + for(l_rta = (struct rtattr *)(((char *)l_info) + NLMSG_ALIGN(sizeof(struct ifaddrmsg))); RTA_OK(l_rta, l_rtaSize); l_rta = RTA_NEXT(l_rta, l_rtaSize)) + { + void *l_rtaData = RTA_DATA(l_rta); + size_t l_rtaDataSize = RTA_PAYLOAD(l_rta); + switch(l_rta->rta_type) + { + case IFA_ADDRESS: + case IFA_BROADCAST: + case IFA_LOCAL: + { + size_t l_addrLen = calcAddrLen(l_info->ifa_family, l_rtaDataSize); + makeSockaddr(l_info->ifa_family, (struct sockaddr *)l_addr, l_rtaData, l_rtaDataSize); + if(l_info->ifa_family == AF_INET6) + { + if(IN6_IS_ADDR_LINKLOCAL((struct in6_addr *)l_rtaData) || IN6_IS_ADDR_MC_LINKLOCAL((struct in6_addr *)l_rtaData)) + { + ((struct sockaddr_in6 *)l_addr)->sin6_scope_id = l_info->ifa_index; + } + } + + if(l_rta->rta_type == IFA_ADDRESS) + { // apparently in a point-to-point network IFA_ADDRESS contains the dest address and IFA_LOCAL contains the local address + if(l_entry->ifa_addr) + { + l_entry->ifa_dstaddr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + } + else if(l_rta->rta_type == IFA_LOCAL) + { + if(l_entry->ifa_addr) + { + l_entry->ifa_dstaddr = l_entry->ifa_addr; + } + l_entry->ifa_addr = (struct sockaddr *)l_addr; + } + else + { + l_entry->ifa_broadaddr = (struct sockaddr *)l_addr; + } + l_addr += NLMSG_ALIGN(l_addrLen); + break; + } + case IFA_LABEL: + strncpy(l_name, l_rtaData, l_rtaDataSize); + l_name[l_rtaDataSize] = '\0'; + l_entry->ifa_name = l_name; + break; + default: + break; + } + } + + if(l_entry->ifa_addr && (l_entry->ifa_addr->sa_family == AF_INET || l_entry->ifa_addr->sa_family == AF_INET6)) + { + unsigned l_maxPrefix = (l_entry->ifa_addr->sa_family == AF_INET ? 32 : 128); + unsigned l_prefix = (l_info->ifa_prefixlen > l_maxPrefix ? l_maxPrefix : l_info->ifa_prefixlen); + char l_mask[16] = {0}; + unsigned i; + for(i=0; i<(l_prefix/8); ++i) + { + l_mask[i] = 0xff; + } + l_mask[i] = 0xff << (8 - (l_prefix % 8)); + + makeSockaddr(l_entry->ifa_addr->sa_family, (struct sockaddr *)l_addr, l_mask, l_maxPrefix / 8); + l_entry->ifa_netmask = (struct sockaddr *)l_addr; + } + + addToEnd(p_resultList, l_entry); +} + +static void interpret(int p_socket, NetlinkList *p_netlinkList, struct ifaddrs **p_links, struct ifaddrs **p_resultList) +{ + pid_t l_pid = getpid(); + for(; p_netlinkList; p_netlinkList = p_netlinkList->m_next) + { + unsigned int l_nlsize = p_netlinkList->m_size; + struct nlmsghdr *l_hdr; + for(l_hdr = p_netlinkList->m_data; NLMSG_OK(l_hdr, l_nlsize); l_hdr = NLMSG_NEXT(l_hdr, l_nlsize)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + break; + } + + if(l_hdr->nlmsg_type == RTM_NEWLINK) + { + interpretLink(l_hdr, p_links, p_resultList); + } + else if(l_hdr->nlmsg_type == RTM_NEWADDR) + { + interpretAddr(l_hdr, p_links, p_resultList); + } + } + } +} + +static unsigned countLinks(int p_socket, NetlinkList *p_netlinkList) +{ + unsigned l_links = 0; + pid_t l_pid = getpid(); + for(; p_netlinkList; p_netlinkList = p_netlinkList->m_next) + { + unsigned int l_nlsize = p_netlinkList->m_size; + struct nlmsghdr *l_hdr; + for(l_hdr = p_netlinkList->m_data; NLMSG_OK(l_hdr, l_nlsize); l_hdr = NLMSG_NEXT(l_hdr, l_nlsize)) + { + if((pid_t)l_hdr->nlmsg_pid != l_pid || (int)l_hdr->nlmsg_seq != p_socket) + { + continue; + } + + if(l_hdr->nlmsg_type == NLMSG_DONE) + { + break; + } + + if(l_hdr->nlmsg_type == RTM_NEWLINK) + { + ++l_links; + } + } + } + + return l_links; +} + +int getifaddrs(struct ifaddrs **ifap) +{ + if(!ifap) + { + return -1; + } + *ifap = NULL; + + int l_socket = netlink_socket(); + if(l_socket < 0) + { + return -1; + } + + NetlinkList *l_linkResults = getResultList(l_socket, RTM_GETLINK); + if(!l_linkResults) + { + close(l_socket); + return -1; + } + + NetlinkList *l_addrResults = getResultList(l_socket, RTM_GETADDR); + if(!l_addrResults) + { + close(l_socket); + freeResultList(l_linkResults); + return -1; + } + + unsigned l_numLinks = countLinks(l_socket, l_linkResults) + countLinks(l_socket, l_addrResults); + struct ifaddrs *l_links[l_numLinks]; + memset(l_links, 0, l_numLinks * sizeof(struct ifaddrs *)); + + interpret(l_socket, l_linkResults, l_links, ifap); + interpret(l_socket, l_addrResults, l_links, ifap); + + freeResultList(l_linkResults); + freeResultList(l_addrResults); + close(l_socket); + return 0; +} + +void freeifaddrs(struct ifaddrs *ifa) +{ + struct ifaddrs *l_cur; + while(ifa) + { + l_cur = ifa; + ifa = ifa->ifa_next; + free(l_cur); + } +} \ No newline at end of file diff --git a/externals/android-ifaddrs/ifaddrs.h b/externals/android-ifaddrs/ifaddrs.h new file mode 100644 index 000000000..42d1b37e8 --- /dev/null +++ b/externals/android-ifaddrs/ifaddrs.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 1995, 1999 + * Berkeley Software Design, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * THIS SOFTWARE IS PROVIDED BY Berkeley Software Design, Inc. ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design, Inc. BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * BSDI ifaddrs.h,v 2.5 2000/02/23 14:51:59 dab Exp + */ + +#ifndef _IFADDRS_H_ +#define _IFADDRS_H_ + +struct ifaddrs { + struct ifaddrs *ifa_next; + char *ifa_name; + unsigned int ifa_flags; + struct sockaddr *ifa_addr; + struct sockaddr *ifa_netmask; + struct sockaddr *ifa_dstaddr; + void *ifa_data; +}; + +/* + * This may have been defined in . Note that if is + * to be included it must be included before this header file. + */ +#ifndef ifa_broadaddr +#define ifa_broadaddr ifa_dstaddr /* broadcast address interface */ +#endif + +#include + +__BEGIN_DECLS +extern int getifaddrs(struct ifaddrs **ifap); +extern void freeifaddrs(struct ifaddrs *ifa); +__END_DECLS + +#endif \ No newline at end of file diff --git a/externals/boost b/externals/boost index 6d7edc593..727f616b6 160000 --- a/externals/boost +++ b/externals/boost @@ -1 +1 @@ -Subproject commit 6d7edc593be8e47c8de7bc5f7d6b32971fad0c24 +Subproject commit 727f616b6e5cafaba072131c077a3b8fea87b8be diff --git a/externals/httplib/README.md b/externals/httplib/README.md index 0e26522b5..33cbaf183 100644 --- a/externals/httplib/README.md +++ b/externals/httplib/README.md @@ -1,4 +1,4 @@ -From https://github.com/yhirose/cpp-httplib/commit/d9479bc0b12e8a1e8bce2d34da4feeef488581f3 +From https://github.com/yhirose/cpp-httplib/commit/b251668522dd459d2c6a75c10390a11b640be708 MIT License @@ -13,3 +13,4 @@ It's extremely easy to setup. Just include httplib.h file in your code! Inspired by Sinatra and express. © 2017 Yuji Hirose + diff --git a/externals/httplib/httplib.h b/externals/httplib/httplib.h index dd9afe693..ab087b184 100644 --- a/externals/httplib/httplib.h +++ b/externals/httplib/httplib.h @@ -1,73 +1,177 @@ // // httplib.h // -// Copyright (c) 2017 Yuji Hirose. All rights reserved. +// Copyright (c) 2020 Yuji Hirose. All rights reserved. // MIT License // -#ifndef _CPPHTTPLIB_HTTPLIB_H_ -#define _CPPHTTPLIB_HTTPLIB_H_ +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits::max)() +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +// if hardware_concurrency() outputs 0 we still wants to use threads for this. +// -1 because we have one thread already in the main function. +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + (std::thread::hardware_concurrency() \ + ? std::thread::hardware_concurrency() - 1 \ + : 2) +#endif + +/* + * Headers + */ #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS -#endif +#endif //_CRT_SECURE_NO_WARNINGS + #ifndef _CRT_NONSTDC_NO_DEPRECATE #define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; #endif -#if defined(_MSC_VER) && _MSC_VER < 1900 +#if _MSC_VER < 1900 #define snprintf _snprintf_s #endif +#endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m)&S_IFREG)==S_IFREG) -#endif +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + #ifndef S_ISDIR -#define S_ISDIR(m) (((m)&S_IFDIR)==S_IFDIR) -#endif +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX #include #include #include -#undef min -#undef max +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif #ifndef strcasecmp #define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif -typedef SOCKET socket_t; -#else -#include -#include -#include -#include -#include +#else // not _WIN32 + #include -#include -#include -#include - -typedef int socket_t; -#define INVALID_SOCKET (-1) +#include +#include +#include +#include +#ifdef CPPHTTPLIB_USE_POLL +#include #endif +#include +#include +#include +#include +#include +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include #include #include +#include #include #include #include +#include #include #include -#include #include -#include -#include +#include #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include +#include #include +#include + +#include +#include + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -75,1198 +179,2321 @@ typedef int socket_t; #endif /* - * Configuration + * Declaration */ -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 - -namespace httplib -{ +namespace httplib { namespace detail { struct ci { - bool operator() (const std::string & s1, const std::string & s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), - s2.begin(), s2.end(), - [](char c1, char c2) { - return ::tolower(c1) < ::tolower(c2); - }); - } + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } }; } // namespace detail -enum class HttpVersion { v1_0 = 0, v1_1 }; +using Headers = std::multimap; -typedef std::multimap Headers; +using Params = std::multimap; +using Match = std::smatch; -template -std::pair make_range_header(uint64_t value, Args... args); +using Progress = std::function; -typedef std::multimap Params; -typedef std::smatch Match; -typedef std::function Progress; +struct Response; +using ResponseHandler = std::function; -struct MultipartFile { - std::string filename; - std::string content_type; - size_t offset = 0; - size_t length = 0; +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; }; -typedef std::multimap MultipartFiles; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() = default; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function done; + std::function is_writable; +}; + +using ContentProvider = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; struct Request { - std::string version; - std::string method; - std::string target; - std::string path; - Headers headers; - std::string body; - Params params; - MultipartFiles files; - Match matches; + std::string method; + std::string path; + Headers headers; + std::string body; - Progress progress; + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; - bool has_param(const char* key) const; - std::string get_param_value(const char* key) const; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif - bool has_file(const char* key) const; - MultipartFile get_file_value(const char* key) const; + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; }; struct Response { - std::string version; - int status; - Headers headers; - std::string body; + std::string version; + int status = -1; + Headers headers; + std::string body; - bool has_header(const char* key) const; - std::string get_header_value(const char* key) const; - void set_header(const char* key, const char* val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - void set_redirect(const char* uri); - void set_content(const char* s, size_t n, const char* content_type); - void set_content(const std::string& s, const char* content_type); + void set_redirect(const char *url); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(const std::string &s, const char *content_type); - Response() : status(-1) {} + void set_content_provider( + size_t length, + std::function + provider, + std::function resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function provider, + std::function resource_releaser = [] {}); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function content_provider_resource_releaser; }; class Stream { public: - virtual ~Stream() {} - virtual int read(char* ptr, size_t size) = 0; - virtual int write(const char* ptr, size_t size1) = 0; - virtual int write(const char* ptr) = 0; - virtual std::string get_remote_addr() = 0; + virtual ~Stream() = default; - template - void write_format(const char* fmt, const Args& ...args); + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual int read(char *ptr, size_t size) = 0; + virtual int write(const char *ptr, size_t size) = 0; + virtual std::string get_remote_addr() const = 0; + + template + int write_format(const char *fmt, const Args &... args); + int write(const char *ptr); + int write(const std::string &s); }; -class SocketStream : public Stream { +class TaskQueue { public: - SocketStream(socket_t sock); - virtual ~SocketStream(); + TaskQueue() = default; + virtual ~TaskQueue() = default; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; +}; - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } private: - socket_t sock_; + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; }; +using Logger = std::function; + class Server { public: - typedef std::function Handler; - typedef std::function Logger; + using Handler = std::function; + using HandlerWithContentReader = std::function; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server& Get(const char* pattern, Handler handler); - Server& Post(const char* pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Options(const char *pattern, Handler handler); - Server& Put(const char* pattern, Handler handler); - Server& Delete(const char* pattern, Handler handler); - Server& Options(const char* pattern, Handler handler); + [[deprecated]] bool set_base_dir(const char *dir, + const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); - bool set_base_dir(const char* path); + void set_error_handler(Handler handler); + void set_logger(Logger logger); - void set_error_handler(Handler handler); - void set_logger(Logger logger); + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); - void set_keep_alive_max_count(size_t count); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - int bind_to_any_port(const char* host, int socket_flags = 0); - bool listen_after_bind(); + bool listen(const char *host, int port, int socket_flags = 0); - bool listen(const char* host, int port, int socket_flags = 0); + bool is_running() const; + void stop(); - bool is_running() const; - void stop(); + std::function new_task_queue; protected: - bool process_request(Stream& strm, bool last_connection, bool& connection_close); + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request); - size_t keep_alive_max_count_; + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; private: - typedef std::vector> Handlers; + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; - socket_t create_server_socket(const char* host, int port, int socket_flags) const; - int bind_internal(const char* host, int port, int socket_flags); - bool listen_internal(); + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); - bool routing(Request& req, Response& res); - bool handle_file_request(Request& req, Response& res); - bool dispatch_request(Request& req, Response& res, Handlers& handlers); + bool routing(Request &req, Response &res, Stream &strm, bool last_connection); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandlersForContentReader &handlers); - bool parse_request_line(const char* s, Request& req); - void write_response(Stream& strm, bool last_connection, const Request& req, Response& res); + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, bool last_connection, Request &req, + Response &res); + bool read_content_with_content_receiver( + Stream &strm, bool last_connection, Request &req, Response &res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, bool last_connection, Request &req, + Response &res, ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - bool is_running_; - socket_t svr_sock_; - std::string base_dir_; - Handlers get_handlers_; - Handlers post_handlers_; - Handlers put_handlers_; - Handlers delete_handlers_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; - - // TODO: Use thread pool... - std::mutex running_threads_mutex_; - int running_threads_; + std::atomic is_running_; + std::atomic svr_sock_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; }; class Client { public: - Client( - const char* host, - int port = 80, - size_t timeout_sec = 300); + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); - virtual ~Client(); + virtual ~Client(); - virtual bool is_valid() const; + virtual bool is_valid() const; - std::shared_ptr Get(const char* path, Progress progress = nullptr); - std::shared_ptr Get(const char* path, const Headers& headers, Progress progress = nullptr); + std::shared_ptr Get(const char *path); - std::shared_ptr Head(const char* path); - std::shared_ptr Head(const char* path, const Headers& headers); + std::shared_ptr Get(const char *path, const Headers &headers); - std::shared_ptr Post(const char* path, const std::string& body, const char* content_type); - std::shared_ptr Post(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr Get(const char *path, Progress progress); - std::shared_ptr Post(const char* path, const Params& params); - std::shared_ptr Post(const char* path, const Headers& headers, const Params& params); + std::shared_ptr Get(const char *path, const Headers &headers, + Progress progress); - std::shared_ptr Put(const char* path, const std::string& body, const char* content_type); - std::shared_ptr Put(const char* path, const Headers& headers, const std::string& body, const char* content_type); + std::shared_ptr Get(const char *path, + ContentReceiver content_receiver); - std::shared_ptr Delete(const char* path); - std::shared_ptr Delete(const char* path, const Headers& headers); + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); - std::shared_ptr Options(const char* path); - std::shared_ptr Options(const char* path, const Headers& headers); + std::shared_ptr + Get(const char *path, ContentReceiver content_receiver, Progress progress); - bool send(Request& req, Response& res); + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); -protected: - bool process_request(Stream& strm, Request& req, Response& res, bool& connection_close); + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); - const std::string host_; - const int port_; - size_t timeout_sec_; - const std::string host_and_port_; + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); -private: - socket_t create_client_socket() const; - bool read_response_line(Stream& strm, Response& res); - void write_request(Stream& strm, Request& req); + std::shared_ptr Head(const char *path); - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); -}; + std::shared_ptr Head(const char *path, const Headers &headers); + + std::shared_ptr Post(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Params ¶ms); + + std::shared_ptr Post(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Post(const char *path, + const MultipartFormDataItems &items); + + std::shared_ptr Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + std::shared_ptr Put(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Params ¶ms); + + std::shared_ptr Put(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Patch(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Delete(const char *path, const Headers &headers); + + std::shared_ptr Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Options(const char *path); + + std::shared_ptr Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector &requests, + std::vector &responses); + + void set_timeout_sec(time_t timeout_sec); + + void set_read_timeout(time_t sec, time_t usec); + + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char *username, const char *password); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream : public Stream { -public: - SSLSocketStream(socket_t sock, SSL* ssl); - virtual ~SSLSocketStream(); + void set_digest_auth(const char *username, const char *password); +#endif - virtual int read(char* ptr, size_t size); - virtual int write(const char* ptr, size_t size); - virtual int write(const char* ptr); - virtual std::string get_remote_addr(); + void set_follow_location(bool on); + + void set_compress(bool on); + + void set_interface(const char *intf); + + void set_proxy(const char *host, int port); + + void set_proxy_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const char *username, const char *password); +#endif + + void set_logger(Logger logger); + +protected: + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); + + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t timeout_sec_ = 300; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool follow_location_ = false; + + bool compress_ = false; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + + Logger logger_; + + void copy_settings(const Client &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif + logger_ = rhs.logger_; + } private: - socket_t sock_; - SSL* ssl_; + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool connect(socket_t sock, Response &res, bool &error); +#endif + + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + + virtual bool is_ssl() const; }; +inline void Get(std::vector &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLServer : public Server { public: - SSLServer( - const char* cert_path, const char* private_key_path); + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - virtual ~SSLServer(); + virtual ~SSLServer(); - virtual bool is_valid() const; + virtual bool is_valid() const; private: - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; }; class SSLClient : public Client { public: - SSLClient( - const char* host, - int port = 80, - size_t timeout_sec = 300); + SSLClient(const std::string &host, int port = 443, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); - virtual ~SSLClient(); + virtual ~SSLClient(); - virtual bool is_valid() const; + virtual bool is_valid() const; + + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const noexcept; private: - virtual bool read_and_close_socket(socket_t sock, Request& req, Response& res); + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + virtual bool is_ssl() const; - SSL_CTX* ctx_; - std::mutex ctx_mutex_; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector host_components_; + + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; }; #endif +// ---------------------------------------------------------------------------- + /* * Implementation */ + namespace detail { -template -void split(const char* b, const char* e, char d, Fn fn) -{ - int i = 0; - int beg = 0; +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} - while (e ? (b + i != e) : (b[i] != '\0')) { - if (b[i] == d) { - fn(&b[beg], &b[i]); - beg = i + 1; - } - i++; +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = (0xC0 | ((code >> 6) & 0x1F)); + buff[1] = (0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = (0xF0 | ((code >> 18) & 0x7)); + buff[1] = (0x80 | ((code >> 12) & 0x3F)); + buff[2] = (0x80 | ((code >> 6) & 0x3F)); + buff[3] = (0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (uint8_t c : in) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; } - if (i) { - fn(&b[beg], &b[i]); + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +template void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } } // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { public: - stream_line_reader(Stream& strm, char* fixed_buffer, size_t fixed_buffer_size) - : strm_(strm) - , fixed_buffer_(fixed_buffer) - , fixed_buffer_size_(fixed_buffer_size) { - } + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} - const char* ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); - } - } - - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); - - for (size_t i = 0; ; i++) { - char byte; - auto n = strm_.read(&byte, 1); - - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; - } - } - - append(byte); - - if (byte == '\n') { - break; - } - } - - return true; - } - -private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); - } - glowable_buffer_ += c; - } - } - - Stream& strm_; - char* fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_; - std::string glowable_buffer_; -}; - -inline int close_socket(socket_t sock) -{ -#ifdef _WIN32 - return closesocket(sock); -#else - return close(sock); -#endif -} - -inline int select_read(socket_t sock, size_t sec, size_t usec) -{ - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; - - return select(sock + 1, &fds, NULL, NULL, &tv); -} - -inline bool wait_until_socket_is_ready(socket_t sock, size_t sec, size_t usec) -{ - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = sec; - tv.tv_usec = usec; - - if (select(sock + 1, &fdsr, &fdsw, &fdse, &tv) < 0) { - return false; - } else if (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw)) { - int error = 0; - socklen_t len = sizeof(error); - if (getsockopt(sock, SOL_SOCKET, SO_ERROR, (char*)&error, &len) < 0 || error) { - return false; - } + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; } else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } } return true; + } + +private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); +#endif +} + +inline int select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); +#endif +} + +inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + if (poll(&pfd_read, 1, timeout) > 0 && + pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + if (select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#endif +} + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + virtual ~SSLSocketStream(); + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; +#endif + +class BufferStream : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + std::string get_remote_addr() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + int position = 0; +}; + +template +inline bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); + + auto ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { // keep_alive_max_count is 0 or 1 + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + return ret; } template -inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, T callback) -{ - bool ret = false; - - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SocketStream strm(sock); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } - - count--; - } - } else { - SocketStream strm(sock); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } - - close_socket(sock); - return ret; +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); + close_socket(sock); + return ret; } -inline int shutdown_socket(socket_t sock) -{ +inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif } template -socket_t create_socket(const char* host, int port, Fn fn, int socket_flags = 0) -{ +socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { #ifdef _WIN32 #define SO_SYNCHRONOUS_NONALERT 0x20 #define SO_OPENTYPE 0x7008 - int opt = SO_SYNCHRONOUS_NONALERT; - setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt)); + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt, + sizeof(opt)); #endif - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { - return INVALID_SOCKET; - } - - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (sock == INVALID_SOCKET) { - continue; - } - - // Make 'reuse address' option available - int yes = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes)); - - // bind or connect - if (fn(sock, *rp)) { - freeaddrinfo(result); - return sock; - } - - close_socket(sock); - } - - freeaddrinfo(result); + if (getaddrinfo(host, service.c_str(), &hints, &result)) { return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif + + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#endif + + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; } -inline void set_nonblocking(socket_t sock, bool nonblocking) -{ +inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; - ioctlsocket(sock, FIONBIO, &flags); + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif } -inline bool is_connection_error() -{ +inline bool is_connection_error() { #ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; + return WSAGetLastError() != WSAEWOULDBLOCK; #else - return errno != EINPROGRESS; + return errno != EINPROGRESS; #endif } +inline bool bind_ip_address(socket_t sock, const char *host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +inline std::string if2ip(const std::string &ifn) { +#ifndef _WIN32 + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); +} + +inline socket_t create_client_socket(const char *host, int port, + time_t timeout_sec, + const std::string &intf) { + return create_socket( + host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + + set_nonblocking(sock, true); + + auto ret = ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } + + set_nonblocking(sock, false); + return true; + }); +} + inline std::string get_remote_addr(socket_t sock) { - struct sockaddr_storage addr; - socklen_t len = sizeof(addr); + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); - if (!getpeername(sock, (struct sockaddr*)&addr, &len)) { - char ipstr[NI_MAXHOST]; + if (!getpeername(sock, reinterpret_cast(&addr), &len)) { + std::array ipstr{}; - if (!getnameinfo((struct sockaddr*)&addr, len, - ipstr, sizeof(ipstr), nullptr, 0, NI_NUMERICHOST)) { - return ipstr; - } + if (!getnameinfo(reinterpret_cast(&addr), len, + ipstr.data(), ipstr.size(), nullptr, 0, NI_NUMERICHOST)) { + return ipstr.data(); } + } - return std::string(); + return std::string(); } -inline bool is_file(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +inline const char * +find_content_type(const std::string &path, + const std::map &user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second.c_str(); } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; } -inline bool is_dir(const std::string& path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - -inline bool is_valid_path(const std::string& path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; - } - - auto len = i - beg; - assert(len > 0); - - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { - return false; - } - level--; - } else { - level++; - } - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - } - - return true; -} - -inline void read_file(const std::string& path, std::string& out) -{ - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], size); -} - -inline std::string file_extension(const std::string& path) -{ - std::smatch m; - auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, pat)) { - return m[1].str(); - } - return std::string(); -} - -inline const char* find_content_type(const std::string& path) -{ - auto ext = file_extension(path); - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; -} - -inline const char* status_message(int status) -{ - switch (status) { - case 200: return "OK"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 400: return "Bad Request"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 415: return "Unsupported Media Type"; - default: - case 500: return "Internal Server Error"; - } -} - -inline const char* get_header_value(const Headers& headers, const char* key, const char* def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return it->second.c_str(); - } - return def; -} - -inline int get_header_value_int(const Headers& headers, const char* key, int def) -{ - auto it = headers.find(key); - if (it != headers.end()) { - return std::stoi(it->second); - } - return def; -} - -inline bool read_headers(Stream& strm, Headers& headers) -{ - static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); - - const auto bufsiz = 2048; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); - - for (;;) { - if (!reader.getline()) { - return false; - } - if (!strcmp(reader.ptr(), "\r\n")) { - break; - } - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - auto key = std::string(m[1]); - auto val = std::string(m[2]); - headers.emplace(key, val); - } - } - - return true; -} - -inline bool read_content_with_length(Stream& strm, std::string& out, size_t len, Progress progress) -{ - out.assign(len, 0); - size_t r = 0; - while (r < len){ - auto n = strm.read(&out[r], len - r); - if (n <= 0) { - return false; - } - - r += n; - - if (progress) { - progress(r, len); - } - } - - return true; -} - -inline bool read_content_without_length(Stream& strm, std::string& out) -{ - for (;;) { - char byte; - auto n = strm.read(&byte, 1); - if (n < 0) { - return false; - } else if (n == 0) { - return true; - } - out += byte; - } - - return true; -} - -inline bool read_content_chunked(Stream& strm, std::string& out) -{ - const auto bufsiz = 16; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); - - if (!reader.getline()) { - return false; - } - - auto chunk_len = std::stoi(reader.ptr(), 0, 16); - - while (chunk_len > 0){ - std::string chunk; - if (!read_content_with_length(strm, chunk, chunk_len, nullptr)) { - return false; - } - - if (!reader.getline()) { - return false; - } - - if (strcmp(reader.ptr(), "\r\n")) { - break; - } - - out += chunk; - - if (!reader.getline()) { - return false; - } - - chunk_len = std::stoi(reader.ptr(), 0, 16); - } - - if (chunk_len == 0) { - // Reader terminator after chunks - if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) - return false; - } - - return true; -} - -template -bool read_content(Stream& strm, T& x, Progress progress = Progress()) -{ - auto len = get_header_value_int(x.headers, "Content-Length", 0); - - if (len) { - return read_content_with_length(strm, x.body, len, progress); - } else { - const auto& encoding = get_header_value(x.headers, "Transfer-Encoding", ""); - - if (!strcasecmp(encoding, "chunked")) { - return read_content_chunked(strm, x.body); - } else { - return read_content_without_length(strm, x.body); - } - } - - return true; -} - -template -inline void write_headers(Stream& strm, const T& info) -{ - for (const auto& x: info.headers) { - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - } - strm.write("\r\n"); -} - -inline std::string encode_url(const std::string& s) -{ - std::string result; - - for (auto i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "+"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - case ':': result += "%3A"; break; - case ';': result += "%3B"; break; - default: - if (s[i] < 0) { - result += '%'; - char hex[4]; - size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", (unsigned char)s[i]); - assert(len == 2); - result.append(hex, len); - } else { - result += s[i]; - } - break; - } - } - - return result; -} - -inline bool is_hex(char c, int& v) -{ - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; -} - -inline bool from_hex_to_i(const std::string& s, size_t i, size_t cnt, int& val) -{ - if (i >= s.size()) { - return false; - } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { - return false; - } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { - return false; - } - } - return true; -} - -inline size_t to_utf8(int code, char* buff) -{ - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = (0xC0 | ((code >> 6) & 0x1F)); - buff[1] = (0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... - return 0; - } else if (code < 0x10000) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = (0xF0 | ((code >> 18) & 0x7)); - buff[1] = (0x80 | ((code >> 12) & 0x3F)); - buff[2] = (0x80 | ((code >> 6) & 0x3F)); - buff[3] = (0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; -} - -inline std::string decode_url(const std::string& s) -{ - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { - result.append(buff, len); - } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += val; - i += 2; // '00' - } else { - result += s[i]; - } - } - } else if (s[i] == '+') { - result += ' '; - } else { - result += s[i]; - } - } - - return result; -} - -inline void parse_query_text(const std::string& s, Params& params) -{ - split(&s[0], &s[s.size()], '&', [&](const char* b, const char* e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char* b, const char* e) { - if (key.empty()) { - key.assign(b, e); - } else { - val.assign(b, e); - } - }); - params.emplace(key, decode_url(val)); - }); -} - -inline bool parse_multipart_boundary(const std::string& content_type, std::string& boundary) -{ - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { - return false; - } - - boundary = content_type.substr(pos + 9); - return true; -} - -inline bool parse_multipart_formdata( - const std::string& boundary, const std::string& body, MultipartFiles& files) -{ - static std::string dash = "--"; - static std::string crlf = "\r\n"; - - static std::regex re_content_type( - "Content-Type: (.*?)", std::regex_constants::icase); - - static std::regex re_content_disposition( - "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", - std::regex_constants::icase); - - auto dash_boundary = dash + boundary; - - auto pos = body.find(dash_boundary); - if (pos != 0) { - return false; - } - - pos += dash_boundary.size(); - - auto next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - pos = next_pos + crlf.size(); - - while (pos < body.size()) { - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - std::string name; - MultipartFile file; - - auto header = body.substr(pos, (next_pos - pos)); - - while (pos != next_pos) { - std::smatch m; - if (std::regex_match(header, m, re_content_type)) { - file.content_type = m[1]; - } else if (std::regex_match(header, m, re_content_disposition)) { - name = m[1]; - file.filename = m[2]; - } - - pos = next_pos + crlf.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - header = body.substr(pos, (next_pos - pos)); - } - - pos = next_pos + crlf.size(); - - next_pos = body.find(crlf + dash_boundary, pos); - - if (next_pos == std::string::npos) { - return false; - } - - file.offset = pos; - file.length = next_pos - pos; - - pos = next_pos + crlf.size() + dash_boundary.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { - return false; - } - - files.emplace(name, file); - - pos = next_pos + crlf.size(); - } - - return true; -} - -inline std::string to_lower(const char* beg, const char* end) -{ - std::string out; - auto it = beg; - while (it != end) { - out += ::tolower(*it); - it++; - } - return out; -} - -inline void make_range_header_core(std::string&) {} - -template -inline void make_range_header_core(std::string& field, uint64_t value) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value) + "-"; -} - -template -inline void make_range_header_core(std::string& field, uint64_t value1, uint64_t value2, Args... args) -{ - if (!field.empty()) { - field += ", "; - } - field += std::to_string(value1) + "-" + std::to_string(value2); - make_range_header_core(field, args...); +inline const char *status_message(int status) { + switch (status) { + case 200: return "OK"; + case 202: return "Accepted"; + case 204: return "No Content"; + case 206: return "Partial Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 413: return "Payload Too Large"; + case 414: return "Request-URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + case 503: return "Service Unavailable"; + + default: + case 500: return "Internal Server Error"; + } } #ifdef CPPHTTPLIB_ZLIB_SUPPORT -inline bool can_compress(const std::string& content_type) { - return !content_type.find("text/") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; +inline bool can_compress(const std::string &content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; } -inline void compress(std::string& content) -{ - z_stream strm; +inline bool compress(std::string &content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } + + strm.avail_in = content.size(); + strm.next_in = + const_cast(reinterpret_cast(content.data())); + + std::string compressed; + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff.data(), buff.size() - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; +} + +class decompressor { +public: + decompressor() { strm.zalloc = Z_NULL; strm.zfree = Z_NULL; strm.opaque = Z_NULL; - auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY); - if (ret != Z_OK) { - return; - } + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 16 specifies + // that the stream to decompress will be formatted with a gzip wrapper. + is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK; + } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + ~decompressor() { inflateEnd(&strm); } - std::string compressed; + bool is_valid() const { return is_valid_; } - const auto bufsiz = 16384; - char buff[bufsiz]; + template + bool decompress(const char *data, size_t data_length, T callback) { + int ret = Z_OK; + + strm.avail_in = data_length; + strm.next_in = const_cast(reinterpret_cast(data)); + + std::array buff{}; do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - deflate(&strm, Z_FINISH); - compressed.append(buff, bufsiz - strm.avail_out); + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } } while (strm.avail_out == 0); - content.swap(compressed); + return ret == Z_OK || ret == Z_STREAM_END; + } - deflateEnd(&strm); +private: + bool is_valid_; + z_stream strm; +}; +#endif + +inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); } -inline void decompress(std::string& content) -{ - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; +inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, id); + if (it != headers.end()) { return it->second.c_str(); } + return def; +} - // 15 is the value of wbits, which should be at the maximum possible value to ensure - // that any gzip stream can be decoded. The offset of 16 specifies that the stream - // to decompress will be formatted with a gzip wrapper. - auto ret = inflateInit2(&strm, 16 + 15); - if (ret != Z_OK) { - return; +inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + int def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { + continue; // Skip invalid line. } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } - std::string decompressed; + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"((.+?):[\t ]*(.+))"); - const auto bufsiz = 16384; - char buff[bufsiz]; - do { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - inflate(&strm, Z_NO_FLUSH); - decompressed.append(buff, bufsiz - strm.avail_out); - } while (strm.avail_out == 0); + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } - content.swap(decompressed); + return true; +} - inflateEnd(&strm); +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, n)) { return false; } + + r += n; + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += n; + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, n)) { return false; } + } + + return true; +} + +inline bool read_content_chunked(Stream &strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + auto chunk_len = std::stoi(line_reader.ptr(), 0, 16); + + while (chunk_len > 0) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + + chunk_len = std::stoi(line_reader.ptr(), 0, 16); + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor decompressor; + + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + if (x.get_header_value("Content-Encoding") == "gzip") { + out = [&](const char *buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif + + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; +} + +template +inline int write_headers(Stream &strm, const T &info, const Headers &headers) { + auto write_len = 0; + for (const auto &x : info.headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }; + data_sink.done = [&](void) { written_length = -1; }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, end_offset - offset, data_sink); + if (written_length < 0) { return written_length; } + } + return static_cast(offset - begin_offset); +} + +template +inline ssize_t write_content_chunked(Stream &strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available && !is_shutting_down()) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }; + data_sink.done = [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, 0, data_sink); + + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; +} + +template +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + Response new_res; + + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + +inline std::string encode_url(const std::string &s) { + std::string result; + + for (auto i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, len); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b, const char *e) { + if (key.empty()) { + key.assign(b, e); + } else { + val.assign(b, e); + } + }); + params.emplace(key, decode_url(val)); + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } + + boundary = content_type.substr(pos + 9); + return true; +} + +inline bool parse_range_header(const std::string &s, Ranges &ranges) { + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = m.position(1); + auto len = m.length(1); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re_another_range)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast(std::stoll(m.str(1))); + } + + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast(std::stoll(m.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +} + +class MultipartFormDataParser { +public: + MultipartFormDataParser() {} + + void set_boundary(const std::string &boundary) { boundary_ = boundary; } + + bool is_valid() const { return is_valid_; } + + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { pos = buf_.size(); } + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + MultipartFormData file_; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; +} + +inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; +} + +inline std::pair +get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + if (r.first == -1) { + r.first = content_length - r.second; + r.second = content_length - 1; + } + + if (r.second == -1) { r.second = content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; +} + +inline size_t +get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider, offset, length) >= 0; + }); +} + +inline std::pair +get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { r.second = res.content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +template +inline std::string message_digest(const std::string &s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); } #endif #ifdef _WIN32 class WSInit { public: - WSInit() { - WSADATA wsaData; - WSAStartup(0x0002, &wsaData); - } + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } - ~WSInit() { - WSACleanup(); - } + ~WSInit() { WSACleanup(); } }; static WSInit wsinit_; @@ -1275,876 +2502,1804 @@ static WSInit wsinit_; } // namespace detail // Header utilities -template -inline std::pair make_range_header(uint64_t value, Args... args) -{ - std::string field; - detail::make_range_header_core(field, value, args...); - field.insert(0, "bytes="); - return std::make_pair("Range", field); +inline std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + string response; + { + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const httplib::Response &res, + std::map &auth, + bool is_proxy) { + auto key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(m.position(1), m.length(1)); + auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2)) + : s.substr(m.position(3), m.length(3)); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; } // Request implementation -inline bool Request::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); } -inline std::string Request::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); } -inline void Request::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline bool Request::has_param(const char* key) const -{ - return params.find(key) != params.end(); +inline void Request::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline std::string Request::get_param_value(const char* key) const -{ - auto it = params.find(key); - if (it != params.end()) { - return it->second; - } - return std::string(); +inline void Request::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline bool Request::has_file(const char* key) const -{ - return files.find(key) != files.end(); +inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); } -inline MultipartFile Request::get_file_value(const char* key) const -{ - auto it = files.find(key); - if (it != files.end()) { - return it->second; - } - return MultipartFile(); +inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, id); + if (it != params.end()) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return std::distance(r.first, r.second); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); +} + +inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); } // Response implementation -inline bool Response::has_header(const char* key) const -{ - return headers.find(key) != headers.end(); +inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); } -inline std::string Response::get_header_value(const char* key) const -{ - return detail::get_header_value(headers, key, ""); +inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); } -inline void Response::set_header(const char* key, const char* val) -{ - headers.emplace(key, val); +inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); } -inline void Response::set_redirect(const char* url) -{ - set_header("Location", url); - status = 302; +inline void Response::set_header(const char *key, const char *val) { + headers.emplace(key, val); } -inline void Response::set_content(const char* s, size_t n, const char* content_type) -{ - body.assign(s, n); - set_header("Content-Type", content_type); +inline void Response::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); } -inline void Response::set_content(const std::string& s, const char* content_type) -{ - body = s; - set_header("Content-Type", content_type); +inline void Response::set_redirect(const char *url) { + set_header("Location", url); + status = 302; +} + +inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const char *content_type) { + body = s; + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t length, + std::function provider, + std::function resource_releaser) { + assert(length > 0); + content_length = length; + content_provider = [provider](size_t offset, size_t length, DataSink &sink) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function provider, + std::function resource_releaser) { + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink &sink) { + provider(offset, sink); + }; + content_provider_resource_releaser = resource_releaser; } // Rstream implementation -template -inline void Stream::write_format(const char* fmt, const Args& ...args) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; +inline int Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); } -#if defined(_MSC_VER) && _MSC_VER < 1900 - auto n = _snprintf_s(buf, bufsiz, bufsiz - 1, fmt, args...); -#else - auto n = snprintf(buf, bufsiz - 1, fmt, args...); -#endif - if (n > 0) { - if (n >= bufsiz - 1) { - std::vector glowable_buf(bufsiz); - - while (n >= static_cast(glowable_buf.size() - 1)) { - glowable_buf.resize(glowable_buf.size() * 2); -#if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); -#else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); -#endif - } - write(&glowable_buf[0], n); - } else { - write(buf, n); - } - } +inline int Stream::write(const std::string &s) { + return write(s.data(), s.size()); } +template +inline int Stream::write_format(const char *fmt, const Args &... args) { + std::array buf; + +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto n = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); +#else + auto n = snprintf(buf.data(), buf.size() - 1, fmt, args...); +#endif + if (n <= 0) { return n; } + + if (n >= static_cast(buf.size()) - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= static_cast(glowable_buf.size() - 1)) { + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, args...); +#else + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); +#endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +namespace detail { + // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock): sock_(sock) -{ +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SocketStream::~SocketStream() {} + +inline bool SocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } -inline SocketStream::~SocketStream() -{ +inline bool SocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; } -inline int SocketStream::read(char* ptr, size_t size) -{ - return recv(sock_, ptr, size, 0); +inline int SocketStream::read(char *ptr, size_t size) { + if (is_readable()) { return recv(sock_, ptr, static_cast(size), 0); } + return -1; } -inline int SocketStream::write(const char* ptr, size_t size) -{ - return send(sock_, ptr, size, 0); +inline int SocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return send(sock_, ptr, static_cast(size), 0); } + return -1; } -inline int SocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); +inline std::string SocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); } -inline std::string SocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline int BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + int len_read = static_cast(buffer._Copy_s(ptr, size, size, position)); +#else + int len_read = static_cast(buffer.copy(ptr, size, position)); +#endif + position += len_read; + return len_read; } +inline int BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline std::string BufferStream::get_remote_addr() const { return ""; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +} // namespace detail + // HTTP server implementation inline Server::Server() - : keep_alive_max_count_(5) - , is_running_(false) - , svr_sock_(INVALID_SOCKET) - , running_threads_(0) -{ + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif + new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }; } -inline Server::~Server() -{ +inline Server::~Server() {} + +inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Get(const char* pattern, Handler handler) -{ - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Post(const char* pattern, Handler handler) -{ - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Put(const char* pattern, Handler handler) -{ - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Delete(const char* pattern, Handler handler) -{ - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; } -inline Server& Server::Options(const char* pattern, Handler handler) -{ - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; +inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; } -inline bool Server::set_base_dir(const char* path) -{ - if (detail::is_dir(path)) { - base_dir_ = path; - return true; +inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline bool Server::set_base_dir(const char *dir, const char *mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const char *mount_point, const char *dir) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; } + } + return false; +} + +inline bool Server::remove_mount_point(const char *mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime) { + file_extension_and_mimetype_map_[ext] = mime; +} + +inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); +} + +inline void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); +} + +inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + +inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; +} + +inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; +} +inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const char *host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) { + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3]); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; + } + + return false; +} + +inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + // Response line + if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { return false; -} + } -inline void Server::set_error_handler(Handler handler) -{ - error_handler_ = handler; -} + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } -inline void Server::set_logger(Logger logger) -{ - logger_ = logger; -} + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } -inline void Server::set_keep_alive_max_count(size_t count) -{ - keep_alive_max_count_ = count; -} + if (!res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } -inline int Server::bind_to_any_port(const char* host, int socket_flags) -{ - return bind_internal(host, 0, socket_flags); -} + if (!res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } -inline bool Server::listen_after_bind() { - return listen_internal(); -} + std::string content_type; + std::string boundary; -inline bool Server::listen(const char* host, int port, int socket_flags) -{ - if (bind_internal(host, port, socket_flags) < 0) - return false; - return listen_internal(); -} + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); -inline bool Server::is_running() const -{ - return is_running_; -} - -inline void Server::stop() -{ - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - detail::shutdown_socket(svr_sock_); - detail::close_socket(svr_sock_); - svr_sock_ = INVALID_SOCKET; - } -} - -inline bool Server::parse_request_line(const char* s, Request& req) -{ - static std::regex re("(GET|HEAD|POST|PUT|DELETE|OPTIONS) (([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[4]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3]); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { - detail::parse_query_text(m[4], req.params); - } - - return true; + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); } - return false; -} + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } -inline void Server::write_response(Stream& strm, bool last_connection, const Request& req, Response& res) -{ - assert(res.status != -1); - - if (400 <= res.status && error_handler_) { - error_handler_(req, res); + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); } - // Response line - strm.write_format("HTTP/1.1 %d %s\r\n", - res.status, - detail::status_message(res.status)); - - // Headers - if (last_connection || - req.version == "HTTP/1.0" || - req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } - - if (!res.body.empty()) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 - const auto& encodings = req.get_header_value("Accept-Encoding"); - if (encodings.find("gzip") != std::string::npos && - detail::can_compress(res.get_header_value("Content-Type"))) { - detail::compress(res.body); - res.set_header("Content-Encoding", "gzip"); - } + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + const auto &encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } #endif - if (!res.has_header("Content-Type")) { - res.set_header("Content-Type", "text/plain"); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length.c_str()); + if (!detail::write_headers(strm, res, Headers())) { return false; } + + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } } + } - detail::write_headers(strm, res); + // Log + if (logger_) { logger_(req, res); } - // Body - if (!res.body.empty() && req.method != "HEAD") { - strm.write(res.body.c_str(), res.body.size()); - } - - // Log - if (logger_) { - logger_(req, res); - } + return true; } -inline bool Server::handle_file_request(Request& req, Response& res) -{ - if (!base_dir_.empty() && detail::is_valid_path(req.path)) { - std::string path = base_dir_ + req.path; +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (detail::write_content_chunked(strm, res.content_provider, + is_shutting_down) < 0) { + return false; + } + } + return true; +} - if (!path.empty() && path.back() == '/') { - path += "index.html"; - } +inline bool Server::read_content(Stream &strm, bool last_connection, + Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto ret = read_content_core( + strm, last_connection, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + }); + + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + + return ret; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, bool last_connection, Request &req, Response &res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, last_connection, req, res, receiver, + multipart_header, multipart_receiver); +} + +inline bool Server::read_content_core(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + multipart_form_data_parser.set_boundary(boundary); + out = [&](const char *buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return write_response(strm, last_connection, req, res); + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + + return true; +} + +inline bool Server::handle_file_request(Request &req, Response &res, + bool head) { + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = detail::find_content_type(path); - if (type) { - res.set_header("Content-Type", type); - } - res.status = 200; - return true; + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; } + } } - - return false; + } + return false; } -inline socket_t Server::create_server_socket(const char* host, int port, int socket_flags) const -{ - return detail::create_socket(host, port, - [](socket_t sock, struct addrinfo& ai) -> bool { - if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) { - return false; - } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }, socket_flags); +inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, + socket_flags); } -inline int Server::bind_internal(const char* host, int port, int socket_flags) -{ - if (!is_valid()) { - return -1; - } +inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } - svr_sock_ = create_server_socket(host, port, socket_flags); - if (svr_sock_ == INVALID_SOCKET) { - return -1; - } + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } - if (port == 0) { - struct sockaddr_storage address; - socklen_t len = sizeof(address); - if (getsockname(svr_sock_, reinterpret_cast(&address), &len) == -1) { - return -1; + if (port == 0) { + struct sockaddr_storage address; + socklen_t len = sizeof(address); + if (getsockname(svr_sock_, reinterpret_cast(&address), + &len) == -1) { + return -1; + } + if (address.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&address)->sin_port); + } else if (address.ss_family == AF_INET6) { + return ntohs( + reinterpret_cast(&address)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + + { + std::unique_ptr task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; } - if (address.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&address)->sin_port); - } else if (address.ss_family == AF_INET6) { - return ntohs(reinterpret_cast(&address)->sin6_port); + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; } else { - return -1; + ; // The server socket was closed by user. } - } else { - return port; + break; + } + + task_queue->enqueue([=]() { process_and_close_socket(sock); }); } + + task_queue->shutdown(); + } + + is_running_ = false; + return ret; } -inline bool Server::listen_internal() -{ - auto ret = true; - - is_running_ = true; - - for (;;) { - auto val = detail::select_read(svr_sock_, 0, 100000); - - if (val == 0) { // Timeout - if (svr_sock_ == INVALID_SOCKET) { - // The server socket was closed by 'stop' method. - break; - } - continue; - } - - socket_t sock = accept(svr_sock_, NULL, NULL); - - if (sock == INVALID_SOCKET) { - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. - } - break; - } - - // TODO: Use thread pool... - std::thread([=]() { - { - std::lock_guard guard(running_threads_mutex_); - running_threads_++; - } - - read_and_close_socket(sock); - - { - std::lock_guard guard(running_threads_mutex_); - running_threads_--; - } - }).detach(); - } - - // TODO: Use thread pool... - for (;;) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::lock_guard guard(running_threads_mutex_); - if (!running_threads_) { - break; - } - } - - is_running_ = false; - - return ret; -} - -inline bool Server::routing(Request& req, Response& res) -{ - if (req.method == "GET" && handle_file_request(req, res)) { - return true; - } - - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } - return false; -} - -inline bool Server::dispatch_request(Request& req, Response& res, Handlers& handlers) -{ - for (const auto& x: handlers) { - const auto& pattern = x.first; - const auto& handler = x.second; - - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; - } - } - return false; -} - -inline bool Server::process_request(Stream& strm, bool last_connection, bool& connection_close) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; - - detail::stream_line_reader reader(strm, buf, bufsiz); - - // Connection has been closed on client - if (!reader.getline()) { - return false; - } - - Request req; - Response res; - - res.version = "HTTP/1.1"; - - // Request line and headers - if (!parse_request_line(reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return true; - } - - auto ret = true; - if (req.get_header_value("Connection") == "close") { - // ret = false; - connection_close = true; - } - - req.set_header("REMOTE_ADDR", strm.get_remote_addr().c_str()); - - // Body - if (req.method == "POST" || req.method == "PUT") { - if (!detail::read_content(strm, req)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } - - const auto& content_type = req.get_header_value("Content-Type"); - - if (req.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(req.body); -#else - res.status = 415; - write_response(strm, last_connection, req, res); - return ret; -#endif - } - - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); - } else if(!content_type.find("multipart/form-data")) { - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary) || - !detail::parse_multipart_formdata(boundary, req.body, req.files)) { - res.status = 400; - write_response(strm, last_connection, req, res); - return ret; - } - } - } - - if (routing(req, res)) { - if (res.status == -1) { - res.status = 200; - } - } else { - res.status = 404; - } - - write_response(strm, last_connection, req, res); - return ret; -} - -inline bool Server::is_valid() const -{ +inline bool Server::routing(Request &req, Response &res, Stream &strm, + bool last_connection) { + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, receiver, nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, nullptr, header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, last_connection, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; } -inline bool Server::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket( - sock, - keep_alive_max_count_, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); +inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + return false; +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + HandlersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } + } + + if (setup_request) { setup_request(req); } + + // Rounting + if (routing(req, res, strm, last_connection)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); } // HTTP client implementation -inline Client::Client( - const char* host, int port, size_t timeout_sec) - : host_(host) - , port_(port) - , timeout_sec_(timeout_sec) - , host_and_port_(host_ + ":" + std::to_string(port_)) -{ +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(host), port_(port), + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline Client::~Client() {} + +inline bool Client::is_valid() const { return true; } + +inline socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); } -inline Client::~Client() -{ +inline bool Client::read_response_line(Stream &strm, Response &res) { + std::array buf; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; } -inline bool Client::is_valid() const -{ - return true; -} +inline bool Client::send(const Request &req, Response &res) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } -inline socket_t Client::create_client_socket() const -{ - return detail::create_socket(host_.c_str(), port_, - [=](socket_t sock, struct addrinfo& ai) -> bool { - detail::set_nonblocking(sock, true); - - auto ret = connect(sock, ai.ai_addr, ai.ai_addrlen); - if (ret < 0) { - if (detail::is_connection_error() || - !detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) { - detail::close_socket(sock); - return false; - } - } - - detail::set_nonblocking(sock, false); - return true; - }); -} - -inline bool Client::read_response_line(Stream& strm, Response& res) -{ - const auto bufsiz = 2048; - char buf[bufsiz]; - - detail::stream_line_reader reader(strm, buf, bufsiz); - - if (!reader.getline()) { - return false; - } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .+\r\n"); - - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); - } - - return true; -} - -inline bool Client::send(Request& req, Response& res) -{ - if (req.path.empty()) { - return false; - } - - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { - return false; - } - - return read_and_close_socket(sock, req, res); -} - -inline void Client::write_request(Stream& strm, Request& req) -{ - auto path = detail::encode_url(req.path); - - // Request line - strm.write_format("%s %s HTTP/1.1\r\n", - req.method.c_str(), - path.c_str()); - - // Headers - req.set_header("Host", host_and_port_.c_str()); - - if (!req.has_header("Accept")) { - req.set_header("Accept", "*/*"); - } - - if (!req.has_header("User-Agent")) { - req.set_header("User-Agent", "cpp-httplib/0.2"); - } - - // TODO: Support KeepAlive connection - // if (!req.has_header("Connection")) { - req.set_header("Connection", "close"); - // } - - if (!req.body.empty()) { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length.c_str()); - } - - detail::write_headers(strm, req); - - // Body - if (!req.body.empty()) { - if (req.get_header_value("Content-Type") == "application/x-www-form-urlencoded") { - auto str = detail::encode_url(req.body); - strm.write(str.c_str(), str.size()); - } else { - strm.write(req.body.c_str(), req.body.size()); - } - } -} - -inline bool Client::process_request(Stream& strm, Request& req, Response& res, bool& connection_close) -{ - // Send request - write_request(strm, req); - - // Receive response and headers - if (!read_response_line(strm, res) || !detail::read_headers(strm, res.headers)) { - return false; - } - - if (res.get_header_value("Connection") == "close" || res.version == "HTTP/1.0") { - connection_close = true; - } - - // Body - if (req.method != "HEAD") { - if (!detail::read_content(strm, res, req.progress)) { - return false; - } - - if (res.get_header_value("Content-Encoding") == "gzip") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(res.body); -#else - return false; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + bool error; + if (!connect(sock, res, error)) { return error; } + } #endif + + return process_and_close_socket( + sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + return handle_request(strm, req, res, last_connection, + connection_close); + }); +} + +inline bool Client::send(const std::vector &requests, + std::vector &responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock, res, error)) { return false; } + } +#endif + + if (!process_and_close_socket(sock, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, + last_connection, + connection_close); + if (ret) { + responses.emplace_back(std::move(res)); + } + return ret; + })) { + return false; + } + } + + return true; +} + +inline bool Client::handle_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + if (req.path.empty()) { return false; } + + bool ret; + + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, last_connection, connection_close); + } else { + ret = process_request(strm, req, res, last_connection, connection_close); + } + + if (!ret) { return false; } + + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool Client::connect(socket_t sock, Response &res, bool &error) { + error = true; + Response res2; + + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, + connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; } + } + } else { + res = res2; + return false; + } + } + + return true; +} +#endif + +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::smatch m; + if (!regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } else { + Client cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); + } + } +} + +inline bool Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + detail::BufferStream bstrm; + + // Request line + const auto &path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.5"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); } - return true; + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + auto written_length = strm.write(d, l); + offset += written_length; + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, data_sink); + } + } + } else { + strm.write(req.body); + } + + return true; } -inline bool Client::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return detail::read_and_close_socket( - sock, - 0, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); -} +inline std::shared_ptr Client::send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; -inline std::shared_ptr Client::Get(const char* path, Progress progress) -{ - return Get(path, Headers(), progress); -} + req.headers.emplace("Content-Type", content_type); -inline std::shared_ptr Client::Get(const char* path, const Headers& headers, Progress progress) -{ - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - req.progress = progress; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + if (content_provider) { + size_t offset = 0; - auto res = std::make_shared(); + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }; + data_sink.is_writable = [&](void) { return true; }; - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Head(const char* path) -{ - return Head(path, Headers()); -} - -inline std::shared_ptr Client::Head(const char* path, const Headers& headers) -{ - Request req; - req.method = "HEAD"; - req.headers = headers; - req.path = path; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Post( - const char* path, const std::string& body, const char* content_type) -{ - return Post(path, Headers(), body, content_type); -} - -inline std::shared_ptr Client::Post( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "POST"; - req.headers = headers; - req.path = path; - - req.headers.emplace("Content-Type", content_type); - req.body = body; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Post(const char* path, const Params& params) -{ - return Post(path, Headers(), params); -} - -inline std::shared_ptr Client::Post(const char* path, const Headers& headers, const Params& params) -{ - std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { - query += "&"; - } - query += it->first; - query += "="; - query += it->second; + while (offset < content_length) { + content_provider(offset, content_length - offset, data_sink); + } + } else { + req.body = body; } - return Post(path, headers, query, "application/x-www-form-urlencoded"); + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } + } + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; } -inline std::shared_ptr Client::Put( - const char* path, const std::string& body, const char* content_type) -{ - return Put(path, Headers(), body, content_type); +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + if (!write_request(strm, req, last_connection)) { return false; } + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + ContentReceiver out = [&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }; + + if (req.content_receiver) { + out = [&](const char *buf, size_t n) { + return req.content_receiver(buf, n); + }; + } + + int dummy_status; + if (!detail::read_content(strm, res, std::numeric_limits::max(), + dummy_status, req.progress, out)) { + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; } -inline std::shared_ptr Client::Put( - const char* path, const Headers& headers, const std::string& body, const char* content_type) -{ - Request req; - req.method = "PUT"; - req.headers = headers; - req.path = path; - - req.headers.emplace("Content-Type", content_type); - req.body = body; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); } -inline std::shared_ptr Client::Delete(const char* path) -{ - return Delete(path, Headers()); +inline bool Client::is_ssl() const { return false; } + +inline std::shared_ptr Client::Get(const char *path) { + return Get(path, Headers(), Progress()); } -inline std::shared_ptr Client::Delete(const char* path, const Headers& headers) -{ - Request req; - req.method = "DELETE"; - req.path = path; - req.headers = headers; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; +inline std::shared_ptr Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); } -inline std::shared_ptr Client::Options(const char* path) -{ - return Options(path, Headers()); +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers) { + return Get(path, headers, Progress()); } -inline std::shared_ptr Client::Options(const char* path, const Headers& headers) -{ - Request req; - req.method = "OPTIONS"; - req.path = path; - req.headers = headers; +inline std::shared_ptr +Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; } +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), content_receiver, + Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Head(const char *path) { + return Head(path, Headers()); +} + +inline std::shared_ptr Client::Head(const char *path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Post(const char *path, + const std::string &body, + const char *content_type) { + return Post(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline std::shared_ptr Client::Post(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Post(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, const Params ¶ms) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr +Client::Post(const char *path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) { + auto boundary = detail::make_multipart_data_boundary(); + + std::string body; + + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); +} + +inline std::shared_ptr Client::Put(const char *path, + const std::string &body, + const char *content_type) { + return Put(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Put(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, const Params ¶ms) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr Client::Patch(const char *path, + const std::string &body, + const char *content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Patch(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Options(const char *path) { + return Options(path, Headers()); +} + +inline std::shared_ptr Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline void Client::set_timeout_sec(time_t timeout_sec) { + timeout_sec_ = timeout_sec; +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Client::set_basic_auth(const char *username, const char *password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const char *username, + const char *password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void Client::set_follow_location(bool on) { follow_location_ = on; } + +inline void Client::set_compress(bool on) { compress_ = on; } + +inline void Client::set_interface(const char *intf) { interface_ = intf; } + +inline void Client::set_proxy(const char *host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void Client::set_proxy_basic_auth(const char *username, + const char *password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const char *username, + const char *password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} +#endif + +inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + /* * SSL Implementation */ @@ -2152,193 +4307,446 @@ inline std::shared_ptr Client::Options(const char* path, const Headers namespace detail { template -inline bool read_and_close_socket_ssl( - socket_t sock, size_t keep_alive_max_count, - // TODO: OpenSSL 1.0.2 occasionally crashes... - // The upcoming 1.1.0 is going to be thread safe. - SSL_CTX* ctx, std::mutex& ctx_mutex, - U SSL_connect_or_accept, V setup, - T callback) -{ - SSL* ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); +inline bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, + std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); - ssl = SSL_new(ctx); - if (!ssl) { - return false; - } - } + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } - auto bio = BIO_new_socket(sock, BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); + if (!ssl) { + close_socket(sock); + return false; + } - setup(ssl); - - SSL_connect_or_accept(ssl); - - bool ret = false; - - if (keep_alive_max_count > 0) { - auto count = keep_alive_max_count; - while (count > 0 && - detail::select_read(sock, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { - SSLSocketStream strm(sock, ssl); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { - break; - } - - count--; - } - } else { - SSLSocketStream strm(sock, ssl); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + if (!setup(ssl)) { SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); } close_socket(sock); + return false; + } - return ret; + auto ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; } +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr> openSSL_locks_; + +class SSLThreadLocks { +public: + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + +private: + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &locks = *openSSL_locks_; + if (mode & CRYPTO_LOCK) { + locks[type].lock(); + } else { + locks[type].unlock(); + } + } +}; + +#endif + class SSLInit { public: - SSLInit() { - SSL_load_error_strings(); - SSL_library_init(); - } + SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); +#endif + } + + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + +private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif }; +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SSLSocketStream::~SSLSocketStream() {} + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; +} + +inline int SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; +} + +inline int SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } + return -1; +} + +inline std::string SSLSocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); +} + static SSLInit sslinit_; } // namespace detail -// SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL* ssl) - : sock_(sock), ssl_(ssl) -{ -} - -inline SSLSocketStream::~SSLSocketStream() -{ -} - -inline int SSLSocketStream::read(char* ptr, size_t size) -{ - return SSL_read(ssl_, ptr, size); -} - -inline int SSLSocketStream::write(const char* ptr, size_t size) -{ - return SSL_write(ssl_, ptr, size); -} - -inline int SSLSocketStream::write(const char* ptr) -{ - return write(ptr, strlen(ptr)); -} - -inline std::string SSLSocketStream::get_remote_addr() { - return detail::get_remote_addr(sock_); -} - // SSL HTTP server implementation -inline SSLServer::SSLServer(const char* cert_path, const char* private_key_path) -{ - ctx_ = SSL_CTX_new(SSLv23_server_method()); +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); - if (SSL_CTX_use_certificate_file(ctx_, cert_path, SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); } + } } -inline SSLServer::~SSLServer() -{ - if (ctx_) { - SSL_CTX_free(ctx_); - } +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } } -inline bool SSLServer::is_valid() const -{ - return ctx_; -} +inline bool SSLServer::is_valid() const { return ctx_; } -inline bool SSLServer::read_and_close_socket(socket_t sock) -{ - return detail::read_and_close_socket_ssl( - sock, - keep_alive_max_count_, - ctx_, ctx_mutex_, - SSL_accept, - [](SSL* /*ssl*/) {}, - [this](Stream& strm, bool last_connection, bool& connection_close) { - return process_request(strm, last_connection, connection_close); - }); +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); + }); } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char* host, int port, size_t timeout_sec) - : Client(host, port, timeout_sec) -{ - ctx_ = SSL_CTX_new(SSLv23_client_method()); -} +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : Client(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); -inline SSLClient::~SSLClient() -{ - if (ctx_) { - SSL_CTX_free(ctx_); + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; } + } } -inline bool SSLClient::is_valid() const -{ - return ctx_; +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } } -inline bool SSLClient::read_and_close_socket(socket_t sock, Request& req, Response& res) -{ - return is_valid() && detail::read_and_close_socket_ssl( - sock, 0, - ctx_, ctx_mutex_, - SSL_connect, - [&](SSL* ssl) { - SSL_set_tlsext_host_name(ssl, host_.c_str()); - }, - [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { - return process_request(strm, req, res, connection_close); - }); +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } +} + +inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const noexcept { return ctx_; } + +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (ca_cert_file_path_.empty()) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } else { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, + bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { return check_host_name(name, name_len); } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; } #endif +// ---------------------------------------------------------------------------- + } // namespace httplib -#endif - -// vim: et ts=4 sw=4 cin cino={1s ff=unix +#endif // CPPHTTPLIB_HTTPLIB_H \ No newline at end of file diff --git a/src/audio_core/CMakeLists.txt b/src/audio_core/CMakeLists.txt index f2b3e1f3b..2caed1233 100644 --- a/src/audio_core/CMakeLists.txt +++ b/src/audio_core/CMakeLists.txt @@ -62,6 +62,13 @@ elseif(ENABLE_FFMPEG_AUDIO_DECODER) target_include_directories(audio_core PRIVATE ${FFMPEG_DIR}/include) endif() target_compile_definitions(audio_core PUBLIC HAVE_FFMPEG) +elseif(ENABLE_FDK) + target_sources(audio_core PRIVATE + hle/fdk_decoder.cpp + hle/fdk_decoder.h + ) + target_link_libraries(audio_core PRIVATE ${FDK_AAC}) + target_compile_definitions(audio_core PUBLIC HAVE_FDK) endif() if(SDL2_FOUND) diff --git a/src/audio_core/hle/fdk_decoder.cpp b/src/audio_core/hle/fdk_decoder.cpp new file mode 100644 index 000000000..c99e3d43c --- /dev/null +++ b/src/audio_core/hle/fdk_decoder.cpp @@ -0,0 +1,233 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include "audio_core/hle/fdk_decoder.h" + +namespace AudioCore::HLE { + +class FDKDecoder::Impl { +public: + explicit Impl(Memory::MemorySystem& memory); + ~Impl(); + std::optional ProcessRequest(const BinaryRequest& request); + bool IsValid() const { + return decoder != nullptr; + } + +private: + std::optional Initalize(const BinaryRequest& request); + + std::optional Decode(const BinaryRequest& request); + + void Clear(); + + Memory::MemorySystem& memory; + + HANDLE_AACDECODER decoder = nullptr; +}; + +FDKDecoder::Impl::Impl(Memory::MemorySystem& memory) : memory(memory) { + // allocate an array of LIB_INFO structures + // if we don't pre-fill the whole segment with zeros, when we call `aacDecoder_GetLibInfo` + // it will segfault, upon investigation, there is some code in fdk_aac depends on your initial + // values in this array + LIB_INFO decoder_info[FDK_MODULE_LAST] = {}; + // get library information and fill the struct + if (aacDecoder_GetLibInfo(decoder_info) != 0) { + LOG_ERROR(Audio_DSP, "Failed to retrieve fdk_aac library information!"); + return; + } + // This segment: identify the broken fdk_aac implementation + // and refuse to initialize if identified as broken (check for module IDs) + // although our AAC samples do not contain SBC feature, this is a way to detect + // watered down version of fdk_aac implementations + if (FDKlibInfo_getCapabilities(decoder_info, FDK_SBRDEC) == 0) { + LOG_ERROR(Audio_DSP, "Bad fdk_aac library found! Initialization aborted!"); + return; + } + + LOG_INFO(Audio_DSP, "Using fdk_aac version {} (build date: {})", decoder_info[0].versionStr, + decoder_info[0].build_date); + + // choose the input format when initializing: 1 layer of ADTS + decoder = aacDecoder_Open(TRANSPORT_TYPE::TT_MP4_ADTS, 1); + // set maximum output channel to two (stereo) + // if the input samples have more channels, fdk_aac will perform a downmix + AAC_DECODER_ERROR ret = aacDecoder_SetParam(decoder, AAC_PCM_MAX_OUTPUT_CHANNELS, 2); + if (ret != AAC_DEC_OK) { + // unable to set this parameter reflects the decoder implementation might be broken + // we'd better shuts down everything + aacDecoder_Close(decoder); + decoder = nullptr; + LOG_ERROR(Audio_DSP, "Unable to set downmix parameter: {}", ret); + return; + } +} + +std::optional FDKDecoder::Impl::Initalize(const BinaryRequest& request) { + BinaryResponse response; + std::memcpy(&response, &request, sizeof(response)); + response.unknown1 = 0x0; + + if (decoder) { + LOG_INFO(Audio_DSP, "FDK Decoder initialized"); + Clear(); + } else { + LOG_ERROR(Audio_DSP, "Decoder not initialized"); + } + + return response; +} + +FDKDecoder::Impl::~Impl() { + if (decoder) + aacDecoder_Close(decoder); +} + +void FDKDecoder::Impl::Clear() { + s16 decoder_output[8192]; + // flush and re-sync the decoder, discarding the internal buffer + // we actually don't care if this succeeds or not + // FLUSH - flush internal buffer + // INTR - treat the current internal buffer as discontinuous + // CONCEAL - try to interpolate and smooth out the samples + if (decoder) + aacDecoder_DecodeFrame(decoder, decoder_output, 8192, + AACDEC_FLUSH & AACDEC_INTR & AACDEC_CONCEAL); +} + +std::optional FDKDecoder::Impl::ProcessRequest(const BinaryRequest& request) { + if (request.codec != DecoderCodec::AAC) { + LOG_ERROR(Audio_DSP, "FDK AAC Decoder cannot handle such codec: {}", + static_cast(request.codec)); + return {}; + } + + switch (request.cmd) { + case DecoderCommand::Init: { + return Initalize(request); + } + case DecoderCommand::Decode: { + return Decode(request); + } + case DecoderCommand::Unknown: { + BinaryResponse response; + std::memcpy(&response, &request, sizeof(response)); + response.unknown1 = 0x0; + return response; + } + default: + LOG_ERROR(Audio_DSP, "Got unknown binary request: {}", static_cast(request.cmd)); + return {}; + } +} + +std::optional FDKDecoder::Impl::Decode(const BinaryRequest& request) { + BinaryResponse response; + response.codec = request.codec; + response.cmd = request.cmd; + response.size = request.size; + + if (!decoder) { + LOG_DEBUG(Audio_DSP, "Decoder not initalized"); + // This is a hack to continue games that are not compiled with the aac codec + response.num_channels = 2; + response.num_samples = 1024; + return response; + } + + if (request.src_addr < Memory::FCRAM_PADDR || + request.src_addr + request.size > Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds src_addr {:08x}", request.src_addr); + return {}; + } + u8* data = memory.GetFCRAMPointer(request.src_addr - Memory::FCRAM_PADDR); + + std::array, 2> out_streams; + + std::size_t data_size = request.size; + + // decoding loops + AAC_DECODER_ERROR result = AAC_DEC_OK; + // 8192 units of s16 are enough to hold one frame of AAC-LC or AAC-HE/v2 data + s16 decoder_output[8192]; + // note that we don't free this pointer as it is automatically freed by fdk_aac + CStreamInfo* stream_info; + // how many bytes to be queued into the decoder, decrementing from the buffer size + u32 buffer_remaining = data_size; + // alias the data_size as an u32 + u32 input_size = data_size; + + while (buffer_remaining) { + // queue the input buffer, fdk_aac will automatically slice out the buffer it needs + // from the input buffer + result = aacDecoder_Fill(decoder, &data, &input_size, &buffer_remaining); + if (result != AAC_DEC_OK) { + // there are some issues when queuing the input buffer + LOG_ERROR(Audio_DSP, "Failed to enqueue the input samples"); + return std::nullopt; + } + // get output from decoder + result = aacDecoder_DecodeFrame(decoder, decoder_output, 8192, 0); + if (result == AAC_DEC_OK) { + // get the stream information + stream_info = aacDecoder_GetStreamInfo(decoder); + // fill the stream information for binary response + response.num_channels = stream_info->aacNumChannels; + response.num_samples = stream_info->frameSize; + // fill the output + // the sample size = frame_size * channel_counts + for (int sample = 0; sample < (stream_info->frameSize * 2); sample++) { + for (int ch = 0; ch < stream_info->aacNumChannels; ch++) { + out_streams[ch].push_back(decoder_output[(sample * 2) + 1]); + } + } + } else if (result == AAC_DEC_TRANSPORT_SYNC_ERROR) { + // decoder has some synchronization problems, try again with new samples, + // using old samples might trigger this error again + continue; + } else { + LOG_ERROR(Audio_DSP, "Error decoding the sample: {}", result); + return std::nullopt; + } + } + // transfer the decoded buffer from vector to the FCRAM + if (out_streams[0].size() != 0) { + if (request.dst_addr_ch0 < Memory::FCRAM_PADDR || + request.dst_addr_ch0 + out_streams[0].size() > + Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds dst_addr_ch0 {:08x}", request.dst_addr_ch0); + return {}; + } + std::memcpy(memory.GetFCRAMPointer(request.dst_addr_ch0 - Memory::FCRAM_PADDR), + out_streams[0].data(), out_streams[0].size()); + } + + if (out_streams[1].size() != 0) { + if (request.dst_addr_ch1 < Memory::FCRAM_PADDR || + request.dst_addr_ch1 + out_streams[1].size() > + Memory::FCRAM_PADDR + Memory::FCRAM_SIZE) { + LOG_ERROR(Audio_DSP, "Got out of bounds dst_addr_ch1 {:08x}", request.dst_addr_ch1); + return {}; + } + std::memcpy(memory.GetFCRAMPointer(request.dst_addr_ch1 - Memory::FCRAM_PADDR), + out_streams[1].data(), out_streams[1].size()); + } + return response; +} + +FDKDecoder::FDKDecoder(Memory::MemorySystem& memory) : impl(std::make_unique(memory)) {} + +FDKDecoder::~FDKDecoder() = default; + +std::optional FDKDecoder::ProcessRequest(const BinaryRequest& request) { + return impl->ProcessRequest(request); +} + +bool FDKDecoder::IsValid() const { + return impl->IsValid(); +} + +} // namespace AudioCore::HLE diff --git a/src/audio_core/hle/fdk_decoder.h b/src/audio_core/hle/fdk_decoder.h new file mode 100644 index 000000000..337c6054a --- /dev/null +++ b/src/audio_core/hle/fdk_decoder.h @@ -0,0 +1,23 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include "audio_core/hle/decoder.h" + +namespace AudioCore::HLE { + +class FDKDecoder final : public DecoderBase { +public: + explicit FDKDecoder(Memory::MemorySystem& memory); + ~FDKDecoder() override; + std::optional ProcessRequest(const BinaryRequest& request) override; + bool IsValid() const override; + +private: + class Impl; + std::unique_ptr impl; +}; + +} // namespace AudioCore::HLE diff --git a/src/audio_core/hle/hle.cpp b/src/audio_core/hle/hle.cpp index f96e2b642..87af17ba7 100644 --- a/src/audio_core/hle/hle.cpp +++ b/src/audio_core/hle/hle.cpp @@ -13,6 +13,8 @@ #include "audio_core/hle/wmf_decoder.h" #elif HAVE_FFMPEG #include "audio_core/hle/ffmpeg_decoder.h" +#elif HAVE_FDK +#include "audio_core/hle/fdk_decoder.h" #endif #include "audio_core/hle/common.h" #include "audio_core/hle/decoder.h" @@ -124,6 +126,8 @@ DspHle::Impl::Impl(DspHle& parent_, Memory::MemorySystem& memory) : parent(paren decoder = std::make_unique(memory); #elif defined(HAVE_FFMPEG) decoder = std::make_unique(memory); +#elif defined(HAVE_FDK) + decoder = std::make_unique(memory); #else LOG_WARNING(Audio_DSP, "No decoder found, this could lead to missing audio"); decoder = std::make_unique(); diff --git a/src/citra_qt/configuration/configure_general.cpp b/src/citra_qt/configuration/configure_general.cpp index e4b6be073..47d559e3d 100644 --- a/src/citra_qt/configuration/configure_general.cpp +++ b/src/citra_qt/configuration/configure_general.cpp @@ -25,10 +25,6 @@ ConfigureGeneral::ConfigureGeneral(QWidget* parent) ConfigureGeneral::~ConfigureGeneral() = default; void ConfigureGeneral::SetConfiguration() { - ui->toggle_frame_limit->setChecked(Settings::values.use_frame_limit); - ui->frame_limit->setEnabled(ui->toggle_frame_limit->isChecked()); - ui->frame_limit->setValue(Settings::values.frame_limit); - ui->toggle_check_exit->setChecked(UISettings::values.confirm_before_closing); ui->toggle_background_pause->setChecked(UISettings::values.pause_when_in_background); @@ -57,9 +53,6 @@ void ConfigureGeneral::ResetDefaults() { } void ConfigureGeneral::ApplyConfiguration() { - Settings::values.use_frame_limit = ui->toggle_frame_limit->isChecked(); - Settings::values.frame_limit = ui->frame_limit->value(); - UISettings::values.confirm_before_closing = ui->toggle_check_exit->isChecked(); UISettings::values.pause_when_in_background = ui->toggle_background_pause->isChecked(); diff --git a/src/citra_qt/debugger/ipc/recorder.cpp b/src/citra_qt/debugger/ipc/recorder.cpp index 24128e38a..ef634d750 100644 --- a/src/citra_qt/debugger/ipc/recorder.cpp +++ b/src/citra_qt/debugger/ipc/recorder.cpp @@ -114,7 +114,7 @@ void IPCRecorderWidget::SetEnabled(bool enabled) { } void IPCRecorderWidget::Clear() { - id_offset = records.size() + 1; + id_offset += records.size(); records.clear(); ui->main->invisibleRootItem()->takeChildren(); diff --git a/src/citra_qt/debugger/registers.cpp b/src/citra_qt/debugger/registers.cpp index f689708ad..eb4d3d6dc 100644 --- a/src/citra_qt/debugger/registers.cpp +++ b/src/citra_qt/debugger/registers.cpp @@ -61,13 +61,14 @@ void RegistersWidget::OnDebugModeEntered() { if (!Core::System::GetInstance().IsPoweredOn()) return; + // Todo: Handle all cores for (int i = 0; i < core_registers->childCount(); ++i) core_registers->child(i)->setText( - 1, QStringLiteral("0x%1").arg(Core::CPU().GetReg(i), 8, 16, QLatin1Char('0'))); + 1, QStringLiteral("0x%1").arg(Core::GetCore(0).GetReg(i), 8, 16, QLatin1Char('0'))); for (int i = 0; i < vfp_registers->childCount(); ++i) vfp_registers->child(i)->setText( - 1, QStringLiteral("0x%1").arg(Core::CPU().GetVFPReg(i), 8, 16, QLatin1Char('0'))); + 1, QStringLiteral("0x%1").arg(Core::GetCore(0).GetVFPReg(i), 8, 16, QLatin1Char('0'))); UpdateCPSRValues(); UpdateVFPSystemRegisterValues(); @@ -127,7 +128,8 @@ void RegistersWidget::CreateCPSRChildren() { } void RegistersWidget::UpdateCPSRValues() { - const u32 cpsr_val = Core::CPU().GetCPSR(); + // Todo: Handle all cores + const u32 cpsr_val = Core::GetCore(0).GetCPSR(); cpsr->setText(1, QStringLiteral("0x%1").arg(cpsr_val, 8, 16, QLatin1Char('0'))); cpsr->child(0)->setText( @@ -191,10 +193,11 @@ void RegistersWidget::CreateVFPSystemRegisterChildren() { } void RegistersWidget::UpdateVFPSystemRegisterValues() { - const u32 fpscr_val = Core::CPU().GetVFPSystemReg(VFP_FPSCR); - const u32 fpexc_val = Core::CPU().GetVFPSystemReg(VFP_FPEXC); - const u32 fpinst_val = Core::CPU().GetVFPSystemReg(VFP_FPINST); - const u32 fpinst2_val = Core::CPU().GetVFPSystemReg(VFP_FPINST2); + // Todo: handle all cores + const u32 fpscr_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPSCR); + const u32 fpexc_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPEXC); + const u32 fpinst_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPINST); + const u32 fpinst2_val = Core::GetCore(0).GetVFPSystemReg(VFP_FPINST2); QTreeWidgetItem* const fpscr = vfp_system_registers->child(0); fpscr->setText(1, QStringLiteral("0x%1").arg(fpscr_val, 8, 16, QLatin1Char('0'))); diff --git a/src/citra_qt/debugger/wait_tree.cpp b/src/citra_qt/debugger/wait_tree.cpp index 019153576..15a7ada03 100644 --- a/src/citra_qt/debugger/wait_tree.cpp +++ b/src/citra_qt/debugger/wait_tree.cpp @@ -12,6 +12,7 @@ #include "core/hle/kernel/thread.h" #include "core/hle/kernel/timer.h" #include "core/hle/kernel/wait_object.h" +#include "core/settings.h" WaitTreeItem::~WaitTreeItem() = default; @@ -51,12 +52,16 @@ std::size_t WaitTreeItem::Row() const { } std::vector> WaitTreeItem::MakeThreadItemList() { - const auto& threads = Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); + u32 num_cores = Core::GetNumCores(); std::vector> item_list; - item_list.reserve(threads.size()); - for (std::size_t i = 0; i < threads.size(); ++i) { - item_list.push_back(std::make_unique(*threads[i])); - item_list.back()->row = i; + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + item_list.reserve(item_list.size() + threads.size()); + for (std::size_t i = 0; i < threads.size(); ++i) { + item_list.push_back(std::make_unique(*threads[i])); + item_list.back()->row = i; + } } return item_list; } diff --git a/src/citra_qt/game_list.cpp b/src/citra_qt/game_list.cpp index 2d5af1b0c..8a5df5172 100644 --- a/src/citra_qt/game_list.cpp +++ b/src/citra_qt/game_list.cpp @@ -468,6 +468,8 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra QAction* open_texture_dump_location = context_menu.addAction(tr("Open Texture Dump Location")); QAction* open_texture_load_location = context_menu.addAction(tr("Open Custom Texture Location")); + QAction* open_mods_location = context_menu.addAction(tr("Open Mods Location")); + QAction* dump_romfs = context_menu.addAction(tr("Dump RomFS")); QAction* navigate_to_gamedb_entry = context_menu.addAction(tr("Navigate to GameDB entry")); const bool is_application = @@ -497,6 +499,8 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra open_texture_dump_location->setVisible(is_application); open_texture_load_location->setVisible(is_application); + open_mods_location->setVisible(is_application); + dump_romfs->setVisible(is_application); navigate_to_gamedb_entry->setVisible(it != compatibility_list.end()); @@ -526,6 +530,15 @@ void GameList::AddGamePopup(QMenu& context_menu, const QString& path, u64 progra emit OpenFolderRequested(program_id, GameListOpenTarget::TEXTURE_LOAD); } }); + connect(open_mods_location, &QAction::triggered, [this, program_id] { + if (FileUtil::CreateFullPath(fmt::format("{}mods/{:016X}/", + FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + program_id))) { + emit OpenFolderRequested(program_id, GameListOpenTarget::MODS); + } + }); + connect(dump_romfs, &QAction::triggered, + [this, path, program_id] { emit DumpRomFSRequested(path, program_id); }); connect(navigate_to_gamedb_entry, &QAction::triggered, [this, program_id]() { emit NavigateToGamedbEntryRequested(program_id, compatibility_list); }); diff --git a/src/citra_qt/game_list.h b/src/citra_qt/game_list.h index ef280ef04..334089037 100644 --- a/src/citra_qt/game_list.h +++ b/src/citra_qt/game_list.h @@ -35,7 +35,8 @@ enum class GameListOpenTarget { APPLICATION = 2, UPDATE_DATA = 3, TEXTURE_DUMP = 4, - TEXTURE_LOAD = 5 + TEXTURE_LOAD = 5, + MODS = 6, }; class GameList : public QWidget { @@ -81,6 +82,7 @@ signals: void OpenFolderRequested(u64 program_id, GameListOpenTarget target); void NavigateToGamedbEntryRequested(u64 program_id, const CompatibilityList& compatibility_list); + void DumpRomFSRequested(QString game_path, u64 program_id); void OpenDirectory(const QString& directory); void AddDirectory(); void ShowList(bool show); diff --git a/src/citra_qt/main.cpp b/src/citra_qt/main.cpp index de66e14ac..b73877440 100644 --- a/src/citra_qt/main.cpp +++ b/src/citra_qt/main.cpp @@ -597,6 +597,7 @@ void GMainWindow::ConnectWidgetEvents() { connect(game_list, &GameList::OpenFolderRequested, this, &GMainWindow::OnGameListOpenFolder); connect(game_list, &GameList::NavigateToGamedbEntryRequested, this, &GMainWindow::OnGameListNavigateToGamedbEntry); + connect(game_list, &GameList::DumpRomFSRequested, this, &GMainWindow::OnGameListDumpRomFS); connect(game_list, &GameList::AddDirectory, this, &GMainWindow::OnGameListAddDirectory); connect(game_list_placeholder, &GameListPlaceholder::AddDirectory, this, &GMainWindow::OnGameListAddDirectory); @@ -1231,6 +1232,11 @@ void GMainWindow::OnGameListOpenFolder(u64 data_id, GameListOpenTarget target) { path = fmt::format("{}textures/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), data_id); break; + case GameListOpenTarget::MODS: + open_target = "Mods"; + path = fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + data_id); + break; default: LOG_ERROR(Frontend, "Unexpected target {}", static_cast(target)); return; @@ -1262,6 +1268,46 @@ void GMainWindow::OnGameListNavigateToGamedbEntry(u64 program_id, QDesktopServices::openUrl(QUrl(QStringLiteral("https://citra-emu.org/game/") + directory)); } +void GMainWindow::OnGameListDumpRomFS(QString game_path, u64 program_id) { + auto* dialog = new QProgressDialog(tr("Dumping..."), tr("Cancel"), 0, 0, this); + dialog->setWindowModality(Qt::WindowModal); + dialog->setWindowFlags(dialog->windowFlags() & + ~(Qt::WindowCloseButtonHint | Qt::WindowContextHelpButtonHint)); + dialog->setCancelButton(nullptr); + dialog->setMinimumDuration(0); + dialog->setValue(0); + + const auto base_path = fmt::format( + "{}romfs/{:016X}", FileUtil::GetUserPath(FileUtil::UserPath::DumpDir), program_id); + const auto update_path = + fmt::format("{}romfs/{:016X}", FileUtil::GetUserPath(FileUtil::UserPath::DumpDir), + program_id | 0x0004000e00000000); + using FutureWatcher = QFutureWatcher>; + auto* future_watcher = new FutureWatcher(this); + connect(future_watcher, &FutureWatcher::finished, + [this, program_id, dialog, base_path, update_path, future_watcher] { + dialog->hide(); + const auto& [base, update] = future_watcher->result(); + if (base != Loader::ResultStatus::Success) { + QMessageBox::critical( + this, tr("Citra"), + tr("Could not dump base RomFS.\nRefer to the log for details.")); + return; + } + QDesktopServices::openUrl(QUrl::fromLocalFile(QString::fromStdString(base_path))); + if (update == Loader::ResultStatus::Success) { + QDesktopServices::openUrl( + QUrl::fromLocalFile(QString::fromStdString(update_path))); + } + }); + + auto future = QtConcurrent::run([game_path, base_path, update_path] { + std::unique_ptr loader = Loader::GetLoader(game_path.toStdString()); + return std::make_pair(loader->DumpRomFS(base_path), loader->DumpUpdateRomFS(update_path)); + }); + future_watcher->setFuture(future); +} + void GMainWindow::OnGameListOpenDirectory(const QString& directory) { QString path; if (directory == QStringLiteral("INSTALLED")) { diff --git a/src/citra_qt/main.h b/src/citra_qt/main.h index ebe1a013a..63979c6b5 100644 --- a/src/citra_qt/main.h +++ b/src/citra_qt/main.h @@ -176,6 +176,7 @@ private slots: void OnGameListOpenFolder(u64 program_id, GameListOpenTarget target); void OnGameListNavigateToGamedbEntry(u64 program_id, const CompatibilityList& compatibility_list); + void OnGameListDumpRomFS(QString game_path, u64 program_id); void OnGameListOpenDirectory(const QString& directory); void OnGameListAddDirectory(); void OnGameListShowList(bool show); diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 2b02ffc41..9a1e3f783 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -72,6 +72,8 @@ add_library(core STATIC file_sys/delay_generator.h file_sys/ivfc_archive.cpp file_sys/ivfc_archive.h + file_sys/layered_fs.cpp + file_sys/layered_fs.h file_sys/ncch_container.cpp file_sys/ncch_container.h file_sys/patch.cpp @@ -469,9 +471,17 @@ create_target_directory_groups(core) target_link_libraries(core PUBLIC common PRIVATE audio_core network video_core) target_link_libraries(core PUBLIC Boost::boost PRIVATE cryptopp fmt open_source_archives Boost::serialization) + if (ENABLE_WEB_SERVICE) - target_compile_definitions(core PRIVATE -DENABLE_WEB_SERVICE) - target_link_libraries(core PRIVATE web_service) + get_directory_property(OPENSSL_LIBS + DIRECTORY ${PROJECT_SOURCE_DIR}/externals/libressl + DEFINITION OPENSSL_LIBS) + + target_compile_definitions(core PRIVATE -DENABLE_WEB_SERVICE -DCPPHTTPLIB_OPENSSL_SUPPORT) + target_link_libraries(core PRIVATE web_service ${OPENSSL_LIBS} httplib lurlparser) + if (ANDROID) + target_link_libraries(core PRIVATE ifaddrs) + endif() endif() if (ARCHITECTURE_x86_64) diff --git a/src/core/arm/arm_interface.h b/src/core/arm/arm_interface.h index 8ef51519b..b02e3941f 100644 --- a/src/core/arm/arm_interface.h +++ b/src/core/arm/arm_interface.h @@ -10,6 +10,7 @@ #include "common/common_types.h" #include "core/arm/skyeye_common/arm_regformat.h" #include "core/arm/skyeye_common/vfp/asm_vfp.h" +#include "core/core_timing.h" namespace Memory { struct PageTable; @@ -18,6 +19,8 @@ struct PageTable; /// Generic ARM11 CPU interface class ARM_Interface : NonCopyable { public: + explicit ARM_Interface(u32 id, std::shared_ptr timer) + : timer(timer), id(id){}; virtual ~ARM_Interface() {} class ThreadContext { @@ -222,11 +225,26 @@ public: virtual void PurgeState() = 0; + std::shared_ptr GetTimer() { + return timer; + } + + u32 GetID() const { + return id; + } + +protected: + std::shared_ptr timer; + private: + u32 id; + friend class boost::serialization::access; template void save(Archive& ar, const unsigned int file_version) const { + ar << timer; + ar << id; auto page_table = GetPageTable(); ar << page_table; for (auto i = 0; i < 15; i++) { @@ -254,6 +272,8 @@ private: template void load(Archive& ar, const unsigned int file_version) { PurgeState(); + ar >> timer; + ar >> id; std::shared_ptr page_table = nullptr; ar >> page_table; SetPageTable(page_table); diff --git a/src/core/arm/dynarmic/arm_dynarmic.cpp b/src/core/arm/dynarmic/arm_dynarmic.cpp index 33bc03a4f..b85dfdcf6 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.cpp +++ b/src/core/arm/dynarmic/arm_dynarmic.cpp @@ -72,8 +72,7 @@ private: class DynarmicUserCallbacks final : public Dynarmic::A32::UserCallbacks { public: explicit DynarmicUserCallbacks(ARM_Dynarmic& parent) - : parent(parent), timing(parent.system.CoreTiming()), svc_context(parent.system), - memory(parent.memory) {} + : parent(parent), svc_context(parent.system), memory(parent.memory) {} ~DynarmicUserCallbacks() = default; std::uint8_t MemoryRead8(VAddr vaddr) override { @@ -137,7 +136,7 @@ public: parent.jit->HaltExecution(); parent.SetPC(pc); Kernel::Thread* thread = - parent.system.Kernel().GetThreadManager().GetCurrentThread(); + parent.system.Kernel().GetCurrentThreadManager().GetCurrentThread(); parent.SaveContext(thread->context); GDBStub::Break(); GDBStub::SendTrap(thread, 5); @@ -150,22 +149,23 @@ public: } void AddTicks(std::uint64_t ticks) override { - timing.AddTicks(ticks); + parent.GetTimer()->AddTicks(ticks); } std::uint64_t GetTicksRemaining() override { - s64 ticks = timing.GetDowncount(); + s64 ticks = parent.GetTimer()->GetDowncount(); return static_cast(ticks <= 0 ? 0 : ticks); } ARM_Dynarmic& parent; - Core::Timing& timing; Kernel::SVCContext svc_context; Memory::MemorySystem& memory; }; ARM_Dynarmic::ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode) - : system(*system), memory(memory), cb(std::make_unique(*this)) { + PrivilegeMode initial_mode, u32 id, + std::shared_ptr timer) + : ARM_Interface(id, timer), system(*system), memory(memory), + cb(std::make_unique(*this)) { interpreter_state = std::make_shared(system, memory, initial_mode); SetPageTable(memory.GetCurrentPageTable()); } diff --git a/src/core/arm/dynarmic/arm_dynarmic.h b/src/core/arm/dynarmic/arm_dynarmic.h index c4d01835d..a8f224083 100644 --- a/src/core/arm/dynarmic/arm_dynarmic.h +++ b/src/core/arm/dynarmic/arm_dynarmic.h @@ -24,7 +24,8 @@ class DynarmicUserCallbacks; class ARM_Dynarmic final : public ARM_Interface { public: - ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, PrivilegeMode initial_mode); + ARM_Dynarmic(Core::System* system, Memory::MemorySystem& memory, PrivilegeMode initial_mode, + u32 id, std::shared_ptr timer); ~ARM_Dynarmic() override; void Run() override; diff --git a/src/core/arm/dyncom/arm_dyncom.cpp b/src/core/arm/dyncom/arm_dyncom.cpp index c069b428e..099a5c06a 100644 --- a/src/core/arm/dyncom/arm_dyncom.cpp +++ b/src/core/arm/dyncom/arm_dyncom.cpp @@ -69,8 +69,9 @@ private: }; ARM_DynCom::ARM_DynCom(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode) - : system(system) { + PrivilegeMode initial_mode, u32 id, + std::shared_ptr timer) + : ARM_Interface(id, timer), system(system) { state = std::make_unique(system, memory, initial_mode); } @@ -78,7 +79,7 @@ ARM_DynCom::~ARM_DynCom() {} void ARM_DynCom::Run() { DEBUG_ASSERT(system != nullptr); - ExecuteInstructions(std::max(system->CoreTiming().GetDowncount(), 0)); + ExecuteInstructions(std::max(timer->GetDowncount(), 0)); } void ARM_DynCom::Step() { @@ -156,7 +157,7 @@ void ARM_DynCom::ExecuteInstructions(u64 num_instructions) { state->NumInstrsToExecute = num_instructions; unsigned ticks_executed = InterpreterMainLoop(state.get()); if (system != nullptr) { - system->CoreTiming().AddTicks(ticks_executed); + timer->AddTicks(ticks_executed); } state->ServeBreak(); } diff --git a/src/core/arm/dyncom/arm_dyncom.h b/src/core/arm/dyncom/arm_dyncom.h index 39d55a62a..91cbcded0 100644 --- a/src/core/arm/dyncom/arm_dyncom.h +++ b/src/core/arm/dyncom/arm_dyncom.h @@ -21,7 +21,8 @@ class MemorySystem; class ARM_DynCom final : public ARM_Interface { public: explicit ARM_DynCom(Core::System* system, Memory::MemorySystem& memory, - PrivilegeMode initial_mode); + PrivilegeMode initial_mode, u32 id, + std::shared_ptr timer); ~ARM_DynCom() override; void Run() override; diff --git a/src/core/arm/dyncom/arm_dyncom_interpreter.cpp b/src/core/arm/dyncom/arm_dyncom_interpreter.cpp index ba4a15b0c..706e0092b 100644 --- a/src/core/arm/dyncom/arm_dyncom_interpreter.cpp +++ b/src/core/arm/dyncom/arm_dyncom_interpreter.cpp @@ -3865,7 +3865,7 @@ SWI_INST : { if (inst_base->cond == ConditionCode::AL || CondPassed(cpu, inst_base->cond)) { DEBUG_ASSERT(cpu->system != nullptr); swi_inst* const inst_cream = (swi_inst*)inst_base->component; - cpu->system->CoreTiming().AddTicks(num_instrs); + cpu->system->GetRunningCore().GetTimer()->AddTicks(num_instrs); cpu->NumInstrsToExecute = num_instrs >= cpu->NumInstrsToExecute ? 0 : cpu->NumInstrsToExecute - num_instrs; num_instrs = 0; diff --git a/src/core/arm/skyeye_common/armstate.cpp b/src/core/arm/skyeye_common/armstate.cpp index 26520da76..775618a8b 100644 --- a/src/core/arm/skyeye_common/armstate.cpp +++ b/src/core/arm/skyeye_common/armstate.cpp @@ -607,8 +607,8 @@ void ARMul_State::ServeBreak() { } DEBUG_ASSERT(system != nullptr); - Kernel::Thread* thread = system->Kernel().GetThreadManager().GetCurrentThread(); - system->CPU().SaveContext(thread->context); + Kernel::Thread* thread = system->Kernel().GetCurrentThreadManager().GetCurrentThread(); + system->GetRunningCore().SaveContext(thread->context); if (last_bkpt_hit || GDBStub::IsMemoryBreak() || GDBStub::GetCpuStepFlag()) { last_bkpt_hit = false; diff --git a/src/core/cheats/gateway_cheat.cpp b/src/core/cheats/gateway_cheat.cpp index 15a0e8ca9..cba74bb7c 100644 --- a/src/core/cheats/gateway_cheat.cpp +++ b/src/core/cheats/gateway_cheat.cpp @@ -35,7 +35,7 @@ static inline std::enable_if_t> WriteOp(const GatewayCheat Core::System& system) { u32 addr = line.address + state.offset; write_func(addr, static_cast(line.value)); - system.CPU().InvalidateCacheRange(addr, sizeof(T)); + system.InvalidateCacheRange(addr, sizeof(T)); } template @@ -105,7 +105,7 @@ static inline std::enable_if_t> IncrementiveWriteOp( Core::System& system) { u32 addr = line.value + state.offset; write_func(addr, static_cast(state.reg)); - system.CPU().InvalidateCacheRange(addr, sizeof(T)); + system.InvalidateCacheRange(addr, sizeof(T)); state.offset += sizeof(T); } @@ -143,7 +143,8 @@ static inline void PatchOp(const GatewayCheat::CheatLine& line, State& state, Co } u32 num_bytes = line.value; u32 addr = line.address + state.offset; - system.CPU().InvalidateCacheRange(addr, num_bytes); + system.InvalidateCacheRange(addr, num_bytes); + bool first = true; u32 bit_offset = 0; if (num_bytes > 0) diff --git a/src/core/core.cpp b/src/core/core.cpp index f17474a85..fea1705ed 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include "audio_core/dsp_interface.h" @@ -65,7 +66,8 @@ System::~System() = default; System::ResultStatus System::RunLoop(bool tight_loop) { status = ResultStatus::Success; - if (!cpu_core) { + if (std::any_of(cpu_cores.begin(), cpu_cores.end(), + [](std::shared_ptr ptr) { return ptr == nullptr; })) { return ResultStatus::ErrorNotInitialized; } @@ -83,22 +85,73 @@ System::ResultStatus System::RunLoop(bool tight_loop) { } } - // If we don't have a currently active thread then don't execute instructions, - // instead advance to the next event and try to yield to the next thread - if (kernel->GetThreadManager().GetCurrentThread() == nullptr) { - LOG_TRACE(Core_ARM11, "Idling"); - timing->Idle(); - timing->Advance(); - PrepareReschedule(); - } else { - timing->Advance(); - if (tight_loop) { - cpu_core->Run(); - } else { - cpu_core->Step(); + // All cores should have executed the same amount of ticks. If this is not the case an event was + // scheduled with a cycles_into_future smaller then the current downcount. + // So we have to get those cores to the same global time first + u64 global_ticks = timing->GetGlobalTicks(); + s64 max_delay = 0; + std::shared_ptr current_core_to_execute = nullptr; + for (auto& cpu_core : cpu_cores) { + if (cpu_core->GetTimer()->GetTicks() < global_ticks) { + s64 delay = global_ticks - cpu_core->GetTimer()->GetTicks(); + cpu_core->GetTimer()->Advance(delay); + if (max_delay < delay) { + max_delay = delay; + current_core_to_execute = cpu_core; + } } } + if (max_delay > 0) { + LOG_TRACE(Core_ARM11, "Core {} running (delayed) for {} ticks", + current_core_to_execute->GetID(), + current_core_to_execute->GetTimer()->GetDowncount()); + running_core = current_core_to_execute.get(); + kernel->SetRunningCPU(current_core_to_execute); + if (kernel->GetCurrentThreadManager().GetCurrentThread() == nullptr) { + LOG_TRACE(Core_ARM11, "Core {} idling", current_core_to_execute->GetID()); + current_core_to_execute->GetTimer()->Idle(); + PrepareReschedule(); + } else { + if (tight_loop) { + current_core_to_execute->Run(); + } else { + current_core_to_execute->Step(); + } + } + } else { + // Now all cores are at the same global time. So we will run them one after the other + // with a max slice that is the minimum of all max slices of all cores + // TODO: Make special check for idle since we can easily revert the time of idle cores + s64 max_slice = Timing::MAX_SLICE_LENGTH; + for (const auto& cpu_core : cpu_cores) { + max_slice = std::min(max_slice, cpu_core->GetTimer()->GetMaxSliceLength()); + } + for (auto& cpu_core : cpu_cores) { + cpu_core->GetTimer()->Advance(max_slice); + } + for (auto& cpu_core : cpu_cores) { + LOG_TRACE(Core_ARM11, "Core {} running for {} ticks", cpu_core->GetID(), + cpu_core->GetTimer()->GetDowncount()); + running_core = cpu_core.get(); + kernel->SetRunningCPU(cpu_core); + // If we don't have a currently active thread then don't execute instructions, + // instead advance to the next event and try to yield to the next thread + if (kernel->GetCurrentThreadManager().GetCurrentThread() == nullptr) { + LOG_TRACE(Core_ARM11, "Core {} idling", cpu_core->GetID()); + cpu_core->GetTimer()->Idle(); + PrepareReschedule(); + } else { + if (tight_loop) { + cpu_core->Run(); + } else { + cpu_core->Step(); + } + } + } + timing->AddToGlobalTicks(max_slice); + } + if (GDBStub::IsServerEnabled()) { GDBStub::SetCpuStepFlag(false); } @@ -183,7 +236,9 @@ System::ResultStatus System::Load(Frontend::EmuWindow& emu_window, const std::st } ASSERT(system_mode.first); - ResultStatus init_result{Init(emu_window, *system_mode.first)}; + auto n3ds_mode = app_loader->LoadKernelN3dsMode(); + ASSERT(n3ds_mode.first); + ResultStatus init_result{Init(emu_window, *system_mode.first, *n3ds_mode.first)}; if (init_result != ResultStatus::Success) { LOG_CRITICAL(Core, "Failed to initialize system (Error {})!", static_cast(init_result)); @@ -235,7 +290,7 @@ System::ResultStatus System::Load(Frontend::EmuWindow& emu_window, const std::st } void System::PrepareReschedule() { - cpu_core->PrepareReschedule(); + running_core->PrepareReschedule(); reschedule_pending = true; } @@ -249,31 +304,50 @@ void System::Reschedule() { } reschedule_pending = false; - kernel->GetThreadManager().Reschedule(); + for (const auto& core : cpu_cores) { + LOG_TRACE(Core_ARM11, "Reschedule core {}", core->GetID()); + kernel->GetThreadManager(core->GetID()).Reschedule(); + } } -System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mode) { +System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mode, u8 n3ds_mode) { LOG_DEBUG(HW_Memory, "initialized OK"); + std::size_t num_cores = 2; + if (Settings::values.is_new_3ds) { + num_cores = 4; + } + memory = std::make_unique(); - timing = std::make_unique(); + timing = std::make_unique(num_cores); - kernel = std::make_unique(*memory, *timing, - [this] { PrepareReschedule(); }, system_mode); + kernel = std::make_unique( + *memory, *timing, [this] { PrepareReschedule(); }, system_mode, num_cores, n3ds_mode); if (Settings::values.use_cpu_jit) { #ifdef ARCHITECTURE_x86_64 - cpu_core = std::make_shared(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } #else - cpu_core = std::make_shared(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } LOG_WARNING(Core, "CPU JIT requested, but Dynarmic not available"); #endif } else { - cpu_core = std::make_shared(this, *memory, USER32MODE); + for (std::size_t i = 0; i < num_cores; ++i) { + cpu_cores.push_back( + std::make_shared(this, *memory, USER32MODE, i, timing->GetTimer(i))); + } } + running_core = cpu_cores[0].get(); - kernel->SetCPU(cpu_core); + kernel->SetCPUs(cpu_cores); + kernel->SetRunningCPU(cpu_cores[0]); if (Settings::values.enable_dsp_lle) { dsp_core = std::make_unique(*memory, @@ -296,7 +370,7 @@ System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mo HW::Init(*memory); Service::Init(*this); - GDBStub::Init(); + GDBStub::DeferStart(); VideoCore::ResultStatus result = VideoCore::Init(emu_window, *memory); if (result != VideoCore::ResultStatus::Success) { @@ -318,6 +392,8 @@ System::ResultStatus System::Init(Frontend::EmuWindow& emu_window, u32 system_mo LOG_DEBUG(Core, "Initialized OK"); + initalized = true; + return ResultStatus::Success; } @@ -421,9 +497,10 @@ void System::Shutdown() { perf_stats.reset(); rpc_server.reset(); cheat_engine.reset(); + archive_manager.reset(); service_manager.reset(); dsp_core.reset(); - cpu_core.reset(); + cpu_cores.clear(); kernel.reset(); timing.reset(); app_loader.reset(); @@ -452,11 +529,18 @@ void System::Reset() { template void System::serialize(Archive& ar, const unsigned int file_version) { + u32 num_cores; + ar& num_cores; + if (num_cores != this->GetNumCores()) { + throw std::runtime_error("Wrong N3DS mode"); + } // flush on save, don't flush on load bool should_flush = !Archive::is_loading::value; Memory::RasterizerClearAll(should_flush); ar&* timing.get(); - ar&* cpu_core.get(); + for (int i = 0; i < num_cores; i++) { + ar&* cpu_cores[i].get(); + } ar&* service_manager.get(); ar& GPU::g_regs; ar& LCD::g_regs; diff --git a/src/core/core.h b/src/core/core.h index 0ce6924cd..66e880fd1 100644 --- a/src/core/core.h +++ b/src/core/core.h @@ -148,7 +148,10 @@ public: * @returns True if the emulated system is powered on, otherwise false. */ bool IsPoweredOn() const { - return cpu_core != nullptr; + return cpu_cores.size() > 0 && + std::all_of(cpu_cores.begin(), cpu_cores.end(), + [](std::shared_ptr ptr) { return ptr != nullptr; }); + ; } /** @@ -168,8 +171,29 @@ public: * Gets a reference to the emulated CPU. * @returns A reference to the emulated CPU. */ - ARM_Interface& CPU() { - return *cpu_core; + + ARM_Interface& GetRunningCore() { + return *running_core; + }; + + /** + * Gets a reference to the emulated CPU. + * @param core_id The id of the core requested. + * @returns A reference to the emulated CPU. + */ + + ARM_Interface& GetCore(u32 core_id) { + return *cpu_cores[core_id]; + }; + + u32 GetNumCores() const { + return static_cast(cpu_cores.size()); + } + + void InvalidateCacheRange(u32 start_address, std::size_t length) { + for (const auto& cpu : cpu_cores) { + cpu->InvalidateCacheRange(start_address, length); + } } /** @@ -291,7 +315,7 @@ private: * @param system_mode The system mode. * @return ResultStatus code, indicating if the operation succeeded. */ - ResultStatus Init(Frontend::EmuWindow& emu_window, u32 system_mode); + ResultStatus Init(Frontend::EmuWindow& emu_window, u32 system_mode, u8 n3ds_mode); /// Reschedule the core emulation void Reschedule(); @@ -300,7 +324,8 @@ private: std::unique_ptr app_loader; /// ARM11 CPU core - std::shared_ptr cpu_core; + std::vector> cpu_cores; + ARM_Interface* running_core = nullptr; /// DSP core std::unique_ptr dsp_core; @@ -342,6 +367,8 @@ private: private: static System s_instance; + bool initalized = false; + ResultStatus status = ResultStatus::Success; std::string status_details = ""; /// Saved variables for reset @@ -358,8 +385,16 @@ private: void serialize(Archive& ar, const unsigned int file_version); }; -inline ARM_Interface& CPU() { - return System::GetInstance().CPU(); +inline ARM_Interface& GetRunningCore() { + return System::GetInstance().GetRunningCore(); +} + +inline ARM_Interface& GetCore(u32 core_id) { + return System::GetInstance().GetCore(core_id); +} + +inline u32 GetNumCores() { + return System::GetInstance().GetNumCores(); } inline AudioCore::DspInterface& DSP() { diff --git a/src/core/core_timing.cpp b/src/core/core_timing.cpp index 116ba3a40..fd7d5d7a9 100644 --- a/src/core/core_timing.cpp +++ b/src/core/core_timing.cpp @@ -14,14 +14,22 @@ namespace Core { Timing* Timing::deserializing = nullptr; // Sort by time, unless the times are the same, in which case sort by the order added to the queue -bool Timing::Event::operator>(const Event& right) const { +bool Timing::Event::operator>(const Timing::Event& right) const { return std::tie(time, fifo_order) > std::tie(right.time, right.fifo_order); } -bool Timing::Event::operator<(const Event& right) const { +bool Timing::Event::operator<(const Timing::Event& right) const { return std::tie(time, fifo_order) < std::tie(right.time, right.fifo_order); } +Timing::Timing(std::size_t num_cores) { + timers.resize(num_cores); + for (std::size_t i = 0; i < num_cores; ++i) { + timers[i] = std::make_shared(); + } + current_timer = timers[0]; +} + TimingEventType* Timing::RegisterEvent(const std::string& name, TimedCallback callback) { // check for existing type with same name. // we want event type names to remain unique so that we can use them for serialization. @@ -34,73 +42,102 @@ TimingEventType* Timing::RegisterEvent(const std::string& name, TimedCallback ca return event_type; } -Timing::~Timing() { +void Timing::ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata, + std::size_t core_id) { + ASSERT(event_type != nullptr); + std::shared_ptr timer; + if (core_id == std::numeric_limits::max()) { + timer = current_timer; + } else { + ASSERT(core_id < timers.size()); + timer = timers.at(core_id); + } + + s64 timeout = timer->GetTicks() + cycles_into_future; + if (current_timer == timer) { + // If this event needs to be scheduled before the next advance(), force one early + if (!timer->is_timer_sane) + timer->ForceExceptionCheck(cycles_into_future); + + timer->event_queue.emplace_back( + Event{timeout, timer->event_fifo_id++, userdata, event_type}); + std::push_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } else { + timer->ts_queue.Push(Event{static_cast(timer->GetTicks() + cycles_into_future), 0, + userdata, event_type}); + } +} + +void Timing::UnscheduleEvent(const TimingEventType* event_type, u64 userdata) { + for (auto timer : timers) { + auto itr = std::remove_if( + timer->event_queue.begin(), timer->event_queue.end(), + [&](const Event& e) { return e.type == event_type && e.userdata == userdata; }); + + // Removing random items breaks the invariant so we have to re-establish it. + if (itr != timer->event_queue.end()) { + timer->event_queue.erase(itr, timer->event_queue.end()); + std::make_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } + } + // TODO:remove events from ts_queue +} + +void Timing::RemoveEvent(const TimingEventType* event_type) { + for (auto timer : timers) { + auto itr = std::remove_if(timer->event_queue.begin(), timer->event_queue.end(), + [&](const Event& e) { return e.type == event_type; }); + + // Removing random items breaks the invariant so we have to re-establish it. + if (itr != timer->event_queue.end()) { + timer->event_queue.erase(itr, timer->event_queue.end()); + std::make_heap(timer->event_queue.begin(), timer->event_queue.end(), std::greater<>()); + } + } + // TODO:remove events from ts_queue +} + +void Timing::SetCurrentTimer(std::size_t core_id) { + current_timer = timers[core_id]; +} + +s64 Timing::GetTicks() const { + return current_timer->GetTicks(); +} + +s64 Timing::GetGlobalTicks() const { + return global_timer; +} + +std::chrono::microseconds Timing::GetGlobalTimeUs() const { + return std::chrono::microseconds{GetTicks() * 1000000 / BASE_CLOCK_RATE_ARM11}; +} + +std::shared_ptr Timing::GetTimer(std::size_t cpu_id) { + return timers[cpu_id]; +} + +Timing::Timer::~Timer() { MoveEvents(); } -u64 Timing::GetTicks() const { - u64 ticks = static_cast(global_timer); - if (!is_global_timer_sane) { +u64 Timing::Timer::GetTicks() const { + u64 ticks = static_cast(executed_ticks); + if (!is_timer_sane) { ticks += slice_length - downcount; } return ticks; } -void Timing::AddTicks(u64 ticks) { +void Timing::Timer::AddTicks(u64 ticks) { downcount -= ticks; } -u64 Timing::GetIdleTicks() const { +u64 Timing::Timer::GetIdleTicks() const { return static_cast(idled_cycles); } -void Timing::ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata) { - ASSERT(event_type != nullptr); - s64 timeout = GetTicks() + cycles_into_future; - - // If this event needs to be scheduled before the next advance(), force one early - if (!is_global_timer_sane) - ForceExceptionCheck(cycles_into_future); - - event_queue.emplace_back(Event{timeout, event_fifo_id++, userdata, event_type}); - std::push_heap(event_queue.begin(), event_queue.end(), std::greater<>()); -} - -void Timing::ScheduleEventThreadsafe(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata) { - ts_queue.Push(Event{global_timer + cycles_into_future, 0, userdata, event_type}); -} - -void Timing::UnscheduleEvent(const TimingEventType* event_type, u64 userdata) { - auto itr = std::remove_if(event_queue.begin(), event_queue.end(), [&](const Event& e) { - return e.type == event_type && e.userdata == userdata; - }); - - // Removing random items breaks the invariant so we have to re-establish it. - if (itr != event_queue.end()) { - event_queue.erase(itr, event_queue.end()); - std::make_heap(event_queue.begin(), event_queue.end(), std::greater<>()); - } -} - -void Timing::RemoveEvent(const TimingEventType* event_type) { - auto itr = std::remove_if(event_queue.begin(), event_queue.end(), - [&](const Event& e) { return e.type == event_type; }); - - // Removing random items breaks the invariant so we have to re-establish it. - if (itr != event_queue.end()) { - event_queue.erase(itr, event_queue.end()); - std::make_heap(event_queue.begin(), event_queue.end(), std::greater<>()); - } -} - -void Timing::RemoveNormalAndThreadsafeEvent(const TimingEventType* event_type) { - MoveEvents(); - RemoveEvent(event_type); -} - -void Timing::ForceExceptionCheck(s64 cycles) { +void Timing::Timer::ForceExceptionCheck(s64 cycles) { cycles = std::max(0, cycles); if (downcount > cycles) { slice_length -= downcount - cycles; @@ -108,7 +145,7 @@ void Timing::ForceExceptionCheck(s64 cycles) { } } -void Timing::MoveEvents() { +void Timing::Timer::MoveEvents() { for (Event ev; ts_queue.Pop(ev);) { ev.fifo_order = event_fifo_id++; event_queue.emplace_back(std::move(ev)); @@ -116,50 +153,54 @@ void Timing::MoveEvents() { } } -void Timing::Advance() { +s64 Timing::Timer::GetMaxSliceLength() const { + auto next_event = std::find_if(event_queue.begin(), event_queue.end(), + [&](const Event& e) { return e.time - executed_ticks > 0; }); + if (next_event != event_queue.end()) { + return next_event->time - executed_ticks; + } + return MAX_SLICE_LENGTH; +} + +void Timing::Timer::Advance(s64 max_slice_length) { MoveEvents(); s64 cycles_executed = slice_length - downcount; - global_timer += cycles_executed; - slice_length = MAX_SLICE_LENGTH; + idled_cycles = 0; + executed_ticks += cycles_executed; + slice_length = max_slice_length; - is_global_timer_sane = true; + is_timer_sane = true; - while (!event_queue.empty() && event_queue.front().time <= global_timer) { + while (!event_queue.empty() && event_queue.front().time <= executed_ticks) { Event evt = std::move(event_queue.front()); std::pop_heap(event_queue.begin(), event_queue.end(), std::greater<>()); event_queue.pop_back(); - if (event_types.find(*evt.type->name) == event_types.end()) { - LOG_ERROR(Core, "Unknown queued event {}", *evt.type->name); - } else if (evt.type->callback == nullptr) { + if (evt.type->callback == nullptr) { LOG_ERROR(Core, "Event '{}' has no callback", *evt.type->name); } if (evt.type->callback != nullptr) { - evt.type->callback(evt.userdata, global_timer - evt.time); + evt.type->callback(evt.userdata, executed_ticks - evt.time); } } - is_global_timer_sane = false; + is_timer_sane = false; // Still events left (scheduled in the future) if (!event_queue.empty()) { slice_length = static_cast( - std::min(event_queue.front().time - global_timer, MAX_SLICE_LENGTH)); + std::min(event_queue.front().time - executed_ticks, max_slice_length)); } downcount = slice_length; } -void Timing::Idle() { +void Timing::Timer::Idle() { idled_cycles += downcount; downcount = 0; } -std::chrono::microseconds Timing::GetGlobalTimeUs() const { - return std::chrono::microseconds{GetTicks() * 1000000 / BASE_CLOCK_RATE_ARM11}; -} - -s64 Timing::GetDowncount() const { +s64 Timing::Timer::GetDowncount() const { return downcount; } diff --git a/src/core/core_timing.h b/src/core/core_timing.h index 1e6a3b021..a56a9097d 100644 --- a/src/core/core_timing.h +++ b/src/core/core_timing.h @@ -135,65 +135,10 @@ struct TimingEventType { }; class Timing { -public: - ~Timing(); - - /** - * This should only be called from the emu thread, if you are calling it any other thread, you - * are doing something evil - */ - u64 GetTicks() const; - u64 GetIdleTicks() const; - void AddTicks(u64 ticks); - - /** - * Returns the event_type identifier. if name is not unique, it will assert. - */ - TimingEventType* RegisterEvent(const std::string& name, TimedCallback callback); - - /** - * After the first Advance, the slice lengths and the downcount will be reduced whenever an - * event is scheduled earlier than the current values. Scheduling from a callback will not - * update the downcount until the Advance() completes. - */ - void ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata = 0); - - /** - * This is to be called when outside of hle threads, such as the graphics thread, wants to - * schedule things to be executed on the main thread. - * Not that this doesn't change slice_length and thus events scheduled by this might be called - * with a delay of up to MAX_SLICE_LENGTH - */ - void ScheduleEventThreadsafe(s64 cycles_into_future, const TimingEventType* event_type, - u64 userdata); - - void UnscheduleEvent(const TimingEventType* event_type, u64 userdata); - - /// We only permit one event of each type in the queue at a time. - void RemoveEvent(const TimingEventType* event_type); - void RemoveNormalAndThreadsafeEvent(const TimingEventType* event_type); - - /** Advance must be called at the beginning of dispatcher loops, not the end. Advance() ends - * the previous timing slice and begins the next one, you must Advance from the previous - * slice to the current one before executing any cycles. CoreTiming starts in slice -1 so an - * Advance() is required to initialize the slice length before the first cycle of emulated - * instructions is executed. - */ - void Advance(); - void MoveEvents(); - - /// Pretend that the main CPU has executed enough cycles to reach the next event. - void Idle(); - - void ForceExceptionCheck(s64 cycles); - - std::chrono::microseconds GetGlobalTimeUs() const; - - s64 GetDowncount() const; - private: static Timing* deserializing; +public: struct Event { s64 time; u64 fifo_order; @@ -229,48 +174,116 @@ private: static constexpr int MAX_SLICE_LENGTH = 20000; + class Timer { + public: + ~Timer(); + + s64 GetMaxSliceLength() const; + + void Advance(s64 max_slice_length = MAX_SLICE_LENGTH); + + void Idle(); + + u64 GetTicks() const; + u64 GetIdleTicks() const; + + void AddTicks(u64 ticks); + + s64 GetDowncount() const; + + void ForceExceptionCheck(s64 cycles); + + void MoveEvents(); + + private: + friend class Timing; + // The queue is a min-heap using std::make_heap/push_heap/pop_heap. + // We don't use std::priority_queue because we need to be able to serialize, unserialize and + // erase arbitrary events (RemoveEvent()) regardless of the queue order. These aren't + // accomodated by the standard adaptor class. + std::vector event_queue; + u64 event_fifo_id = 0; + // the queue for storing the events from other threads threadsafe until they will be added + // to the event_queue by the emu thread + Common::MPSCQueue ts_queue; + // Are we in a function that has been called from Advance() + // If events are sheduled from a function that gets called from Advance(), + // don't change slice_length and downcount. + // The time between CoreTiming being intialized and the first call to Advance() is + // considered the slice boundary between slice -1 and slice 0. Dispatcher loops must call + // Advance() before executing the first cycle of each slice to prepare the slice length and + // downcount for that slice. + bool is_timer_sane = true; + + s64 slice_length = MAX_SLICE_LENGTH; + s64 downcount = MAX_SLICE_LENGTH; + s64 executed_ticks = 0; + u64 idled_cycles; + + template + void serialize(Archive& ar, const unsigned int) { + MoveEvents(); + ar& slice_length; + ar& downcount; + ar& event_queue; + ar& event_fifo_id; + ar& idled_cycles; + } + friend class boost::serialization::access; + }; + + explicit Timing(std::size_t num_cores); + + ~Timing(){}; + + /** + * Returns the event_type identifier. if name is not unique, it will assert. + */ + TimingEventType* RegisterEvent(const std::string& name, TimedCallback callback); + + void ScheduleEvent(s64 cycles_into_future, const TimingEventType* event_type, u64 userdata = 0, + std::size_t core_id = std::numeric_limits::max()); + + void UnscheduleEvent(const TimingEventType* event_type, u64 userdata); + + /// We only permit one event of each type in the queue at a time. + void RemoveEvent(const TimingEventType* event_type); + + void SetCurrentTimer(std::size_t core_id); + + s64 GetTicks() const; + + s64 GetGlobalTicks() const; + + void AddToGlobalTicks(s64 ticks) { + global_timer += ticks; + } + + std::chrono::microseconds GetGlobalTimeUs() const; + + std::shared_ptr GetTimer(std::size_t cpu_id); + +private: s64 global_timer = 0; - s64 slice_length = MAX_SLICE_LENGTH; - s64 downcount = MAX_SLICE_LENGTH; // unordered_map stores each element separately as a linked list node so pointers to // elements remain stable regardless of rehashes/resizing. std::unordered_map event_types; - // The queue is a min-heap using std::make_heap/push_heap/pop_heap. - // We don't use std::priority_queue because we need to be able to serialize, unserialize and - // erase arbitrary events (RemoveEvent()) regardless of the queue order. These aren't - // accomodated by the standard adaptor class. - std::vector event_queue; - u64 event_fifo_id = 0; - // the queue for storing the events from other threads threadsafe until they will be added - // to the event_queue by the emu thread - Common::MPSCQueue ts_queue; - s64 idled_cycles = 0; - - // Are we in a function that has been called from Advance() - // If events are sheduled from a function that gets called from Advance(), - // don't change slice_length and downcount. - // The time between CoreTiming being intialized and the first call to Advance() is considered - // the slice boundary between slice -1 and slice 0. Dispatcher loops must call Advance() before - // executing the first cycle of each slice to prepare the slice length and downcount for - // that slice. - bool is_global_timer_sane = true; + std::vector> timers; + std::shared_ptr current_timer; template void serialize(Archive& ar, const unsigned int) { // event_types set during initialization of other things deserializing = this; - MoveEvents(); ar& global_timer; - ar& slice_length; - ar& downcount; - ar& event_queue; - ar& event_fifo_id; - ar& idled_cycles; + ar& timers; + ar& current_timer; deserializing = nullptr; } friend class boost::serialization::access; + }; } // namespace Core diff --git a/src/core/file_sys/layered_fs.cpp b/src/core/file_sys/layered_fs.cpp new file mode 100644 index 000000000..9d5fbf7c2 --- /dev/null +++ b/src/core/file_sys/layered_fs.cpp @@ -0,0 +1,604 @@ +// Copyright 2020 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include +#include "common/alignment.h" +#include "common/assert.h" +#include "common/common_paths.h" +#include "common/file_util.h" +#include "common/string_util.h" +#include "common/swap.h" +#include "core/file_sys/layered_fs.h" +#include "core/file_sys/patch.h" + +namespace FileSys { + +struct FileRelocationInfo { + int type; // 0 - none, 1 - replaced / created, 2 - patched, 3 - removed + u64 original_offset; // Type 0. Offset is absolute + std::string replace_file_path; // Type 1 + std::vector patched_file; // Type 2 + u64 size; // Relocated file size +}; +struct LayeredFS::File { + std::string name; + std::string path; + FileRelocationInfo relocation{}; + Directory* parent; +}; + +struct DirectoryMetadata { + u32_le parent_directory_offset; + u32_le next_sibling_offset; + u32_le first_child_directory_offset; + u32_le first_file_offset; + u32_le hash_bucket_next; + u32_le name_length; + // Followed by a name of name length (aligned up to 4) +}; +static_assert(sizeof(DirectoryMetadata) == 0x18, "Size of DirectoryMetadata is not correct"); + +struct FileMetadata { + u32_le parent_directory_offset; + u32_le next_sibling_offset; + u64_le file_data_offset; + u64_le file_data_length; + u32_le hash_bucket_next; + u32_le name_length; + // Followed by a name of name length (aligned up to 4) +}; +static_assert(sizeof(FileMetadata) == 0x20, "Size of FileMetadata is not correct"); + +LayeredFS::LayeredFS(std::shared_ptr romfs_, std::string patch_path_, + std::string patch_ext_path_, bool load_relocations) + : romfs(std::move(romfs_)), patch_path(std::move(patch_path_)), + patch_ext_path(std::move(patch_ext_path_)) { + + romfs->ReadFile(0, sizeof(header), reinterpret_cast(&header)); + + ASSERT_MSG(header.header_length == sizeof(header), "Header size is incorrect"); + + // TODO: is root always the first directory in table? + root.parent = &root; + LoadDirectory(root, 0); + + if (load_relocations) { + LoadRelocations(); + LoadExtRelocations(); + } + + RebuildMetadata(); +} + +LayeredFS::~LayeredFS() = default; + +void LayeredFS::LoadDirectory(Directory& current, u32 offset) { + DirectoryMetadata metadata; + romfs->ReadFile(header.directory_metadata_table.offset + offset, sizeof(metadata), + reinterpret_cast(&metadata)); + + current.name = ReadName(header.directory_metadata_table.offset + offset + sizeof(metadata), + metadata.name_length); + current.path = current.parent->path + current.name + DIR_SEP; + directory_path_map.emplace(current.path, ¤t); + + if (metadata.first_file_offset != 0xFFFFFFFF) { + LoadFile(current, metadata.first_file_offset); + } + + if (metadata.first_child_directory_offset != 0xFFFFFFFF) { + auto child = std::make_unique(); + auto& directory = *child; + directory.parent = ¤t; + current.directories.emplace_back(std::move(child)); + LoadDirectory(directory, metadata.first_child_directory_offset); + } + + if (metadata.next_sibling_offset != 0xFFFFFFFF) { + auto sibling = std::make_unique(); + auto& directory = *sibling; + directory.parent = current.parent; + current.parent->directories.emplace_back(std::move(sibling)); + LoadDirectory(directory, metadata.next_sibling_offset); + } +} + +void LayeredFS::LoadFile(Directory& parent, u32 offset) { + FileMetadata metadata; + romfs->ReadFile(header.file_metadata_table.offset + offset, sizeof(metadata), + reinterpret_cast(&metadata)); + + auto file = std::make_unique(); + file->name = ReadName(header.file_metadata_table.offset + offset + sizeof(metadata), + metadata.name_length); + file->path = parent.path + file->name; + file->relocation.original_offset = header.file_data_offset + metadata.file_data_offset; + file->relocation.size = metadata.file_data_length; + file->parent = &parent; + + file_path_map.emplace(file->path, file.get()); + parent.files.emplace_back(std::move(file)); + + if (metadata.next_sibling_offset != 0xFFFFFFFF) { + LoadFile(parent, metadata.next_sibling_offset); + } +} + +std::string LayeredFS::ReadName(u32 offset, u32 name_length) { + std::vector buffer(name_length / sizeof(u16_le)); + romfs->ReadFile(offset, name_length, reinterpret_cast(buffer.data())); + + std::u16string name(buffer.size(), 0); + std::transform(buffer.begin(), buffer.end(), name.begin(), [](u16_le character) { + return static_cast(static_cast(character)); + }); + return Common::UTF16ToUTF8(name); +} + +void LayeredFS::LoadRelocations() { + if (!FileUtil::Exists(patch_path)) { + return; + } + + const FileUtil::DirectoryEntryCallable callback = [this, + &callback](u64* /*num_entries_out*/, + const std::string& directory, + const std::string& virtual_name) { + auto* parent = directory_path_map.at(directory.substr(patch_path.size() - 1)); + + if (FileUtil::IsDirectory(directory + virtual_name + DIR_SEP)) { + const auto path = (directory + virtual_name + DIR_SEP).substr(patch_path.size() - 1); + if (!directory_path_map.count(path)) { // Add this directory + auto directory = std::make_unique(); + directory->name = virtual_name; + directory->path = path; + directory->parent = parent; + directory_path_map.emplace(path, directory.get()); + parent->directories.emplace_back(std::move(directory)); + LOG_INFO(Service_FS, "LayeredFS created directory {}", path); + } + return FileUtil::ForeachDirectoryEntry(nullptr, directory + virtual_name + DIR_SEP, + callback); + } + + const auto path = (directory + virtual_name).substr(patch_path.size() - 1); + if (!file_path_map.count(path)) { // Newly created file + auto file = std::make_unique(); + file->name = virtual_name; + file->path = path; + file->parent = parent; + file_path_map.emplace(path, file.get()); + parent->files.emplace_back(std::move(file)); + LOG_INFO(Service_FS, "LayeredFS created file {}", path); + } + + auto* file = file_path_map.at(path); + file->relocation.type = 1; + file->relocation.replace_file_path = directory + virtual_name; + file->relocation.size = FileUtil::GetSize(directory + virtual_name); + LOG_INFO(Service_FS, "LayeredFS replacement file in use for {}", path); + return true; + }; + + FileUtil::ForeachDirectoryEntry(nullptr, patch_path, callback); +} + +void LayeredFS::LoadExtRelocations() { + if (!FileUtil::Exists(patch_ext_path)) { + return; + } + + if (patch_ext_path.back() == '/' || patch_ext_path.back() == '\\') { + // ScanDirectoryTree expects a path without trailing '/' + patch_ext_path.erase(patch_ext_path.size() - 1, 1); + } + + FileUtil::FSTEntry result; + FileUtil::ScanDirectoryTree(patch_ext_path, result, 256); + + for (const auto& entry : result.children) { + if (FileUtil::IsDirectory(entry.physicalName)) { + continue; + } + + const auto path = entry.physicalName.substr(patch_ext_path.size()); + if (path.size() >= 5 && path.substr(path.size() - 5) == ".stub") { + // Remove the corresponding file if exists + const auto file_path = path.substr(0, path.size() - 5); + if (file_path_map.count(file_path)) { + auto& file = *file_path_map[file_path]; + file.relocation.type = 3; + file.relocation.size = 0; + file_path_map.erase(file_path); + LOG_INFO(Service_FS, "LayeredFS removed file {}", file_path); + } else { + LOG_WARNING(Service_FS, "LayeredFS file for stub {} not found", path); + } + } else if (path.size() >= 4) { + const auto extension = path.substr(path.size() - 4); + if (extension != ".ips" && extension != ".bps") { + LOG_WARNING(Service_FS, "LayeredFS unknown ext file {}", path); + } + + const auto file_path = path.substr(0, path.size() - 4); + if (!file_path_map.count(file_path)) { + LOG_WARNING(Service_FS, "LayeredFS original file for patch {} not found", path); + continue; + } + + FileUtil::IOFile patch_file(entry.physicalName, "rb"); + if (!patch_file) { + LOG_ERROR(Service_FS, "LayeredFS Could not open file {}", entry.physicalName); + continue; + } + + const auto size = patch_file.GetSize(); + std::vector patch(size); + if (patch_file.ReadBytes(patch.data(), size) != size) { + LOG_ERROR(Service_FS, "LayeredFS Could not read file {}", entry.physicalName); + continue; + } + + auto& file = *file_path_map[file_path]; + std::vector buffer(file.relocation.size); // Original size + romfs->ReadFile(file.relocation.original_offset, buffer.size(), buffer.data()); + + bool ret = false; + if (extension == ".ips") { + ret = Patch::ApplyIpsPatch(patch, buffer); + } else { + ret = Patch::ApplyBpsPatch(patch, buffer); + } + + if (ret) { + LOG_INFO(Service_FS, "LayeredFS patched file {}", file_path); + + file.relocation.type = 2; + file.relocation.size = buffer.size(); + file.relocation.patched_file = std::move(buffer); + } else { + LOG_ERROR(Service_FS, "LayeredFS failed to patch file {}", file_path); + } + } else { + LOG_WARNING(Service_FS, "LayeredFS unknown ext file {}", path); + } + } +} + +std::size_t GetNameSize(const std::string& name) { + std::u16string u16name = Common::UTF8ToUTF16(name); + return Common::AlignUp(u16name.size() * 2, 4); +} + +void LayeredFS::PrepareBuildDirectory(Directory& current) { + directory_metadata_offset_map.emplace(¤t, current_directory_offset); + directory_list.emplace_back(¤t); + current_directory_offset += sizeof(DirectoryMetadata) + GetNameSize(current.name); +} + +void LayeredFS::PrepareBuildFile(File& current) { + if (current.relocation.type == 3) { // Deleted files are not counted + return; + } + file_metadata_offset_map.emplace(¤t, current_file_offset); + file_list.emplace_back(¤t); + current_file_offset += sizeof(FileMetadata) + GetNameSize(current.name); +} + +void LayeredFS::PrepareBuild(Directory& current) { + for (const auto& child : current.files) { + PrepareBuildFile(*child); + } + + for (const auto& child : current.directories) { + PrepareBuildDirectory(*child); + } + + for (const auto& child : current.directories) { + PrepareBuild(*child); + } +} + +// Implementation from 3dbrew +u32 CalcHash(const std::string& name, u32 parent_offset) { + u32 hash = parent_offset ^ 123456789; + std::u16string u16name = Common::UTF8ToUTF16(name); + for (char16_t c : u16name) { + hash = (hash >> 5) | (hash << 27); + hash ^= static_cast(c); + } + return hash; +} + +std::size_t WriteName(u8* dest, std::u16string name) { + const auto buffer_size = Common::AlignUp(name.size() * 2, 4); + std::vector buffer(buffer_size / 2); + std::transform(name.begin(), name.end(), buffer.begin(), [](char16_t character) { + return static_cast(static_cast(character)); + }); + std::memcpy(dest, buffer.data(), buffer_size); + + return buffer_size; +} + +void LayeredFS::BuildDirectories() { + directory_metadata_table.resize(current_directory_offset, 0xFF); + + std::size_t written = 0; + for (const auto& directory : directory_list) { + DirectoryMetadata metadata; + std::memset(&metadata, 0xFF, sizeof(metadata)); + metadata.parent_directory_offset = directory_metadata_offset_map.at(directory->parent); + + if (directory->parent != directory) { + bool flag = false; + for (const auto& sibling : directory->parent->directories) { + if (flag) { + metadata.next_sibling_offset = directory_metadata_offset_map.at(sibling.get()); + break; + } else if (sibling.get() == directory) { + flag = true; + } + } + } + + if (!directory->directories.empty()) { + metadata.first_child_directory_offset = + directory_metadata_offset_map.at(directory->directories.front().get()); + } + + if (!directory->files.empty()) { + metadata.first_file_offset = + file_metadata_offset_map.at(directory->files.front().get()); + } + + const auto bucket = CalcHash(directory->name, metadata.parent_directory_offset) % + directory_hash_table.size(); + metadata.hash_bucket_next = directory_hash_table[bucket]; + directory_hash_table[bucket] = directory_metadata_offset_map.at(directory); + + // Write metadata and name + std::u16string u16name = Common::UTF8ToUTF16(directory->name); + metadata.name_length = u16name.size() * 2; + + std::memcpy(directory_metadata_table.data() + written, &metadata, sizeof(metadata)); + written += sizeof(metadata); + + written += WriteName(directory_metadata_table.data() + written, u16name); + } + + ASSERT_MSG(written == directory_metadata_table.size(), + "Calculated size for directory metadata table is wrong"); +} + +void LayeredFS::BuildFiles() { + file_metadata_table.resize(current_file_offset, 0xFF); + + std::size_t written = 0; + for (const auto& file : file_list) { + FileMetadata metadata; + std::memset(&metadata, 0xFF, sizeof(metadata)); + + metadata.parent_directory_offset = directory_metadata_offset_map.at(file->parent); + + bool flag = false; + for (const auto& sibling : file->parent->files) { + if (sibling->relocation.type == 3) { // removed file + continue; + } + if (flag) { + metadata.next_sibling_offset = file_metadata_offset_map.at(sibling.get()); + break; + } else if (sibling.get() == file) { + flag = true; + } + } + + metadata.file_data_offset = current_data_offset; + metadata.file_data_length = file->relocation.size; + current_data_offset += Common::AlignUp(metadata.file_data_length, 16); + if (metadata.file_data_length != 0) { + data_offset_map.emplace(metadata.file_data_offset, file); + } + + const auto bucket = + CalcHash(file->name, metadata.parent_directory_offset) % file_hash_table.size(); + metadata.hash_bucket_next = file_hash_table[bucket]; + file_hash_table[bucket] = file_metadata_offset_map.at(file); + + // Write metadata and name + std::u16string u16name = Common::UTF8ToUTF16(file->name); + metadata.name_length = u16name.size() * 2; + + std::memcpy(file_metadata_table.data() + written, &metadata, sizeof(metadata)); + written += sizeof(metadata); + + written += WriteName(file_metadata_table.data() + written, u16name); + } + + ASSERT_MSG(written == file_metadata_table.size(), + "Calculated size for file metadata table is wrong"); +} + +// Implementation from 3dbrew +std::size_t GetHashTableSize(std::size_t entry_count) { + if (entry_count < 3) { + return 3; + } else if (entry_count < 19) { + return entry_count | 1; + } else { + std::size_t count = entry_count; + while (count % 2 == 0 || count % 3 == 0 || count % 5 == 0 || count % 7 == 0 || + count % 11 == 0 || count % 13 == 0 || count % 17 == 0) { + count++; + } + return count; + } +} + +void LayeredFS::RebuildMetadata() { + PrepareBuildDirectory(root); + PrepareBuild(root); + + directory_hash_table.resize(GetHashTableSize(directory_list.size()), 0xFFFFFFFF); + file_hash_table.resize(GetHashTableSize(file_list.size()), 0xFFFFFFFF); + + BuildDirectories(); + BuildFiles(); + + // Create header + RomFSHeader header; + header.header_length = sizeof(header); + header.directory_hash_table = { + /*offset*/ sizeof(header), + /*length*/ static_cast(directory_hash_table.size() * sizeof(u32_le))}; + header.directory_metadata_table = { + /*offset*/ + header.directory_hash_table.offset + header.directory_hash_table.length, + /*length*/ static_cast(directory_metadata_table.size())}; + header.file_hash_table = { + /*offset*/ + header.directory_metadata_table.offset + header.directory_metadata_table.length, + /*length*/ static_cast(file_hash_table.size() * sizeof(u32_le))}; + header.file_metadata_table = {/*offset*/ header.file_hash_table.offset + + header.file_hash_table.length, + /*length*/ static_cast(file_metadata_table.size())}; + header.file_data_offset = + Common::AlignUp(header.file_metadata_table.offset + header.file_metadata_table.length, 16); + + // Write hash table and metadata table + metadata.resize(header.file_data_offset); + std::memcpy(metadata.data(), &header, header.header_length); + std::memcpy(metadata.data() + header.directory_hash_table.offset, directory_hash_table.data(), + header.directory_hash_table.length); + std::memcpy(metadata.data() + header.directory_metadata_table.offset, + directory_metadata_table.data(), header.directory_metadata_table.length); + std::memcpy(metadata.data() + header.file_hash_table.offset, file_hash_table.data(), + header.file_hash_table.length); + std::memcpy(metadata.data() + header.file_metadata_table.offset, file_metadata_table.data(), + header.file_metadata_table.length); +} + +std::size_t LayeredFS::GetSize() const { + return metadata.size() + current_data_offset; +} + +std::size_t LayeredFS::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { + ASSERT_MSG(offset + length <= GetSize(), "Out of bound"); + + std::size_t read_size = 0; + if (offset < metadata.size()) { + // First read the metadata + const auto to_read = std::min(metadata.size() - offset, length); + std::memcpy(buffer, metadata.data() + offset, to_read); + read_size += to_read; + offset = 0; + } else { + offset -= metadata.size(); + } + + // Read files + auto current = (--data_offset_map.upper_bound(offset)); + while (read_size < length) { + const auto relative_offset = offset - current->first; + std::size_t to_read{}; + if (current->second->relocation.size > relative_offset) { + to_read = std::min(current->second->relocation.size - relative_offset, + length - read_size); + } + const auto alignment = + std::min(Common::AlignUp(current->second->relocation.size, 16) - + relative_offset, + length - read_size) - + to_read; + + // Read the file in different ways depending on relocation type + auto& relocation = current->second->relocation; + if (relocation.type == 0) { // none + romfs->ReadFile(relocation.original_offset + relative_offset, to_read, + buffer + read_size); + } else if (relocation.type == 1) { // replace + FileUtil::IOFile replace_file(relocation.replace_file_path, "rb"); + if (replace_file) { + replace_file.Seek(relative_offset, SEEK_SET); + replace_file.ReadBytes(buffer + read_size, to_read); + } else { + LOG_ERROR(Service_FS, "Could not open replacement file for {}", + current->second->path); + } + } else if (relocation.type == 2) { // patch + std::memcpy(buffer + read_size, relocation.patched_file.data() + relative_offset, + to_read); + } else { + UNREACHABLE(); + } + + std::memset(buffer + read_size + to_read, 0, alignment); + + read_size += to_read + alignment; + offset += to_read + alignment; + current++; + } + + return read_size; +} + +bool LayeredFS::ExtractDirectory(Directory& current, const std::string& target_path) { + if (!FileUtil::CreateFullPath(target_path + current.path)) { + LOG_ERROR(Service_FS, "Could not create path {}", target_path + current.path); + return false; + } + + constexpr std::size_t BufferSize = 0x10000; + std::array buffer; + for (const auto& file : current.files) { + // Extract file + const auto path = target_path + file->path; + LOG_INFO(Service_FS, "Extracting {} to {}", file->path, path); + + FileUtil::IOFile target_file(path, "wb"); + if (!target_file) { + LOG_ERROR(Service_FS, "Could not open file {}", path); + return false; + } + + std::size_t written = 0; + while (written < file->relocation.size) { + const auto to_read = + std::min(buffer.size(), file->relocation.size - written); + if (romfs->ReadFile(file->relocation.original_offset + written, to_read, + buffer.data()) != to_read) { + LOG_ERROR(Service_FS, "Could not read from RomFS"); + return false; + } + + if (target_file.WriteBytes(buffer.data(), to_read) != to_read) { + LOG_ERROR(Service_FS, "Could not write to file {}", path); + return false; + } + + written += to_read; + } + } + + for (const auto& directory : current.directories) { + if (!ExtractDirectory(*directory, target_path)) { + return false; + } + } + + return true; +} + +bool LayeredFS::DumpRomFS(const std::string& target_path) { + std::string path = target_path; + if (path.back() == '/' || path.back() == '\\') { + path.erase(path.size() - 1, 1); + } + + return ExtractDirectory(root, path); +} + +} // namespace FileSys diff --git a/src/core/file_sys/layered_fs.h b/src/core/file_sys/layered_fs.h new file mode 100644 index 000000000..956eedcfa --- /dev/null +++ b/src/core/file_sys/layered_fs.h @@ -0,0 +1,123 @@ +// Copyright 2020 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include +#include +#include +#include +#include "common/common_types.h" +#include "common/swap.h" +#include "core/file_sys/romfs_reader.h" + +namespace FileSys { + +struct RomFSHeader { + struct Descriptor { + u32_le offset; + u32_le length; + }; + u32_le header_length; + Descriptor directory_hash_table; + Descriptor directory_metadata_table; + Descriptor file_hash_table; + Descriptor file_metadata_table; + u32_le file_data_offset; +}; +static_assert(sizeof(RomFSHeader) == 0x28, "Size of RomFSHeader is not correct"); + +/** + * LayeredFS implementation. This basically adds a layer to another RomFSReader. + * + * patch_path: Path for RomFS replacements. Files present in this path replace or create + * corresponding files in RomFS. + * patch_ext_path: Path for RomFS extensions. Files present in this path: + * - When with an extension of ".stub", remove the corresponding file in the RomFS. + * - When with an extension of ".ips" or ".bps", patch the file in the RomFS. + */ +class LayeredFS : public RomFSReader { +public: + explicit LayeredFS(std::shared_ptr romfs, std::string patch_path, + std::string patch_ext_path, bool load_relocations = true); + ~LayeredFS() override; + + std::size_t GetSize() const override; + std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; + + bool DumpRomFS(const std::string& target_path); + +private: + struct File; + struct Directory { + std::string name; + std::string path; // with trailing '/' + std::vector> files; + std::vector> directories; + Directory* parent; + }; + + std::string ReadName(u32 offset, u32 name_length); + + // Loads the current directory, then its siblings, and then its children. + void LoadDirectory(Directory& current, u32 offset); + + // Load the file at offset, and then its siblings. + void LoadFile(Directory& parent, u32 offset); + + // Load replace/create relocations + void LoadRelocations(); + + // Load patch/remove relocations + void LoadExtRelocations(); + + // Calculate the offset of a single directory add it to the map and list of directories + void PrepareBuildDirectory(Directory& current); + + // Calculate the offset of a single file add it to the map and list of files + void PrepareBuildFile(File& current); + + // Recursively generate a sequence of files and directories and their offsets for all + // children of current. (The current directory itself is not handled.) + void PrepareBuild(Directory& current); + + void BuildDirectories(); + void BuildFiles(); + + // Recursively extract a directory and all its contents to target_path + // target_path should be without trailing '/'. + bool ExtractDirectory(Directory& current, const std::string& target_path); + + void RebuildMetadata(); + + std::shared_ptr romfs; + std::string patch_path; + std::string patch_ext_path; + + RomFSHeader header; + Directory root; + std::unordered_map file_path_map; + std::unordered_map directory_path_map; + std::map data_offset_map; // assigned data offset -> file + std::vector metadata; // Includes header, hash table and metadata + + // Used for rebuilding header + std::vector directory_hash_table; + std::vector file_hash_table; + + std::unordered_map + directory_metadata_offset_map; // directory -> metadata offset + std::vector directory_list; // sequence of directories to be written to metadata + u64 current_directory_offset{}; // current directory metadata offset + std::vector directory_metadata_table; // rebuilt directory metadata table + + std::unordered_map file_metadata_offset_map; // file -> metadata offset + std::vector file_list; // sequence of files to be written to metadata + u64 current_file_offset{}; // current file metadata offset + std::vector file_metadata_table; // rebuilt file metadata table + u64 current_data_offset{}; // current assigned data offset +}; + +} // namespace FileSys diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp index f0687fa9e..056f7a901 100644 --- a/src/core/file_sys/ncch_container.cpp +++ b/src/core/file_sys/ncch_container.cpp @@ -11,6 +11,7 @@ #include "common/common_types.h" #include "common/logging/log.h" #include "core/core.h" +#include "core/file_sys/layered_fs.h" #include "core/file_sys/ncch_container.h" #include "core/file_sys/patch.h" #include "core/file_sys/seed_db.h" @@ -25,6 +26,14 @@ namespace FileSys { static const int kMaxSections = 8; ///< Maximum number of sections (files) in an ExeFs static const int kBlockSize = 0x200; ///< Size of ExeFS blocks (in bytes) +u64 GetModId(u64 program_id) { + constexpr u64 UPDATE_MASK = 0x0000000e'00000000; + if ((program_id & 0x000000ff'00000000) == UPDATE_MASK) { // Apply the mods to updates + return program_id & ~UPDATE_MASK; + } + return program_id; +} + /** * Get the decompressed size of an LZSS compressed ExeFS file * @param buffer Buffer of compressed file @@ -303,8 +312,22 @@ Loader::ResultStatus NCCHContainer::Load() { } } - FileUtil::IOFile exheader_override_file{filepath + ".exheader", "rb"}; - const bool has_exheader_override = read_exheader(exheader_override_file); + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + GetModId(ncch_header.program_id)); + const std::array exheader_override_paths{{ + mods_path + "exheader.bin", + filepath + ".exheader", + }}; + + bool has_exheader_override = false; + for (const auto& path : exheader_override_paths) { + FileUtil::IOFile exheader_override_file{path, "rb"}; + if (read_exheader(exheader_override_file)) { + has_exheader_override = true; + break; + } + } if (has_exheader_override) { if (exheader_header.system_info.jump_id != exheader_header.arm11_system_local_caps.program_id) { @@ -512,7 +535,15 @@ Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector& code) const std::string path; bool (*patch_fn)(const std::vector& patch, std::vector& code); }; - const std::array patch_paths{{ + + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + GetModId(ncch_header.program_id)); + const std::array patch_paths{{ + {mods_path + "exefs/code.ips", Patch::ApplyIpsPatch}, + {mods_path + "exefs/code.bps", Patch::ApplyBpsPatch}, + {mods_path + "code.ips", Patch::ApplyIpsPatch}, + {mods_path + "code.bps", Patch::ApplyBpsPatch}, {filepath + ".exefsdir/code.ips", Patch::ApplyIpsPatch}, {filepath + ".exefsdir/code.bps", Patch::ApplyBpsPatch}, }}; @@ -551,23 +582,34 @@ Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name, else return Loader::ResultStatus::Error; - std::string section_override = filepath + ".exefsdir/" + override_name; - FileUtil::IOFile section_file(section_override, "rb"); + const auto mods_path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + GetModId(ncch_header.program_id)); + const std::array override_paths{{ + mods_path + "exefs/" + override_name, + mods_path + override_name, + filepath + ".exefsdir/" + override_name, + }}; - if (section_file.IsOpen()) { - auto section_size = section_file.GetSize(); - buffer.resize(section_size); + for (const auto& path : override_paths) { + FileUtil::IOFile section_file(path, "rb"); - section_file.Seek(0, SEEK_SET); - if (section_file.ReadBytes(&buffer[0], section_size) == section_size) { - LOG_WARNING(Service_FS, "File {} overriding built-in ExeFS file", section_override); - return Loader::ResultStatus::Success; + if (section_file.IsOpen()) { + auto section_size = section_file.GetSize(); + buffer.resize(section_size); + + section_file.Seek(0, SEEK_SET); + if (section_file.ReadBytes(&buffer[0], section_size) == section_size) { + LOG_WARNING(Service_FS, "File {} overriding built-in ExeFS file", path); + return Loader::ResultStatus::Success; + } } } return Loader::ResultStatus::ErrorNotUsed; } -Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr& romfs_file) { +Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr& romfs_file, + bool use_layered_fs) { Loader::ResultStatus result = Load(); if (result != Loader::ResultStatus::Success) return result; @@ -597,14 +639,43 @@ Loader::ResultStatus NCCHContainer::ReadRomFS(std::shared_ptr& romf if (!romfs_file_inner.IsOpen()) return Loader::ResultStatus::Error; + std::shared_ptr direct_romfs; if (is_encrypted) { - romfs_file = std::make_shared(std::move(romfs_file_inner), romfs_offset, - romfs_size, secondary_key, romfs_ctr, 0x1000); + direct_romfs = + std::make_shared(std::move(romfs_file_inner), romfs_offset, + romfs_size, secondary_key, romfs_ctr, 0x1000); } else { - romfs_file = - std::make_shared(std::move(romfs_file_inner), romfs_offset, romfs_size); + direct_romfs = std::make_shared(std::move(romfs_file_inner), + romfs_offset, romfs_size); } + const auto path = + fmt::format("{}mods/{:016X}/", FileUtil::GetUserPath(FileUtil::UserPath::LoadDir), + GetModId(ncch_header.program_id)); + if (use_layered_fs && + (FileUtil::Exists(path + "romfs/") || FileUtil::Exists(path + "romfs_ext/"))) { + + romfs_file = std::make_shared(std::move(direct_romfs), path + "romfs/", + path + "romfs_ext/"); + } else { + romfs_file = std::move(direct_romfs); + } + + return Loader::ResultStatus::Success; +} + +Loader::ResultStatus NCCHContainer::DumpRomFS(const std::string& target_path) { + std::shared_ptr direct_romfs; + Loader::ResultStatus result = ReadRomFS(direct_romfs, false); + if (result != Loader::ResultStatus::Success) + return result; + + std::shared_ptr layered_fs = + std::make_shared(std::move(direct_romfs), "", "", false); + + if (!layered_fs->DumpRomFS(target_path)) { + return Loader::ResultStatus::Error; + } return Loader::ResultStatus::Success; } @@ -614,9 +685,10 @@ Loader::ResultStatus NCCHContainer::ReadOverrideRomFS(std::shared_ptr(std::move(romfs_file_inner), 0, - romfs_file_inner.GetSize()); + LOG_WARNING(Service_FS, "File {} overriding built-in RomFS; LayeredFS not enabled", + split_filepath); + romfs_file = std::make_shared(std::move(romfs_file_inner), 0, + romfs_file_inner.GetSize()); return Loader::ResultStatus::Success; } } diff --git a/src/core/file_sys/ncch_container.h b/src/core/file_sys/ncch_container.h index f06ee8ef6..7eadd9835 100644 --- a/src/core/file_sys/ncch_container.h +++ b/src/core/file_sys/ncch_container.h @@ -149,7 +149,8 @@ struct ExHeader_StorageInfo { struct ExHeader_ARM11_SystemLocalCaps { u64_le program_id; u32_le core_version; - u8 reserved_flags[2]; + u8 reserved_flag; + u8 n3ds_mode; union { u8 flags0; BitField<0, 2, u8> ideal_processor; @@ -247,7 +248,15 @@ public: * @param size The size of the romfs * @return ResultStatus result of function */ - Loader::ResultStatus ReadRomFS(std::shared_ptr& romfs_file); + Loader::ResultStatus ReadRomFS(std::shared_ptr& romfs_file, + bool use_layered_fs = true); + + /** + * Dump the RomFS of the NCCH container to the user folder. + * @param target_path target path to dump to + * @return ResultStatus result of function. + */ + Loader::ResultStatus DumpRomFS(const std::string& target_path); /** * Get the override RomFS of the NCCH container diff --git a/src/core/file_sys/romfs_reader.cpp b/src/core/file_sys/romfs_reader.cpp index 8e624cbfc..64374684a 100644 --- a/src/core/file_sys/romfs_reader.cpp +++ b/src/core/file_sys/romfs_reader.cpp @@ -5,7 +5,7 @@ namespace FileSys { -std::size_t RomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { +std::size_t DirectRomFSReader::ReadFile(std::size_t offset, std::size_t length, u8* buffer) { if (length == 0) return 0; // Crypto++ does not like zero size buffer file.Seek(file_offset + offset, SEEK_SET); diff --git a/src/core/file_sys/romfs_reader.h b/src/core/file_sys/romfs_reader.h index ab1986fe6..1cfaa3b4f 100644 --- a/src/core/file_sys/romfs_reader.h +++ b/src/core/file_sys/romfs_reader.h @@ -7,23 +7,39 @@ namespace FileSys { +/** + * Interface for reading RomFS data. + */ class RomFSReader { public: - RomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size) + virtual ~RomFSReader() = default; + + virtual std::size_t GetSize() const = 0; + virtual std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) = 0; +}; + +/** + * A RomFS reader that directly reads the RomFS file. + */ +class DirectRomFSReader : public RomFSReader { +public: + DirectRomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size) : is_encrypted(false), file(std::move(file)), file_offset(file_offset), data_size(data_size) {} - RomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size, - const std::array& key, const std::array& ctr, - std::size_t crypto_offset) + DirectRomFSReader(FileUtil::IOFile&& file, std::size_t file_offset, std::size_t data_size, + const std::array& key, const std::array& ctr, + std::size_t crypto_offset) : is_encrypted(true), file(std::move(file)), key(key), ctr(ctr), file_offset(file_offset), crypto_offset(crypto_offset), data_size(data_size) {} - std::size_t GetSize() const { + ~DirectRomFSReader() override = default; + + std::size_t GetSize() const override { return data_size; } - std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer); + std::size_t ReadFile(std::size_t offset, std::size_t length, u8* buffer) override; private: bool is_encrypted; @@ -34,7 +50,7 @@ private: u64 crypto_offset; u64 data_size; - RomFSReader() = default; + DirectRomFSReader() = default; template void serialize(Archive& ar, const unsigned int) { diff --git a/src/core/gdbstub/gdbstub.cpp b/src/core/gdbstub/gdbstub.cpp index 7f722ab0f..1e42ff5e4 100644 --- a/src/core/gdbstub/gdbstub.cpp +++ b/src/core/gdbstub/gdbstub.cpp @@ -121,6 +121,7 @@ constexpr char target_xml[] = )"; int gdbserver_socket = -1; +bool defer_start = false; u8 command_buffer[GDB_BUFFER_SIZE]; u32 command_length; @@ -160,10 +161,14 @@ BreakpointMap breakpoints_write; } // Anonymous namespace static Kernel::Thread* FindThreadById(int id) { - const auto& threads = Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (auto& thread : threads) { - if (thread->GetThreadId() == static_cast(id)) { - return thread.get(); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (auto& thread : threads) { + if (thread->GetThreadId() == static_cast(id)) { + return thread.get(); + } } } return nullptr; @@ -414,7 +419,10 @@ static void RemoveBreakpoint(BreakpointType type, VAddr addr) { Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), bp->second.addr, bp->second.inst.data(), bp->second.inst.size()); - Core::CPU().ClearInstructionCache(); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + Core::GetCore(i).ClearInstructionCache(); + } } p.erase(addr); } @@ -540,10 +548,13 @@ static void HandleQuery() { SendReply(target_xml); } else if (strncmp(query, "fThreadInfo", strlen("fThreadInfo")) == 0) { std::string val = "m"; - const auto& threads = - Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (const auto& thread : threads) { - val += fmt::format("{:x},", thread->GetThreadId()); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (const auto& thread : threads) { + val += fmt::format("{:x},", thread->GetThreadId()); + } } val.pop_back(); SendReply(val.c_str()); @@ -553,11 +564,14 @@ static void HandleQuery() { std::string buffer; buffer += "l"; buffer += ""; - const auto& threads = - Core::System::GetInstance().Kernel().GetThreadManager().GetThreadList(); - for (const auto& thread : threads) { - buffer += fmt::format(R"*()*", - thread->GetThreadId(), thread->GetThreadId()); + u32 num_cores = Core::GetNumCores(); + for (u32 i = 0; i < num_cores; ++i) { + const auto& threads = + Core::System::GetInstance().Kernel().GetThreadManager(i).GetThreadList(); + for (const auto& thread : threads) { + buffer += fmt::format(R"*()*", + thread->GetThreadId(), thread->GetThreadId()); + } } buffer += ""; SendReply(buffer.c_str()); @@ -619,9 +633,9 @@ static void SendSignal(Kernel::Thread* thread, u32 signal, bool full = true) { if (full) { buffer = fmt::format("T{:02x}{:02x}:{:08x};{:02x}:{:08x};{:02x}:{:08x}", latest_signal, - PC_REGISTER, htonl(Core::CPU().GetPC()), SP_REGISTER, - htonl(Core::CPU().GetReg(SP_REGISTER)), LR_REGISTER, - htonl(Core::CPU().GetReg(LR_REGISTER))); + PC_REGISTER, htonl(Core::GetRunningCore().GetPC()), SP_REGISTER, + htonl(Core::GetRunningCore().GetReg(SP_REGISTER)), LR_REGISTER, + htonl(Core::GetRunningCore().GetReg(LR_REGISTER))); } else { buffer = fmt::format("T{:02x}", latest_signal); } @@ -782,7 +796,7 @@ static void WriteRegister() { return SendReply("E01"); } - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); SendReply("OK"); } @@ -812,7 +826,7 @@ static void WriteRegisters() { } } - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); SendReply("OK"); } @@ -869,7 +883,7 @@ static void WriteMemory() { GdbHexToMem(data.data(), len_pos + 1, len); Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), addr, data.data(), len); - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); SendReply("OK"); } @@ -883,12 +897,12 @@ void Break(bool is_memory_break) { static void Step() { if (command_length > 1) { RegWrite(PC_REGISTER, GdbHexToInt(command_buffer + 1), current_thread); - Core::CPU().LoadContext(current_thread->context); + Core::GetRunningCore().LoadContext(current_thread->context); } step_loop = true; halt_loop = true; send_trap = true; - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } bool IsMemoryBreak() { @@ -904,7 +918,7 @@ static void Continue() { memory_break = false; step_loop = false; halt_loop = false; - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } /** @@ -930,7 +944,7 @@ static bool CommitBreakpoint(BreakpointType type, VAddr addr, u32 len) { Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), addr, btrap.data(), btrap.size()); - Core::CPU().ClearInstructionCache(); + Core::GetRunningCore().ClearInstructionCache(); } p.insert({addr, breakpoint}); @@ -1030,6 +1044,9 @@ static void RemoveBreakpoint() { void HandlePacket() { if (!IsConnected()) { + if (defer_start) { + ToggleServer(true); + } return; } @@ -1120,6 +1137,10 @@ void ToggleServer(bool status) { } } +void DeferStart() { + defer_start = true; +} + static void Init(u16 port) { if (!server_enabled) { // Set the halt loop to false in case the user enabled the gdbstub mid-execution. @@ -1203,6 +1224,7 @@ void Shutdown() { if (!server_enabled) { return; } + defer_start = false; LOG_INFO(Debug_GDBStub, "Stopping GDB ..."); if (gdbserver_socket != -1) { diff --git a/src/core/gdbstub/gdbstub.h b/src/core/gdbstub/gdbstub.h index 131b3f823..b878b7957 100644 --- a/src/core/gdbstub/gdbstub.h +++ b/src/core/gdbstub/gdbstub.h @@ -42,6 +42,13 @@ void ToggleServer(bool status); /// Start the gdbstub server. void Init(); +/** + * Defer initialization of the gdbstub to the first packet processing functions. + * This avoids a case where the gdbstub thread is frozen after initialization + * and fails to respond in time to packets. + */ +void DeferStart(); + /// Stop gdbstub server. void Shutdown(); diff --git a/src/core/hle/kernel/handle_table.cpp b/src/core/hle/kernel/handle_table.cpp index 71e18eb7c..d717c8399 100644 --- a/src/core/hle/kernel/handle_table.cpp +++ b/src/core/hle/kernel/handle_table.cpp @@ -83,7 +83,7 @@ bool HandleTable::IsValid(Handle handle) const { std::shared_ptr HandleTable::GetGeneric(Handle handle) const { if (handle == CurrentThread) { - return SharedFrom(kernel.GetThreadManager().GetCurrentThread()); + return SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()); } else if (handle == CurrentProcess) { return kernel.GetCurrentProcess(); } diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index ea3da508e..8afddb499 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -20,22 +20,30 @@ namespace Kernel { /// Initialize the kernel KernelSystem::KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, - std::function prepare_reschedule_callback, u32 system_mode) + std::function prepare_reschedule_callback, u32 system_mode, + u32 num_cores, u8 n3ds_mode) : memory(memory), timing(timing), prepare_reschedule_callback(std::move(prepare_reschedule_callback)) { for (auto i = 0; i < memory_regions.size(); i++) { memory_regions[i] = std::make_shared(); } - MemoryInit(system_mode); + MemoryInit(system_mode, n3ds_mode); resource_limits = std::make_unique(*this); - thread_manager = std::make_unique(*this); + for (u32 core_id = 0; core_id < num_cores; ++core_id) { + thread_managers.push_back(std::make_unique(*this, core_id)); + } timer_manager = std::make_unique(timing); ipc_recorder = std::make_unique(); + stored_processes.assign(num_cores, nullptr); + + next_thread_id = 1; } /// Shutdown the kernel -KernelSystem::~KernelSystem() = default; +KernelSystem::~KernelSystem() { + ResetThreadIDs(); +}; ResourceLimitList& KernelSystem::ResourceLimit() { return *resource_limits; @@ -58,6 +66,15 @@ void KernelSystem::SetCurrentProcess(std::shared_ptr process) { SetCurrentMemoryPageTable(process->vm_manager.page_table); } +void KernelSystem::SetCurrentProcessForCPU(std::shared_ptr process, u32 core_id) { + if (current_cpu->GetID() == core_id) { + current_process = process; + SetCurrentMemoryPageTable(process->vm_manager.page_table); + } else { + stored_processes[core_id] = process; + } +} + void KernelSystem::SetCurrentMemoryPageTable(std::shared_ptr page_table) { memory.SetCurrentPageTable(page_table); if (current_cpu != nullptr) { @@ -65,17 +82,39 @@ void KernelSystem::SetCurrentMemoryPageTable(std::shared_ptr } } -void KernelSystem::SetCPU(std::shared_ptr cpu) { +void KernelSystem::SetCPUs(std::vector> cpus) { + ASSERT(cpus.size() == thread_managers.size()); + u32 i = 0; + for (const auto& cpu : cpus) { + thread_managers[i++]->SetCPU(*cpu); + } +} + +void KernelSystem::SetRunningCPU(std::shared_ptr cpu) { + if (current_process) { + stored_processes[current_cpu->GetID()] = current_process; + } current_cpu = cpu; - thread_manager->SetCPU(*cpu); + timing.SetCurrentTimer(cpu->GetID()); + if (stored_processes[current_cpu->GetID()]) { + SetCurrentProcess(stored_processes[current_cpu->GetID()]); + } } -ThreadManager& KernelSystem::GetThreadManager() { - return *thread_manager; +ThreadManager& KernelSystem::GetThreadManager(u32 core_id) { + return *thread_managers[core_id]; } -const ThreadManager& KernelSystem::GetThreadManager() const { - return *thread_manager; +const ThreadManager& KernelSystem::GetThreadManager(u32 core_id) const { + return *thread_managers[core_id]; +} + +ThreadManager& KernelSystem::GetCurrentThreadManager() { + return *thread_managers[current_cpu->GetID()]; +} + +const ThreadManager& KernelSystem::GetCurrentThreadManager() const { + return *thread_managers[current_cpu->GetID()]; } TimerManager& KernelSystem::GetTimerManager() { @@ -106,6 +145,14 @@ void KernelSystem::AddNamedPort(std::string name, std::shared_ptr po named_ports.emplace(std::move(name), std::move(port)); } +u32 KernelSystem::NewThreadId() { + return next_thread_id++; +} + +void KernelSystem::ResetThreadIDs() { + next_thread_id = 0; +} + template void KernelSystem::serialize(Archive& ar, const unsigned int file_version) { ar& memory_regions; @@ -118,9 +165,14 @@ void KernelSystem::serialize(Archive& ar, const unsigned int file_version) { ar& next_process_id; ar& process_list; ar& current_process; - ar&* thread_manager.get(); + // NB: core count checked in 'core' + for (auto& thread_manager : thread_managers) { + ar&* thread_manager.get(); + } ar& config_mem_handler; ar& shared_page_handler; + ar& stored_processes; + ar& next_thread_id; // Deliberately don't include debugger info to allow debugging through loads } diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index 41ef5272b..828843afc 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -88,7 +88,8 @@ enum class MemoryRegion : u16 { class KernelSystem { public: explicit KernelSystem(Memory::MemorySystem& memory, Core::Timing& timing, - std::function prepare_reschedule_callback, u32 system_mode); + std::function prepare_reschedule_callback, u32 system_mode, + u32 num_cores, u8 n3ds_mode); ~KernelSystem(); using PortPair = std::pair, std::shared_ptr>; @@ -214,13 +215,19 @@ public: std::shared_ptr GetCurrentProcess() const; void SetCurrentProcess(std::shared_ptr process); + void SetCurrentProcessForCPU(std::shared_ptr process, u32 core_id); void SetCurrentMemoryPageTable(std::shared_ptr page_table); - void SetCPU(std::shared_ptr cpu); + void SetCPUs(std::vector> cpu); - ThreadManager& GetThreadManager(); - const ThreadManager& GetThreadManager() const; + void SetRunningCPU(std::shared_ptr cpu); + + ThreadManager& GetThreadManager(u32 core_id); + const ThreadManager& GetThreadManager(u32 core_id) const; + + ThreadManager& GetCurrentThreadManager(); + const ThreadManager& GetCurrentThreadManager() const; TimerManager& GetTimerManager(); const TimerManager& GetTimerManager() const; @@ -246,6 +253,10 @@ public: prepare_reschedule_callback(); } + u32 NewThreadId(); + + void ResetThreadIDs(); + /// Map of named ports managed by the kernel, which can be retrieved using the ConnectToPort std::unordered_map> named_ports; @@ -256,7 +267,7 @@ public: Core::Timing& timing; private: - void MemoryInit(u32 mem_type); + void MemoryInit(u32 mem_type, u8 n3ds_mode); std::function prepare_reschedule_callback; @@ -280,14 +291,17 @@ private: std::vector> process_list; std::shared_ptr current_process; + std::vector> stored_processes; - std::unique_ptr thread_manager; + std::vector> thread_managers; std::shared_ptr config_mem_handler; std::shared_ptr shared_page_handler; std::unique_ptr ipc_recorder; + u32 next_thread_id; + friend class boost::serialization::access; template void serialize(Archive& ar, const unsigned int file_version); diff --git a/src/core/hle/kernel/memory.cpp b/src/core/hle/kernel/memory.cpp index f8db3c31e..a4e77e0b2 100644 --- a/src/core/hle/kernel/memory.cpp +++ b/src/core/hle/kernel/memory.cpp @@ -19,6 +19,7 @@ #include "core/hle/kernel/vm_manager.h" #include "core/hle/result.h" #include "core/memory.h" +#include "core/settings.h" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -40,11 +41,32 @@ static const u32 memory_region_sizes[8][3] = { {0x0B200000, 0x02E00000, 0x02000000}, // 7 }; -void KernelSystem::MemoryInit(u32 mem_type) { - // TODO(yuriks): On the n3DS, all o3DS configurations (<=5) are forced to 6 instead. - ASSERT_MSG(mem_type <= 5, "New 3DS memory configuration aren't supported yet!"); +namespace MemoryMode { +enum N3DSMode : u8 { + Mode6 = 1, + Mode7 = 2, + Mode6_2 = 3, +}; +} + +void KernelSystem::MemoryInit(u32 mem_type, u8 n3ds_mode) { ASSERT(mem_type != 1); + const bool is_new_3ds = Settings::values.is_new_3ds; + u32 reported_mem_type = mem_type; + if (is_new_3ds) { + if (n3ds_mode == MemoryMode::Mode6 || n3ds_mode == MemoryMode::Mode6_2) { + mem_type = 6; + reported_mem_type = 6; + } else if (n3ds_mode == MemoryMode::Mode7) { + mem_type = 7; + reported_mem_type = 7; + } else { + // On the N3ds, all O3ds configurations (<=5) are forced to 6 instead. + mem_type = 6; + } + } + // The kernel allocation regions (APPLICATION, SYSTEM and BASE) are laid out in sequence, with // the sizes specified in the memory_region_sizes table. VAddr base = 0; @@ -55,14 +77,12 @@ void KernelSystem::MemoryInit(u32 mem_type) { } // We must've allocated the entire FCRAM by the end - ASSERT(base == Memory::FCRAM_SIZE); + ASSERT(base == (is_new_3ds ? Memory::FCRAM_N3DS_SIZE : Memory::FCRAM_SIZE)); config_mem_handler = std::make_shared(); auto& config_mem = config_mem_handler->GetConfigMem(); - config_mem.app_mem_type = mem_type; - // app_mem_malloc does not always match the configured size for memory_region[0]: in case the - // n3DS type override is in effect it reports the size the game expects, not the real one. - config_mem.app_mem_alloc = memory_region_sizes[mem_type][0]; + config_mem.app_mem_type = reported_mem_type; + config_mem.app_mem_alloc = memory_region_sizes[reported_mem_type][0]; config_mem.sys_mem_alloc = memory_regions[1]->size; config_mem.base_mem_alloc = memory_regions[2]->size; diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp index 6aff80224..77a5fe903 100644 --- a/src/core/hle/kernel/mutex.cpp +++ b/src/core/hle/kernel/mutex.cpp @@ -39,7 +39,7 @@ std::shared_ptr KernelSystem::CreateMutex(bool initial_locked, std::strin // Acquire mutex with current thread if initialized as locked if (initial_locked) - mutex->Acquire(thread_manager->GetCurrentThread()); + mutex->Acquire(thread_managers[current_cpu->GetID()]->GetCurrentThread()); return mutex; } diff --git a/src/core/hle/kernel/shared_page.cpp b/src/core/hle/kernel/shared_page.cpp index 5456a48e0..896055142 100644 --- a/src/core/hle/kernel/shared_page.cpp +++ b/src/core/hle/kernel/shared_page.cpp @@ -70,7 +70,7 @@ Handler::Handler(Core::Timing& timing) : timing(timing) { using namespace std::placeholders; update_time_event = timing.RegisterEvent("SharedPage::UpdateTimeCallback", std::bind(&Handler::UpdateTimeCallback, this, _1, _2)); - timing.ScheduleEvent(0, update_time_event); + timing.ScheduleEvent(0, update_time_event, 0, 0); float slidestate = Settings::values.factor_3d / 100.0f; shared_page.sliderstate_3d = static_cast(slidestate); diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index 371bbdfe3..9aeeea236 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -280,12 +280,12 @@ void SVC::ExitProcess() { current_process->status = ProcessStatus::Exited; // Stop all the process threads that are currently waiting for objects. - auto& thread_list = kernel.GetThreadManager().GetThreadList(); + auto& thread_list = kernel.GetCurrentThreadManager().GetThreadList(); for (auto& thread : thread_list) { if (thread->owner_process != current_process) continue; - if (thread.get() == kernel.GetThreadManager().GetCurrentThread()) + if (thread.get() == kernel.GetCurrentThreadManager().GetCurrentThread()) continue; // TODO(Subv): When are the other running/ready threads terminated? @@ -297,7 +297,7 @@ void SVC::ExitProcess() { } // Kill the current thread - kernel.GetThreadManager().GetCurrentThread()->Stop(); + kernel.GetCurrentThreadManager().GetCurrentThread()->Stop(); system.PrepareReschedule(); } @@ -388,7 +388,7 @@ ResultCode SVC::SendSyncRequest(Handle handle) { system.PrepareReschedule(); - auto thread = SharedFrom(kernel.GetThreadManager().GetCurrentThread()); + auto thread = SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()); if (kernel.GetIPCRecorder().IsEnabled()) { kernel.GetIPCRecorder().RegisterRequest(session, thread); @@ -476,7 +476,7 @@ private: /// Wait for a handle to synchronize, timeout after the specified nanoseconds ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { auto object = kernel.GetCurrentProcess()->handle_table.Get(handle); - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); if (object == nullptr) return ERR_INVALID_HANDLE; @@ -514,7 +514,7 @@ ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { /// Wait for the given handles to synchronize, timeout after the specified nanoseconds ResultCode SVC::WaitSynchronizationN(s32* out, VAddr handles_address, s32 handle_count, bool wait_all, s64 nano_seconds) { - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); if (!Memory::IsValidVirtualAddress(*kernel.GetCurrentProcess(), handles_address)) return ERR_INVALID_POINTER; @@ -684,7 +684,7 @@ ResultCode SVC::ReplyAndReceive(s32* index, VAddr handles_address, s32 handle_co // We are also sending a command reply. // Do not send a reply if the command id in the command buffer is 0xFFFF. - Thread* thread = kernel.GetThreadManager().GetCurrentThread(); + Thread* thread = kernel.GetCurrentThreadManager().GetCurrentThread(); u32 cmd_buff_header = memory.Read32(thread->GetCommandBufferAddress()); IPC::Header header{cmd_buff_header}; if (reply_target != 0 && header.command_id != 0xFFFF) { @@ -791,7 +791,7 @@ ResultCode SVC::ArbitrateAddress(Handle handle, u32 address, u32 type, u32 value return ERR_INVALID_HANDLE; auto res = - arbiter->ArbitrateAddress(SharedFrom(kernel.GetThreadManager().GetCurrentThread()), + arbiter->ArbitrateAddress(SharedFrom(kernel.GetCurrentThreadManager().GetCurrentThread()), static_cast(type), address, value, nanoseconds); // TODO(Subv): Identify in which specific cases this call should cause a reschedule. @@ -912,14 +912,19 @@ ResultCode SVC::CreateThread(Handle* out_handle, u32 entry_point, u32 arg, VAddr break; case ThreadProcessorIdAll: LOG_INFO(Kernel_SVC, - "Newly created thread is allowed to be run in any Core, unimplemented."); + "Newly created thread is allowed to be run in any Core, for now run in core 0."); + processor_id = ThreadProcessorId0; break; case ThreadProcessorId1: - LOG_ERROR(Kernel_SVC, - "Newly created thread must run in the SysCore (Core1), unimplemented."); + case ThreadProcessorId2: + case ThreadProcessorId3: + // TODO: Check and log for: When processorid==0x2 and the process is not a BASE mem-region + // process, exheader kernel-flags bitmask 0x2000 must be set (otherwise error 0xD9001BEA is + // returned). When processorid==0x3 and the process is not a BASE mem-region process, error + // 0xD9001BEA is returned. These are the only restriction checks done by the kernel for + // processorid. break; default: - // TODO(bunnei): Implement support for other processor IDs ASSERT_MSG(false, "Unsupported thread processor ID: {}", processor_id); break; } @@ -945,9 +950,9 @@ ResultCode SVC::CreateThread(Handle* out_handle, u32 entry_point, u32 arg, VAddr /// Called when a thread exits void SVC::ExitThread() { - LOG_TRACE(Kernel_SVC, "called, pc=0x{:08X}", system.CPU().GetPC()); + LOG_TRACE(Kernel_SVC, "called, pc=0x{:08X}", system.GetRunningCore().GetPC()); - kernel.GetThreadManager().ExitCurrentThread(); + kernel.GetCurrentThreadManager().ExitCurrentThread(); system.PrepareReschedule(); } @@ -993,7 +998,7 @@ ResultCode SVC::SetThreadPriority(Handle handle, u32 priority) { /// Create a mutex ResultCode SVC::CreateMutex(Handle* out_handle, u32 initial_locked) { std::shared_ptr mutex = kernel.CreateMutex(initial_locked != 0); - mutex->name = fmt::format("mutex-{:08x}", system.CPU().GetReg(14)); + mutex->name = fmt::format("mutex-{:08x}", system.GetRunningCore().GetReg(14)); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(mutex))); LOG_TRACE(Kernel_SVC, "called initial_locked={} : created handle=0x{:08X}", @@ -1010,7 +1015,7 @@ ResultCode SVC::ReleaseMutex(Handle handle) { if (mutex == nullptr) return ERR_INVALID_HANDLE; - return mutex->Release(kernel.GetThreadManager().GetCurrentThread()); + return mutex->Release(kernel.GetCurrentThreadManager().GetCurrentThread()); } /// Get the ID of the specified process @@ -1060,7 +1065,7 @@ ResultCode SVC::GetThreadId(u32* thread_id, Handle handle) { ResultCode SVC::CreateSemaphore(Handle* out_handle, s32 initial_count, s32 max_count) { CASCADE_RESULT(std::shared_ptr semaphore, kernel.CreateSemaphore(initial_count, max_count)); - semaphore->name = fmt::format("semaphore-{:08x}", system.CPU().GetReg(14)); + semaphore->name = fmt::format("semaphore-{:08x}", system.GetRunningCore().GetReg(14)); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(semaphore))); @@ -1130,8 +1135,9 @@ ResultCode SVC::QueryMemory(MemoryInfo* memory_info, PageInfo* page_info, u32 ad /// Create an event ResultCode SVC::CreateEvent(Handle* out_handle, u32 reset_type) { - std::shared_ptr evt = kernel.CreateEvent( - static_cast(reset_type), fmt::format("event-{:08x}", system.CPU().GetReg(14))); + std::shared_ptr evt = + kernel.CreateEvent(static_cast(reset_type), + fmt::format("event-{:08x}", system.GetRunningCore().GetReg(14))); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(evt))); LOG_TRACE(Kernel_SVC, "called reset_type=0x{:08X} : created handle=0x{:08X}", reset_type, @@ -1173,8 +1179,9 @@ ResultCode SVC::ClearEvent(Handle handle) { /// Creates a timer ResultCode SVC::CreateTimer(Handle* out_handle, u32 reset_type) { - std::shared_ptr timer = kernel.CreateTimer( - static_cast(reset_type), fmt ::format("timer-{:08x}", system.CPU().GetReg(14))); + std::shared_ptr timer = + kernel.CreateTimer(static_cast(reset_type), + fmt ::format("timer-{:08x}", system.GetRunningCore().GetReg(14))); CASCADE_RESULT(*out_handle, kernel.GetCurrentProcess()->handle_table.Create(std::move(timer))); LOG_TRACE(Kernel_SVC, "called reset_type=0x{:08X} : created handle=0x{:08X}", reset_type, @@ -1228,7 +1235,7 @@ ResultCode SVC::CancelTimer(Handle handle) { void SVC::SleepThread(s64 nanoseconds) { LOG_TRACE(Kernel_SVC, "called nanoseconds={}", nanoseconds); - ThreadManager& thread_manager = kernel.GetThreadManager(); + ThreadManager& thread_manager = kernel.GetCurrentThreadManager(); // Don't attempt to yield execution if there are no available threads to run, // this way we avoid a useless reschedule to the idle thread. @@ -1246,10 +1253,11 @@ void SVC::SleepThread(s64 nanoseconds) { /// This returns the total CPU ticks elapsed since the CPU was powered-on s64 SVC::GetSystemTick() { - s64 result = system.CoreTiming().GetTicks(); + // TODO: Use globalTicks here? + s64 result = system.GetRunningCore().GetTimer()->GetTicks(); // Advance time to defeat dumb games (like Cubic Ninja) that busy-wait for the frame to end. // Measured time between two calls on a 9.2 o3DS with Ninjhax 1.1b - system.CoreTiming().AddTicks(150); + system.GetRunningCore().GetTimer()->AddTicks(150); return result; } @@ -1611,11 +1619,11 @@ void SVC::CallSVC(u32 immediate) { SVC::SVC(Core::System& system) : system(system), kernel(system.Kernel()), memory(system.Memory()) {} u32 SVC::GetReg(std::size_t n) { - return system.CPU().GetReg(static_cast(n)); + return system.GetRunningCore().GetReg(static_cast(n)); } void SVC::SetReg(std::size_t n, u32 value) { - system.CPU().SetReg(static_cast(n), value); + system.GetRunningCore().SetReg(static_cast(n), value); } SVCContext::SVCContext(Core::System& system) : impl(std::make_unique(system)) {} diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 5382f9ec8..465ba46f5 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -62,13 +62,10 @@ void Thread::Acquire(Thread* thread) { ASSERT_MSG(!ShouldWait(thread), "object unavailable!"); } -u32 ThreadManager::NewThreadId() { - return next_thread_id++; -} - -Thread::Thread(KernelSystem& kernel) - : WaitObject(kernel), context(kernel.GetThreadManager().NewContext()), - thread_manager(kernel.GetThreadManager()) {} +Thread::Thread(KernelSystem& kernel, u32 core_id) + : WaitObject(kernel), context(kernel.GetThreadManager(core_id).NewContext()), + core_id(core_id), + thread_manager(kernel.GetThreadManager(core_id)) {} Thread::~Thread() {} Thread* ThreadManager::GetCurrentThread() const { @@ -113,7 +110,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { // Save context for previous thread if (previous_thread) { - previous_thread->last_running_ticks = timing.GetTicks(); + previous_thread->last_running_ticks = timing.GetGlobalTicks(); cpu->SaveContext(previous_thread->context); if (previous_thread->status == ThreadStatus::Running) { @@ -140,7 +137,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { new_thread->status = ThreadStatus::Running; if (previous_process != current_thread->owner_process) { - kernel.SetCurrentProcess(current_thread->owner_process); + kernel.SetCurrentProcessForCPU(current_thread->owner_process, cpu->GetID()); } cpu->LoadContext(new_thread->context); @@ -153,7 +150,7 @@ void ThreadManager::SwitchContext(Thread* new_thread) { } Thread* ThreadManager::PopNextReadyThread() { - Thread* next; + Thread* next = nullptr; Thread* thread = GetCurrentThread(); if (thread && thread->status == ThreadStatus::Running) { @@ -337,22 +334,22 @@ ResultVal> KernelSystem::CreateThread( ErrorSummary::InvalidArgument, ErrorLevel::Permanent); } - auto thread{std::make_shared(*this)}; + auto thread{std::make_shared(*this, processor_id)}; - thread_manager->thread_list.push_back(thread); - thread_manager->ready_queue.prepare(priority); + thread_managers[processor_id]->thread_list.push_back(thread); + thread_managers[processor_id]->ready_queue.prepare(priority); - thread->thread_id = thread_manager->NewThreadId(); + thread->thread_id = NewThreadId(); thread->status = ThreadStatus::Dormant; thread->entry_point = entry_point; thread->stack_top = stack_top; thread->nominal_priority = thread->current_priority = priority; - thread->last_running_ticks = timing.GetTicks(); + thread->last_running_ticks = timing.GetGlobalTicks(); thread->processor_id = processor_id; thread->wait_objects.clear(); thread->wait_address = 0; thread->name = std::move(name); - thread_manager->wakeup_callback_table[thread->thread_id] = thread.get(); + thread_managers[processor_id]->wakeup_callback_table[thread->thread_id] = thread.get(); thread->owner_process = owner_process; // Find the next available TLS index, and mark it as used @@ -397,7 +394,7 @@ ResultVal> KernelSystem::CreateThread( // to initialize the context ResetThreadContext(thread->context, stack_top, entry_point, arg); - thread_manager->ready_queue.push_back(thread->current_priority, thread.get()); + thread_managers[processor_id]->ready_queue.push_back(thread->current_priority, thread.get()); thread->status = ThreadStatus::Ready; return MakeResult>(std::move(thread)); @@ -463,6 +460,9 @@ void ThreadManager::Reschedule() { LOG_TRACE(Kernel, "context switch {} -> idle", cur->GetObjectId()); } else if (next) { LOG_TRACE(Kernel, "context switch idle -> {}", next->GetObjectId()); + } else { + LOG_TRACE(Kernel, "context switch idle -> idle, do nothing"); + return; } SwitchContext(next); @@ -489,11 +489,10 @@ VAddr Thread::GetCommandBufferAddress() const { return GetTLSAddress() + command_header_offset; } -ThreadManager::ThreadManager(Kernel::KernelSystem& kernel) : kernel(kernel) { - ThreadWakeupEventType = - kernel.timing.RegisterEvent("ThreadWakeupCallback", [this](u64 thread_id, s64 cycle_late) { - ThreadWakeupCallback(thread_id, cycle_late); - }); +ThreadManager::ThreadManager(Kernel::KernelSystem& kernel, u32 core_id) : kernel(kernel) { + ThreadWakeupEventType = kernel.timing.RegisterEvent( + "ThreadWakeupCallback_" + std::to_string(core_id), + [this](u64 thread_id, s64 cycle_late) { ThreadWakeupCallback(thread_id, cycle_late); }); } ThreadManager::~ThreadManager() { diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index c64e24071..09179c0e9 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -38,7 +38,9 @@ enum ThreadProcessorId : s32 { ThreadProcessorIdAll = -1, ///< Run thread on either core ThreadProcessorId0 = 0, ///< Run thread on core 0 (AppCore) ThreadProcessorId1 = 1, ///< Run thread on core 1 (SysCore) - ThreadProcessorIdMax = 2, ///< Processor ID must be less than this + ThreadProcessorId2 = 2, ///< Run thread on core 2 (additional n3ds core) + ThreadProcessorId3 = 3, ///< Run thread on core 3 (additional n3ds core) + ThreadProcessorIdMax = 4, ///< Processor ID must be less than this }; enum class ThreadStatus { @@ -75,15 +77,9 @@ private: class ThreadManager { public: - explicit ThreadManager(Kernel::KernelSystem& kernel); + explicit ThreadManager(Kernel::KernelSystem& kernel, u32 core_id); ~ThreadManager(); - /** - * Creates a new thread ID - * @return The new thread ID - */ - u32 NewThreadId(); - /** * Gets the current thread */ @@ -150,7 +146,6 @@ private: Kernel::KernelSystem& kernel; ARM_Interface* cpu; - u32 next_thread_id = 1; std::shared_ptr current_thread; Common::ThreadQueueList ready_queue; std::unordered_map wakeup_callback_table; @@ -167,7 +162,6 @@ private: friend class boost::serialization::access; template void serialize(Archive& ar, const unsigned int file_version) { - ar& next_thread_id; ar& current_thread; ar& ready_queue; ar& wakeup_callback_table; @@ -177,7 +171,7 @@ private: class Thread final : public WaitObject { public: - explicit Thread(KernelSystem&); + explicit Thread(KernelSystem&, u32 core_id); ~Thread() override; std::string GetName() const override { @@ -329,6 +323,8 @@ public: // available. In case of a timeout, the object will be nullptr. std::shared_ptr wakeup_callback; + const u32 core_id; + private: ThreadManager& thread_manager; @@ -351,4 +347,20 @@ std::shared_ptr SetupMainThread(KernelSystem& kernel, u32 entry_point, u } // namespace Kernel BOOST_CLASS_EXPORT_KEY(Kernel::Thread) -CONSTRUCT_KERNEL_OBJECT(Kernel::Thread) + +namespace boost::serialization { + +template +inline void save_construct_data(Archive& ar, const Kernel::Thread* t, + const unsigned int file_version) { + ar << t->core_id; +} + +template +inline void load_construct_data(Archive& ar, Kernel::Thread* t, const unsigned int file_version) { + u32 core_id; + ar >> core_id; + ::new (t) Kernel::Thread(Core::Global(), core_id); +} + +} // namespace boost::serialization diff --git a/src/core/hle/service/hid/hid.h b/src/core/hle/service/hid/hid.h index e34c1186a..ad526ac6f 100644 --- a/src/core/hle/service/hid/hid.h +++ b/src/core/hle/service/hid/hid.h @@ -6,9 +6,7 @@ #include #include -#ifndef _MSC_VER #include -#endif #include #include "common/bit_field.h" #include "common/common_funcs.h" @@ -177,10 +175,6 @@ struct GyroscopeCalibrateParam { } x, y, z; }; -// TODO: MSVC does not support using offsetof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#ifndef _MSC_VER #define ASSERT_REG_POSITION(field_name, position) \ static_assert(offsetof(SharedMem, field_name) == position * 4, \ "Field " #field_name " has invalid position") @@ -189,7 +183,6 @@ ASSERT_REG_POSITION(pad.index_reset_ticks, 0x0); ASSERT_REG_POSITION(touch.index_reset_ticks, 0x2A); #undef ASSERT_REG_POSITION -#endif // !defined(_MSC_VER) struct DirectionState { bool up; diff --git a/src/core/hle/service/http_c.cpp b/src/core/hle/service/http_c.cpp index 768dd3bd8..af639e493 100644 --- a/src/core/hle/service/http_c.cpp +++ b/src/core/hle/service/http_c.cpp @@ -2,9 +2,14 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include +#ifdef ENABLE_WEB_SERVICE +#include +#endif #include #include #include "common/archives.h" +#include "common/assert.h" #include "core/core.h" #include "core/file_sys/archive_ncch.h" #include "core/file_sys/file_backend.h" @@ -52,6 +57,82 @@ const ResultCode ERROR_WRONG_CERT_HANDLE = // 0xD8A0A0C9 const ResultCode ERROR_CERT_ALREADY_SET = // 0xD8A0A03D ResultCode(61, ErrorModule::HTTP, ErrorSummary::InvalidState, ErrorLevel::Permanent); +void Context::MakeRequest() { + ASSERT(state == RequestState::NotStarted); + +#ifdef ENABLE_WEB_SERVICE + LUrlParser::clParseURL parsedUrl = LUrlParser::clParseURL::ParseURL(url); + int port; + std::unique_ptr client; + if (parsedUrl.m_Scheme == "http") { + if (!parsedUrl.GetPort(&port)) { + port = 80; + } + // TODO(B3N30): Support for setting timeout + // Figure out what the default timeout on 3DS is + client = std::make_unique(parsedUrl.m_Host.c_str(), port); + } else { + if (!parsedUrl.GetPort(&port)) { + port = 443; + } + // TODO(B3N30): Support for setting timeout + // Figure out what the default timeout on 3DS is + + auto ssl_client = std::make_unique(parsedUrl.m_Host, port); + SSL_CTX* ctx = ssl_client->ssl_context(); + client = std::move(ssl_client); + + if (auto client_cert = ssl_config.client_cert_ctx.lock()) { + SSL_CTX_use_certificate_ASN1(ctx, client_cert->certificate.size(), + client_cert->certificate.data()); + SSL_CTX_use_PrivateKey_ASN1(EVP_PKEY_RSA, ctx, client_cert->private_key.data(), + client_cert->private_key.size()); + } + + // TODO(B3N30): Check for SSLOptions-Bits and set the verify method accordingly + // https://www.3dbrew.org/wiki/SSL_Services#SSLOpt + // Hack: Since for now RootCerts are not implemented we set the VerifyMode to None. + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL); + } + + state = RequestState::InProgress; + + static const std::unordered_map request_method_strings{ + {RequestMethod::Get, "GET"}, {RequestMethod::Post, "POST"}, + {RequestMethod::Head, "HEAD"}, {RequestMethod::Put, "PUT"}, + {RequestMethod::Delete, "DELETE"}, {RequestMethod::PostEmpty, "POST"}, + {RequestMethod::PutEmpty, "PUT"}, + }; + + httplib::Request request; + request.method = request_method_strings.at(method); + request.path = url; + // TODO(B3N30): Add post data body + request.progress = [this](u64 current, u64 total) -> bool { + // TODO(B3N30): Is there a state that shows response header are available + current_download_size_bytes = current; + total_download_size_bytes = total; + return true; + }; + + for (const auto& header : headers) { + request.headers.emplace(header.name, header.value); + } + + if (!client->send(request, response)) { + LOG_ERROR(Service_HTTP, "Request failed"); + state = RequestState::TimedOut; + } else { + LOG_DEBUG(Service_HTTP, "Request successful"); + // TODO(B3N30): Verify this state on HW + state = RequestState::ReadyToDownloadContent; + } +#else + LOG_ERROR(Service_HTTP, "Tried to make request but WebServices is not enabled in this build"); + state = RequestState::TimedOut; +#endif +} + void HTTP_C::Initialize(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x1, 1, 4); const u32 shmem_size = rp.Pop(); @@ -156,7 +237,15 @@ void HTTP_C::BeginRequest(Kernel::HLERequestContext& ctx) { auto itr = contexts.find(context_handle); ASSERT(itr != contexts.end()); - // TODO(B3N30): Make the request + // On a 3DS BeginRequest and BeginRequestAsync will push the Request to a worker queue. + // You can only enqueue 8 requests at the same time. + // trying to enqueue any more will either fail (BeginRequestAsync), or block (BeginRequest) + // Note that you only can have 8 Contexts at a time. So this difference shouldn't matter + // Then there are 3? worker threads that pop the requests from the queue and send them + // For now make every request async in it's own thread. + + itr->second.request_future = + std::async(std::launch::async, &Context::MakeRequest, std::ref(itr->second)); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -201,7 +290,15 @@ void HTTP_C::BeginRequestAsync(Kernel::HLERequestContext& ctx) { auto itr = contexts.find(context_handle); ASSERT(itr != contexts.end()); - // TODO(B3N30): Make the request + // On a 3DS BeginRequest and BeginRequestAsync will push the Request to a worker queue. + // You can only enqueue 8 requests at the same time. + // trying to enqueue any more will either fail (BeginRequestAsync), or block (BeginRequest) + // Note that you only can have 8 Contexts at a time. So this difference shouldn't matter + // Then there are 3? worker threads that pop the requests from the queue and send them + // For now make every request async in it's own thread. + + itr->second.request_future = + std::async(std::launch::async, &Context::MakeRequest, std::ref(itr->second)); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -264,7 +361,7 @@ void HTTP_C::CreateContext(Kernel::HLERequestContext& ctx) { return; } - contexts.emplace(++context_counter, Context()); + contexts.try_emplace(++context_counter); contexts[context_counter].url = std::move(url); contexts[context_counter].method = method; contexts[context_counter].state = RequestState::NotStarted; @@ -311,10 +408,9 @@ void HTTP_C::CloseContext(Kernel::HLERequestContext& ctx) { } // TODO(Subv): What happens if you try to close a context that's currently being used? - ASSERT(itr->second.state == RequestState::NotStarted); - // TODO(Subv): Make sure that only the session that created the context can close it. + // Note that this will block if a request is still in progress contexts.erase(itr); session_data->num_http_contexts--; diff --git a/src/core/hle/service/http_c.h b/src/core/hle/service/http_c.h index 641b6b4d1..d94b47a45 100644 --- a/src/core/hle/service/http_c.h +++ b/src/core/hle/service/http_c.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -15,6 +16,12 @@ #include #include #include +#ifdef ENABLE_WEB_SERVICE +#if defined(__ANDROID__) +#include +#endif +#include +#endif #include "core/hle/kernel/shared_memory.h" #include "core/hle/service/service.h" @@ -113,8 +120,7 @@ public: Context(const Context&) = delete; Context& operator=(const Context&) = delete; - Context(Context&& other) = default; - Context& operator=(Context&&) = default; + void MakeRequest(); struct Proxy { std::string url; @@ -195,14 +201,21 @@ public: u32 session_id; std::string url; RequestMethod method; - RequestState state = RequestState::NotStarted; - boost::optional proxy; - boost::optional basic_auth; + std::atomic state = RequestState::NotStarted; + std::optional proxy; + std::optional basic_auth; SSLConfig ssl_config{}; u32 socket_buffer_size; std::vector headers; std::vector post_data; + std::future request_future; + std::atomic current_download_size_bytes; + std::atomic total_download_size_bytes; +#ifdef ENABLE_WEB_SERVICE + httplib::Response response; +#endif + private: template void serialize(Archive& ar, const unsigned int) { @@ -219,6 +232,7 @@ private: ar& post_data; } friend class boost::serialization::access; + }; struct SessionData : public Kernel::SessionRequestHandler::SessionDataBase { diff --git a/src/core/hle/service/ldr_ro/cro_helper.cpp b/src/core/hle/service/ldr_ro/cro_helper.cpp index 86600e7a9..89e64f9d8 100644 --- a/src/core/hle/service/ldr_ro/cro_helper.cpp +++ b/src/core/hle/service/ldr_ro/cro_helper.cpp @@ -55,7 +55,7 @@ VAddr CROHelper::SegmentTagToAddress(SegmentTag segment_tag) const { return 0; SegmentEntry entry; - GetEntry(memory, segment_tag.segment_index, entry); + GetEntry(system.Memory(), segment_tag.segment_index, entry); if (segment_tag.offset_into_segment >= entry.size) return 0; @@ -71,12 +71,12 @@ ResultCode CROHelper::ApplyRelocation(VAddr target_address, RelocationType reloc break; case RelocationType::AbsoluteAddress: case RelocationType::AbsoluteAddress2: - memory.Write32(target_address, symbol_address + addend); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, symbol_address + addend); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::RelativeAddress: - memory.Write32(target_address, symbol_address + addend - target_future_address); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, symbol_address + addend - target_future_address); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::ThumbBranch: case RelocationType::ArmBranch: @@ -98,8 +98,8 @@ ResultCode CROHelper::ClearRelocation(VAddr target_address, RelocationType reloc case RelocationType::AbsoluteAddress: case RelocationType::AbsoluteAddress2: case RelocationType::RelativeAddress: - memory.Write32(target_address, 0); - cpu.InvalidateCacheRange(target_address, sizeof(u32)); + system.Memory().Write32(target_address, 0); + system.InvalidateCacheRange(target_address, sizeof(u32)); break; case RelocationType::ThumbBranch: case RelocationType::ArmBranch: @@ -121,7 +121,8 @@ ResultCode CROHelper::ApplyRelocationBatch(VAddr batch, u32 symbol_address, bool VAddr relocation_address = batch; while (true) { RelocationEntry relocation; - memory.ReadBlock(process, relocation_address, &relocation, sizeof(RelocationEntry)); + system.Memory().ReadBlock(process, relocation_address, &relocation, + sizeof(RelocationEntry)); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -142,9 +143,9 @@ ResultCode CROHelper::ApplyRelocationBatch(VAddr batch, u32 symbol_address, bool } RelocationEntry relocation; - memory.ReadBlock(process, batch, &relocation, sizeof(RelocationEntry)); + system.Memory().ReadBlock(process, batch, &relocation, sizeof(RelocationEntry)); relocation.is_batch_resolved = reset ? 0 : 1; - memory.WriteBlock(process, batch, &relocation, sizeof(RelocationEntry)); + system.Memory().WriteBlock(process, batch, &relocation, sizeof(RelocationEntry)); return RESULT_SUCCESS; } @@ -154,13 +155,13 @@ VAddr CROHelper::FindExportNamedSymbol(const std::string& name) const { std::size_t len = name.size(); ExportTreeEntry entry; - GetEntry(memory, 0, entry); + GetEntry(system.Memory(), 0, entry); ExportTreeEntry::Child next; next.raw = entry.left.raw; u32 found_id; while (true) { - GetEntry(memory, next.next_index, entry); + GetEntry(system.Memory(), next.next_index, entry); if (next.is_end) { found_id = entry.export_table_index; @@ -186,9 +187,9 @@ VAddr CROHelper::FindExportNamedSymbol(const std::string& name) const { u32 export_strings_size = GetField(ExportStringsSize); ExportNamedSymbolEntry symbol_entry; - GetEntry(memory, found_id, symbol_entry); + GetEntry(system.Memory(), found_id, symbol_entry); - if (memory.ReadCString(symbol_entry.name_offset, export_strings_size) != name) + if (system.Memory().ReadCString(symbol_entry.name_offset, export_strings_size) != name) return 0; return SegmentTagToAddress(symbol_entry.symbol_position); @@ -279,7 +280,7 @@ ResultVal CROHelper::RebaseSegmentTable(u32 cro_size, VAddr data_segment_ u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry segment; - GetEntry(memory, i, segment); + GetEntry(system.Memory(), i, segment); if (segment.type == SegmentType::Data) { if (segment.size != 0) { if (segment.size > data_segment_size) @@ -298,7 +299,7 @@ ResultVal CROHelper::RebaseSegmentTable(u32 cro_size, VAddr data_segment_ if (segment.offset > module_address + cro_size) return CROFormatError(0x19); } - SetEntry(memory, i, segment); + SetEntry(system.Memory(), i, segment); } return MakeResult(prev_data_segment + module_address); } @@ -310,7 +311,7 @@ ResultCode CROHelper::RebaseExportNamedSymbolTable() { u32 export_named_symbol_num = GetField(ExportNamedSymbolNum); for (u32 i = 0; i < export_named_symbol_num; ++i) { ExportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -320,7 +321,7 @@ ResultCode CROHelper::RebaseExportNamedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -329,7 +330,7 @@ ResultCode CROHelper::VerifyExportTreeTable() const { u32 tree_num = GetField(ExportTreeNum); for (u32 i = 0; i < tree_num; ++i) { ExportTreeEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.left.next_index >= tree_num || entry.right.next_index >= tree_num) { return CROFormatError(0x11); @@ -353,7 +354,7 @@ ResultCode CROHelper::RebaseImportModuleTable() { u32 module_num = GetField(ImportModuleNum); for (u32 i = 0; i < module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -379,7 +380,7 @@ ResultCode CROHelper::RebaseImportModuleTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -395,7 +396,7 @@ ResultCode CROHelper::RebaseImportNamedSymbolTable() { u32 num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset += module_address; @@ -413,7 +414,7 @@ ResultCode CROHelper::RebaseImportNamedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -427,7 +428,7 @@ ResultCode CROHelper::RebaseImportIndexedSymbolTable() { u32 num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset += module_address; @@ -437,7 +438,7 @@ ResultCode CROHelper::RebaseImportIndexedSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -451,7 +452,7 @@ ResultCode CROHelper::RebaseImportAnonymousSymbolTable() { u32 num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset += module_address; @@ -461,7 +462,7 @@ ResultCode CROHelper::RebaseImportAnonymousSymbolTable() { } } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } return RESULT_SUCCESS; } @@ -476,14 +477,14 @@ ResultCode CROHelper::ResetExternalRelocations() { ExternalRelocationEntry relocation; // Verifies that the last relocation is the end of a batch - GetEntry(memory, external_relocation_num - 1, relocation); + GetEntry(system.Memory(), external_relocation_num - 1, relocation); if (!relocation.is_batch_end) { return CROFormatError(0x12); } bool batch_begin = true; for (u32 i = 0; i < external_relocation_num; ++i) { - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -500,7 +501,7 @@ ResultCode CROHelper::ResetExternalRelocations() { if (batch_begin) { // resets to unresolved state relocation.is_batch_resolved = 0; - SetEntry(memory, i, relocation); + SetEntry(system.Memory(), i, relocation); } // if current is an end, then the next is a beginning @@ -516,7 +517,7 @@ ResultCode CROHelper::ClearExternalRelocations() { bool batch_begin = true; for (u32 i = 0; i < external_relocation_num; ++i) { - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr relocation_target = SegmentTagToAddress(relocation.target_position); if (relocation_target == 0) { @@ -532,7 +533,7 @@ ResultCode CROHelper::ClearExternalRelocations() { if (batch_begin) { // resets to unresolved state relocation.is_batch_resolved = 0; - SetEntry(memory, i, relocation); + SetEntry(system.Memory(), i, relocation); } // if current is an end, then the next is a beginning @@ -548,13 +549,13 @@ ResultCode CROHelper::ApplyStaticAnonymousSymbolToCRS(VAddr crs_address) { static_relocation_table_offset + GetField(StaticRelocationNum) * sizeof(StaticRelocationEntry); - CROHelper crs(crs_address, process, memory, cpu); + CROHelper crs(crs_address, process, system); u32 offset_export_num = GetField(StaticAnonymousSymbolNum); LOG_INFO(Service_LDR, "CRO \"{}\" exports {} static anonymous symbols", ModuleName(), offset_export_num); for (u32 i = 0; i < offset_export_num; ++i) { StaticAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); u32 batch_address = entry.relocation_batch_offset + module_address; if (batch_address < static_relocation_table_offset || @@ -579,7 +580,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { u32 internal_relocation_num = GetField(InternalRelocationNum); for (u32 i = 0; i < internal_relocation_num; ++i) { InternalRelocationEntry relocation; - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr target_addressB = SegmentTagToAddress(relocation.target_position); if (target_addressB == 0) { return CROFormatError(0x15); @@ -587,7 +588,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { VAddr target_address; SegmentEntry target_segment; - GetEntry(memory, relocation.target_position.segment_index, target_segment); + GetEntry(system.Memory(), relocation.target_position.segment_index, target_segment); if (target_segment.type == SegmentType::Data) { // If the relocation is to the .data segment, we need to relocate it in the old buffer @@ -602,7 +603,7 @@ ResultCode CROHelper::ApplyInternalRelocations(u32 old_data_segment_address) { } SegmentEntry symbol_segment; - GetEntry(memory, relocation.symbol_segment, symbol_segment); + GetEntry(system.Memory(), relocation.symbol_segment, symbol_segment); LOG_TRACE(Service_LDR, "Internally relocates 0x{:08X} with 0x{:08X}", target_address, symbol_segment.offset); ResultCode result = ApplyRelocation(target_address, relocation.type, relocation.addend, @@ -619,7 +620,7 @@ ResultCode CROHelper::ClearInternalRelocations() { u32 internal_relocation_num = GetField(InternalRelocationNum); for (u32 i = 0; i < internal_relocation_num; ++i) { InternalRelocationEntry relocation; - GetEntry(memory, i, relocation); + GetEntry(system.Memory(), i, relocation); VAddr target_address = SegmentTagToAddress(relocation.target_position); if (target_address == 0) { @@ -639,13 +640,13 @@ void CROHelper::UnrebaseImportAnonymousSymbolTable() { u32 num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -653,13 +654,13 @@ void CROHelper::UnrebaseImportIndexedSymbolTable() { u32 num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.relocation_batch_offset != 0) { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -667,7 +668,7 @@ void CROHelper::UnrebaseImportNamedSymbolTable() { u32 num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; @@ -677,7 +678,7 @@ void CROHelper::UnrebaseImportNamedSymbolTable() { entry.relocation_batch_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -685,7 +686,7 @@ void CROHelper::UnrebaseImportModuleTable() { u32 module_num = GetField(ImportModuleNum); for (u32 i = 0; i < module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; @@ -699,7 +700,7 @@ void CROHelper::UnrebaseImportModuleTable() { entry.import_anonymous_symbol_table_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -707,13 +708,13 @@ void CROHelper::UnrebaseExportNamedSymbolTable() { u32 export_named_symbol_num = GetField(ExportNamedSymbolNum); for (u32 i = 0; i < export_named_symbol_num; ++i) { ExportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.name_offset != 0) { entry.name_offset -= module_address; } - SetEntry(memory, i, entry); + SetEntry(system.Memory(), i, entry); } } @@ -721,7 +722,7 @@ void CROHelper::UnrebaseSegmentTable() { u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry segment; - GetEntry(memory, i, segment); + GetEntry(system.Memory(), i, segment); if (segment.type == SegmentType::BSS) { segment.offset = 0; @@ -729,7 +730,7 @@ void CROHelper::UnrebaseSegmentTable() { segment.offset -= module_address; } - SetEntry(memory, i, segment); + SetEntry(system.Memory(), i, segment); } } @@ -751,17 +752,17 @@ ResultCode CROHelper::ApplyImportNamedSymbol(VAddr crs_address) { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (!relocation_entry.is_batch_resolved) { ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal { + process, system, crs_address, [&](CROHelper source) -> ResultVal { std::string symbol_name = - memory.ReadCString(entry.name_offset, import_strings_size); + system.Memory().ReadCString(entry.name_offset, import_strings_size); u32 symbol_address = source.FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { @@ -794,11 +795,11 @@ ResultCode CROHelper::ResetImportNamedSymbol() { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -815,11 +816,11 @@ ResultCode CROHelper::ResetImportIndexedSymbol() { u32 import_num = GetField(ImportIndexedSymbolNum); for (u32 i = 0; i < import_num; ++i) { ImportIndexedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -836,11 +837,11 @@ ResultCode CROHelper::ResetImportAnonymousSymbol() { u32 import_num = GetField(ImportAnonymousSymbolNum); for (u32 i = 0; i < import_num; ++i) { ImportAnonymousSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); ResultCode result = ApplyRelocationBatch(relocation_addr, unresolved_symbol, true); if (result.IsError()) { @@ -857,19 +858,20 @@ ResultCode CROHelper::ApplyModuleImport(VAddr crs_address) { u32 import_module_num = GetField(ImportModuleNum); for (u32 i = 0; i < import_module_num; ++i) { ImportModuleEntry entry; - GetEntry(memory, i, entry); - std::string want_cro_name = memory.ReadCString(entry.name_offset, import_strings_size); + GetEntry(system.Memory(), i, entry); + std::string want_cro_name = + system.Memory().ReadCString(entry.name_offset, import_strings_size); ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal { + process, system, crs_address, [&](CROHelper source) -> ResultVal { if (want_cro_name == source.ModuleName()) { LOG_INFO(Service_LDR, "CRO \"{}\" imports {} indexed symbols from \"{}\"", ModuleName(), entry.import_indexed_symbol_num, source.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ExportIndexedSymbolEntry ex; - source.GetEntry(memory, im.index, ex); + source.GetEntry(system.Memory(), im.index, ex); u32 symbol_address = source.SegmentTagToAddress(ex.symbol_position); LOG_TRACE(Service_LDR, " Imports 0x{:08X}", symbol_address); ResultCode result = @@ -884,7 +886,7 @@ ResultCode CROHelper::ApplyModuleImport(VAddr crs_address) { ModuleName(), entry.import_anonymous_symbol_num, source.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); u32 symbol_address = source.SegmentTagToAddress(im.symbol_position); LOG_TRACE(Service_LDR, " Imports 0x{:08X}", symbol_address); ResultCode result = @@ -913,15 +915,15 @@ ResultCode CROHelper::ApplyExportNamedSymbol(CROHelper target) { u32 target_symbol_import_num = target.GetField(ImportNamedSymbolNum); for (u32 i = 0; i < target_symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (!relocation_entry.is_batch_resolved) { std::string symbol_name = - memory.ReadCString(entry.name_offset, target_import_strings_size); + system.Memory().ReadCString(entry.name_offset, target_import_strings_size); u32 symbol_address = FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { LOG_TRACE(Service_LDR, " exports symbol \"{}\"", symbol_name); @@ -944,15 +946,15 @@ ResultCode CROHelper::ResetExportNamedSymbol(CROHelper target) { u32 target_symbol_import_num = target.GetField(ImportNamedSymbolNum); for (u32 i = 0; i < target_symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); if (relocation_entry.is_batch_resolved) { std::string symbol_name = - memory.ReadCString(entry.name_offset, target_import_strings_size); + system.Memory().ReadCString(entry.name_offset, target_import_strings_size); u32 symbol_address = FindExportNamedSymbol(symbol_name); if (symbol_address != 0) { LOG_TRACE(Service_LDR, " unexports symbol \"{}\"", symbol_name); @@ -974,18 +976,19 @@ ResultCode CROHelper::ApplyModuleExport(CROHelper target) { u32 target_import_module_num = target.GetField(ImportModuleNum); for (u32 i = 0; i < target_import_module_num; ++i) { ImportModuleEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); - if (memory.ReadCString(entry.name_offset, target_import_string_size) != module_name) + if (system.Memory().ReadCString(entry.name_offset, target_import_string_size) != + module_name) continue; LOG_INFO(Service_LDR, "CRO \"{}\" exports {} indexed symbols to \"{}\"", module_name, entry.import_indexed_symbol_num, target.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ExportIndexedSymbolEntry ex; - GetEntry(memory, im.index, ex); + GetEntry(system.Memory(), im.index, ex); u32 symbol_address = SegmentTagToAddress(ex.symbol_position); LOG_TRACE(Service_LDR, " exports symbol 0x{:08X}", symbol_address); ResultCode result = @@ -1000,7 +1003,7 @@ ResultCode CROHelper::ApplyModuleExport(CROHelper target) { entry.import_anonymous_symbol_num, target.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); u32 symbol_address = SegmentTagToAddress(im.symbol_position); LOG_TRACE(Service_LDR, " exports symbol 0x{:08X}", symbol_address); ResultCode result = @@ -1023,16 +1026,17 @@ ResultCode CROHelper::ResetModuleExport(CROHelper target) { u32 target_import_module_num = target.GetField(ImportModuleNum); for (u32 i = 0; i < target_import_module_num; ++i) { ImportModuleEntry entry; - target.GetEntry(memory, i, entry); + target.GetEntry(system.Memory(), i, entry); - if (memory.ReadCString(entry.name_offset, target_import_string_size) != module_name) + if (system.Memory().ReadCString(entry.name_offset, target_import_string_size) != + module_name) continue; LOG_DEBUG(Service_LDR, "CRO \"{}\" unexports indexed symbols to \"{}\"", module_name, target.ModuleName()); for (u32 j = 0; j < entry.import_indexed_symbol_num; ++j) { ImportIndexedSymbolEntry im; - entry.GetImportIndexedSymbolEntry(process, memory, j, im); + entry.GetImportIndexedSymbolEntry(process, system.Memory(), j, im); ResultCode result = target.ApplyRelocationBatch(im.relocation_batch_offset, unresolved_symbol, true); if (result.IsError()) { @@ -1045,7 +1049,7 @@ ResultCode CROHelper::ResetModuleExport(CROHelper target) { target.ModuleName()); for (u32 j = 0; j < entry.import_anonymous_symbol_num; ++j) { ImportAnonymousSymbolEntry im; - entry.GetImportAnonymousSymbolEntry(process, memory, j, im); + entry.GetImportAnonymousSymbolEntry(process, system.Memory(), j, im); ResultCode result = target.ApplyRelocationBatch(im.relocation_batch_offset, unresolved_symbol, true); if (result.IsError()) { @@ -1063,15 +1067,16 @@ ResultCode CROHelper::ApplyExitRelocations(VAddr crs_address) { u32 symbol_import_num = GetField(ImportNamedSymbolNum); for (u32 i = 0; i < symbol_import_num; ++i) { ImportNamedSymbolEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); VAddr relocation_addr = entry.relocation_batch_offset; ExternalRelocationEntry relocation_entry; - memory.ReadBlock(process, relocation_addr, &relocation_entry, - sizeof(ExternalRelocationEntry)); + system.Memory().ReadBlock(process, relocation_addr, &relocation_entry, + sizeof(ExternalRelocationEntry)); - if (memory.ReadCString(entry.name_offset, import_strings_size) == "__aeabi_atexit") { + if (system.Memory().ReadCString(entry.name_offset, import_strings_size) == + "__aeabi_atexit") { ResultCode result = ForEachAutoLinkCRO( - process, memory, cpu, crs_address, [&](CROHelper source) -> ResultVal { + process, system, crs_address, [&](CROHelper source) -> ResultVal { u32 symbol_address = source.FindExportNamedSymbol("nnroAeabiAtexit_"); if (symbol_address != 0) { @@ -1126,7 +1131,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = VerifyStringTableLength(memory, GetField(ModuleNameOffset), GetField(ModuleNameSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ModuleNameOffset), + GetField(ModuleNameSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying module name {:08X}", result.raw); return result; @@ -1155,8 +1161,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = - VerifyStringTableLength(memory, GetField(ExportStringsOffset), GetField(ExportStringsSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ExportStringsOffset), + GetField(ExportStringsSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying export strings {:08X}", result.raw); return result; @@ -1192,8 +1198,8 @@ ResultCode CROHelper::Rebase(VAddr crs_address, u32 cro_size, VAddr data_segment return result; } - result = - VerifyStringTableLength(memory, GetField(ImportStringsOffset), GetField(ImportStringsSize)); + result = VerifyStringTableLength(system.Memory(), GetField(ImportStringsOffset), + GetField(ImportStringsSize)); if (result.IsError()) { LOG_ERROR(Service_LDR, "Error verifying import strings {:08X}", result.raw); return result; @@ -1266,11 +1272,11 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { // so we do the same if (GetField(SegmentNum) >= 2) { // means we have .data segment SegmentEntry entry; - GetEntry(memory, 2, entry); + GetEntry(system.Memory(), 2, entry); ASSERT(entry.type == SegmentType::Data); data_segment_address = entry.offset; entry.offset = GetField(DataOffset); - SetEntry(memory, 2, entry); + SetEntry(system.Memory(), 2, entry); } } SCOPE_EXIT({ @@ -1278,9 +1284,9 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { if (link_on_load_bug_fix) { if (GetField(SegmentNum) >= 2) { SegmentEntry entry; - GetEntry(memory, 2, entry); + GetEntry(system.Memory(), 2, entry); entry.offset = data_segment_address; - SetEntry(memory, 2, entry); + SetEntry(system.Memory(), 2, entry); } } }); @@ -1301,7 +1307,7 @@ ResultCode CROHelper::Link(VAddr crs_address, bool link_on_load_bug_fix) { } // Exports symbols to other modules - result = ForEachAutoLinkCRO(process, memory, cpu, crs_address, + result = ForEachAutoLinkCRO(process, system, crs_address, [this](CROHelper target) -> ResultVal { ResultCode result = ApplyExportNamedSymbol(target); if (result.IsError()) @@ -1346,7 +1352,7 @@ ResultCode CROHelper::Unlink(VAddr crs_address) { // Resets all symbols in other modules imported from this module // Note: the RO service seems only searching in auto-link modules - result = ForEachAutoLinkCRO(process, memory, cpu, crs_address, + result = ForEachAutoLinkCRO(process, system, crs_address, [this](CROHelper target) -> ResultVal { ResultCode result = ResetExportNamedSymbol(target); if (result.IsError()) @@ -1387,13 +1393,13 @@ void CROHelper::InitCRS() { } void CROHelper::Register(VAddr crs_address, bool auto_link) { - CROHelper crs(crs_address, process, memory, cpu); - CROHelper head(auto_link ? crs.NextModule() : crs.PreviousModule(), process, memory, cpu); + CROHelper crs(crs_address, process, system); + CROHelper head(auto_link ? crs.NextModule() : crs.PreviousModule(), process, system); if (head.module_address) { // there are already CROs registered // register as the new tail - CROHelper tail(head.PreviousModule(), process, memory, cpu); + CROHelper tail(head.PreviousModule(), process, system); // link with the old tail ASSERT(tail.NextModule() == 0); @@ -1419,11 +1425,11 @@ void CROHelper::Register(VAddr crs_address, bool auto_link) { } void CROHelper::Unregister(VAddr crs_address) { - CROHelper crs(crs_address, process, memory, cpu); - CROHelper next_head(crs.NextModule(), process, memory, cpu); - CROHelper previous_head(crs.PreviousModule(), process, memory, cpu); - CROHelper next(NextModule(), process, memory, cpu); - CROHelper previous(PreviousModule(), process, memory, cpu); + CROHelper crs(crs_address, process, system); + CROHelper next_head(crs.NextModule(), process, system); + CROHelper previous_head(crs.PreviousModule(), process, system); + CROHelper next(NextModule(), process, system); + CROHelper previous(PreviousModule(), process, system); if (module_address == next_head.module_address || module_address == previous_head.module_address) { @@ -1517,7 +1523,7 @@ std::tuple CROHelper::GetExecutablePages() const { u32 segment_num = GetField(SegmentNum); for (u32 i = 0; i < segment_num; ++i) { SegmentEntry entry; - GetEntry(memory, i, entry); + GetEntry(system.Memory(), i, entry); if (entry.type == SegmentType::Code && entry.size != 0) { VAddr begin = Common::AlignDown(entry.offset, Memory::PAGE_SIZE); VAddr end = Common::AlignUp(entry.offset + entry.size, Memory::PAGE_SIZE); diff --git a/src/core/hle/service/ldr_ro/cro_helper.h b/src/core/hle/service/ldr_ro/cro_helper.h index 46fbe05a6..265b6971e 100644 --- a/src/core/hle/service/ldr_ro/cro_helper.h +++ b/src/core/hle/service/ldr_ro/cro_helper.h @@ -33,12 +33,11 @@ static constexpr u32 CRO_HASH_SIZE = 0x80; class CROHelper final { public: // TODO (wwylele): pass in the process handle for memory access - explicit CROHelper(VAddr cro_address, Kernel::Process& process, Memory::MemorySystem& memory, - ARM_Interface& cpu) - : module_address(cro_address), process(process), memory(memory), cpu(cpu) {} + explicit CROHelper(VAddr cro_address, Kernel::Process& process, Core::System& system) + : module_address(cro_address), process(process), system(system) {} std::string ModuleName() const { - return memory.ReadCString(GetField(ModuleNameOffset), GetField(ModuleNameSize)); + return system.Memory().ReadCString(GetField(ModuleNameOffset), GetField(ModuleNameSize)); } u32 GetFileSize() const { @@ -144,8 +143,7 @@ public: private: const VAddr module_address; ///< the virtual address of this module Kernel::Process& process; ///< the owner process of this module - Memory::MemorySystem& memory; - ARM_Interface& cpu; + Core::System& system; /** * Each item in this enum represents a u32 field in the header begin from address+0x80, @@ -403,11 +401,11 @@ private: } u32 GetField(HeaderField field) const { - return memory.Read32(Field(field)); + return system.Memory().Read32(Field(field)); } void SetField(HeaderField field, u32 value) { - memory.Write32(Field(field), value); + system.Memory().Write32(Field(field), value); } /** @@ -474,12 +472,11 @@ private: * otherwise error code of the last iteration. */ template - static ResultCode ForEachAutoLinkCRO(Kernel::Process& process, Memory::MemorySystem& memory, - ARM_Interface& cpu, VAddr crs_address, - FunctionObject func) { + static ResultCode ForEachAutoLinkCRO(Kernel::Process& process, Core::System& system, + VAddr crs_address, FunctionObject func) { VAddr current = crs_address; while (current != 0) { - CROHelper cro(current, process, memory, cpu); + CROHelper cro(current, process, system); CASCADE_RESULT(bool next, func(cro)); if (!next) break; diff --git a/src/core/hle/service/ldr_ro/ldr_ro.cpp b/src/core/hle/service/ldr_ro/ldr_ro.cpp index 8ac6f1068..0ffcfda24 100644 --- a/src/core/hle/service/ldr_ro/ldr_ro.cpp +++ b/src/core/hle/service/ldr_ro/ldr_ro.cpp @@ -120,7 +120,7 @@ void RO::Initialize(Kernel::HLERequestContext& ctx) { return; } - CROHelper crs(crs_address, *process, system.Memory(), system.CPU()); + CROHelper crs(crs_address, *process, system); crs.InitCRS(); result = crs.Rebase(0, crs_size, 0, 0, 0, 0, true); @@ -254,7 +254,7 @@ void RO::LoadCRO(Kernel::HLERequestContext& ctx, bool link_on_load_bug_fix) { return; } - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); result = cro.VerifyHash(cro_size, crr_address); if (result.IsError()) { @@ -318,7 +318,7 @@ void RO::LoadCRO(Kernel::HLERequestContext& ctx, bool link_on_load_bug_fix) { } } - system.CPU().InvalidateCacheRange(cro_address, cro_size); + system.InvalidateCacheRange(cro_address, cro_size); LOG_INFO(Service_LDR, "CRO \"{}\" loaded at 0x{:08X}, fixed_end=0x{:08X}", cro.ModuleName(), cro_address, cro_address + fix_size); @@ -336,7 +336,7 @@ void RO::UnloadCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}, zero={}, cro_buffer_ptr=0x{:08X}", cro_address, zero, cro_buffer_ptr); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -391,7 +391,7 @@ void RO::UnloadCRO(Kernel::HLERequestContext& ctx) { LOG_ERROR(Service_LDR, "Error unmapping CRO {:08X}", result.raw); } - system.CPU().InvalidateCacheRange(cro_address, fixed_size); + system.InvalidateCacheRange(cro_address, fixed_size); rb.Push(result); } @@ -403,7 +403,7 @@ void RO::LinkCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}", cro_address); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -443,7 +443,7 @@ void RO::UnlinkCRO(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_LDR, "called, cro_address=0x{:08X}", cro_address); - CROHelper cro(cro_address, *process, system.Memory(), system.CPU()); + CROHelper cro(cro_address, *process, system); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); @@ -492,7 +492,7 @@ void RO::Shutdown(Kernel::HLERequestContext& ctx) { return; } - CROHelper crs(slot->loaded_crs, *process, system.Memory(), system.CPU()); + CROHelper crs(slot->loaded_crs, *process, system); crs.Unrebase(true); ResultCode result = RESULT_SUCCESS; diff --git a/src/core/hw/gpu.cpp b/src/core/hw/gpu.cpp index b283b7d91..4f6a31441 100644 --- a/src/core/hw/gpu.cpp +++ b/src/core/hw/gpu.cpp @@ -402,8 +402,8 @@ inline void Write(u32 addr, const T data) { switch (index) { // Memory fills are triggered once the fill value is written. - case GPU_REG_INDEX_WORKAROUND(memory_fill_config[0].trigger, 0x00004 + 0x3): - case GPU_REG_INDEX_WORKAROUND(memory_fill_config[1].trigger, 0x00008 + 0x3): { + case GPU_REG_INDEX(memory_fill_config[0].trigger): + case GPU_REG_INDEX(memory_fill_config[1].trigger): { const bool is_second_filler = (index != GPU_REG_INDEX(memory_fill_config[0].trigger)); auto& config = g_regs.memory_fill_config[is_second_filler]; diff --git a/src/core/hw/gpu.h b/src/core/hw/gpu.h index 9169bf9ae..3252364e2 100644 --- a/src/core/hw/gpu.h +++ b/src/core/hw/gpu.h @@ -22,41 +22,15 @@ namespace GPU { constexpr float SCREEN_REFRESH_RATE = 60; // Returns index corresponding to the Regs member labeled by field_name -// TODO: Due to Visual studio bug 209229, offsetof does not return constant expressions -// when used with array elements (e.g. GPU_REG_INDEX(memory_fill_config[0])). -// For details cf. -// https://connect.microsoft.com/VisualStudio/feedback/details/209229/offsetof-does-not-produce-a-constant-expression-for-array-members -// Hopefully, this will be fixed sometime in the future. -// For lack of better alternatives, we currently hardcode the offsets when constant -// expressions are needed via GPU_REG_INDEX_WORKAROUND (on sane compilers, static_asserts -// will then make sure the offsets indeed match the automatically calculated ones). #define GPU_REG_INDEX(field_name) (offsetof(GPU::Regs, field_name) / sizeof(u32)) -#if defined(_MSC_VER) -#define GPU_REG_INDEX_WORKAROUND(field_name, backup_workaround_index) (backup_workaround_index) -#else -// NOTE: Yeah, hacking in a static_assert here just to workaround the lacking MSVC compiler -// really is this annoying. This macro just forwards its first argument to GPU_REG_INDEX -// and then performs a (no-op) cast to std::size_t iff the second argument matches the -// expected field offset. Otherwise, the compiler will fail to compile this code. -#define GPU_REG_INDEX_WORKAROUND(field_name, backup_workaround_index) \ - ((typename std::enable_if::type) GPU_REG_INDEX(field_name)) -#endif // MMIO region 0x1EFxxxxx struct Regs { // helper macro to make sure the defined structures are of the expected size. -#if defined(_MSC_VER) -// TODO: MSVC does not support using sizeof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#define ASSERT_MEMBER_SIZE(name, size_in_bytes) -#else #define ASSERT_MEMBER_SIZE(name, size_in_bytes) \ static_assert(sizeof(name) == size_in_bytes, \ "Structure size and register block length don't match") -#endif // Components are laid out in reverse byte order, most significant bits first. enum class PixelFormat : u32 { @@ -307,10 +281,6 @@ private: }; static_assert(std::is_standard_layout::value, "Structure does not use standard layout"); -// TODO: MSVC does not support using offsetof() on non-static data members even though this -// is technically allowed since C++11. This macro should be enabled once MSVC adds -// support for that. -#ifndef _MSC_VER #define ASSERT_REG_POSITION(field_name, position) \ static_assert(offsetof(Regs, field_name) == position * 4, \ "Field " #field_name " has invalid position") @@ -323,7 +293,6 @@ ASSERT_REG_POSITION(display_transfer_config, 0x00300); ASSERT_REG_POSITION(command_processor_config, 0x00638); #undef ASSERT_REG_POSITION -#endif // !defined(_MSC_VER) // The total number of registers is chosen arbitrarily, but let's make sure it's not some odd value // anyway. diff --git a/src/core/loader/3dsx.cpp b/src/core/loader/3dsx.cpp index 7629ad376..84321011b 100644 --- a/src/core/loader/3dsx.cpp +++ b/src/core/loader/3dsx.cpp @@ -309,8 +309,8 @@ ResultStatus AppLoader_THREEDSX::ReadRomFS(std::shared_ptr if (!romfs_file_inner.IsOpen()) return ResultStatus::Error; - romfs_file = std::make_shared(std::move(romfs_file_inner), - romfs_offset, romfs_size); + romfs_file = std::make_shared(std::move(romfs_file_inner), + romfs_offset, romfs_size); return ResultStatus::Success; } diff --git a/src/core/loader/loader.h b/src/core/loader/loader.h index 20e84c6a9..f9a8e9296 100644 --- a/src/core/loader/loader.h +++ b/src/core/loader/loader.h @@ -105,13 +105,22 @@ public: * Loads the system mode that this application needs. * This function defaults to 2 (96MB allocated to the application) if it can't read the * information. - * @returns A pair with the optional system mode, and and the status. + * @returns A pair with the optional system mode, and the status. */ virtual std::pair, ResultStatus> LoadKernelSystemMode() { // 96MB allocated to the application. return std::make_pair(2, ResultStatus::Success); } + /** + * Loads the N3ds mode that this application uses. + * It defaults to 0 (O3DS default) if it can't read the information. + * @returns A pair with the optional N3ds mode, and the status. + */ + virtual std::pair, ResultStatus> LoadKernelN3dsMode() { + return std::make_pair(0, ResultStatus::Success); + } + /** * Get whether this application is executable. * @param out_executable Reference to store the executable flag into. @@ -186,6 +195,15 @@ public: return ResultStatus::ErrorNotImplemented; } + /** + * Dump the RomFS of the applciation + * @param target_path The target path to dump to + * @return ResultStatus result of function + */ + virtual ResultStatus DumpRomFS(const std::string& target_path) { + return ResultStatus::ErrorNotImplemented; + } + /** * Get the update RomFS of the application * Since the RomFS can be huge, we return a file reference instead of copying to a buffer @@ -196,6 +214,15 @@ public: return ResultStatus::ErrorNotImplemented; } + /** + * Dump the update RomFS of the applciation + * @param target_path The target path to dump to + * @return ResultStatus result of function + */ + virtual ResultStatus DumpUpdateRomFS(const std::string& target_path) { + return ResultStatus::ErrorNotImplemented; + } + /** * Get the title of the application * @param title Reference to store the application title into diff --git a/src/core/loader/ncch.cpp b/src/core/loader/ncch.cpp index 2e688e011..81053524f 100644 --- a/src/core/loader/ncch.cpp +++ b/src/core/loader/ncch.cpp @@ -61,6 +61,19 @@ std::pair, ResultStatus> AppLoader_NCCH::LoadKernelSystemMode ResultStatus::Success); } +std::pair, ResultStatus> AppLoader_NCCH::LoadKernelN3dsMode() { + if (!is_loaded) { + ResultStatus res = base_ncch.Load(); + if (res != ResultStatus::Success) { + return std::make_pair(std::optional{}, res); + } + } + + // Set the system mode as the one from the exheader. + return std::make_pair(overlay_ncch->exheader_header.arm11_system_local_caps.n3ds_mode, + ResultStatus::Success); +} + ResultStatus AppLoader_NCCH::LoadExec(std::shared_ptr& process) { using Kernel::CodeSet; @@ -254,6 +267,18 @@ ResultStatus AppLoader_NCCH::ReadUpdateRomFS(std::shared_ptr data; Loader::SMDH smdh; diff --git a/src/core/loader/ncch.h b/src/core/loader/ncch.h index 7c86f85d8..6f680b063 100644 --- a/src/core/loader/ncch.h +++ b/src/core/loader/ncch.h @@ -41,6 +41,8 @@ public: */ std::pair, ResultStatus> LoadKernelSystemMode() override; + std::pair, ResultStatus> LoadKernelN3dsMode() override; + ResultStatus IsExecutable(bool& out_executable) override; ResultStatus ReadCode(std::vector& buffer) override; @@ -59,6 +61,10 @@ public: ResultStatus ReadUpdateRomFS(std::shared_ptr& romfs_file) override; + ResultStatus DumpRomFS(const std::string& target_path) override; + + ResultStatus DumpUpdateRomFS(const std::string& target_path) override; + ResultStatus ReadTitle(std::string& title) override; private: diff --git a/src/core/rpc/rpc_server.cpp b/src/core/rpc/rpc_server.cpp index aec99f273..a95c53fd7 100644 --- a/src/core/rpc/rpc_server.cpp +++ b/src/core/rpc/rpc_server.cpp @@ -46,7 +46,9 @@ void RPCServer::HandleWriteMemory(Packet& packet, u32 address, const u8* data, u Core::System::GetInstance().Memory().WriteBlock( *Core::System::GetInstance().Kernel().GetCurrentProcess(), address, data, data_size); // If the memory happens to be executable code, make sure the changes become visible - Core::CPU().InvalidateCacheRange(address, data_size); + + // Is current core correct here? + Core::System::GetInstance().InvalidateCacheRange(address, data_size); } packet.SetPacketDataSize(0); packet.SendReply(); diff --git a/src/input_common/udp/client.cpp b/src/input_common/udp/client.cpp index 887436550..c9ffe899c 100644 --- a/src/input_common/udp/client.cpp +++ b/src/input_common/udp/client.cpp @@ -14,7 +14,6 @@ #include "input_common/udp/client.h" #include "input_common/udp/protocol.h" -using boost::asio::ip::address_v4; using boost::asio::ip::udp; namespace InputCommon::CemuhookUDP { @@ -31,10 +30,10 @@ public: explicit Socket(const std::string& host, u16 port, u8 pad_index, u32 client_id, SocketCallback callback) - : client_id(client_id), timer(io_service), - send_endpoint(udp::endpoint(address_v4::from_string(host), port)), - socket(io_service, udp::endpoint(udp::v4(), 0)), pad_index(pad_index), - callback(std::move(callback)) {} + : callback(std::move(callback)), timer(io_service), + socket(io_service, udp::endpoint(udp::v4(), 0)), client_id(client_id), + pad_index(pad_index), + send_endpoint(udp::endpoint(boost::asio::ip::make_address_v4(host), port)) {} void Stop() { io_service.stop(); @@ -126,7 +125,7 @@ static void SocketLoop(Socket* socket) { Client::Client(std::shared_ptr status, const std::string& host, u16 port, u8 pad_index, u32 client_id) - : status(status) { + : status(std::move(status)) { StartCommunication(host, port, pad_index, client_id); } @@ -208,7 +207,7 @@ void TestCommunication(const std::string& host, u16 port, u8 pad_index, u32 clie Common::Event success_event; SocketCallback callback{[](Response::Version version) {}, [](Response::PortInfo info) {}, [&](Response::PadData data) { success_event.Set(); }}; - Socket socket{host, port, pad_index, client_id, callback}; + Socket socket{host, port, pad_index, client_id, std::move(callback)}; std::thread worker_thread{SocketLoop, &socket}; bool result = success_event.WaitFor(std::chrono::seconds(8)); socket.Stop(); @@ -264,7 +263,7 @@ CalibrationConfigurationJob::CalibrationConfigurationJob( complete_event.Set(); } }}; - Socket socket{host, port, pad_index, client_id, callback}; + Socket socket{host, port, pad_index, client_id, std::move(callback)}; std::thread worker_thread{SocketLoop, &socket}; complete_event.Wait(); socket.Stop(); diff --git a/src/input_common/udp/client.h b/src/input_common/udp/client.h index 5177f46be..82bdb68d9 100644 --- a/src/input_common/udp/client.h +++ b/src/input_common/udp/client.h @@ -11,7 +11,6 @@ #include #include #include -#include #include "common/common_types.h" #include "common/thread.h" #include "common/vector_math.h" diff --git a/src/input_common/udp/protocol.h b/src/input_common/udp/protocol.h index d31bbeb89..5b1852d55 100644 --- a/src/input_common/udp/protocol.h +++ b/src/input_common/udp/protocol.h @@ -7,7 +7,6 @@ #include #include #include -#include #include #include "common/bit_field.h" #include "common/swap.h" diff --git a/src/input_common/udp/udp.cpp b/src/input_common/udp/udp.cpp index 43691ae2c..c4d5121b9 100644 --- a/src/input_common/udp/udp.cpp +++ b/src/input_common/udp/udp.cpp @@ -2,7 +2,8 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. -#include "common/logging/log.h" +#include +#include #include "common/param_package.h" #include "core/frontend/input.h" #include "core/settings.h" @@ -14,7 +15,7 @@ namespace InputCommon::CemuhookUDP { class UDPTouchDevice final : public Input::TouchDevice { public: explicit UDPTouchDevice(std::shared_ptr status_) : status(std::move(status_)) {} - std::tuple GetStatus() const { + std::tuple GetStatus() const override { std::lock_guard guard(status->update_mutex); return status->touch_status; } @@ -26,7 +27,7 @@ private: class UDPMotionDevice final : public Input::MotionDevice { public: explicit UDPMotionDevice(std::shared_ptr status_) : status(std::move(status_)) {} - std::tuple, Common::Vec3> GetStatus() const { + std::tuple, Common::Vec3> GetStatus() const override { std::lock_guard guard(status->update_mutex); return status->motion_status; } diff --git a/src/input_common/udp/udp.h b/src/input_common/udp/udp.h index ea3de60bb..3eac8c7ea 100644 --- a/src/input_common/udp/udp.h +++ b/src/input_common/udp/udp.h @@ -2,16 +2,13 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#pragma once + #include -#include -#include "input_common/main.h" #include "input_common/udp/client.h" namespace InputCommon::CemuhookUDP { -class UDPTouchDevice; -class UDPMotionDevice; - class State { public: State(); diff --git a/src/tests/core/arm/arm_test_common.cpp b/src/tests/core/arm/arm_test_common.cpp index cf2ecf0d4..381f36bee 100644 --- a/src/tests/core/arm/arm_test_common.cpp +++ b/src/tests/core/arm/arm_test_common.cpp @@ -15,9 +15,9 @@ static std::shared_ptr page_table = nullptr; TestEnvironment::TestEnvironment(bool mutable_memory_) : mutable_memory(mutable_memory_), test_memory(std::make_shared(this)) { - timing = std::make_unique(); + timing = std::make_unique(1); memory = std::make_unique(); - kernel = std::make_unique(*memory, *timing, [] {}, 0); + kernel = std::make_unique(*memory, *timing, [] {}, 0, 1, 0); kernel->SetCurrentProcess(kernel->CreateProcess(kernel->CreateCodeSet("", 0))); page_table = kernel->GetCurrentProcess()->vm_manager.page_table; diff --git a/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp b/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp index 5fadaf85e..f0152eb7d 100644 --- a/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp +++ b/src/tests/core/arm/dyncom/arm_dyncom_vfp_tests.cpp @@ -23,7 +23,7 @@ TEST_CASE("ARM_DynCom (vfp): vadd", "[arm_dyncom]") { test_env.SetMemory32(0, 0xEE321A03); // vadd.f32 s2, s4, s6 test_env.SetMemory32(4, 0xEAFFFFFE); // b +#0 - ARM_DynCom dyncom(nullptr, test_env.GetMemory(), USER32MODE); + ARM_DynCom dyncom(nullptr, test_env.GetMemory(), USER32MODE, 0, nullptr); std::vector test_cases{{ #include "vfp_vadd_f32.inc" diff --git a/src/tests/core/core_timing.cpp b/src/tests/core/core_timing.cpp index 1cfd9e971..850f13bc5 100644 --- a/src/tests/core/core_timing.cpp +++ b/src/tests/core/core_timing.cpp @@ -34,16 +34,16 @@ static void AdvanceAndCheck(Core::Timing& timing, u32 idx, int downcount, int ex expected_callback = CB_IDS[idx]; lateness = expected_lateness; - timing.AddTicks(timing.GetDowncount() - - cpu_downcount); // Pretend we executed X cycles of instructions. - timing.Advance(); + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount() - + cpu_downcount); // Pretend we executed X cycles of instructions. + timing.GetTimer(0)->Advance(); REQUIRE(decltype(callbacks_ran_flags)().set(idx) == callbacks_ran_flags); - REQUIRE(downcount == timing.GetDowncount()); + REQUIRE(downcount == timing.GetTimer(0)->GetDowncount()); } TEST_CASE("CoreTiming[BasicOrder]", "[core]") { - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); @@ -52,60 +52,19 @@ TEST_CASE("CoreTiming[BasicOrder]", "[core]") { Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", CallbackTemplate<4>); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); // D -> B -> C -> A -> E - timing.ScheduleEvent(1000, cb_a, CB_IDS[0]); - REQUIRE(1000 == timing.GetDowncount()); - timing.ScheduleEvent(500, cb_b, CB_IDS[1]); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEvent(800, cb_c, CB_IDS[2]); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEvent(100, cb_d, CB_IDS[3]); - REQUIRE(100 == timing.GetDowncount()); - timing.ScheduleEvent(1200, cb_e, CB_IDS[4]); - REQUIRE(100 == timing.GetDowncount()); - - AdvanceAndCheck(timing, 3, 400); - AdvanceAndCheck(timing, 1, 300); - AdvanceAndCheck(timing, 2, 200); - AdvanceAndCheck(timing, 0, 200); - AdvanceAndCheck(timing, 4, MAX_SLICE_LENGTH); -} - -TEST_CASE("CoreTiming[Threadsave]", "[core]") { - Core::Timing timing; - - Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); - Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); - Core::TimingEventType* cb_c = timing.RegisterEvent("callbackC", CallbackTemplate<2>); - Core::TimingEventType* cb_d = timing.RegisterEvent("callbackD", CallbackTemplate<3>); - Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", CallbackTemplate<4>); - - // Enter slice 0 - timing.Advance(); - - // D -> B -> C -> A -> E - timing.ScheduleEventThreadsafe(1000, cb_a, CB_IDS[0]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(1000); - REQUIRE(1000 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(500, cb_b, CB_IDS[1]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(500); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(800, cb_c, CB_IDS[2]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(800); - REQUIRE(500 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(100, cb_d, CB_IDS[3]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(100); - REQUIRE(100 == timing.GetDowncount()); - timing.ScheduleEventThreadsafe(1200, cb_e, CB_IDS[4]); - // Manually force since ScheduleEventThreadsafe doesn't call it - timing.ForceExceptionCheck(1200); - REQUIRE(100 == timing.GetDowncount()); + timing.ScheduleEvent(1000, cb_a, CB_IDS[0], 0); + REQUIRE(1000 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(500, cb_b, CB_IDS[1], 0); + REQUIRE(500 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(800, cb_c, CB_IDS[2], 0); + REQUIRE(500 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(100, cb_d, CB_IDS[3], 0); + REQUIRE(100 == timing.GetTimer(0)->GetDowncount()); + timing.ScheduleEvent(1200, cb_e, CB_IDS[4], 0); + REQUIRE(100 == timing.GetTimer(0)->GetDowncount()); AdvanceAndCheck(timing, 3, 400); AdvanceAndCheck(timing, 1, 300); @@ -131,7 +90,7 @@ void FifoCallback(u64 userdata, s64 cycles_late) { TEST_CASE("CoreTiming[SharedSlot]", "[core]") { using namespace SharedSlotTest; - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", FifoCallback<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", FifoCallback<1>); @@ -139,36 +98,36 @@ TEST_CASE("CoreTiming[SharedSlot]", "[core]") { Core::TimingEventType* cb_d = timing.RegisterEvent("callbackD", FifoCallback<3>); Core::TimingEventType* cb_e = timing.RegisterEvent("callbackE", FifoCallback<4>); - timing.ScheduleEvent(1000, cb_a, CB_IDS[0]); - timing.ScheduleEvent(1000, cb_b, CB_IDS[1]); - timing.ScheduleEvent(1000, cb_c, CB_IDS[2]); - timing.ScheduleEvent(1000, cb_d, CB_IDS[3]); - timing.ScheduleEvent(1000, cb_e, CB_IDS[4]); + timing.ScheduleEvent(1000, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(1000, cb_b, CB_IDS[1], 0); + timing.ScheduleEvent(1000, cb_c, CB_IDS[2], 0); + timing.ScheduleEvent(1000, cb_d, CB_IDS[3], 0); + timing.ScheduleEvent(1000, cb_e, CB_IDS[4], 0); // Enter slice 0 - timing.Advance(); - REQUIRE(1000 == timing.GetDowncount()); + timing.GetTimer(0)->Advance(); + REQUIRE(1000 == timing.GetTimer(0)->GetDowncount()); callbacks_ran_flags = 0; counter = 0; lateness = 0; - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); - REQUIRE(MAX_SLICE_LENGTH == timing.GetDowncount()); + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); + REQUIRE(MAX_SLICE_LENGTH == timing.GetTimer(0)->GetDowncount()); REQUIRE(0x1FULL == callbacks_ran_flags.to_ullong()); } TEST_CASE("CoreTiming[PredictableLateness]", "[core]") { - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); - timing.ScheduleEvent(100, cb_a, CB_IDS[0]); - timing.ScheduleEvent(200, cb_b, CB_IDS[1]); + timing.ScheduleEvent(100, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(200, cb_b, CB_IDS[1], 0); AdvanceAndCheck(timing, 0, 90, 10, -10); // (100 - 10) AdvanceAndCheck(timing, 1, MAX_SLICE_LENGTH, 50, -50); @@ -190,7 +149,7 @@ static void RescheduleCallback(Core::Timing& timing, u64 userdata, s64 cycles_la TEST_CASE("CoreTiming[ChainScheduling]", "[core]") { using namespace ChainSchedulingTest; - Core::Timing timing; + Core::Timing timing(1); Core::TimingEventType* cb_a = timing.RegisterEvent("callbackA", CallbackTemplate<0>); Core::TimingEventType* cb_b = timing.RegisterEvent("callbackB", CallbackTemplate<1>); @@ -201,28 +160,30 @@ TEST_CASE("CoreTiming[ChainScheduling]", "[core]") { }); // Enter slice 0 - timing.Advance(); + timing.GetTimer(0)->Advance(); - timing.ScheduleEvent(800, cb_a, CB_IDS[0]); - timing.ScheduleEvent(1000, cb_b, CB_IDS[1]); - timing.ScheduleEvent(2200, cb_c, CB_IDS[2]); - timing.ScheduleEvent(1000, cb_rs, reinterpret_cast(cb_rs)); - REQUIRE(800 == timing.GetDowncount()); + timing.ScheduleEvent(800, cb_a, CB_IDS[0], 0); + timing.ScheduleEvent(1000, cb_b, CB_IDS[1], 0); + timing.ScheduleEvent(2200, cb_c, CB_IDS[2], 0); + timing.ScheduleEvent(1000, cb_rs, reinterpret_cast(cb_rs), 0); + REQUIRE(800 == timing.GetTimer(0)->GetDowncount()); reschedules = 3; AdvanceAndCheck(timing, 0, 200); // cb_a AdvanceAndCheck(timing, 1, 1000); // cb_b, cb_rs REQUIRE(2 == reschedules); - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); // cb_rs + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); // cb_rs REQUIRE(1 == reschedules); - REQUIRE(200 == timing.GetDowncount()); + REQUIRE(200 == timing.GetTimer(0)->GetDowncount()); AdvanceAndCheck(timing, 2, 800); // cb_c - timing.AddTicks(timing.GetDowncount()); - timing.Advance(); // cb_rs + timing.GetTimer(0)->AddTicks(timing.GetTimer(0)->GetDowncount()); + timing.GetTimer(0)->Advance(); // cb_rs REQUIRE(0 == reschedules); - REQUIRE(MAX_SLICE_LENGTH == timing.GetDowncount()); + REQUIRE(MAX_SLICE_LENGTH == timing.GetTimer(0)->GetDowncount()); } + +// TODO: Add tests for multiple timers diff --git a/src/tests/core/hle/kernel/hle_ipc.cpp b/src/tests/core/hle/kernel/hle_ipc.cpp index a4f7c8062..880bb9995 100644 --- a/src/tests/core/hle/kernel/hle_ipc.cpp +++ b/src/tests/core/hle/kernel/hle_ipc.cpp @@ -24,9 +24,9 @@ static std::shared_ptr MakeObject(Kernel::KernelSystem& kernel) { } TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); @@ -239,9 +239,9 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel } TEST_CASE("HLERequestContext::WriteToOutgoingCommandBuffer", "[core][kernel]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); auto [server, client] = kernel.CreateSessionPair(); HLERequestContext context(kernel, std::move(server), nullptr); diff --git a/src/tests/core/memory/memory.cpp b/src/tests/core/memory/memory.cpp index 4a6d54bf7..8f08862a1 100644 --- a/src/tests/core/memory/memory.cpp +++ b/src/tests/core/memory/memory.cpp @@ -11,9 +11,9 @@ #include "core/memory.h" TEST_CASE("Memory::IsValidVirtualAddress", "[core][memory]") { - Core::Timing timing; + Core::Timing timing(1); Memory::MemorySystem memory; - Kernel::KernelSystem kernel(memory, timing, [] {}, 0); + Kernel::KernelSystem kernel(memory, timing, [] {}, 0, 1, 0); SECTION("these regions should not be mapped on an empty process") { auto process = kernel.CreateProcess(kernel.CreateCodeSet("", 0)); CHECK(Memory::IsValidVirtualAddress(*process, Memory::PROCESS_IMAGE_VADDR) == false); diff --git a/src/video_core/renderer_opengl/gl_rasterizer_cache.cpp b/src/video_core/renderer_opengl/gl_rasterizer_cache.cpp index 4dcc64832..3365c6afa 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer_cache.cpp +++ b/src/video_core/renderer_opengl/gl_rasterizer_cache.cpp @@ -1926,7 +1926,7 @@ void RasterizerCacheOpenGL::ValidateSurface(const Surface& surface, PAddr addr, } void RasterizerCacheOpenGL::ClearAll(bool flush) { - const SurfaceInterval flush_interval(0x0, 0xFFFFFFFF); + const auto flush_interval = PageMap::interval_type::right_open(0x0, 0xFFFFFFFF); // Force flush all surfaces from the cache if (flush) { FlushRegion(0x0, 0xFFFFFFFF); @@ -1945,8 +1945,8 @@ void RasterizerCacheOpenGL::ClearAll(bool flush) { // Remove the whole cache without really looking at it. cached_pages -= flush_interval; - dirty_regions -= flush_interval; - surface_cache -= flush_interval; + dirty_regions -= SurfaceInterval(0x0, 0xFFFFFFFF); + surface_cache -= SurfaceInterval(0x0, 0xFFFFFFFF); remove_surfaces.clear(); } diff --git a/src/video_core/renderer_opengl/gl_rasterizer_cache.h b/src/video_core/renderer_opengl/gl_rasterizer_cache.h index 69322f713..46645daec 100644 --- a/src/video_core/renderer_opengl/gl_rasterizer_cache.h +++ b/src/video_core/renderer_opengl/gl_rasterizer_cache.h @@ -80,11 +80,15 @@ struct CachedSurface; using Surface = std::shared_ptr; using SurfaceSet = std::set; -using SurfaceRegions = boost::icl::interval_set; -using SurfaceMap = boost::icl::interval_map; -using SurfaceCache = boost::icl::interval_map; +using SurfaceInterval = boost::icl::right_open_interval; +using SurfaceRegions = boost::icl::interval_set; +using SurfaceMap = + boost::icl::interval_map; +using SurfaceCache = + boost::icl::interval_map; -using SurfaceInterval = SurfaceCache::interval_type; static_assert(std::is_same() && std::is_same(), "incorrect interval types"); @@ -101,6 +105,29 @@ enum class ScaleMatch { }; struct SurfaceParams { +private: + static constexpr std::array BPP_TABLE = { + 32, // RGBA8 + 24, // RGB8 + 16, // RGB5A1 + 16, // RGB565 + 16, // RGBA4 + 16, // IA8 + 16, // RG8 + 8, // I8 + 8, // A8 + 8, // IA4 + 4, // I4 + 4, // A4 + 4, // ETC1 + 8, // ETC1A4 + 16, // D16 + 0, + 24, // D24 + 32, // D24S8 + }; + +public: enum class PixelFormat { // First 5 formats are shared between textures and color buffers RGBA8 = 0, @@ -139,30 +166,11 @@ struct SurfaceParams { }; static constexpr unsigned int GetFormatBpp(PixelFormat format) { - constexpr std::array bpp_table = { - 32, // RGBA8 - 24, // RGB8 - 16, // RGB5A1 - 16, // RGB565 - 16, // RGBA4 - 16, // IA8 - 16, // RG8 - 8, // I8 - 8, // A8 - 8, // IA4 - 4, // I4 - 4, // A4 - 4, // ETC1 - 8, // ETC1A4 - 16, // D16 - 0, - 24, // D24 - 32, // D24S8 - }; - - assert(static_cast(format) < bpp_table.size()); - return bpp_table[static_cast(format)]; + const auto format_idx = static_cast(format); + DEBUG_ASSERT_MSG(format_idx < BPP_TABLE.size(), "Invalid pixel format {}", format_idx); + return BPP_TABLE[format_idx]; } + unsigned int GetFormatBpp() const { return GetFormatBpp(pixel_format); } @@ -245,7 +253,7 @@ struct SurfaceParams { } SurfaceInterval GetInterval() const { - return SurfaceInterval::right_open(addr, end); + return SurfaceInterval(addr, end); } // Returns the outer rectangle containing "interval" diff --git a/src/web_service/CMakeLists.txt b/src/web_service/CMakeLists.txt index c3d42fe8b..5695e25f0 100644 --- a/src/web_service/CMakeLists.txt +++ b/src/web_service/CMakeLists.txt @@ -18,3 +18,6 @@ get_directory_property(OPENSSL_LIBS DEFINITION OPENSSL_LIBS) target_compile_definitions(web_service PRIVATE -DCPPHTTPLIB_OPENSSL_SUPPORT) target_link_libraries(web_service PRIVATE common network json-headers ${OPENSSL_LIBS} httplib lurlparser cpp-jwt) +if (ANDROID) + target_link_libraries(web_service PRIVATE ifaddrs) +endif() diff --git a/src/web_service/web_backend.cpp b/src/web_service/web_backend.cpp index 6683f459f..c047677f9 100644 --- a/src/web_service/web_backend.cpp +++ b/src/web_service/web_backend.cpp @@ -8,6 +8,9 @@ #include #include #include +#if defined(__ANDROID__) +#include +#endif #include #include "common/common_types.h" #include "common/logging/log.h" @@ -73,14 +76,14 @@ struct Client::Impl { if (!parsedUrl.GetPort(&port)) { port = HTTP_PORT; } - cli = std::make_unique(parsedUrl.m_Host.c_str(), port, - TIMEOUT_SECONDS); + cli = std::make_unique(parsedUrl.m_Host.c_str(), port); + cli->set_timeout_sec(TIMEOUT_SECONDS); } else if (parsedUrl.m_Scheme == "https") { if (!parsedUrl.GetPort(&port)) { port = HTTPS_PORT; } - cli = std::make_unique(parsedUrl.m_Host.c_str(), port, - TIMEOUT_SECONDS); + cli = std::make_unique(parsedUrl.m_Host.c_str(), port); + cli->set_timeout_sec(TIMEOUT_SECONDS); } else { LOG_ERROR(WebService, "Bad URL scheme {}", parsedUrl.m_Scheme); return Common::WebResult{Common::WebResult::Code::InvalidURL, "Bad URL scheme"};