diff --git a/ext/prism/extension.c b/ext/prism/extension.c index 93fa7b0989..47603fd9b4 100644 --- a/ext/prism/extension.c +++ b/ext/prism/extension.c @@ -23,6 +23,7 @@ VALUE rb_cPrismResult; VALUE rb_cPrismParseResult; VALUE rb_cPrismLexResult; VALUE rb_cPrismParseLexResult; +VALUE rb_cPrismStringQuery; VALUE rb_cPrismDebugEncoding; @@ -1133,6 +1134,67 @@ parse_file_failure_p(int argc, VALUE *argv, VALUE self) { return RTEST(parse_file_success_p(argc, argv, self)) ? Qfalse : Qtrue; } +/******************************************************************************/ +/* String query methods */ +/******************************************************************************/ + +/** + * Process the result of a call to a string query method and return an + * appropriate value. + */ +static VALUE +string_query(pm_string_query_t result) { + switch (result) { + case PM_STRING_QUERY_ERROR: + rb_raise(rb_eArgError, "Invalid or non ascii-compatible encoding"); + return Qfalse; + case PM_STRING_QUERY_FALSE: + return Qfalse; + case PM_STRING_QUERY_TRUE: + return Qtrue; + } +} + +/** + * call-seq: + * Prism::StringQuery::local?(string) -> bool + * + * Returns true if the string constitutes a valid local variable name. Note that + * this means the names that can be set through Binding#local_variable_set, not + * necessarily the ones that can be set through a local variable assignment. + */ +static VALUE +string_query_local_p(VALUE self, VALUE string) { + const uint8_t *source = (const uint8_t *) check_string(string); + return string_query(pm_string_query_local(source, RSTRING_LEN(string), rb_enc_get(string)->name)); +} + +/** + * call-seq: + * Prism::StringQuery::constant?(string) -> bool + * + * Returns true if the string constitutes a valid constant name. Note that this + * means the names that can be set through Module#const_set, not necessarily the + * ones that can be set through a constant assignment. + */ +static VALUE +string_query_constant_p(VALUE self, VALUE string) { + const uint8_t *source = (const uint8_t *) check_string(string); + return string_query(pm_string_query_constant(source, RSTRING_LEN(string), rb_enc_get(string)->name)); +} + +/** + * call-seq: + * Prism::StringQuery::method_name?(string) -> bool + * + * Returns true if the string constitutes a valid method name. + */ +static VALUE +string_query_method_name_p(VALUE self, VALUE string) { + const uint8_t *source = (const uint8_t *) check_string(string); + return string_query(pm_string_query_method_name(source, RSTRING_LEN(string), rb_enc_get(string)->name)); +} + /******************************************************************************/ /* Initialization of the extension */ /******************************************************************************/ @@ -1170,6 +1232,7 @@ Init_prism(void) { rb_cPrismParseResult = rb_define_class_under(rb_cPrism, "ParseResult", rb_cPrismResult); rb_cPrismLexResult = rb_define_class_under(rb_cPrism, "LexResult", rb_cPrismResult); rb_cPrismParseLexResult = rb_define_class_under(rb_cPrism, "ParseLexResult", rb_cPrismResult); + rb_cPrismStringQuery = rb_define_class_under(rb_cPrism, "StringQuery", rb_cObject); // Intern all of the IDs eagerly that we support so that we don't have to do // it every time we parse. @@ -1211,6 +1274,10 @@ Init_prism(void) { rb_define_singleton_method(rb_cPrism, "dump_file", dump_file, -1); #endif + rb_define_singleton_method(rb_cPrismStringQuery, "local?", string_query_local_p, 1); + rb_define_singleton_method(rb_cPrismStringQuery, "constant?", string_query_constant_p, 1); + rb_define_singleton_method(rb_cPrismStringQuery, "method_name?", string_query_method_name_p, 1); + // Next, initialize the other APIs. Init_prism_api_node(); Init_prism_pack(); diff --git a/include/prism.h b/include/prism.h index 755c38fca2..6f7b850a31 100644 --- a/include/prism.h +++ b/include/prism.h @@ -234,6 +234,53 @@ PRISM_EXPORTED_FUNCTION void pm_dump_json(pm_buffer_t *buffer, const pm_parser_t #endif +/** + * Represents the results of a slice query. + */ +typedef enum { + /** Returned if the encoding given to a slice query was invalid. */ + PM_STRING_QUERY_ERROR = -1, + + /** Returned if the result of the slice query is false. */ + PM_STRING_QUERY_FALSE, + + /** Returned if the result of the slice query is true. */ + PM_STRING_QUERY_TRUE +} pm_string_query_t; + +/** + * Check that the slice is a valid local variable name. + * + * @param source The source to check. + * @param length The length of the source. + * @param encoding_name The name of the encoding of the source. + * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if + * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_local(const uint8_t *source, size_t length, const char *encoding_name); + +/** + * Check that the slice is a valid constant name. + * + * @param source The source to check. + * @param length The length of the source. + * @param encoding_name The name of the encoding of the source. + * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if + * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_constant(const uint8_t *source, size_t length, const char *encoding_name); + +/** + * Check that the slice is a valid method name. + * + * @param source The source to check. + * @param length The length of the source. + * @param encoding_name The name of the encoding of the source. + * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if + * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_method_name(const uint8_t *source, size_t length, const char *encoding_name); + /** * @mainpage * diff --git a/lib/prism.rb b/lib/prism.rb index 66a64e7fd0..50b14a5486 100644 --- a/lib/prism.rb +++ b/lib/prism.rb @@ -25,6 +25,7 @@ module Prism autoload :Pattern, "prism/pattern" autoload :Reflection, "prism/reflection" autoload :Serialize, "prism/serialize" + autoload :StringQuery, "prism/string_query" autoload :Translation, "prism/translation" autoload :Visitor, "prism/visitor" @@ -75,13 +76,13 @@ def self.load(source, serialized) # it's going to require the built library. Otherwise, it's going to require a # module that uses FFI to call into the library. if RUBY_ENGINE == "ruby" and !ENV["PRISM_FFI_BACKEND"] - require "prism/prism" - # The C extension is the default backend on CRuby. Prism::BACKEND = :CEXT -else - require_relative "prism/ffi" + require "prism/prism" +else # The FFI backend is used on other Ruby implementations. Prism::BACKEND = :FFI + + require_relative "prism/ffi" end diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb index 0520f7cdd2..a16d7f848f 100644 --- a/lib/prism/ffi.rb +++ b/lib/prism/ffi.rb @@ -73,6 +73,7 @@ def self.load_exported_functions_from(header, *functions, callbacks) callback :pm_parse_stream_fgets_t, [:pointer, :int, :pointer], :pointer enum :pm_string_init_result_t, %i[PM_STRING_INIT_SUCCESS PM_STRING_INIT_ERROR_GENERIC PM_STRING_INIT_ERROR_DIRECTORY] + enum :pm_string_query_t, [:PM_STRING_QUERY_ERROR, -1, :PM_STRING_QUERY_FALSE, :PM_STRING_QUERY_TRUE] load_exported_functions_from( "prism.h", @@ -83,6 +84,9 @@ def self.load_exported_functions_from(header, *functions, callbacks) "pm_serialize_lex", "pm_serialize_parse_lex", "pm_parse_success_p", + "pm_string_query_local", + "pm_string_query_constant", + "pm_string_query_method_name", [:pm_parse_stream_fgets_t] ) @@ -492,4 +496,36 @@ def dump_options(options) values.pack(template) end end + + # Here we are going to patch StringQuery to put in the class-level methods so + # that it can maintain a consistent interface + class StringQuery + class << self + # Mirrors the C extension's StringQuery::local? method. + def local?(string) + query(LibRubyParser.pm_string_query_local(string, string.bytesize, string.encoding.name)) + end + + # Mirrors the C extension's StringQuery::constant? method. + def constant?(string) + query(LibRubyParser.pm_string_query_constant(string, string.bytesize, string.encoding.name)) + end + + # Mirrors the C extension's StringQuery::method_name? method. + def method_name?(string) + query(LibRubyParser.pm_string_query_method_name(string, string.bytesize, string.encoding.name)) + end + + private + + # Parse the enum result and return an appropriate boolean. + def query(result) + case result + when :PM_STRING_QUERY_ERROR then raise ArgumentError, "Invalid or non ascii-compatible encoding" + when :PM_STRING_QUERY_FALSE then false + when :PM_STRING_QUERY_TRUE then true + end + end + end + end end diff --git a/lib/prism/string_query.rb b/lib/prism/string_query.rb new file mode 100644 index 0000000000..9011051d2b --- /dev/null +++ b/lib/prism/string_query.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true + +module Prism + # Query methods that allow categorizing strings based on their context for + # where they could be valid in a Ruby syntax tree. + class StringQuery + # The string that this query is wrapping. + attr_reader :string + + # Initialize a new query with the given string. + def initialize(string) + @string = string + end + + # Whether or not this string is a valid local variable name. + def local? + StringQuery.local?(string) + end + + # Whether or not this string is a valid constant name. + def constant? + StringQuery.constant?(string) + end + + # Whether or not this string is a valid method name. + def method_name? + StringQuery.method_name?(string) + end + end +end diff --git a/prism.gemspec b/prism.gemspec index 1a0547f5bc..c4efd8ae03 100644 --- a/prism.gemspec +++ b/prism.gemspec @@ -89,6 +89,7 @@ Gem::Specification.new do |spec| "lib/prism/polyfill/unpack1.rb", "lib/prism/reflection.rb", "lib/prism/serialize.rb", + "lib/prism/string_query.rb", "lib/prism/translation.rb", "lib/prism/translation/parser.rb", "lib/prism/translation/parser33.rb", diff --git a/src/prism.c b/src/prism.c index cfd1434ddd..ff5fd933f6 100644 --- a/src/prism.c +++ b/src/prism.c @@ -22643,3 +22643,166 @@ pm_serialize_parse_comments(pm_buffer_t *buffer, const uint8_t *source, size_t s } #endif + +/******************************************************************************/ +/* Slice queries for the Ruby API */ +/******************************************************************************/ + +/** The category of slice returned from pm_slice_type. */ +typedef enum { + /** Returned when the given encoding name is invalid. */ + PM_SLICE_TYPE_ERROR = -1, + + /** Returned when no other types apply to the slice. */ + PM_SLICE_TYPE_NONE, + + /** Returned when the slice is a valid local variable name. */ + PM_SLICE_TYPE_LOCAL, + + /** Returned when the slice is a valid constant name. */ + PM_SLICE_TYPE_CONSTANT, + + /** Returned when the slice is a valid method name. */ + PM_SLICE_TYPE_METHOD_NAME +} pm_slice_type_t; + +/** + * Check that the slice is a valid local variable name or constant. + */ +pm_slice_type_t +pm_slice_type(const uint8_t *source, size_t length, const char *encoding_name) { + // first, get the right encoding object + const pm_encoding_t *encoding = pm_encoding_find((const uint8_t *) encoding_name, (const uint8_t *) (encoding_name + strlen(encoding_name))); + if (encoding == NULL) return PM_SLICE_TYPE_ERROR; + + // check that there is at least one character + if (length == 0) return PM_SLICE_TYPE_NONE; + + size_t width; + if ((width = encoding->alpha_char(source, (ptrdiff_t) length)) != 0) { + // valid because alphabetical + } else if (*source == '_') { + // valid because underscore + width = 1; + } else if ((*source >= 0x80) && ((width = encoding->char_width(source, (ptrdiff_t) length)) > 0)) { + // valid because multibyte + } else { + // invalid because no match + return PM_SLICE_TYPE_NONE; + } + + // determine the type of the slice based on the first character + const uint8_t *end = source + length; + pm_slice_type_t result = encoding->isupper_char(source, end - source) ? PM_SLICE_TYPE_CONSTANT : PM_SLICE_TYPE_LOCAL; + + // next, iterate through all of the bytes of the string to ensure that they + // are all valid identifier characters + source += width; + + while (source < end) { + if ((width = encoding->alnum_char(source, end - source)) != 0) { + // valid because alphanumeric + source += width; + } else if (*source == '_') { + // valid because underscore + source++; + } else if ((*source >= 0x80) && ((width = encoding->char_width(source, end - source)) > 0)) { + // valid because multibyte + source += width; + } else { + // invalid because no match + break; + } + } + + // accept a ! or ? at the end of the slice as a method name + if (*source == '!' || *source == '?' || *source == '=') { + source++; + result = PM_SLICE_TYPE_METHOD_NAME; + } + + // valid if we are at the end of the slice + return source == end ? result : PM_SLICE_TYPE_NONE; +} + +/** + * Check that the slice is a valid local variable name. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t +pm_string_query_local(const uint8_t *source, size_t length, const char *encoding_name) { + switch (pm_slice_type(source, length, encoding_name)) { + case PM_SLICE_TYPE_ERROR: + return PM_STRING_QUERY_ERROR; + case PM_SLICE_TYPE_NONE: + case PM_SLICE_TYPE_CONSTANT: + case PM_SLICE_TYPE_METHOD_NAME: + return PM_STRING_QUERY_FALSE; + case PM_SLICE_TYPE_LOCAL: + return PM_STRING_QUERY_TRUE; + } + + assert(false && "unreachable"); + return PM_STRING_QUERY_FALSE; +} + +/** + * Check that the slice is a valid constant name. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t +pm_string_query_constant(const uint8_t *source, size_t length, const char *encoding_name) { + switch (pm_slice_type(source, length, encoding_name)) { + case PM_SLICE_TYPE_ERROR: + return PM_STRING_QUERY_ERROR; + case PM_SLICE_TYPE_NONE: + case PM_SLICE_TYPE_LOCAL: + case PM_SLICE_TYPE_METHOD_NAME: + return PM_STRING_QUERY_FALSE; + case PM_SLICE_TYPE_CONSTANT: + return PM_STRING_QUERY_TRUE; + } + + assert(false && "unreachable"); + return PM_STRING_QUERY_FALSE; +} + +/** + * Check that the slice is a valid method name. + */ +PRISM_EXPORTED_FUNCTION pm_string_query_t +pm_string_query_method_name(const uint8_t *source, size_t length, const char *encoding_name) { +#define B(p) ((p) ? PM_STRING_QUERY_TRUE : PM_STRING_QUERY_FALSE) +#define C1(c) (*source == c) +#define C2(s) (memcmp(source, s, 2) == 0) +#define C3(s) (memcmp(source, s, 3) == 0) + + switch (pm_slice_type(source, length, encoding_name)) { + case PM_SLICE_TYPE_ERROR: + return PM_STRING_QUERY_ERROR; + case PM_SLICE_TYPE_NONE: + break; + case PM_SLICE_TYPE_LOCAL: + // numbered parameters are not valid method names + return B((length != 2) || (source[0] != '_') || (source[1] == '0') || !pm_char_is_decimal_digit(source[1])); + case PM_SLICE_TYPE_CONSTANT: + // all constants are valid method names + case PM_SLICE_TYPE_METHOD_NAME: + // all method names are valid method names + return PM_STRING_QUERY_TRUE; + } + + switch (length) { + case 1: + return B(C1('&') || C1('`') || C1('!') || C1('^') || C1('>') || C1('<') || C1('-') || C1('%') || C1('|') || C1('+') || C1('/') || C1('*') || C1('~')); + case 2: + return B(C2("!=") || C2("!~") || C2("[]") || C2("==") || C2("=~") || C2(">=") || C2(">>") || C2("<=") || C2("<<") || C2("**")); + case 3: + return B(C3("===") || C3("<=>") || C3("[]=")); + default: + return PM_STRING_QUERY_FALSE; + } + +#undef B +#undef C1 +#undef C2 +#undef C3 +} diff --git a/test/prism/ruby/string_query_test.rb b/test/prism/ruby/string_query_test.rb new file mode 100644 index 0000000000..aa50c10ff3 --- /dev/null +++ b/test/prism/ruby/string_query_test.rb @@ -0,0 +1,60 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +module Prism + class StringQueryTest < TestCase + def test_local? + assert_predicate StringQuery.new("a"), :local? + assert_predicate StringQuery.new("a1"), :local? + assert_predicate StringQuery.new("self"), :local? + + assert_predicate StringQuery.new("_a"), :local? + assert_predicate StringQuery.new("_1"), :local? + + assert_predicate StringQuery.new("😀"), :local? + assert_predicate StringQuery.new("ア".encode("Windows-31J")), :local? + + refute_predicate StringQuery.new("1"), :local? + refute_predicate StringQuery.new("A"), :local? + end + + def test_constant? + assert_predicate StringQuery.new("A"), :constant? + assert_predicate StringQuery.new("A1"), :constant? + assert_predicate StringQuery.new("A_B"), :constant? + assert_predicate StringQuery.new("BEGIN"), :constant? + + assert_predicate StringQuery.new("À"), :constant? + assert_predicate StringQuery.new("A".encode("US-ASCII")), :constant? + + refute_predicate StringQuery.new("a"), :constant? + refute_predicate StringQuery.new("1"), :constant? + end + + def test_method_name? + assert_predicate StringQuery.new("a"), :method_name? + assert_predicate StringQuery.new("A"), :method_name? + assert_predicate StringQuery.new("__FILE__"), :method_name? + + assert_predicate StringQuery.new("a?"), :method_name? + assert_predicate StringQuery.new("a!"), :method_name? + assert_predicate StringQuery.new("a="), :method_name? + + assert_predicate StringQuery.new("+"), :method_name? + assert_predicate StringQuery.new("<<"), :method_name? + assert_predicate StringQuery.new("==="), :method_name? + + assert_predicate StringQuery.new("_0"), :method_name? + + refute_predicate StringQuery.new("1"), :method_name? + refute_predicate StringQuery.new("_1"), :method_name? + end + + def test_invalid_encoding + assert_raise ArgumentError do + StringQuery.new("A".encode("UTF-16LE")).local? + end + end + end +end