Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Base64::Decode performance #11467

Merged
8 commits merged into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/spelling/excludes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ SUMS$
^src/terminal/parser/ft_fuzzer/run\.bat$
^src/terminal/parser/ft_fuzzer/VTCommandFuzzer\.cpp$
^src/terminal/parser/ft_fuzzwrapper/run\.bat$
^src/terminal/parser/ut_parser/Base64Test.cpp$
^src/terminal/parser/ut_parser/run\.bat$
^src/tools/integrity/packageuwp/ConsoleUWP\.appxSources$
^src/tools/lnkd/lnkd\.bat$
Expand Down
26 changes: 13 additions & 13 deletions src/terminal/parser/OutputStateMachineEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,22 +1101,22 @@ bool OutputStateMachineEngine::_GetOscSetClipboard(const std::wstring_view strin
std::wstring& content,
bool& queryClipboard) const noexcept
{
const size_t pos = string.find(';');
if (pos != std::wstring_view::npos)
const auto pos = string.find(L';');
if (pos == std::wstring_view::npos)
{
const std::wstring_view substr = string.substr(pos + 1);
if (substr == L"?")
{
queryClipboard = true;
return true;
}
else
{
return Base64::s_Decode(substr, content);
}
return false;
}

return false;
const auto substr = string.substr(pos + 1);
if (substr == L"?")
{
queryClipboard = true;
return true;
}

// Log_IfFailed has the following description: "Should be decorated WI_NOEXCEPT, but conflicts with forceinline."
#pragma warning(suppress : 26447) // The function is declared 'noexcept' but calls function 'Log_IfFailed()' which may throw exceptions (f.6).
return SUCCEEDED_LOG(Base64::Decode(substr, content));
}

// Method Description:
Expand Down
289 changes: 127 additions & 162 deletions src/terminal/parser/base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,190 +4,155 @@
#include "precomp.h"
#include "base64.hpp"

using namespace Microsoft::Console::VirtualTerminal;

static const char base64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static const char padChar = '=';

#pragma warning(disable : 26446 26447 26482 26485 26493 26494)

// Routine Description:
// - Encode a string using base64. When there are not enough characters
// for one quantum, paddings are added.
// Arguments:
// - src - String to base64 encode.
// Return Value:
// - the encoded string.
std::wstring Base64::s_Encode(const std::wstring_view src) noexcept
{
std::wstring dst;
wchar_t input[3];

const auto len = (src.size() + 2) / 3 * 4;
if (len == 0)
{
return dst;
}
dst.reserve(len);

auto iter = src.cbegin();
// Encode each three chars into one quantum (four chars).
while (iter < src.cend() - 2)
{
input[0] = *iter++;
input[1] = *iter++;
input[2] = *iter++;
dst.push_back(base64Chars[input[0] >> 2]);
dst.push_back(base64Chars[(input[0] & 0x03) << 4 | input[1] >> 4]);
dst.push_back(base64Chars[(input[1] & 0x0f) << 2 | input[2] >> 6]);
dst.push_back(base64Chars[(input[2] & 0x3f)]);
}

// Here only zero, or one, or two chars are left. We may need to add paddings.
if (iter < src.cend())
{
input[0] = *iter++;
dst.push_back(base64Chars[input[0] >> 2]);
if (iter < src.cend()) // Two chars left.
{
input[1] = *iter++;
dst.push_back(base64Chars[(input[0] & 0x03) << 4 | input[1] >> 4]);
dst.push_back(base64Chars[(input[1] & 0x0f) << 2]);
}
else // Only one char left.
{
dst.push_back(base64Chars[(input[0] & 0x03) << 4]);
dst.push_back(padChar);
}
dst.push_back(padChar);
}
#pragma warning(disable : 26446) // Prefer to use gsl::at() instead of unchecked subscript operator (bounds.4).
// I didn't want to handle out of memory errors. There's no reasonable mode of
// operation for this application without the ability to allocate memory anyways.
#pragma warning(disable : 26447) // The function is declared 'noexcept' but calls function '...' which may throw exceptions (f.6).
#pragma warning(disable : 26481) // Don't use pointer arithmetic. Use span instead (bounds.1).
#pragma warning(disable : 26482) // Only index into arrays using constant expressions (bounds.2).

return dst;
}
using namespace Microsoft::Console::VirtualTerminal;

