diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index 81bfde27718..8d751a187ff 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -67,6 +67,7 @@ SUMS$ ^src/terminal/parser/ft_fuzzer/run\.bat$ ^src/terminal/parser/ft_fuzzer/VTCommandFuzzer\.cpp$ ^src/terminal/parser/ft_fuzzwrapper/run\.bat$ +^src/terminal/parser/ut_parser/Base64Test.cpp$ ^src/terminal/parser/ut_parser/run\.bat$ ^src/tools/integrity/packageuwp/ConsoleUWP\.appxSources$ ^src/tools/lnkd/lnkd\.bat$ diff --git a/src/terminal/parser/OutputStateMachineEngine.cpp b/src/terminal/parser/OutputStateMachineEngine.cpp index bc9299d0db3..948ffaa4ee7 100644 --- a/src/terminal/parser/OutputStateMachineEngine.cpp +++ b/src/terminal/parser/OutputStateMachineEngine.cpp @@ -1101,22 +1101,22 @@ bool OutputStateMachineEngine::_GetOscSetClipboard(const std::wstring_view strin std::wstring& content, bool& queryClipboard) const noexcept { - const size_t pos = string.find(';'); - if (pos != std::wstring_view::npos) + const auto pos = string.find(L';'); + if (pos == std::wstring_view::npos) { - const std::wstring_view substr = string.substr(pos + 1); - if (substr == L"?") - { - queryClipboard = true; - return true; - } - else - { - return Base64::s_Decode(substr, content); - } + return false; } - return false; + const auto substr = string.substr(pos + 1); + if (substr == L"?") + { + queryClipboard = true; + return true; + } + +// Log_IfFailed has the following description: "Should be decorated WI_NOEXCEPT, but conflicts with forceinline." +#pragma warning(suppress : 26447) // The function is declared 'noexcept' but calls function 'Log_IfFailed()' which may throw exceptions (f.6). 
+ return SUCCEEDED_LOG(Base64::Decode(substr, content)); } // Method Description: diff --git a/src/terminal/parser/base64.cpp b/src/terminal/parser/base64.cpp index 48e22833ece..83a85028278 100644 --- a/src/terminal/parser/base64.cpp +++ b/src/terminal/parser/base64.cpp @@ -4,190 +4,155 @@ #include "precomp.h" #include "base64.hpp" -using namespace Microsoft::Console::VirtualTerminal; - -static const char base64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -static const char padChar = '='; - -#pragma warning(disable : 26446 26447 26482 26485 26493 26494) - -// Routine Description: -// - Encode a string using base64. When there are not enough characters -// for one quantum, paddings are added. -// Arguments: -// - src - String to base64 encode. -// Return Value: -// - the encoded string. -std::wstring Base64::s_Encode(const std::wstring_view src) noexcept -{ - std::wstring dst; - wchar_t input[3]; - - const auto len = (src.size() + 2) / 3 * 4; - if (len == 0) - { - return dst; - } - dst.reserve(len); - - auto iter = src.cbegin(); - // Encode each three chars into one quantum (four chars). - while (iter < src.cend() - 2) - { - input[0] = *iter++; - input[1] = *iter++; - input[2] = *iter++; - dst.push_back(base64Chars[input[0] >> 2]); - dst.push_back(base64Chars[(input[0] & 0x03) << 4 | input[1] >> 4]); - dst.push_back(base64Chars[(input[1] & 0x0f) << 2 | input[2] >> 6]); - dst.push_back(base64Chars[(input[2] & 0x3f)]); - } - - // Here only zero, or one, or two chars are left. We may need to add paddings. - if (iter < src.cend()) - { - input[0] = *iter++; - dst.push_back(base64Chars[input[0] >> 2]); - if (iter < src.cend()) // Two chars left. - { - input[1] = *iter++; - dst.push_back(base64Chars[(input[0] & 0x03) << 4 | input[1] >> 4]); - dst.push_back(base64Chars[(input[1] & 0x0f) << 2]); - } - else // Only one char left. 
- { - dst.push_back(base64Chars[(input[0] & 0x03) << 4]); - dst.push_back(padChar); - } - dst.push_back(padChar); - } +#pragma warning(disable : 26446) // Prefer to use gsl::at() instead of unchecked subscript operator (bounds.4). +// I didn't want to handle out of memory errors. There's no reasonable mode of +// operation for this application without the ability to allocate memory anyways. +#pragma warning(disable : 26447) // The function is declared 'noexcept' but calls function '...' which may throw exceptions (f.6). +#pragma warning(disable : 26481) // Don't use pointer arithmetic. Use span instead (bounds.1). +#pragma warning(disable : 26482) // Only index into arrays using constant expressions (bounds.2). - return dst; -} +using namespace Microsoft::Console::VirtualTerminal; -// Routine Description: -// - Decode a base64 string. This requires the base64 string is properly padded. -// Otherwise, false will be returned. -// Arguments: -// - src - String to decode. -// - dst - Destination to decode into. -// Return Value: -// - true if decoding successfully, otherwise false. -bool Base64::s_Decode(const std::wstring_view src, std::wstring& dst) noexcept +// clang-format off +static constexpr uint8_t decodeTable[128] = { + 255 /* NUL */, 255 /* SOH */, 255 /* STX */, 255 /* ETX */, 255 /* EOT */, 255 /* ENQ */, 255 /* ACK */, 255 /* BEL */, 255 /* BS */, 255 /* HT */, 255 /* LF */, 255 /* VT */, 255 /* FF */, 255 /* CR */, 255 /* SO */, 255 /* SI */, + 255 /* DLE */, 255 /* DC1 */, 255 /* DC2 */, 255 /* DC3 */, 255 /* DC4 */, 255 /* NAK */, 255 /* SYN */, 255 /* ETB */, 255 /* CAN */, 255 /* EM */, 255 /* SUB */, 255 /* ESC */, 255 /* FS */, 255 /* GS */, 255 /* RS */, 255 /* US */, + 255 /* SP */, 255 /* ! */, 255 /* " */, 255 /* # */, 255 /* $ */, 255 /* % */, 255 /* & */, 255 /* ' */, 255 /* ( */, 255 /* ) */, 255 /* * */, 62 /* + */, 255 /* , */, 62 /* - */, 255 /* . 
*/, 63 /* / */, + 52 /* 0 */, 53 /* 1 */, 54 /* 2 */, 55 /* 3 */, 56 /* 4 */, 57 /* 5 */, 58 /* 6 */, 59 /* 7 */, 60 /* 8 */, 61 /* 9 */, 255 /* : */, 255 /* ; */, 255 /* < */, 255 /* = */, 255 /* > */, 255 /* ? */, + 255 /* @ */, 0 /* A */, 1 /* B */, 2 /* C */, 3 /* D */, 4 /* E */, 5 /* F */, 6 /* G */, 7 /* H */, 8 /* I */, 9 /* J */, 10 /* K */, 11 /* L */, 12 /* M */, 13 /* N */, 14 /* O */, + 15 /* P */, 16 /* Q */, 17 /* R */, 18 /* S */, 19 /* T */, 20 /* U */, 21 /* V */, 22 /* W */, 23 /* X */, 24 /* Y */, 25 /* Z */, 255 /* [ */, 255 /* \ */, 255 /* ] */, 255 /* ^ */, 63 /* _ */, + 255 /* ` */, 26 /* a */, 27 /* b */, 28 /* c */, 29 /* d */, 30 /* e */, 31 /* f */, 32 /* g */, 33 /* h */, 34 /* i */, 35 /* j */, 36 /* k */, 37 /* l */, 38 /* m */, 39 /* n */, 40 /* o */, + 41 /* p */, 42 /* q */, 43 /* r */, 44 /* s */, 45 /* t */, 46 /* u */, 47 /* v */, 48 /* w */, 49 /* x */, 50 /* y */, 51 /* z */, 255 /* { */, 255 /* | */, 255 /* } */, 255 /* ~ */, 255 /* DEL */, +}; +// clang-format on + +// Decodes a UTF-8 string encoded with RFC 4648 (Base64) and returns it as UTF-16 in dst. +// It supports both variants of the RFC (base64 and base64url), but +// returns an error for non-alphabet characters, including newlines. +// * Returns HRESULT_FROM_WIN32(ERROR_INVALID_DATA) for all invalid base64 inputs. +// * Doesn't support whitespace and will return that error for such strings. +// * Doesn't validate the number of trailing "=". Those are basically ignored. +// Strings like "YQ===" will be accepted as valid input and simply result in "a". +HRESULT Base64::Decode(const std::wstring_view& src, std::wstring& dst) noexcept { - std::string mbStr; - int state = 0; - char tmp; - - const auto len = src.size() / 4 * 3; - if (len == 0) + std::string result; + result.resize(((src.size() + 3) / 4) * 3); + + // in and inEnd may be nullptr if src.empty(). + // The remaining code in this function ensures not to read from in if src.empty(). 
+#pragma warning(suppress : 26429) // Symbol 'in' is never tested for nullness, it can be marked as not_null (f.23). + auto in = src.data(); + const auto inEnd = in + src.size(); + // Sometimes in programming you have to ask yourself what the right offset for a pointer is. + // Is 4 enough? Certainly not. 6 on the other hand is just way too much. Clearly 5 is just right. + // + // In all seriousness however the offset is 5, because the batched loop reads 4 characters at a time, + // a base64 string can end with two "=" and the batched loop doesn't handle any such "=". + // Additionally the while() condition of the batched loop would make a lot more sense if it were using <=, + // but for reasons outlined below it needs to use < so we need to add 1 back again. + // We thus get -4-2+1 which is -5. + // + // There's a special reason we need to use < and not <= for the loop: + // In C++ it's undefined behavior to perform any pointer arithmetic that leads to unallocated memory, + // which is why we can't just write `inEnd - 6` as that might be UB if `src.size()` is less than 6. + // We thus would need to write `inEnd - min(6, src.size())` in combination with `<=` for the batched loop. + // But if `src.size()` is actually less than 6 then that batched end pointer is equal to the initial `in`, aka: an empty range. + // In such cases we'd enter the batched loop and read from `in` despite us not wanting to enter the loop. + // We can fix the issue by using < instead and adding +1 to the offset. + // + // Yes this works. + const auto inEndBatched = inEnd - std::min(5, src.size()); + + // outBeg and out may be nullptr if src.empty(). + // The remaining code in this function ensures not to write to out if src.empty(). + const auto outBeg = result.data(); +#pragma warning(suppress : 26429) // Symbol 'out' is never tested for nullness, it can be marked as not_null (f.23). + auto out = outBeg; + + // r is just a generic "remainder" we use to accumulate 4 base64 chars into 3 output bytes. 
+ uint_fast32_t r = 0; + // error is treated as a boolean. If it's not 0 we had an invalid input character. + uint_fast16_t error = 0; + + // Capturing r/error by reference produces less optimal assembly. + static constexpr auto accumulate = [](auto& r, auto& error, auto ch) { + // n will be in the range [0, 0x3f] for valid ch + // and exactly 0xff for invalid ch. + const auto n = decodeTable[ch & 0x7f]; + // Both ch > 0x7f, as well as n > 0x7f are invalid values and count as an error. + // We can add the error state by checking if any bits ~0x7f are set (which is 0xff80). + error |= (ch | n) & 0xff80; + r = r << 6 | n; + }; + + // If src.empty() then `in == inEndBatched == nullptr` and this is skipped. + while (in < inEndBatched) { - return false; + const auto ch0 = *in++; + const auto ch1 = *in++; + const auto ch2 = *in++; + const auto ch3 = *in++; + + // Most other base64 libraries do something like this: + // const auto n0 = decodeTable[a]; + // const auto n1 = decodeTable[b]; + // const auto n2 = decodeTable[c]; + // const auto n3 = decodeTable[d]; + // *out++ = n0 << 2 | n1 >> 4; + // *out++ = (n1 & 0xf) << 4 | n2 >> 2; + // *out++ = (n2 & 0x3) << 6 | n3; + // + // But on all modern CPUs I tested (well even those 10 years old at this point) shifting base64 + // characters into a single register (here: r) is faster than the traditional approach. + // I believe this is due to reducing the dependency of instructions on prior calculations. + accumulate(r, error, ch0); + accumulate(r, error, ch1); + accumulate(r, error, ch2); + accumulate(r, error, ch3); + + *out++ = gsl::narrow_cast(r >> 16); + *out++ = gsl::narrow_cast(r >> 8); + *out++ = gsl::narrow_cast(r >> 0); } - mbStr.reserve(len); - auto iter = src.cbegin(); - while (iter < src.cend()) { - if (s_IsSpace(*iter)) // Skip whitespace anywhere. - { - iter++; - continue; - } + uint_fast8_t ri = 0; - if (*iter == padChar) + // If src.empty() then `in == inEnd == nullptr` and this is skipped. 
+ for (; in < inEnd; ++in) { - break; - } - - auto pos = strchr(base64Chars, *iter); - if (!pos) // A non-base64 character found. - { - return false; + if (const auto ch = *in; ch != '=') + { + accumulate(r, error, ch); + ri++; + } } - switch (state) + switch (ri) { - case 0: - tmp = (char)(pos - base64Chars) << 2; - state = 1; - break; - case 1: - tmp |= (char)(pos - base64Chars) >> 4; - mbStr += tmp; - tmp = (char)((pos - base64Chars) & 0x0f) << 4; - state = 2; - break; case 2: - tmp |= (char)(pos - base64Chars) >> 2; - mbStr += tmp; - tmp = (char)((pos - base64Chars) & 0x03) << 6; - state = 3; + *out++ = gsl::narrow_cast(r >> 4); break; case 3: - tmp |= pos - base64Chars; - mbStr += tmp; - state = 0; + *out++ = gsl::narrow_cast(r >> 10); + *out++ = gsl::narrow_cast(r >> 2); break; - default: - break; - } - - iter++; - } - - if (iter < src.cend()) // Padding char is met. - { - iter++; - switch (state) - { - // Invalid when state is 0 or 1. - case 0: - case 1: - return false; - case 2: - // Skip any number of spaces. - while (iter < src.cend() && s_IsSpace(*iter)) - { - iter++; - } - // Make sure there is another trailing padding character. - if (iter == src.cend() || *iter != padChar) - { - return false; - } - iter++; // Skip the padding character and fallthrough to "single trailing padding character" case. - [[fallthrough]]; - case 3: - while (iter < src.cend()) - { - if (!s_IsSpace(*iter)) - { - return false; - } - iter++; - } + case 4: + *out++ = gsl::narrow_cast(r >> 16); + *out++ = gsl::narrow_cast(r >> 8); + *out++ = gsl::narrow_cast(r >> 0); break; default: + error |= ri; break; } } - else if (state != 0) // When no padding, we must be in state 0. + + if (error) { - return false; + return HRESULT_FROM_WIN32(ERROR_INVALID_DATA); } - return SUCCEEDED(til::u8u16(mbStr, dst)); -} - -// Routine Description: -// - Check if parameter is a base64 whitespace. Only carriage return or line feed -// is valid whitespace. -// Arguments: -// - ch - Character to check. 
-// Return Value: -// - true iff ch is a carriage return or line feed. -constexpr bool Base64::s_IsSpace(const wchar_t ch) noexcept -{ - return ch == L'\r' || ch == L'\n'; + result.resize(out - outBeg); + return til::u8u16(result, dst); } diff --git a/src/terminal/parser/base64.hpp b/src/terminal/parser/base64.hpp index 976bcc5209a..da8f22ec509 100644 --- a/src/terminal/parser/base64.hpp +++ b/src/terminal/parser/base64.hpp @@ -16,10 +16,6 @@ namespace Microsoft::Console::VirtualTerminal class Base64 { public: - static std::wstring s_Encode(const std::wstring_view src) noexcept; - static bool s_Decode(const std::wstring_view src, std::wstring& dst) noexcept; - - private: - static constexpr bool s_IsSpace(const wchar_t ch) noexcept; + static HRESULT Decode(const std::wstring_view& src, std::wstring& dst) noexcept; }; } diff --git a/src/terminal/parser/ut_parser/Base64Test.cpp b/src/terminal/parser/ut_parser/Base64Test.cpp index e3d9e5f29fc..5374a303ac9 100644 --- a/src/terminal/parser/ut_parser/Base64Test.cpp +++ b/src/terminal/parser/ut_parser/Base64Test.cpp @@ -3,7 +3,8 @@ #include "precomp.h" #include "WexTestClass.h" -#include "../../inc/consoletaeftemplates.hpp" + +#include #include "base64.hpp" @@ -28,74 +29,90 @@ class Microsoft::Console::VirtualTerminal::Base64Test { TEST_CLASS(Base64Test); - TEST_METHOD(TestBase64Encode) - { - VERIFY_ARE_EQUAL(L"Zm9v", Base64::s_Encode(L"foo")); - VERIFY_ARE_EQUAL(L"Zm9vYg==", Base64::s_Encode(L"foob")); - VERIFY_ARE_EQUAL(L"Zm9vYmE=", Base64::s_Encode(L"fooba")); - VERIFY_ARE_EQUAL(L"Zm9vYmFy", Base64::s_Encode(L"foobar")); - VERIFY_ARE_EQUAL(L"Zm9vYmFyDQo=", Base64::s_Encode(L"foobar\r\n")); - } - - TEST_METHOD(TestBase64Decode) + TEST_METHOD(DecodeFuzz) { - std::wstring result; - bool success; - - success = Base64::s_Decode(L"Zm9v", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"foo", result); - - result = L""; - success = Base64::s_Decode(L"Zm9vYg==", result); - VERIFY_ARE_EQUAL(true, success); - 
VERIFY_ARE_EQUAL(L"foob", result); + // NOTE: Modify testRounds to get the feeling of running a fuzz test on Base64::Decode. + static constexpr auto testRounds = 8; + pcg_engines::oneseq_dxsm_64_32 rng{ til::gen_random() }; - result = L""; - success = Base64::s_Decode(L"Zm9vYmE=", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"fooba", result); - - result = L""; - success = Base64::s_Decode(L"Zm9vYmFy", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"foobar", result); + // Fills referenceData with random ASCII characters. + // We use ASCII as Base64::Decode uses til::u8u16 internally and I don't want to test that. + char referenceData[128]; + { + uint32_t randomData[sizeof(referenceData) / sizeof(uint32_t)]; + for (auto& i : randomData) + { + i = rng(); + } - result = L""; - success = Base64::s_Decode(L"Zm9vYmFyDQo=", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"foobar\r\n", result); + const std::string_view randomDataView{ reinterpret_cast(randomData), sizeof(randomData) }; + auto out = std::begin(referenceData); - result = L""; - success = Base64::s_Decode(L"Zm9v\rYmFy", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"foobar", result); + for (const auto& ch : randomDataView) + { + *out++ = static_cast(ch & 0x7f); + } + } - result = L""; - success = Base64::s_Decode(L"Zm9v\r\nYmFy\n", result); - VERIFY_ARE_EQUAL(true, success); - VERIFY_ARE_EQUAL(L"foobar", result); + wchar_t wideReferenceData[std::size(referenceData)]; + std::copy_n(std::begin(referenceData), std::size(referenceData), std::begin(wideReferenceData)); - success = Base64::s_Decode(L"Z", result); - VERIFY_ARE_EQUAL(false, success); + std::wstring encoded; + std::wstring decoded; - success = Base64::s_Decode(L"Zm9vYg", result); - VERIFY_ARE_EQUAL(false, success); + for (auto i = 0; i < testRounds; ++i) + { + const auto referenceLength = rng(static_cast(std::size(referenceData))); + const std::wstring_view wideReference{ 
std::begin(wideReferenceData), referenceLength }; + + if (!referenceLength) + { + encoded.clear(); + } + else + { + const auto reference = reinterpret_cast(std::begin(referenceData)); + DWORD encodedLen; + THROW_IF_WIN32_BOOL_FALSE(CryptBinaryToStringW(reference, referenceLength, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, nullptr, &encodedLen)); + + // encodedLen is returned by CryptBinaryToStringW including the trailing null byte. + encoded.resize(encodedLen - 1); + + THROW_IF_WIN32_BOOL_FALSE(CryptBinaryToStringW(reference, referenceLength, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, encoded.data(), &encodedLen)); + } + + // Test whether Decode() handles strings with and without trailing "=". + if (rng(2)) + { + while (!encoded.empty() && encoded.back() == '=') + { + encoded.pop_back(); + } + } + + // Test whether Decode() handles null-pointer arguments correctly. + std::wstring_view encodedView{ encoded }; + if (encodedView.empty() && rng(2)) + { + encodedView = {}; + } + + Base64::Decode(encodedView, decoded); + VERIFY_ARE_EQUAL(wideReference, decoded); + } + } - success = Base64::s_Decode(L"Zm9vYg=", result); - VERIFY_ARE_EQUAL(false, success); + TEST_METHOD(DecodeUTF8) + { + std::wstring result; // U+306b U+307b U+3093 U+3054 U+6c49 U+8bed U+d55c U+ad6d - result = L""; - success = Base64::s_Decode(L"44Gr44G744KT44GU5rGJ6K+t7ZWc6rWt", result); - VERIFY_ARE_EQUAL(true, success); + Base64::Decode(L"44Gr44G744KT44GU5rGJ6K+t7ZWc6rWt", result); VERIFY_ARE_EQUAL(L"にほんご汉语한국", result); // U+d83d U+dc4d U+d83d U+dc4d U+d83c U+dffb U+d83d U+dc4d U+d83c U+dffc U+d83d // U+dc4d U+d83c U+dffd U+d83d U+dc4d U+d83c U+dffe U+d83d U+dc4d U+d83c U+dfff - result = L""; - success = Base64::s_Decode(L"8J+RjfCfkY3wn4+78J+RjfCfj7zwn5GN8J+PvfCfkY3wn4++8J+RjfCfj78=", result); - VERIFY_ARE_EQUAL(true, success); + Base64::Decode(L"8J+RjfCfkY3wn4+78J+RjfCfj7zwn5GN8J+PvfCfkY3wn4++8J+RjfCfj78=", result); VERIFY_ARE_EQUAL(L"👍👍🏻👍🏼👍🏽👍🏾👍🏿", result); } }; diff --git 
a/src/terminal/parser/ut_parser/OutputEngineTest.cpp b/src/terminal/parser/ut_parser/OutputEngineTest.cpp index c5b9647312d..f9471f47fa2 100644 --- a/src/terminal/parser/ut_parser/OutputEngineTest.cpp +++ b/src/terminal/parser/ut_parser/OutputEngineTest.cpp @@ -3255,7 +3255,7 @@ class StateMachineExternalTest final pDispatch->_copyContent = L"UNCHANGED"; // Passing a non-base64 `Pd` param is illegal, won't change the content. - mach.ProcessString(L"\x1b]52;;foo\x07"); + mach.ProcessString(L"\x1b]52;;???\x07"); VERIFY_ARE_EQUAL(L"UNCHANGED", pDispatch->_copyContent); pDispatch->ClearState();