From 51996c54f03c554b8126cd7242801879be292f7c Mon Sep 17 00:00:00 2001 From: SachinVin <26602104+SachinVin@users.noreply.github.com> Date: Sat, 29 Jul 2023 00:45:58 +0530 Subject: [PATCH] audio_core\hle\adts_reader.cpp: Use BitField to parse ADTS header (#6719) --- src/audio_core/hle/adts.h | 24 ++++--- src/audio_core/hle/adts_reader.cpp | 64 +++++++++++------- src/audio_core/hle/audiotoolbox_decoder.cpp | 11 +-- src/audio_core/hle/mediandk_decoder.cpp | 13 ++-- src/audio_core/hle/wmf_decoder.cpp | 11 +-- src/audio_core/hle/wmf_decoder_utils.cpp | 13 ++-- src/audio_core/hle/wmf_decoder_utils.h | 10 +-- src/tests/CMakeLists.txt | 1 + src/tests/audio_core/hle/adts_reader.cpp | 75 +++++++++++++++++++++ 9 files changed, 161 insertions(+), 61 deletions(-) create mode 100644 src/tests/audio_core/hle/adts_reader.cpp diff --git a/src/audio_core/hle/adts.h b/src/audio_core/hle/adts.h index cc602e125..3729a81da 100644 --- a/src/audio_core/hle/adts.h +++ b/src/audio_core/hle/adts.h @@ -5,20 +5,24 @@ #include "common/common_types.h" +namespace AudioCore { + struct ADTSData { - u8 header_length; - bool MPEG2; - u8 profile; - u8 channels; - u8 channel_idx; - u8 framecount; - u8 samplerate_idx; - u32 length; - u32 samplerate; + u8 header_length = 0; + bool mpeg2 = false; + u8 profile = 0; + u8 channels = 0; + u8 channel_idx = 0; + u8 framecount = 0; + u8 samplerate_idx = 0; + u32 length = 0; + u32 samplerate = 0; }; -ADTSData ParseADTS(const char* buffer); +ADTSData ParseADTS(const u8* buffer); // last two bytes of MF AAC decoder user data // see https://docs.microsoft.com/en-us/windows/desktop/medfound/aac-decoder#example-media-types u16 MFGetAACTag(const ADTSData& input); + +} // namespace AudioCore diff --git a/src/audio_core/hle/adts_reader.cpp b/src/audio_core/hle/adts_reader.cpp index 417e02176..d3dc7942e 100644 --- a/src/audio_core/hle/adts_reader.cpp +++ b/src/audio_core/hle/adts_reader.cpp @@ -3,44 +3,59 @@ // Refer to the license.txt file included. #include #include "adts.h" +#include "common/bit_field.h" +namespace AudioCore { constexpr std::array freq_table = {96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350, 0, 0, 0}; constexpr std::array channel_table = {0, 1, 2, 3, 4, 5, 6, 8}; -ADTSData ParseADTS(const char* buffer) { - u32 tmp = 0; - ADTSData out; +struct ADTSHeader { + union { + std::array raw{}; + BitFieldBE<52, 12, u64> sync_word; + BitFieldBE<51, 1, u64> mpeg2; + BitFieldBE<49, 2, u64> layer; + BitFieldBE<48, 1, u64> protection_absent; + BitFieldBE<46, 2, u64> profile; + BitFieldBE<42, 4, u64> samplerate_idx; + BitFieldBE<41, 1, u64> private_bit; + BitFieldBE<38, 3, u64> channel_idx; + BitFieldBE<37, 1, u64> originality; + BitFieldBE<36, 1, u64> home; + BitFieldBE<35, 1, u64> copyright_id; + BitFieldBE<34, 1, u64> copyright_id_start; + BitFieldBE<21, 13, u64> frame_length; + BitFieldBE<10, 11, u64> buffer_fullness; + BitFieldBE<8, 2, u64> frame_count; + }; +}; + +ADTSData ParseADTS(const u8* buffer) { + ADTSHeader header; + memcpy(header.raw.data(), buffer, sizeof(header.raw)); // sync word 0xfff - tmp = (buffer[0] << 8) | (buffer[1] & 0xf0); - if ((tmp & 0xffff) != 0xfff0) { - out.length = 0; - return out; + if (header.sync_word != 0xfff) { + return {}; } + + ADTSData out{}; // bit 16 = no CRC - out.header_length = (buffer[1] & 0x1) ? 7 : 9; - out.MPEG2 = (buffer[1] >> 3) & 0x1; + out.header_length = header.protection_absent ? 7 : 9; + out.mpeg2 = static_cast(header.mpeg2); // bit 17 to 18 - out.profile = (buffer[2] >> 6) + 1; + out.profile = static_cast(header.profile) + 1; // bit 19 to 22 - tmp = (buffer[2] >> 2) & 0xf; - out.samplerate_idx = tmp; - out.samplerate = (tmp > 15) ? 0 : freq_table[tmp]; + out.samplerate_idx = static_cast(header.samplerate_idx); + out.samplerate = header.samplerate_idx > 15 ? 0 : freq_table[header.samplerate_idx]; // bit 24 to 26 - tmp = ((buffer[2] & 0x1) << 2) | ((buffer[3] >> 6) & 0x3); - out.channel_idx = tmp; - out.channels = (tmp > 7) ? 0 : channel_table[tmp]; - + out.channel_idx = static_cast(header.channel_idx); + out.channels = (header.channel_idx > 7) ? 0 : channel_table[header.channel_idx]; // bit 55 to 56 - out.framecount = (buffer[6] & 0x3) + 1; - + out.framecount = static_cast(header.frame_count + 1); // bit 31 to 43 - tmp = (buffer[3] & 0x3) << 11; - tmp |= (buffer[4] << 3) & 0x7f8; - tmp |= (buffer[5] >> 5) & 0x7; - - out.length = tmp; + out.length = static_cast(header.frame_length); return out; } @@ -61,3 +76,4 @@ u16 MFGetAACTag(const ADTSData& input) { return tag; } +} // namespace AudioCore diff --git a/src/audio_core/hle/audiotoolbox_decoder.cpp b/src/audio_core/hle/audiotoolbox_decoder.cpp index 83d4a41ae..122b4f21a 100644 --- a/src/audio_core/hle/audiotoolbox_decoder.cpp +++ b/src/audio_core/hle/audiotoolbox_decoder.cpp @@ -24,7 +24,7 @@ private: std::optional Decode(const BinaryMessage& request); void Clear(); - bool InitializeDecoder(ADTSData& adts_header); + bool InitializeDecoder(AudioCore::ADTSData& adts_header); static OSStatus DataFunc(AudioConverterRef in_audio_converter, u32* io_number_data_packets, AudioBufferList* io_data, @@ -33,7 +33,7 @@ private: Memory::MemorySystem& memory; - ADTSData adts_config; + AudioCore::ADTSData adts_config; AudioStreamBasicDescription output_format = {}; AudioConverterRef converter = nullptr; @@ -101,7 +101,7 @@ std::optional AudioToolboxDecoder::Impl::ProcessRequest( } } -bool AudioToolboxDecoder::Impl::InitializeDecoder(ADTSData& adts_header) { +bool AudioToolboxDecoder::Impl::InitializeDecoder(AudioCore::ADTSData& adts_header) { if (converter) { if (adts_config.channels == adts_header.channels && adts_config.samplerate == adts_header.samplerate) { @@ -183,8 +183,9 @@ std::optional AudioToolboxDecoder::Impl::Decode(const BinaryMessa return {}; } - auto data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); - auto adts_header = ParseADTS(reinterpret_cast(data)); + const auto data = + memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); + auto adts_header = AudioCore::ParseADTS(data); curr_data = data + adts_header.header_length; curr_data_len = request.decode_aac_request.size - adts_header.header_length; diff --git a/src/audio_core/hle/mediandk_decoder.cpp b/src/audio_core/hle/mediandk_decoder.cpp index 7c63b3f75..a6b39f7e5 100644 --- a/src/audio_core/hle/mediandk_decoder.cpp +++ b/src/audio_core/hle/mediandk_decoder.cpp @@ -27,7 +27,7 @@ public: ~Impl(); std::optional ProcessRequest(const BinaryMessage& request); - bool SetMediaType(const ADTSData& adts_data); + bool SetMediaType(const AudioCore::ADTSData& adts_data); private: std::optional Initalize(const BinaryMessage& request); @@ -36,8 +36,8 @@ private: Memory::MemorySystem& memory; std::unique_ptr decoder; // default: 2 channles, 48000 samplerate - ADTSData mADTSData{ - /*header_length*/ 7, /*MPEG2*/ false, /*profile*/ 2, + AudioCore::ADTSData mADTSData{ + /*header_length*/ 7, /*mpeg2*/ false, /*profile*/ 2, /*channels*/ 2, /*channel_idx*/ 2, /*framecount*/ 0, /*samplerate_idx*/ 3, /*length*/ 0, /*samplerate*/ 48000}; }; @@ -54,7 +54,7 @@ std::optional MediaNDKDecoder::Impl::Initalize(const BinaryMessag return response; } -bool MediaNDKDecoder::Impl::SetMediaType(const ADTSData& adts_data) { +bool MediaNDKDecoder::Impl::SetMediaType(const AudioCore::ADTSData& adts_data) { const char* mime = "audio/mp4a-latm"; if (decoder && mADTSData.profile == adts_data.profile && mADTSData.channel_idx == adts_data.channel_idx && @@ -141,8 +141,9 @@ std::optional MediaNDKDecoder::Impl::Decode(const BinaryMessage& return response; } - u8* data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); - ADTSData adts_data = ParseADTS(reinterpret_cast(data)); + const u8* data = + memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); + ADTSData adts_data = AudioCore::ParseADTS(data); SetMediaType(adts_data); response.decode_aac_response.sample_rate = GetSampleRateEnum(adts_data.samplerate); response.decode_aac_response.num_channels = adts_data.channels; diff --git a/src/audio_core/hle/wmf_decoder.cpp b/src/audio_core/hle/wmf_decoder.cpp index e49f86eb0..09043793b 100644 --- a/src/audio_core/hle/wmf_decoder.cpp +++ b/src/audio_core/hle/wmf_decoder.cpp @@ -23,7 +23,8 @@ private: std::optional Decode(const BinaryMessage& request); - MFOutputState DecodingLoop(ADTSData adts_header, std::array, 2>& out_streams); + MFOutputState DecodingLoop(AudioCore::ADTSData adts_header, + std::array, 2>& out_streams); bool transform_initialized = false; bool format_selected = false; @@ -139,7 +140,7 @@ std::optional WMFDecoder::Impl::Initalize(const BinaryMessage& re return response; } -MFOutputState WMFDecoder::Impl::DecodingLoop(ADTSData adts_header, +MFOutputState WMFDecoder::Impl::DecodingLoop(AudioCore::ADTSData adts_header, std::array, 2>& out_streams) { std::optional> output_buffer; @@ -210,14 +211,14 @@ std::optional WMFDecoder::Impl::Decode(const BinaryMessage& reque request.decode_aac_request.src_addr); return std::nullopt; } - u8* data = memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); + const u8* data = + memory.GetFCRAMPointer(request.decode_aac_request.src_addr - Memory::FCRAM_PADDR); std::array, 2> out_streams; unique_mfptr sample; MFInputState input_status = MFInputState::OK; MFOutputState output_status = MFOutputState::OK; - std::optional adts_meta = - DetectMediaType((char*)data, request.decode_aac_request.size); + std::optional adts_meta = DetectMediaType(data, request.decode_aac_request.size); if (!adts_meta) { LOG_ERROR(Audio_DSP, "Unable to deduce decoding parameters from ADTS stream"); diff --git a/src/audio_core/hle/wmf_decoder_utils.cpp b/src/audio_core/hle/wmf_decoder_utils.cpp index e71bb2f19..6cdf73f69 100644 --- a/src/audio_core/hle/wmf_decoder_utils.cpp +++ b/src/audio_core/hle/wmf_decoder_utils.cpp @@ -110,8 +110,9 @@ unique_mfptr CreateSample(const void* data, DWORD len, DWORD alignmen return sample; } -bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSData& adts, - const UINT8* user_data, UINT32 user_data_len, GUID audio_format) { +bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, + const AudioCore::ADTSData& adts, const UINT8* user_data, + UINT32 user_data_len, GUID audio_format) { HRESULT hr = S_OK; unique_mfptr t; @@ -190,12 +191,12 @@ bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, GUID audi return false; } -std::optional DetectMediaType(char* buffer, std::size_t len) { +std::optional DetectMediaType(const u8* buffer, std::size_t len) { if (len < 7) { return std::nullopt; } - ADTSData tmp; + AudioCore::ADTSData tmp; ADTSMeta result; // see https://docs.microsoft.com/en-us/windows/desktop/api/mmreg/ns-mmreg-heaacwaveinfo_tag // for the meaning of the byte array below @@ -207,7 +208,7 @@ std::optional DetectMediaType(char* buffer, std::size_t len) { UINT8 aac_tmp[] = {0x01, 0x00, 0xfe, 00, 00, 00, 00, 00, 00, 00, 00, 00, 0x00, 0x00}; uint16_t tag = 0; - tmp = ParseADTS(buffer); + tmp = AudioCore::ParseADTS(buffer); if (tmp.length == 0) { return std::nullopt; } @@ -215,7 +216,7 @@ std::optional DetectMediaType(char* buffer, std::size_t len) { tag = MFGetAACTag(tmp); aac_tmp[12] |= (tag & 0xff00) >> 8; aac_tmp[13] |= (tag & 0x00ff); - std::memcpy(&(result.ADTSHeader), &tmp, sizeof(ADTSData)); + std::memcpy(&(result.ADTSHeader), &tmp, sizeof(AudioCore::ADTSData)); std::memcpy(&(result.AACTag), aac_tmp, 14); return result; } diff --git a/src/audio_core/hle/wmf_decoder_utils.h b/src/audio_core/hle/wmf_decoder_utils.h index 77a12bef5..e515528d0 100644 --- a/src/audio_core/hle/wmf_decoder_utils.h +++ b/src/audio_core/hle/wmf_decoder_utils.h @@ -99,7 +99,7 @@ void ReportError(std::string msg, HRESULT hr); // data type for transferring ADTS metadata between functions struct ADTSMeta { - ADTSData ADTSHeader; + AudioCore::ADTSData ADTSHeader; u8 AACTag[14]; }; @@ -110,10 +110,10 @@ bool InitMFDLL(); unique_mfptr MFDecoderInit(GUID audio_format = MFAudioFormat_AAC); unique_mfptr CreateSample(const void* data, DWORD len, DWORD alignment = 1, LONGLONG duration = 0); -bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSData& adts, - const UINT8* user_data, UINT32 user_data_len, - GUID audio_format = MFAudioFormat_AAC); -std::optional DetectMediaType(char* buffer, std::size_t len); +bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, + const AudioCore::ADTSData& adts, const UINT8* user_data, + UINT32 user_data_len, GUID audio_format = MFAudioFormat_AAC); +std::optional DetectMediaType(const u8* buffer, std::size_t len); bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, GUID audio_format = MFAudioFormat_PCM); void MFFlush(IMFTransform* transform); diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 026d0cf23..42adecd28 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -12,6 +12,7 @@ add_executable(tests core/memory/vm_manager.cpp precompiled_headers.h audio_core/hle/hle.cpp + audio_core/hle/adts_reader.cpp audio_core/lle/lle.cpp audio_core/audio_fixures.h audio_core/decoder_tests.cpp diff --git a/src/tests/audio_core/hle/adts_reader.cpp b/src/tests/audio_core/hle/adts_reader.cpp new file mode 100644 index 000000000..d4d3de2d8 --- /dev/null +++ b/src/tests/audio_core/hle/adts_reader.cpp @@ -0,0 +1,75 @@ +// Copyright 2023 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include "audio_core/hle/adts.h" + +namespace { +constexpr std::array freq_table = {96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, + 16000, 12000, 11025, 8000, 7350, 0, 0, 0}; +constexpr std::array channel_table = {0, 1, 2, 3, 4, 5, 6, 8}; + +AudioCore::ADTSData ParseADTS_Old(const unsigned char* buffer) { + u32 tmp = 0; + AudioCore::ADTSData out{}; + + // sync word 0xfff + tmp = (buffer[0] << 8) | (buffer[1] & 0xf0); + if ((tmp & 0xffff) != 0xfff0) { + out.length = 0; + return out; + } + // bit 16 = no CRC + out.header_length = (buffer[1] & 0x1) ? 7 : 9; + out.mpeg2 = (buffer[1] >> 3) & 0x1; + // bit 17 to 18 + out.profile = (buffer[2] >> 6) + 1; + // bit 19 to 22 + tmp = (buffer[2] >> 2) & 0xf; + out.samplerate_idx = tmp; + out.samplerate = (tmp > 15) ? 0 : freq_table[tmp]; + // bit 24 to 26 + tmp = ((buffer[2] & 0x1) << 2) | ((buffer[3] >> 6) & 0x3); + out.channel_idx = tmp; + out.channels = (tmp > 7) ? 0 : channel_table[tmp]; + + // bit 55 to 56 + out.framecount = (buffer[6] & 0x3) + 1; + + // bit 31 to 43 + tmp = (buffer[3] & 0x3) << 11; + tmp |= (buffer[4] << 3) & 0x7f8; + tmp |= (buffer[5] >> 5) & 0x7; + + out.length = tmp; + + return out; +} +} // namespace + +TEST_CASE("ParseADTS fuzz", "[audio_core][hle]") { + for (u32 i = 0; i < 0x10000; i++) { + std::array adts_header; + std::string adts_header_string = "ADTS Header: "; + for (auto& it : adts_header) { + it = static_cast(rand()); + adts_header_string.append(fmt::format("{:2X} ", it)); + } + INFO(adts_header_string); + + AudioCore::ADTSData out_old_impl = + ParseADTS_Old(reinterpret_cast(adts_header.data())); + AudioCore::ADTSData out = AudioCore::ParseADTS(adts_header.data()); + + REQUIRE(out_old_impl.length == out.length); + REQUIRE(out_old_impl.channels == out.channels); + REQUIRE(out_old_impl.channel_idx == out.channel_idx); + REQUIRE(out_old_impl.framecount == out.framecount); + REQUIRE(out_old_impl.header_length == out.header_length); + REQUIRE(out_old_impl.mpeg2 == out.mpeg2); + REQUIRE(out_old_impl.profile == out.profile); + REQUIRE(out_old_impl.samplerate == out.samplerate); + REQUIRE(out_old_impl.samplerate_idx == out.samplerate_idx); + } +}