// Routine Description:
// - Decode a base64 string. This requires the base64 string is properly padded.
// Otherwise, false will be returned.
// Arguments:
// - src - String to decode.
// - dst - Destination to decode into.
// Return Value:
// - true if decoding successfully, otherwise false.
bool Base64::s_Decode(const std::wstring_view src, std::wstring& dst) noexcept
// Lookup table that maps a 7-bit ASCII code (the array index) to its 6-bit base64 value.
// * An entry of 255 marks a character that has no base64 value in this table.
//   That includes '=' as well as all whitespace and control characters.
// * '+' and '-' both map to 62, and '/' and '_' both map to 63, so the standard
//   alphabet and the URL-safe alphabet are covered by the same table.
// * The table only has 128 entries, so any index must be limited to the range [0, 127].
// clang-format off
static constexpr uint8_t decodeTable[128] = {
255 /* NUL */, 255 /* SOH */, 255 /* STX */, 255 /* ETX */, 255 /* EOT */, 255 /* ENQ */, 255 /* ACK */, 255 /* BEL */, 255 /* BS */, 255 /* HT */, 255 /* LF */, 255 /* VT */, 255 /* FF */, 255 /* CR */, 255 /* SO */, 255 /* SI */,
255 /* DLE */, 255 /* DC1 */, 255 /* DC2 */, 255 /* DC3 */, 255 /* DC4 */, 255 /* NAK */, 255 /* SYN */, 255 /* ETB */, 255 /* CAN */, 255 /* EM */, 255 /* SUB */, 255 /* ESC */, 255 /* FS */, 255 /* GS */, 255 /* RS */, 255 /* US */,
255 /* SP */, 255 /* ! */, 255 /* " */, 255 /* # */, 255 /* $ */, 255 /* % */, 255 /* & */, 255 /* ' */, 255 /* ( */, 255 /* ) */, 255 /* * */, 62 /* + */, 255 /* , */, 62 /* - */, 255 /* . */, 63 /* / */,
52 /* 0 */, 53 /* 1 */, 54 /* 2 */, 55 /* 3 */, 56 /* 4 */, 57 /* 5 */, 58 /* 6 */, 59 /* 7 */, 60 /* 8 */, 61 /* 9 */, 255 /* : */, 255 /* ; */, 255 /* < */, 255 /* = */, 255 /* > */, 255 /* ? */,
255 /* @ */, 0 /* A */, 1 /* B */, 2 /* C */, 3 /* D */, 4 /* E */, 5 /* F */, 6 /* G */, 7 /* H */, 8 /* I */, 9 /* J */, 10 /* K */, 11 /* L */, 12 /* M */, 13 /* N */, 14 /* O */,
15 /* P */, 16 /* Q */, 17 /* R */, 18 /* S */, 19 /* T */, 20 /* U */, 21 /* V */, 22 /* W */, 23 /* X */, 24 /* Y */, 25 /* Z */, 255 /* [ */, 255 /* \ */, 255 /* ] */, 255 /* ^ */, 63 /* _ */,
255 /* ` */, 26 /* a */, 27 /* b */, 28 /* c */, 29 /* d */, 30 /* e */, 31 /* f */, 32 /* g */, 33 /* h */, 34 /* i */, 35 /* j */, 36 /* k */, 37 /* l */, 38 /* m */, 39 /* n */, 40 /* o */,
41 /* p */, 42 /* q */, 43 /* r */, 44 /* s */, 45 /* t */, 46 /* u */, 47 /* v */, 48 /* w */, 49 /* x */, 50 /* y */, 51 /* z */, 255 /* { */, 255 /* | */, 255 /* } */, 255 /* ~ */, 255 /* DEL */,
};
// clang-format on

// Decodes a UTF-8 string encoded with RFC 4648 (Base64) and returns it as UTF-16 in dst.
// It supports both variants of the RFC (base64 and base64url), but
// returns an error for non-alphabet characters, including newlines.
// * Returns ERROR_INVALID_DATA (as an HRESULT) for all invalid base64 inputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well,

