ssl: tolerate handshake without hostname set (#11328)

This commit is contained in:
liamwhite 2023-08-25 18:02:32 -04:00 committed by GitHub
parent b923f5aa7e
commit 234cc45192
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 24 deletions

View file

@ -139,7 +139,6 @@ private:
bool do_not_close_socket = false; bool do_not_close_socket = false;
bool get_server_cert_chain = false; bool get_server_cert_chain = false;
std::shared_ptr<Network::SocketBase> socket; std::shared_ptr<Network::SocketBase> socket;
bool did_set_host_name = false;
bool did_handshake = false; bool did_handshake = false;
Result SetSocketDescriptorImpl(s32* out_fd, s32 fd) { Result SetSocketDescriptorImpl(s32* out_fd, s32 fd) {
@ -174,11 +173,7 @@ private:
Result SetHostNameImpl(const std::string& hostname) { Result SetHostNameImpl(const std::string& hostname) {
LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
ASSERT(!did_handshake); ASSERT(!did_handshake);
Result res = backend->SetHostName(hostname); return backend->SetHostName(hostname);
if (res == ResultSuccess) {
did_set_host_name = true;
}
return res;
} }
Result SetVerifyOptionImpl(u32 option) { Result SetVerifyOptionImpl(u32 option) {
@ -208,9 +203,6 @@ private:
Result DoHandshakeImpl() { Result DoHandshakeImpl() {
ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
ASSERT_OR_EXECUTE_MSG(
did_set_host_name, { return ResultInternalError; },
"Expected SetHostName before DoHandshake");
Result res = backend->DoHandshake(); Result res = backend->DoHandshake();
did_handshake = res.IsSuccess(); did_handshake = res.IsSuccess();
return res; return res;

View file

@ -167,9 +167,8 @@ public:
} }
~SSLConnectionBackendOpenSSL() { ~SSLConnectionBackendOpenSSL() {
// these are null-tolerant: // this is null-tolerant:
SSL_free(ssl); SSL_free(ssl);
BIO_free(bio);
} }
static void KeyLogCallback(const SSL* ssl, const char* line) { static void KeyLogCallback(const SSL* ssl, const char* line) {

View file

@ -31,9 +31,9 @@ CredHandle cred_handle;
static void OneTimeInit() { static void OneTimeInit() {
schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
schannel_cred.dwFlags = schannel_cred.dwFlags =
SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
SCH_CRED_AUTO_CRED_VALIDATION | // validate certs SCH_CRED_NO_SERVERNAME_CHECK | // don't validate server names
SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
// ^ I'm assuming that nobody would want to connect Yuzu to a // ^ I'm assuming that nobody would want to connect Yuzu to a
// service that requires some OS-provided corporate client // service that requires some OS-provided corporate client
// certificate, and presenting one to some arbitrary server // certificate, and presenting one to some arbitrary server
@ -227,16 +227,15 @@ public:
ciphertext_read_buf.size()); ciphertext_read_buf.size());
} }
const SECURITY_STATUS ret = char* hostname_ptr = hostname ? const_cast<char*>(hostname->c_str()) : nullptr;
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, const SECURITY_STATUS ret = InitializeSecurityContextA(
// Caller ensured we have set a hostname: &cred_handle, initial_call_done ? &ctxt : nullptr, hostname_ptr, req,
const_cast<char*>(hostname.value().c_str()), req, 0, // Reserved1
0, // Reserved1 0, // TargetDataRep not used with Schannel
0, // TargetDataRep not used with Schannel initial_call_done ? &input_desc : nullptr,
initial_call_done ? &input_desc : nullptr, 0, // Reserved2
0, // Reserved2 initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
initial_call_done ? nullptr : &ctxt, &output_desc, &attr, nullptr); // ptsExpiry
nullptr); // ptsExpiry
if (output_buffers[0].pvBuffer) { if (output_buffers[0].pvBuffer) {
const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),