Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Vectorize ASCII helper functions #3859

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions mojo/stdlib/src/collections/string/_ascii.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

# ===----------------------------------------------------------------------=== #
# isdigit
# ===----------------------------------------------------------------------=== #


@always_inline
fn _isdigit_vec[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `0` = SIMD[DType.uint8, w](Byte(ord("0")))
alias `9` = SIMD[DType.uint8, w](Byte(ord("9")))
return (`0` <= v) & (v <= `9`)


# ===----------------------------------------------------------------------=== #
# isprintable
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_printable_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias ` ` = SIMD[DType.uint8, w](Byte(ord(" ")))
alias `~` = SIMD[DType.uint8, w](Byte(ord("~")))
return (` ` <= v) & (v <= `~`)


@always_inline
fn _nonprintable_ascii[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
return (~_is_ascii_printable_vec(v)) & (v < 0b1000_0000)


@always_inline
fn _is_python_printable_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `\\` = SIMD[DType.uint8, w](Byte(ord(" ")))
return (v != `\\`) & _is_ascii_printable_vec(v)


@always_inline
fn _nonprintable_python[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
return (~_is_python_printable_vec(v)) & (v < 0b1000_0000)


# ===----------------------------------------------------------------------=== #
# isupper
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_uppercase_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `A` = SIMD[DType.uint8, w](Byte(ord("A")))
alias `Z` = SIMD[DType.uint8, w](Byte(ord("Z")))
return (`A` <= v) & (v <= `Z`)


@always_inline
fn _is_ascii_uppercase(c: Byte) -> Bool:
return _is_ascii_uppercase_vec(c)


# ===----------------------------------------------------------------------=== #
# islower
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_lowercase_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `a` = SIMD[DType.uint8, w](Byte(ord("a")))
alias `z` = SIMD[DType.uint8, w](Byte(ord("z")))
return (`a` <= v) & (v <= `z`)


@always_inline
fn _is_ascii_lowercase(c: Byte) -> Bool:
return _is_ascii_lowercase_vec(c)


# ===----------------------------------------------------------------------=== #
# toggle_case
# ===----------------------------------------------------------------------=== #


@always_inline
fn _ascii_toggle_case[
w: Int
](value: SIMD[DType.uint8, w]) -> SIMD[DType.uint8, w]:
alias `a` = Byte(ord("a"))
alias `A` = Byte(ord("A"))
# ASCII only has a 1 upper bit bifference in uppercase and lowercase letters
return value ^ (`A` ^ `a`)


# ===----------------------------------------------------------------------=== #
# isspace
# ===----------------------------------------------------------------------=== #


fn _is_ascii_space(c: Byte) -> Bool:
"""Determines whether the given character is an ASCII whitespace character:
`" \\t\\n\\v\\f\\r\\x1c\\x1d\\x1e"`.

Args:
c: The character to check.

Returns:
True if the character is one of the ASCII whitespace characters.
"""

# NOTE: a global LUT doesn't work at compile time so we can't use it here.
alias ` ` = Byte(ord(" "))
alias `\t` = Byte(ord("\t"))
alias `\n` = Byte(ord("\n"))
alias `\r` = Byte(ord("\r"))
alias `\f` = Byte(ord("\f"))
alias `\v` = Byte(ord("\v"))
alias `\x1c` = Byte(ord("\x1c"))
alias `\x1d` = Byte(ord("\x1d"))
alias `\x1e` = Byte(ord("\x1e"))

# This compiles to something very clever that's even faster than a LUT.
return (
c == ` `
or c == `\t`
or c == `\n`
or c == `\r`
or c == `\f`
or c == `\v`
or c == `\x1c`
or c == `\x1d`
or c == `\x1e`
)


# ===----------------------------------------------------------------------=== #
# ascii
# ===----------------------------------------------------------------------=== #


fn _repr_ascii(c: UInt8) -> String:
"""Returns a printable representation of the given ASCII code point.

Args:
c: An integer that represents a code point.

Returns:
A string containing a representation of the given code point.
"""
alias ord_tab = ord("\t")
alias ord_new_line = ord("\n")
alias ord_carriage_return = ord("\r")
alias ord_back_slash = ord("\\")

if c == ord_back_slash:
return r"\\"
elif _is_ascii_printable_vec(c):
return String(String._buffer_type(c, 0))
elif c == ord_tab:
return r"\t"
elif c == ord_new_line:
return r"\n"
elif c == ord_carriage_return:
return r"\r"
else:
var uc = c.cast[DType.uint8]()
if uc < 16:
return hex(uc, prefix=r"\x0")
else:
return hex(uc, prefix=r"\x")


@always_inline
fn ascii(value: StringSlice) -> String:
"""Get the ASCII representation of the object.

Args:
value: The object to get the ASCII representation of.

Returns:
A string containing the ASCII representation of the object.
"""
alias ord_squote = ord("'")
var result = String()
var use_dquote = False

for idx in range(len(value._slice)):
var char = value._slice[idx]
result += _repr_ascii(char)
use_dquote = use_dquote or (char == ord_squote)

if use_dquote:
return '"' + result + '"'
else:
return "'" + result + "'"
75 changes: 1 addition & 74 deletions mojo/stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import KeyElement, List, Optional
from collections._index_normalization import normalize_index
from collections.string import CodepointsIter
from collections.string._ascii import ascii
from collections.string.format import _CurlyEntryFormattable, _FormatCurlyEntry
from collections.string.string_slice import (
StaticString,
Expand Down Expand Up @@ -102,80 +103,6 @@ fn chr(c: Int) -> String:
return String(char)


# ===----------------------------------------------------------------------=== #
# ascii
# ===----------------------------------------------------------------------=== #


fn _chr_ascii(c: UInt8) -> String:
"""Returns a string based on the given ASCII code point.

Args:
c: An integer that represents a code point.

Returns:
A string containing a single character based on the given code point.
"""
return String(String._buffer_type(c, 0))


fn _repr_ascii(c: UInt8) -> String:
"""Returns a printable representation of the given ASCII code point.

Args:
c: An integer that represents a code point.

Returns:
A string containing a representation of the given code point.
"""
alias ord_tab = ord("\t")
alias ord_new_line = ord("\n")
alias ord_carriage_return = ord("\r")
alias ord_back_slash = ord("\\")

if c == ord_back_slash:
return r"\\"
elif Codepoint(c).is_ascii_printable():
return _chr_ascii(c)
elif c == ord_tab:
return r"\t"
elif c == ord_new_line:
return r"\n"
elif c == ord_carriage_return:
return r"\r"
else:
var uc = c.cast[DType.uint8]()
if uc < 16:
return hex(uc, prefix=r"\x0")
else:
return hex(uc, prefix=r"\x")


@always_inline
fn ascii(value: StringSlice) -> String:
"""Get the ASCII representation of the object.

Args:
value: The object to get the ASCII representation of.

Returns:
A string containing the ASCII representation of the object.
"""
alias ord_squote = ord("'")
var result = String()
var use_dquote = False

for idx in range(len(value._slice)):
var char = value._slice[idx]
result += _repr_ascii(char)
use_dquote = use_dquote or (char == ord_squote)

if use_dquote:
return '"' + result + '"'
else:
return "'" + result + "'"


# ===----------------------------------------------------------------------=== #
# strtol
# ===----------------------------------------------------------------------=== #
Expand Down
2 changes: 1 addition & 1 deletion mojo/stdlib/test/collections/string/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1455,9 +1455,9 @@ def main():
test_find()
test_replace()
test_rfind()
test_isspace()
test_split()
test_splitlines()
test_isspace()
test_ascii_aliases()
test_rstrip()
test_lstrip()
Expand Down