Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Make utf8 validation ~10-13x faster on neon and sse4 #3401

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
47c4175
[stdlib] Add new method `SIMD.dynamic_shuffle()`
gabrieldemarmiesse Aug 19, 2024
aa7ffa0
Typos and format
gabrieldemarmiesse Aug 19, 2024
6b81b8c
Add support for neon
gabrieldemarmiesse Aug 20, 2024
f5c1127
Make it work for any size with recursivity
gabrieldemarmiesse Aug 20, 2024
cbe544b
Fix some things, better tests
gabrieldemarmiesse Aug 20, 2024
471a5e0
Add a working implementation of full simd utf8 validation
gabrieldemarmiesse Aug 20, 2024
8a4fc36
Merge branch 'nightly' into add_dynamic_shuffle
gabrieldemarmiesse Aug 22, 2024
1fb3b11
Add nodebug and use rebind
gabrieldemarmiesse Aug 22, 2024
2d7064a
Merge branch 'add_dynamic_shuffle' into faster_utf8_validation
gabrieldemarmiesse Aug 22, 2024
e58d55c
Added no_inline for a 10% speedup
gabrieldemarmiesse Aug 22, 2024
9adfa5d
[stdlib] Add more utf-8 validation unit tests
gabrieldemarmiesse Aug 22, 2024
efdef33
Add todo
gabrieldemarmiesse Aug 22, 2024
84b4f60
Merge branch 'add_more_tests_for_validating_utf8' into faster_utf8_va…
gabrieldemarmiesse Aug 22, 2024
a53fc63
Better formatting and some notes
gabrieldemarmiesse Aug 22, 2024
275eb3e
Some more notes
gabrieldemarmiesse Aug 22, 2024
d3c607c
Add todos about sse3
gabrieldemarmiesse Aug 22, 2024
9587d29
Merge branch 'add_dynamic_shuffle' into faster_utf8_validation
gabrieldemarmiesse Aug 22, 2024
1423dde
Use the C# implementation as reference for validation
gabrieldemarmiesse Aug 23, 2024
d607d60
Some renaming to conform in Mojo style guide
gabrieldemarmiesse Aug 23, 2024
0fc014e
Move references to the top of the file
gabrieldemarmiesse Aug 23, 2024
2513e10
Merge branch 'nightly' into faster_utf8_validation
gabrieldemarmiesse Aug 25, 2024
c483e1c
Merge branch 'nightly' into faster_utf8_validation
gabrieldemarmiesse Sep 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,7 @@ struct SIMD[type: DType, size: Int](

# Not an overload of shuffle because there is ambiguity
# with fn shuffle[*mask: Int](self, other: Self) -> Self:
# TODO: move closer to UTF-8 String validation code - see https://github.com/modularml/mojo/issues/3477
# TODO: move to the utils directory - see https://github.com/modularml/mojo/issues/3477
@always_inline
fn _dynamic_shuffle[
mask_size: Int, //
Expand Down
183 changes: 183 additions & 0 deletions stdlib/src/utils/_utf8_validation.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

gabrieldemarmiesse marked this conversation as resolved.
Show resolved Hide resolved
"""Implement fast utf-8 validation using SIMD instructions.

References for this algorithm:
J. Keiser, D. Lemire, Validating UTF-8 In Less Than One Instruction Per Byte,
Software: Practice and Experience 51 (5), 2021
https://arxiv.org/abs/2010.03090

Blog post:
https://lemire.me/blog/2018/10/19/validating-utf-8-bytes-using-only-0-45-cycles-per-byte-avx-edition/

Code adapted from:
https://github.com/simdutf/SimdUnicode/blob/main/src/UTF8.cs
"""

# Error bits produced by the nibble lookup tables `shuf1`, `shuf2` and `shuf3`,
# listed in bit order. Each bit describes a (previous byte, current byte) pair.
# A bit survives the AND of the three lookups only when the previous byte's
# high nibble, its low nibble and the current byte's high nibble all allow it.

# Bit 0: a lead byte (11______) is followed by a non-continuation byte.
alias TOO_SHORT: UInt8 = 0b0000_0001
# Bit 1: an ASCII byte (0_______) is followed by a continuation byte.
alias TOO_LONG: UInt8 = 0b0000_0010
# Bit 2: 0xE0 is followed by 0x80..0x9F.
alias OVERLONG_3: UInt8 = 0b0000_0100
# Bit 3: 0xF4..0xFF is followed by 0x90..0xBF.
alias TOO_LARGE: UInt8 = 0b0000_1000
# Bit 4: 0xED is followed by 0xA0..0xBF.
alias SURROGATE: UInt8 = 0b0001_0000
# Bit 5: 0xC0 or 0xC1 is followed by a continuation byte.
alias OVERLONG_2: UInt8 = 0b0010_0000
# Bit 6 is deliberately shared by the next two errors; both require the
# previous high nibble to be 0xF and the current high nibble to be 0x8.
# 0xF5..0xFF is followed by 0x80..0x8F.
alias TOO_LARGE_1000: UInt8 = 0b0100_0000
# 0xF0 is followed by 0x80..0x8F.
alias OVERLONG_4: UInt8 = 0b0100_0000
# Bit 7: a continuation byte is followed by another continuation byte.
alias TWO_CONTS: UInt8 = 0b1000_0000
# Errors decided by the two high nibbles alone; every `shuf2` entry (indexed
# by the previous byte's low nibble) lets them through.
alias CARRY: UInt8 = TOO_SHORT | TOO_LONG | TWO_CONTS


# fmt: off
# Lookup table indexed by the HIGH nibble of the previous byte
# (`prev1 >> 4` in `validate_chunk`). Each entry holds the error bits that
# remain possible for that nibble; the result is ANDed with the `shuf2` and
# `shuf3` lookups, so a bit survives only if all three tables allow it.
alias shuf1 = SIMD[DType.uint8, 16](
# 0x0..0x7: previous byte is ASCII (0_______).
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
# 0x8..0xB: previous byte is a continuation byte (10______).
TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
# 0xC: previous byte is 1100____.
TOO_SHORT | OVERLONG_2,
# 0xD: previous byte is 1101____.
TOO_SHORT,
# 0xE: previous byte is 1110____.
TOO_SHORT | OVERLONG_3 | SURROGATE,
# 0xF: previous byte is 1111____.
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4
)

# Lookup table indexed by the LOW nibble of the previous byte
# (`prev1 & v0f` in `validate_chunk`). Every entry contains CARRY, so the
# errors decided purely by the two high nibbles pass through unchanged; the
# remaining bits narrow down the errors that depend on the exact lead byte.
alias shuf2 = SIMD[DType.uint8, 16](
# 0x0: ____0000
CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
# 0x1: ____0001
CARRY | OVERLONG_2,
# 0x2..0x3: ____001_
CARRY,
CARRY,
# 0x4: ____0100
CARRY | TOO_LARGE,
# 0x5..0xC
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000,
# 0xD: ____1101 (the only low nibble that allows SURROGATE)
CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
# 0xE..0xF
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000
)
# Lookup table indexed by the HIGH nibble of the current byte
# (`current_block >> 4` in `validate_chunk`). Each entry holds the error bits
# that remain possible given what kind of byte follows the previous one.
alias shuf3 = SIMD[DType.uint8, 16](
# 0x0..0x7: current byte is ASCII (0_______).
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
# 0x8: current byte is 1000____.
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
# 0x9: current byte is 1001____.
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
# 0xA..0xB: current byte is 101_____.
TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE,
TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE,
# 0xC..0xF: current byte is a lead byte (11______).
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT
)
# fmt: on


@always_inline
fn extract_vector[
    simd_size: Int, //, offset: Int
](a: SIMD[DType.uint8, simd_size], b: SIMD[DType.uint8, simd_size]) -> SIMD[
    DType.uint8, simd_size
]:
    """Take a window of `simd_size` lanes out of `a` joined with `b`.

    This can be a single instruction on some architectures.

    Parameters:
        simd_size: The number of lanes in each input and in the result.
        offset: The index, within the joined vector, of the first lane to
            return.

    Args:
        a: The vector providing the leading lanes of the joined vector.
        b: The vector providing the trailing lanes of the joined vector.

    Returns:
        Lanes `offset` to `offset + simd_size - 1` of `a.join(b)`.
    """
    return a.join(b).slice[simd_size, offset=offset]()


@always_inline
fn _subtract_with_saturation[
    simd_size: Int, //, b: Int
](a: SIMD[DType.uint8, simd_size]) -> SIMD[DType.uint8, simd_size]:
    """Subtract the constant `b` from every lane of `a`, saturating at zero.

    The equivalent of
    https://doc.rust-lang.org/core/arch/x86_64/fn._mm_subs_epu8.html .
    This can be a single instruction on some architectures.

    Parameters:
        simd_size: The number of lanes in the vector.
        b: The value to subtract from each lane.

    Args:
        a: The vector to subtract from.

    Returns:
        A vector whose lanes are `a[i] - b` where `a[i] >= b`, and 0
        elsewhere.
    """
    alias subtrahend = SIMD[DType.uint8, simd_size](b)
    # Raise every lane to at least `b` first, so the subtraction below can
    # never wrap around.
    var clamped = max(a, subtrahend)
    return clamped - subtrahend


fn validate_chunk[
    simd_size: Int
](
    current_block: SIMD[DType.uint8, simd_size],
    previous_input_block: SIMD[DType.uint8, simd_size],
) -> SIMD[DType.uint8, simd_size]:
    """Check one block of bytes for UTF-8 errors.

    Two independent checks are combined: a nibble-table classification of
    every (previous byte, current byte) pair, and a check that the bytes two
    and three positions after a 3-byte or 4-byte lead are continuations.

    Parameters:
        simd_size: The number of bytes in a block.

    Args:
        current_block: The bytes to check.
        previous_input_block: The block that precedes `current_block` in the
            input. Its last three bytes provide the context for the first
            bytes of `current_block`.

    Returns:
        A vector of error flags. The callers in this file treat any non-zero
        lane as an error and an all-zero vector as "no error found".
    """
    alias low_nibble_mask = SIMD[DType.uint8, simd_size](0x0F)
    alias high_bit = SIMD[DType.uint8, simd_size](0x80)
    # Subtracting these with saturation leaves the high bit set only for
    # bytes >= 0b11100000 (respectively >= 0b11110000).
    alias three_byte_lead_offset = 0b11100000 - 0x80
    alias four_byte_lead_offset = 0b11110000 - 0x80

    # The input as seen 1, 2 and 3 bytes earlier: lane i of `back_k` holds
    # the byte located k positions before lane i of `current_block`.
    var back_1 = extract_vector[simd_size - 1](
        previous_input_block, current_block
    )
    var back_2 = extract_vector[simd_size - 2](
        previous_input_block, current_block
    )
    var back_3 = extract_vector[simd_size - 3](
        previous_input_block, current_block
    )

    # Pair classification: an error bit survives only if the previous byte's
    # high nibble, its low nibble and the current byte's high nibble all
    # allow it.
    var pair_errors = (
        shuf1._dynamic_shuffle(back_1 >> 4)
        & shuf2._dynamic_shuffle(back_1 & low_nibble_mask)
        & shuf3._dynamic_shuffle(current_block >> 4)
    )

    # 0x80 in every lane that sits two bytes after a 3/4-byte lead or three
    # bytes after a 4-byte lead.
    var needs_continuation = (
        _subtract_with_saturation[three_byte_lead_offset](back_2)
        | _subtract_with_saturation[four_byte_lead_offset](back_3)
    ) & high_bit

    # The XOR flips the high bit (the TWO_CONTS bit of the pair errors)
    # exactly where a second or third continuation byte is required.
    return needs_continuation ^ pair_errors


fn _is_valid_utf8(ptr: UnsafePointer[UInt8], length: Int) -> Bool:
    """Verify that the bytes are valid UTF-8.

    The input is checked one SIMD block at a time with `validate_chunk`, each
    block being given the block before it as context. The bytes left over
    after the last full block are copied into a zero-padded block and checked
    the same way.

    Args:
        ptr: The pointer to the data.
        length: The length of the items pointed to.

    Returns:
        Whether the data is valid UTF-8.

    #### UTF-8 coding format
    [Table 3-7 page 94](http://www.unicode.org/versions/Unicode6.0.0/ch03.pdf).
    Well-Formed UTF-8 Byte Sequences

    Code Points | First Byte | Second Byte | Third Byte | Fourth Byte |
    :---------- | :--------- | :---------- | :--------- | :---------- |
    U+0000..U+007F | 00..7F | | | |
    U+0080..U+07FF | C2..DF | 80..BF | | |
    U+0800..U+0FFF | E0 | ***A0***..BF| 80..BF | |
    U+1000..U+CFFF | E1..EC | 80..BF | 80..BF | |
    U+D000..U+D7FF | ED | 80..***9F***| 80..BF | |
    U+E000..U+FFFF | EE..EF | 80..BF | 80..BF | |
    U+10000..U+3FFFF | F0 | ***90***..BF| 80..BF | 80..BF |
    U+40000..U+FFFFF | F1..F3 | 80..BF | 80..BF | 80..BF |
    U+100000..U+10FFFF | F4 | 80..***8F***| 80..BF | 80..BF |
    """
    alias simd_size = sys.simdbytewidth()
    var offset: Int = 0
    # Context for the very first block: an all-zero block.
    var preceding = SIMD[DType.uint8, simd_size]()

    # Full blocks.
    while offset + simd_size <= length:
        var block = (ptr + offset).load[width=simd_size]()
        if any(validate_chunk(block, preceding) != 0):
            return False
        preceding = block
        offset += simd_size

    # Tail: the remaining bytes followed by zero padding. When nothing
    # remains this is a block of zeros, which still has to be checked so
    # that a sequence left unfinished at the end of the last full block is
    # reported (continuation bytes are validated one block late).
    var tail = SIMD[DType.uint8, simd_size](0)
    for k in range(length - offset):
        tail[k] = (ptr + offset + k)[]

    return all(validate_chunk(tail, preceding) == 0)
146 changes: 0 additions & 146 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -48,152 +48,6 @@ fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b):
return count_leading_zeros(~(b & UInt8(0b1111_0000)))


fn _validate_utf8_simd_slice[
width: Int, remainder: Bool = False
](ptr: UnsafePointer[UInt8], length: Int, owned iter_len: Int) -> Int:
"""Internal method to validate utf8, use _is_valid_utf8.

Parameters:
width: The width of the SIMD vector to build for validation.
remainder: Whether it is computing the remainder that doesn't fit in the
SIMD vector.

Args:
ptr: Pointer to the data.
length: The length of the items in the pointer.
iter_len: The amount of items to still iterate through.

Returns:
The new amount of items to iterate through that don't fit in the
specified width of SIMD vector. If -1 then it is invalid.
"""
# TODO: implement a faster algorithm like https://github.com/cyb70289/utf8
# and benchmark the difference.
# Index of the first byte that has not been validated yet.
var idx = length - iter_len
while iter_len >= width or remainder:
var d: SIMD[DType.uint8, width] # use a vector of the specified width

# Full-width pass: load `width` bytes straight from memory.
# Remainder pass: copy the `iter_len` leftover bytes into a zero-filled
# vector so that no byte past the end of the data is read.
@parameter
if not remainder:
d = ptr.load[width=width](idx)
else:
debug_assert(iter_len > -1, "iter_len must be > -1")
d = SIMD[DType.uint8, width](0)
for i in range(iter_len):
d[i] = ptr[idx + i]

# True in every lane whose byte has its high bit clear.
var is_ascii = d < 0b1000_0000
if is_ascii.reduce_and(): # skip all ASCII bytes

@parameter
if not remainder:
idx += width
iter_len -= width
continue
else:
return 0
# The vector starts with ASCII but contains a non-ASCII byte: advance to
# the first non-ASCII lane and start over from there.
elif is_ascii[0]:
for i in range(1, width):
if is_ascii[i]:
continue
idx += i
iter_len -= i
break
continue

# From here on lane 0 holds a non-ASCII byte; only the sequence starting
# at lane 0 is validated in this iteration.
var byte_types = _utf8_byte_type(d)
var first_byte_type = byte_types[0]

# byte_type has to match against the amount of continuation bytes
alias Vec = SIMD[DType.uint8, 4]
alias n4_byte_types = Vec(4, 1, 1, 1)
alias n3_byte_types = Vec(3, 1, 1, 0)
alias n3_mask = Vec(0b111, 0b111, 0b111, 0)
alias n2_byte_types = Vec(2, 1, 0, 0)
alias n2_mask = Vec(0b111, 0b111, 0, 0)
# Only the first four lanes matter: a sequence is at most 4 bytes long.
# The masks zero out the lanes that lie beyond a 2- or 3-byte sequence.
var byte_types_4 = byte_types.slice[4]()
var valid_n4 = (byte_types_4 == n4_byte_types).reduce_and()
var valid_n3 = ((byte_types_4 & n3_mask) == n3_byte_types).reduce_and()
var valid_n2 = ((byte_types_4 & n2_mask) == n2_byte_types).reduce_and()
if not (valid_n4 or valid_n3 or valid_n2):
return -1

# special unicode ranges
# These mirror the restricted second-byte ranges of the well-formed UTF-8
# table quoted in the `_is_valid_utf8` docstring: 2-byte leads start at
# 0xC2, and E0 / ED / F0 / F4 only accept part of the 80..BF range.
# NOTE(review): no branch rejects lead bytes 0xF5..0xFF. `_utf8_byte_type`
# appears to classify them as type 4 (four leading 1 bits in the high
# nibble), so they would be accepted when followed by three continuation
# bytes - confirm.
var b0 = d[0]
var b1 = d[1]
if first_byte_type == 2 and b0 < UInt8(0b1100_0010):
return -1
elif b0 == 0xE0 and not (UInt8(0xA0) <= b1 <= UInt8(0xBF)):
return -1
elif b0 == 0xED and not (UInt8(0x80) <= b1 <= UInt8(0x9F)):
return -1
elif b0 == 0xF0 and not (UInt8(0x90) <= b1 <= UInt8(0xBF)):
return -1
elif b0 == 0xF4 and not (UInt8(0x80) <= b1 <= UInt8(0x8F)):
return -1

# amount of bytes evaluated
idx += int(first_byte_type)
iter_len -= int(first_byte_type)

# NOTE(review): if this break sits inside the while loop (the nesting is
# not visible in this view), the remainder pass validates at most one
# non-ASCII sequence before returning; a caller comparing the result to 0
# would then reject a valid tail such as a 2-byte sequence followed by an
# ASCII byte - confirm.
@parameter
if remainder:
break
return iter_len


fn _is_valid_utf8(ptr: UnsafePointer[UInt8], length: Int) -> Bool:
    """Verify that the bytes are valid UTF-8.

    The data is consumed with progressively narrower vectors (64, 32, 16, 8
    and 4 bytes). The three widest steps run only when the target's native
    `uint8` SIMD width is at least that large. Whatever is left goes through
    a final remainder pass.

    Args:
        ptr: The pointer to the data.
        length: The length of the items pointed to.

    Returns:
        Whether the data is valid UTF-8.

    #### UTF-8 coding format
    [Table 3-7 page 94](http://www.unicode.org/versions/Unicode6.0.0/ch03.pdf).
    Well-Formed UTF-8 Byte Sequences

    Code Points | First Byte | Second Byte | Third Byte | Fourth Byte |
    :---------- | :--------- | :---------- | :--------- | :---------- |
    U+0000..U+007F | 00..7F | | | |
    U+0080..U+07FF | C2..DF | 80..BF | | |
    U+0800..U+0FFF | E0 | ***A0***..BF| 80..BF | |
    U+1000..U+CFFF | E1..EC | 80..BF | 80..BF | |
    U+D000..U+D7FF | ED | 80..***9F***| 80..BF | |
    U+E000..U+FFFF | EE..EF | 80..BF | 80..BF | |
    U+10000..U+3FFFF | F0 | ***90***..BF| 80..BF | 80..BF |
    U+40000..U+FFFFF | F1..F3 | 80..BF | 80..BF | 80..BF |
    U+100000..U+10FFFF | F4 | 80..***8F***| 80..BF | 80..BF |
    .
    """
    alias native_width = simdwidthof[DType.uint8]()

    # Number of bytes not validated yet. Each step returns the new count,
    # or a negative value if it found invalid data.
    var remaining = length

    if native_width >= 64 and remaining >= 64:
        remaining = _validate_utf8_simd_slice[64](ptr, length, remaining)
        if remaining < 0:
            return False

    if native_width >= 32 and remaining >= 32:
        remaining = _validate_utf8_simd_slice[32](ptr, length, remaining)
        if remaining < 0:
            return False

    if native_width >= 16 and remaining >= 16:
        remaining = _validate_utf8_simd_slice[16](ptr, length, remaining)
        if remaining < 0:
            return False

    if remaining >= 8:
        remaining = _validate_utf8_simd_slice[8](ptr, length, remaining)
        if remaining < 0:
            return False

    if remaining >= 4:
        remaining = _validate_utf8_simd_slice[4](ptr, length, remaining)
        if remaining < 0:
            return False

    # Remainder pass: the data is valid only if nothing is left afterwards.
    var leftover = _validate_utf8_simd_slice[4, True](ptr, length, remaining)
    return leftover == 0


fn _is_newline_start(
ptr: UnsafePointer[UInt8], read_ahead: Int = 1
) -> (Bool, Int):
Expand Down
24 changes: 12 additions & 12 deletions stdlib/test/builtin/test_simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -888,13 +888,13 @@ def test_shuffle_dynamic_size_32_uint8():
)
# fmt: off
var indices = SIMD[DType.uint8, 32](
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
8 , 9 , 10, 11, 12, 13, 14, 15,
)
result = table_lookup._dynamic_shuffle(indices)

