From 005689065c65886a3402de5f902b8768c01e40f4 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Wed, 9 Oct 2024 14:40:35 -0400 Subject: [PATCH] Prism::CodeUnitsCache Calculating code unit offsets for a source can be very expensive, especially when the source is large. This commit introduces a new class that wraps the source and desired encoding into a cache that reuses pre-computed offsets. It performs quite a bit better. There are still some problems with this approach, namely character boundaries and the fact that the cache is unbounded, but both of these may be addressed in subsequent commits. --- lib/prism/parse_result.rb | 112 ++++++++++++++++++++++++++++ rbi/prism/parse_result.rbi | 29 +++++++ sig/prism/_private/parse_result.rbs | 12 +++ sig/prism/parse_result.rbs | 20 +++++ test/prism/ruby/location_test.rb | 46 ++++++++++++ 5 files changed, 219 insertions(+) diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index e3ba7e7c8e9..46bd33d1db4 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -120,6 +120,12 @@ def code_units_offset(byte_offset, encoding) end end + # Generate a cache that targets a specific encoding for calculating code + # unit offsets. + def code_units_cache(encoding) + CodeUnitsCache.new(source, encoding) + end + # Returns the column number in code units for the given encoding for the # given byte offset. def code_units_column(byte_offset, encoding) @@ -149,6 +155,76 @@ def find_line(byte_offset) end end + # A cache that can be used to quickly compute code unit offsets from byte + # offsets. It purposefully provides only a single #[] method to access the + # cache in order to minimize surface area. + # + # Note that there are some known issues here that may or may not be addressed + # in the future: + # + # * The first is that there are issues when the cache computes values that are + # not on character boundaries. This can result in subsequent computations + # being off by one or more code units. 
+ # * The second is that this cache is currently unbounded. In theory we could + # introduce some kind of LRU cache to limit the number of entries, but this + # has not yet been implemented. + # + class CodeUnitsCache + class UTF16Counter # :nodoc: + def initialize(source, encoding) + @source = source + @encoding = encoding + end + + def count(byte_offset, byte_length) + @source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).bytesize / 2 + end + end + + class LengthCounter # :nodoc: + def initialize(source, encoding) + @source = source + @encoding = encoding + end + + def count(byte_offset, byte_length) + @source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).length + end + end + + private_constant :UTF16Counter, :LengthCounter + + # Initialize a new cache with the given source and encoding. + def initialize(source, encoding) + @source = source + @counter = + if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE + UTF16Counter.new(source, encoding) + else + LengthCounter.new(source, encoding) + end + + @cache = {} + @offsets = [] + end + + # Retrieve the code units offset from the given byte offset. + def [](byte_offset) + @cache[byte_offset] ||= + if (index = @offsets.bsearch_index { |offset| offset > byte_offset }).nil? + @offsets << byte_offset + @counter.count(0, byte_offset) + elsif index == 0 + @offsets.unshift(byte_offset) + @counter.count(0, byte_offset) + else + @offsets.insert(index, byte_offset) + offset = @offsets[index - 1] + @cache[offset] + @counter.count(offset, byte_offset - offset) + end + end + end + # Specialized version of Prism::Source for source code that includes ASCII # characters only. This class is used to apply performance optimizations that # cannot be applied to sources that include multibyte characters. 
@@ -178,6 +254,13 @@ def code_units_offset(byte_offset, encoding) byte_offset end + # Returns a cache that is the identity function in order to maintain the + # same interface. We can do this because code units are always equivalent to + # byte offsets for ASCII-only sources. + def code_units_cache(encoding) + ->(byte_offset) { byte_offset } + end + # Specialized version of `code_units_column` that does not depend on # `code_units_offset`, which is a more expensive operation. This is # essentially the same as `Prism::Source#column`. @@ -287,6 +370,12 @@ def start_code_units_offset(encoding = Encoding::UTF_16LE) source.code_units_offset(start_offset, encoding) end + # The start offset from the start of the file in code units using the given + # cache to fetch or calculate the value. + def cached_start_code_units_offset(cache) + cache[start_offset] + end + # The byte offset from the beginning of the source where this location ends. def end_offset start_offset + length @@ -303,6 +392,12 @@ def end_code_units_offset(encoding = Encoding::UTF_16LE) source.code_units_offset(end_offset, encoding) end + # The end offset from the start of the file in code units using the given + # cache to fetch or calculate the value. + def cached_end_code_units_offset(cache) + cache[end_offset] + end + # The line number where this location starts. def start_line source.line(start_offset) @@ -337,6 +432,12 @@ def start_code_units_column(encoding = Encoding::UTF_16LE) source.code_units_column(start_offset, encoding) end + # The start column in code units using the given cache to fetch or calculate + # the value. + def cached_start_code_units_column(cache) + cache[start_offset] - cache[source.line_start(start_offset)] + end + # The column number in bytes where this location ends from the start of the # line. 
def end_column @@ -355,6 +456,12 @@ def end_code_units_column(encoding = Encoding::UTF_16LE) source.code_units_column(end_offset, encoding) end + # The end column in code units using the given cache to fetch or calculate + # the value. + def cached_end_code_units_column(cache) + cache[end_offset] - cache[source.line_start(end_offset)] + end + # Implement the hash pattern matching interface for Location. def deconstruct_keys(keys) { start_offset: start_offset, end_offset: end_offset } @@ -604,6 +711,11 @@ def success? def failure? !success? end + + # Create a code units cache for the given encoding. + def code_units_cache(encoding) + source.code_units_cache(encoding) + end end # This is a result specific to the `parse` and `parse_file` methods. diff --git a/rbi/prism/parse_result.rbi b/rbi/prism/parse_result.rbi index ef47e93bd1a..7cc1e775021 100644 --- a/rbi/prism/parse_result.rbi +++ b/rbi/prism/parse_result.rbi @@ -40,10 +40,21 @@ class Prism::Source sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_offset(byte_offset, encoding); end + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) } + def code_units_cache(encoding); end + sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_column(byte_offset, encoding); end end +class Prism::CodeUnitsCache + sig { params(source: String, encoding: Encoding).void } + def initialize(source, encoding); end + + sig { params(byte_offset: Integer).returns(Integer) } + def [](byte_offset); end +end + class Prism::ASCIISource < Prism::Source sig { params(byte_offset: Integer).returns(Integer) } def character_offset(byte_offset); end @@ -54,6 +65,9 @@ class Prism::ASCIISource < Prism::Source sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_offset(byte_offset, encoding); end + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, 
T.proc.params(byte_offset: Integer).returns(Integer))) } + def code_units_cache(encoding); end + sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) } def code_units_column(byte_offset, encoding); end end @@ -107,6 +121,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def start_code_units_offset(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_start_code_units_offset(cache); end + sig { returns(Integer) } def end_offset; end @@ -116,6 +133,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def end_code_units_offset(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_end_code_units_offset(cache); end + sig { returns(Integer) } def start_line; end @@ -134,6 +154,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def start_code_units_column(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_start_code_units_column(cache); end + sig { returns(Integer) } def end_column; end @@ -143,6 +166,9 @@ class Prism::Location sig { params(encoding: Encoding).returns(Integer) } def end_code_units_column(encoding = Encoding::UTF_16LE); end + sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) } + def cached_end_code_units_column(cache); end + sig { params(keys: T.nilable(T::Array[Symbol])).returns(T::Hash[Symbol, T.untyped]) } def deconstruct_keys(keys); end @@ -296,6 +322,9 @@ class Prism::Result sig { returns(T::Boolean) } def failure?; end + + sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: 
Integer).returns(Integer))) } + def code_units_cache(encoding); end end class Prism::ParseResult < Prism::Result diff --git a/sig/prism/_private/parse_result.rbs b/sig/prism/_private/parse_result.rbs index 62e0cdc9177..659bedcfe34 100644 --- a/sig/prism/_private/parse_result.rbs +++ b/sig/prism/_private/parse_result.rbs @@ -5,6 +5,18 @@ module Prism def find_line: (Integer) -> Integer end + class CodeUnitsCache + class UTF16Counter + def initialize: (String source, Encoding encoding) -> void + def count: (Integer byte_offset, Integer byte_length) -> Integer + end + + class LengthCounter + def initialize: (String source, Encoding encoding) -> void + def count: (Integer byte_offset, Integer byte_length) -> Integer + end + end + class Location private diff --git a/sig/prism/parse_result.rbs b/sig/prism/parse_result.rbs index d5b9767a01b..d81fe90966b 100644 --- a/sig/prism/parse_result.rbs +++ b/sig/prism/parse_result.rbs @@ -1,4 +1,8 @@ module Prism + interface _CodeUnitsCache + def []: (Integer byte_offset) -> Integer + end + class Source attr_reader source: String attr_reader start_line: Integer @@ -16,15 +20,22 @@ module Prism def character_offset: (Integer byte_offset) -> Integer def character_column: (Integer byte_offset) -> Integer def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer def self.for: (String source) -> Source end + class CodeUnitsCache + def initialize: (String source, Encoding encoding) -> void + def []: (Integer byte_offset) -> Integer + end + class ASCIISource < Source def character_offset: (Integer byte_offset) -> Integer def character_column: (Integer byte_offset) -> Integer def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer 
end @@ -45,15 +56,23 @@ module Prism def slice: () -> String def slice_lines: () -> String def start_character_offset: () -> Integer + def start_code_units_offset: (?Encoding encoding) -> Integer + def cached_start_code_units_offset: (_CodeUnitsCache cache) -> Integer def end_offset: () -> Integer def end_character_offset: () -> Integer + def end_code_units_offset: (?Encoding encoding) -> Integer + def cached_end_code_units_offset: (_CodeUnitsCache cache) -> Integer def start_line: () -> Integer def start_line_slice: () -> String def end_line: () -> Integer def start_column: () -> Integer def start_character_column: () -> Integer + def start_code_units_column: (?Encoding encoding) -> Integer + def cached_start_code_units_column: (_CodeUnitsCache cache) -> Integer def end_column: () -> Integer def end_character_column: () -> Integer + def end_code_units_column: (?Encoding encoding) -> Integer + def cached_end_code_units_column: (_CodeUnitsCache cache) -> Integer def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped] def pretty_print: (untyped q) -> untyped def join: (Location other) -> Location @@ -125,6 +144,7 @@ module Prism def deconstruct_keys: (Array[Symbol]? 
keys) -> Hash[Symbol, untyped] def success?: () -> bool def failure?: () -> bool + def code_units_cache: (Encoding encoding) -> _CodeUnitsCache end class ParseResult < Result diff --git a/test/prism/ruby/location_test.rb b/test/prism/ruby/location_test.rb index 3d3e7dd5623..33f844243c0 100644 --- a/test/prism/ruby/location_test.rb +++ b/test/prism/ruby/location_test.rb @@ -140,6 +140,52 @@ def test_code_units assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE) end + def test_cached_code_units + result = Prism.parse("😀 + 😀\n😍 ||= 😍") + + utf8_cache = result.code_units_cache(Encoding::UTF_8) + utf16_cache = result.code_units_cache(Encoding::UTF_16LE) + utf32_cache = result.code_units_cache(Encoding::UTF_32LE) + + # first 😀 + location = result.value.statements.body.first.receiver.location + + assert_equal 0, location.cached_start_code_units_offset(utf8_cache) + assert_equal 0, location.cached_start_code_units_offset(utf16_cache) + assert_equal 0, location.cached_start_code_units_offset(utf32_cache) + + assert_equal 1, location.cached_end_code_units_offset(utf8_cache) + assert_equal 2, location.cached_end_code_units_offset(utf16_cache) + assert_equal 1, location.cached_end_code_units_offset(utf32_cache) + + assert_equal 0, location.cached_start_code_units_column(utf8_cache) + assert_equal 0, location.cached_start_code_units_column(utf16_cache) + assert_equal 0, location.cached_start_code_units_column(utf32_cache) + + assert_equal 1, location.cached_end_code_units_column(utf8_cache) + assert_equal 2, location.cached_end_code_units_column(utf16_cache) + assert_equal 1, location.cached_end_code_units_column(utf32_cache) + + # second 😀 + location = result.value.statements.body.first.arguments.arguments.first.location + + assert_equal 4, location.cached_start_code_units_offset(utf8_cache) + assert_equal 5, location.cached_start_code_units_offset(utf16_cache) + assert_equal 4, location.cached_start_code_units_offset(utf32_cache) + + 
assert_equal 5, location.cached_end_code_units_offset(utf8_cache) + assert_equal 7, location.cached_end_code_units_offset(utf16_cache) + assert_equal 5, location.cached_end_code_units_offset(utf32_cache) + + assert_equal 4, location.cached_start_code_units_column(utf8_cache) + assert_equal 5, location.cached_start_code_units_column(utf16_cache) + assert_equal 4, location.cached_start_code_units_column(utf32_cache) + + assert_equal 5, location.cached_end_code_units_column(utf8_cache) + assert_equal 7, location.cached_end_code_units_column(utf16_cache) + assert_equal 5, location.cached_end_code_units_column(utf32_cache) + end + def test_code_units_binary_valid_utf8 program = Prism.parse(<<~RUBY).value # -*- encoding: binary -*-