// * Doesn't support whitespace and will return ERROR_INVALID_DATA for such strings.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I bet that this will come back to bite us later, but I am willing to take that risk.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh it’s "return ERROR_INVALID_DATA" now. I‘ll update the comment.
Is there anything else you’re worried about?

// * Doesn't validate the number of trailing "=". Those are basically ignored.
// Strings like "YQ===" will be accepted as valid input and simply result in "a".
HRESULT Base64::Decode(const std::wstring_view& src, std::wstring& dst) noexcept
{
std::string mbStr;
int state = 0;
char tmp;

const auto len = src.size() / 4 * 3;
if (len == 0)
std::string result;
result.resize(((src.size() + 3) / 4) * 3);

// in and inEnd may be nullptr if src.empty().
// The remaining code in this function ensures not to read from in if src.empty().
#pragma warning(suppress : 26429) // Symbol 'in' is never tested for nullness, it can be marked as not_null (f.23).
auto in = src.data();
const auto inEnd = in + src.size();
// Sometimes in programming you have to ask yourself what the right offset for a pointer is.
// Is 4 enough? Certainly not. 6 on the other hand is just way too much. Clearly 5 is just right.
//
// In all seriousness however the offset is 5, because the batched loop reads 4 characters at a time,
// a base64 string can end with two "=" and the batched loop doesn't handle any such "=".
// Additionally the while() condition of the batched loop would make a lot more sense if it were using <=,
// but for reasons outlined below it needs to use < so we need to add 1 back again.
// We thus get -4-2+1 which is -5.
//
// There's a special reason we need to use < and not <= for the loop:
// In C++ it's undefined behavior to perform any pointer arithmetic that leads to unallocated memory,
// which is why we can't just write `inEnd - 6` as that might be UB if `src.size()` is less than 6.
// We would thus need to write `inEnd - min(6, src.size())` in combination with `<=` for the batched loop.
// But if `src.size()` is actually less than 6 then that expression is equal to the initial `in`, aka: an empty range.
// In such cases we'd enter the batched loop and read from `in` despite us not wanting to enter the loop.
// We can fix the issue by using < instead and adding +1 to the offset.
//
// Yes this works.
const auto inEndBatched = inEnd - std::min<size_t>(5, src.size());

// outBeg and out may be nullptr if src.empty().
// The remaining code in this function ensures not to write to out if src.empty().
const auto outBeg = result.data();
#pragma warning(suppress : 26429) // Symbol 'out' is never tested for nullness, it can be marked as not_null (f.23).
auto out = outBeg;

// r is just a generic "remainder" we use to accumulate 4 base64 chars into 3 output bytes.
uint_fast32_t r = 0;
// error is treated as a boolean. If it's not 0 we had an invalid input character.
uint_fast16_t error = 0;

// Capturing r/error by reference produces less optimal assembly.
static constexpr auto accumulate = [](auto& r, auto& error, auto ch) {
// n will be in the range [0, 0x3f] for valid ch
// and exactly 0xff for invalid ch.
const auto n = decodeTable[ch & 0x7f];
// Both ch > 0x7f, as well as n > 0x7f are invalid values and count as an error.
// We can add the error state by checking if any bits ~0x7f are set (which is 0xff80).
error |= (ch | n) & 0xff80;
r = r << 6 | n;
};

// If src.empty() then `in == inEndBatched == nullptr` and this is skipped.
while (in < inEndBatched)
{
return false;
const auto ch0 = *in++;
const auto ch1 = *in++;
const auto ch2 = *in++;
const auto ch3 = *in++;

// Most other base64 libraries do something like this:
// const auto n0 = decodeTable[ch0];
// const auto n1 = decodeTable[ch1];
// const auto n2 = decodeTable[ch2];
// const auto n3 = decodeTable[ch3];
// *out++ = n0 << 2 | n1 >> 4;
// *out++ = (n1 & 0xf) << 4 | n2 >> 2;
// *out++ = (n2 & 0x3) << 6 | n3;
//
// But on all modern CPUs I tested (well even those 10 years old at this point) shifting base64
// characters into a single register (here: r) is faster than the traditional approach.
// I believe this is due to reducing the dependency of instructions on prior calculations.
accumulate(r, error, ch0);
accumulate(r, error, ch1);
accumulate(r, error, ch2);
accumulate(r, error, ch3);

*out++ = gsl::narrow_cast<char>(r >> 16);
*out++ = gsl::narrow_cast<char>(r >> 8);
*out++ = gsl::narrow_cast<char>(r >> 0);
}
mbStr.reserve(len);

auto iter = src.cbegin();
while (iter < src.cend())
{
if (s_IsSpace(*iter)) // Skip whitespace anywhere.
{
iter++;
continue;
}
uint_fast8_t ri = 0;

if (*iter == padChar)
// If src.empty() then `in == inEnd == nullptr` and this is skipped.
for (; in < inEnd; ++in)
{
break;
}

auto pos = strchr(base64Chars, *iter);
if (!pos) // A non-base64 character found.
{
return false;
if (const auto ch = *in; ch != '=')
{
accumulate(r, error, ch);
ri++;
}
}

switch (state)
switch (ri)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ri => remainder index?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I called it r because it's the accumulation "register" (since on a CPU level r will almost certainly live inside a register most of the time).
And since i is usually a counter for something I called it ri, the "register index-counter-thingy".

"remainder index" however is much better. I'll steal that. 😄

{
case 0:
tmp = (char)(pos - base64Chars) << 2;
state = 1;
break;
case 1:
tmp |= (char)(pos - base64Chars) >> 4;
mbStr += tmp;
tmp = (char)((pos - base64Chars) & 0x0f) << 4;
state = 2;
break;
case 2:
tmp |= (char)(pos - base64Chars) >> 2;
mbStr += tmp;
tmp = (char)((pos - base64Chars) & 0x03) << 6;
state = 3;
*out++ = gsl::narrow_cast<char>(r >> 4);
break;
case 3:
tmp |= pos - base64Chars;
mbStr += tmp;
state = 0;
*out++ = gsl::narrow_cast<char>(r >> 10);
*out++ = gsl::narrow_cast<char>(r >> 2);
break;
default:
break;
}

iter++;
}

if (iter < src.cend()) // Padding char is met.
{
iter++;
switch (state)
{
// Invalid when state is 0 or 1.
case 0:
case 1:
return false;
case 2:
// Skip any number of spaces.
while (iter < src.cend() && s_IsSpace(*iter))
{
iter++;
}
// Make sure there is another trailing padding character.
if (iter == src.cend() || *iter != padChar)
{
return false;
}
iter++; // Skip the padding character and fallthrough to "single trailing padding character" case.
[[fallthrough]];
case 3:
while (iter < src.cend())
{
if (!s_IsSpace(*iter))
{
return false;
}
iter++;
}
case 4:
*out++ = gsl::narrow_cast<char>(r >> 16);
*out++ = gsl::narrow_cast<char>(r >> 8);
*out++ = gsl::narrow_cast<char>(r >> 0);
break;
default:
error |= ri;
break;
}
}
else if (state != 0) // When no padding, we must be in state 0.

if (error)
{
return false;
return HRESULT_FROM_WIN32(ERROR_INVALID_DATA);
}

return SUCCEEDED(til::u8u16(mbStr, dst));
}

// Routine Description:
// - Check if parameter is a base64 whitespace. Only carriage return or line feed
// is valid whitespace.
// Arguments:
// - ch - Character to check.
// Return Value:
// - true iff ch is a carriage return or line feed.
constexpr bool Base64::s_IsSpace(const wchar_t ch) noexcept
{
return ch == L'\r' || ch == L'\n';
result.resize(out - outBeg);
return til::u8u16(result, dst);
}
6 changes: 1 addition & 5 deletions src/terminal/parser/base64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ namespace Microsoft::Console::VirtualTerminal
class Base64
{
public:
static std::wstring s_Encode(const std::wstring_view src) noexcept;
static bool s_Decode(const std::wstring_view src, std::wstring& dst) noexcept;

private:
static constexpr bool s_IsSpace(const wchar_t ch) noexcept;
static HRESULT Decode(const std::wstring_view& src, std::wstring& dst) noexcept;
};
}
Loading