expected_result = SIMD[DType.uint8, 32](
30 , 30 , 50 , 50 , 70 , 70 , 90 , 90 ,
110, 110, 130, 130, 150, 150, 0 , 10 ,
Expand All @@ -911,13 +911,13 @@ def test_shuffle_dynamic_size_64_uint8():
)
# fmt: off
var indices = SIMD[DType.uint8, 32](
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
8 , 9 , 10, 11, 12, 13, 14, 15,
)
result = table_lookup._dynamic_shuffle(indices.join(indices))

expected_result = SIMD[DType.uint8, 32](
30 , 30 , 50 , 50 , 70 , 70 , 90 , 90 ,
110, 110, 130, 130, 150, 150, 0 , 10 ,
Expand All @@ -935,13 +935,13 @@ def test_shuffle_dynamic_size_32_float():
80.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0,
)
var indices = SIMD[DType.uint8, 32](
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
3 , 3 , 5 , 5 , 7 , 7 , 9 , 9 ,
11, 11, 13, 13, 15, 15, 0 , 1 ,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ,
8 , 9 , 10, 11, 12, 13, 14, 15,
)
result = table_lookup._dynamic_shuffle(indices)

expected_result = SIMD[DType.float64, 32](
30. , 30. , 50. , 50. , 70. , 70. , 90. , 90. ,
110., 110., 130., 130., 150., 150., 0. , 10. ,
Expand Down
Loading