diff --git a/src/audio_core/CMakeLists.txt b/src/audio_core/CMakeLists.txt index 4e0c9f4de..f2b3e1f3b 100644 --- a/src/audio_core/CMakeLists.txt +++ b/src/audio_core/CMakeLists.txt @@ -45,7 +45,9 @@ if(ENABLE_MF) hle/wmf_decoder_utils.cpp hle/wmf_decoder_utils.h ) - target_link_libraries(audio_core PRIVATE mf.lib mfplat.lib mfuuid.lib) + # We dynamically load the required symbols from mf.dll and mfplat.dll but mfuuid is not a dll + # just a static library of GUIDS so include that one directly. + target_link_libraries(audio_core PRIVATE mfuuid.lib) target_compile_definitions(audio_core PUBLIC HAVE_MF) elseif(ENABLE_FFMPEG_AUDIO_DECODER) target_sources(audio_core PRIVATE diff --git a/src/audio_core/hle/decoder.h b/src/audio_core/hle/decoder.h index a551df4eb..7bb5a25f4 100644 --- a/src/audio_core/hle/decoder.h +++ b/src/audio_core/hle/decoder.h @@ -56,6 +56,9 @@ class DecoderBase { public: virtual ~DecoderBase(); virtual std::optional ProcessRequest(const BinaryRequest& request) = 0; + /// Return true if this Decoder can be loaded. Return false if the system cannot create the + /// decoder + virtual bool IsValid() const = 0; }; class NullDecoder final : public DecoderBase { @@ -63,6 +66,9 @@ public: NullDecoder(); ~NullDecoder() override; std::optional ProcessRequest(const BinaryRequest& request) override; + bool IsValid() const override { + return true; + } }; } // namespace AudioCore::HLE diff --git a/src/audio_core/hle/ffmpeg_decoder.cpp b/src/audio_core/hle/ffmpeg_decoder.cpp index 6354f2c3e..061decf13 100644 --- a/src/audio_core/hle/ffmpeg_decoder.cpp +++ b/src/audio_core/hle/ffmpeg_decoder.cpp @@ -12,6 +12,9 @@ public: explicit Impl(Memory::MemorySystem& memory); ~Impl(); std::optional ProcessRequest(const BinaryRequest& request); + bool IsValid() const { + return initalized; + } private: std::optional Initalize(const BinaryRequest& request); @@ -261,4 +264,8 @@ std::optional FFMPEGDecoder::ProcessRequest(const BinaryRequest& return impl->ProcessRequest(request); } +bool FFMPEGDecoder::IsValid() const { + return impl->IsValid(); +} + } // namespace AudioCore::HLE diff --git a/src/audio_core/hle/ffmpeg_decoder.h b/src/audio_core/hle/ffmpeg_decoder.h index 190251543..ee5e8cda7 100644 --- a/src/audio_core/hle/ffmpeg_decoder.h +++ b/src/audio_core/hle/ffmpeg_decoder.h @@ -13,6 +13,7 @@ public: explicit FFMPEGDecoder(Memory::MemorySystem& memory); ~FFMPEGDecoder() override; std::optional ProcessRequest(const BinaryRequest& request) override; + bool IsValid() const override; private: class Impl; diff --git a/src/audio_core/hle/hle.cpp b/src/audio_core/hle/hle.cpp index 0bb7d4413..873d6a72e 100644 --- a/src/audio_core/hle/hle.cpp +++ b/src/audio_core/hle/hle.cpp @@ -87,15 +87,27 @@ DspHle::Impl::Impl(DspHle& parent_, Memory::MemorySystem& memory) : parent(paren source.SetMemory(memory); } -#ifdef HAVE_MF +#if defined(HAVE_MF) && defined(HAVE_FFMPEG) decoder = std::make_unique(memory); -#elif HAVE_FFMPEG + if (!decoder->IsValid()) { + LOG_WARNING(Audio_DSP, "Unable to load MediaFoundation. Attempting to load FFMPEG instead"); + decoder = std::make_unique(memory); + } +#elif defined(HAVE_MF) + decoder = std::make_unique(memory); +#elif defined(HAVE_FFMPEG) decoder = std::make_unique(memory); #else LOG_WARNING(Audio_DSP, "No decoder found, this could lead to missing audio"); decoder = std::make_unique(); #endif // HAVE_MF + if (!decoder->IsValid()) { + LOG_WARNING(Audio_DSP, + "Unable to load any decoders, this could cause missing audio in some games"); + decoder = std::make_unique(); + } + Core::Timing& timing = Core::System::GetInstance().CoreTiming(); tick_event = timing.RegisterEvent("AudioCore::DspHle::tick_event", [this](u64, s64 cycles_late) { diff --git a/src/audio_core/hle/wmf_decoder.cpp b/src/audio_core/hle/wmf_decoder.cpp index a1507991b..f188e5153 100644 --- a/src/audio_core/hle/wmf_decoder.cpp +++ b/src/audio_core/hle/wmf_decoder.cpp @@ -7,11 +7,16 @@ namespace AudioCore::HLE { +using namespace MFDecoder; + class WMFDecoder::Impl { public: explicit Impl(Memory::MemorySystem& memory); ~Impl(); std::optional ProcessRequest(const BinaryRequest& request); + bool IsValid() const { + return is_valid; + } private: std::optional Initalize(const BinaryRequest& request); @@ -28,21 +33,35 @@ private: unique_mfptr transform; DWORD in_stream_id = 0; DWORD out_stream_id = 0; + bool is_valid = false; + bool mf_started = false; + bool coinited = false; }; WMFDecoder::Impl::Impl(Memory::MemorySystem& memory) : memory(memory) { + // Attempt to load the symbols for mf.dll + if (!InitMFDLL()) { + LOG_CRITICAL(Audio_DSP, + "Unable to load mf.dll. AAC audio through media foundation unavailable"); + return; + } + HRESULT hr = S_OK; hr = CoInitialize(NULL); // S_FALSE will be returned when COM has already been initialized if (hr != S_OK && hr != S_FALSE) { ReportError("Failed to start COM components", hr); + } else { + coinited = true; } // lite startup is faster and all what we need is included - hr = MFStartup(MF_VERSION, MFSTARTUP_LITE); + hr = MFDecoder::MFStartup(MF_VERSION, MFSTARTUP_LITE); if (hr != S_OK) { // Do you know you can't initialize MF in test mode or safe mode? ReportError("Failed to initialize Media Foundation", hr); + } else { + mf_started = true; } LOG_INFO(Audio_DSP, "Media Foundation activated"); @@ -64,6 +83,7 @@ WMFDecoder::Impl::Impl(Memory::MemorySystem& memory) : memory(memory) { return; } transform_initialized = true; + is_valid = true; } WMFDecoder::Impl::~Impl() { @@ -73,8 +93,12 @@ WMFDecoder::Impl::~Impl() { // otherwise access violation will occur transform.reset(); } - MFShutdown(); - CoUninitialize(); + if (mf_started) { + MFDecoder::MFShutdown(); + } + if (coinited) { + CoUninitialize(); + } } std::optional WMFDecoder::Impl::ProcessRequest(const BinaryRequest& request) { @@ -271,4 +295,8 @@ std::optional WMFDecoder::ProcessRequest(const BinaryRequest& re return impl->ProcessRequest(request); } +bool WMFDecoder::IsValid() const { + return impl->IsValid(); +} + } // namespace AudioCore::HLE diff --git a/src/audio_core/hle/wmf_decoder.h b/src/audio_core/hle/wmf_decoder.h index 34e223740..a089f2322 100644 --- a/src/audio_core/hle/wmf_decoder.h +++ b/src/audio_core/hle/wmf_decoder.h @@ -13,6 +13,7 @@ public: explicit WMFDecoder(Memory::MemorySystem& memory); ~WMFDecoder() override; std::optional ProcessRequest(const BinaryRequest& request) override; + bool IsValid() const override; private: class Impl; diff --git a/src/audio_core/hle/wmf_decoder_utils.cpp b/src/audio_core/hle/wmf_decoder_utils.cpp index 21dd8a950..bcaf0b225 100644 --- a/src/audio_core/hle/wmf_decoder_utils.cpp +++ b/src/audio_core/hle/wmf_decoder_utils.cpp @@ -5,6 +5,8 @@ #include "common/string_util.h" #include "wmf_decoder_utils.h" +namespace MFDecoder { + // utility functions void ReportError(std::string msg, HRESULT hr) { if (SUCCEEDED(hr)) { @@ -26,6 +28,7 @@ void ReportError(std::string msg, HRESULT hr) { } unique_mfptr MFDecoderInit(GUID audio_format) { + HRESULT hr = S_OK; MFT_REGISTER_TYPE_INFO reg = {0}; GUID category = MFT_CATEGORY_AUDIO_DECODER; @@ -347,3 +350,112 @@ std::optional> CopySampleToBuffer(IMFSample* sample) { buffer->Unlock(); return output; } + +namespace { + +struct LibraryDeleter { + using pointer = HMODULE; + void operator()(HMODULE h) const { + if (h != nullptr) + FreeLibrary(h); + } +}; + +std::unique_ptr mf_dll{nullptr}; +std::unique_ptr mfplat_dll{nullptr}; + +} // namespace + +bool InitMFDLL() { + + mf_dll.reset(LoadLibrary(TEXT("mf.dll"))); + if (!mf_dll) { + DWORD error_message_id = GetLastError(); + LPSTR message_buffer = nullptr; + size_t size = + FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, error_message_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&message_buffer), 0, nullptr); + + std::string message(message_buffer, size); + + LocalFree(message_buffer); + LOG_ERROR(Audio_DSP, "Could not load mf.dll: {}", message); + return false; + } + + mfplat_dll.reset(LoadLibrary(TEXT("mfplat.dll"))); + if (!mfplat_dll) { + DWORD error_message_id = GetLastError(); + LPSTR message_buffer = nullptr; + size_t size = + FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, error_message_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&message_buffer), 0, nullptr); + + std::string message(message_buffer, size); + + LocalFree(message_buffer); + LOG_ERROR(Audio_DSP, "Could not load mfplat.dll: {}", message); + return false; + } + + MFStartup = Symbol(mfplat_dll.get(), "MFStartup"); + if (!MFStartup) { + LOG_ERROR(Audio_DSP, "Cannot load function MFStartup"); + return false; + } + + MFShutdown = Symbol(mfplat_dll.get(), "MFShutdown"); + if (!MFShutdown) { + LOG_ERROR(Audio_DSP, "Cannot load function MFShutdown"); + return false; + } + + MFShutdownObject = Symbol(mf_dll.get(), "MFShutdownObject"); + if (!MFShutdownObject) { + LOG_ERROR(Audio_DSP, "Cannot load function MFShutdownObject"); + return false; + } + + MFCreateAlignedMemoryBuffer = Symbol( + mfplat_dll.get(), "MFCreateAlignedMemoryBuffer"); + if (!MFCreateAlignedMemoryBuffer) { + LOG_ERROR(Audio_DSP, "Cannot load function MFCreateAlignedMemoryBuffer"); + return false; + } + + MFCreateSample = Symbol(mfplat_dll.get(), "MFCreateSample"); + if (!MFCreateSample) { + LOG_ERROR(Audio_DSP, "Cannot load function MFCreateSample"); + return false; + } + + MFTEnumEx = + Symbol(mfplat_dll.get(), "MFTEnumEx"); + if (!MFTEnumEx) { + LOG_ERROR(Audio_DSP, "Cannot load function MFTEnumEx"); + return false; + } + + MFCreateMediaType = Symbol(mfplat_dll.get(), "MFCreateMediaType"); + if (!MFCreateMediaType) { + LOG_ERROR(Audio_DSP, "Cannot load function MFCreateMediaType"); + return false; + } +} + +Symbol MFStartup; +Symbol MFShutdown; +Symbol MFShutdownObject; +Symbol MFCreateAlignedMemoryBuffer; +Symbol MFCreateSample; +Symbol + MFTEnumEx; +Symbol MFCreateMediaType; + +} // namespace MFDecoder diff --git a/src/audio_core/hle/wmf_decoder_utils.h b/src/audio_core/hle/wmf_decoder_utils.h index 26e1217a2..cdbde5f1f 100644 --- a/src/audio_core/hle/wmf_decoder_utils.h +++ b/src/audio_core/hle/wmf_decoder_utils.h @@ -18,6 +18,39 @@ #include "adts.h" +namespace MFDecoder { + +template +struct Symbol { + Symbol() = default; + Symbol(HMODULE dll, const char* name) { + if (dll) { + ptr_symbol = reinterpret_cast(GetProcAddress(dll, name)); + } + } + + operator T*() const { + return ptr_symbol; + } + + explicit operator bool() const { + return ptr_symbol != nullptr; + } + + T* ptr_symbol = nullptr; +}; + +// Runtime load the MF symbols to prevent mf.dll not found errors on citra load +extern Symbol MFStartup; +extern Symbol MFShutdown; +extern Symbol MFShutdownObject; +extern Symbol MFCreateAlignedMemoryBuffer; +extern Symbol MFCreateSample; +extern Symbol + MFTEnumEx; +extern Symbol MFCreateMediaType; + enum class MFOutputState { FatalError, OK, NeedMoreInput, NeedReconfig, HaveMoreData }; enum class MFInputState { FatalError, OK, NotAccepted }; @@ -73,6 +106,9 @@ struct ADTSMeta { }; // exported functions + +/// Loads the symbols from mf.dll at runtime. Returns false if the symbols can't be loaded +bool InitMFDLL(); unique_mfptr MFDecoderInit(GUID audio_format = MFAudioFormat_AAC); unique_mfptr CreateSample(const void* data, DWORD len, DWORD alignment = 1, LONGLONG duration = 0); @@ -87,3 +123,5 @@ MFInputState SendSample(IMFTransform* transform, DWORD in_stream_id, IMFSample* std::tuple> ReceiveSample(IMFTransform* transform, DWORD out_stream_id); std::optional> CopySampleToBuffer(IMFSample* sample); + +} // namespace MFDecoder