From 96c88ca3c8a788770d81addcc3b5cc167c79aeff Mon Sep 17 00:00:00 2001 From: Alexandre Terrasa Date: Thu, 11 Jul 2024 13:30:44 -0400 Subject: [PATCH 1/2] Introduce Type builder interface This interface is used to manually build RBI types. Signed-off-by: Alexandre Terrasa Co-authored-by: Ufuk Kayserilioglu --- lib/rbi.rb | 1 + lib/rbi/type.rb | 767 ++++++++++++++++++++++++++++++++++++++++++ test/rbi/type_test.rb | 441 ++++++++++++++++++++++++ 3 files changed, 1209 insertions(+) create mode 100644 lib/rbi/type.rb create mode 100644 test/rbi/type_test.rb diff --git a/lib/rbi.rb b/lib/rbi.rb index 641fa57e..4c5a7bb9 100644 --- a/lib/rbi.rb +++ b/lib/rbi.rb @@ -12,6 +12,7 @@ class Error < StandardError require "rbi/loc" require "rbi/model" +require "rbi/type" require "rbi/visitor" require "rbi/index" require "rbi/rewriters/add_sig_templates" diff --git a/lib/rbi/type.rb b/lib/rbi/type.rb new file mode 100644 index 00000000..0e29c470 --- /dev/null +++ b/lib/rbi/type.rb @@ -0,0 +1,767 @@ +# typed: strict +# frozen_string_literal: true + +module RBI + # The base class for all RBI types. + class Type + extend T::Sig + extend T::Helpers + + abstract! + + # Simple + + # A type that represents a simple class name like `String` or `Foo`. + # + # It can also be a qualified name like `::Foo` or `Foo::Bar`. + class Simple < Type + extend T::Sig + + sig { returns(String) } + attr_reader :name + + sig { params(name: String).void } + def initialize(name) + super() + @name = name + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Simple === other && @name == other.name + end + + sig { override.returns(String) } + def to_rbi + @name + end + end + + # Literals + + # `T.anything`. + class Anything < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Anything === other + end + + sig { override.returns(String) } + def to_rbi + "T.anything" + end + end + + # `T.attached_class`. + class AttachedClass < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + AttachedClass === other + end + + sig { override.returns(String) } + def to_rbi + "T.attached_class" + end + end + + # `T::Boolean`. + class Boolean < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Boolean === other + end + + sig { override.returns(String) } + def to_rbi + "T::Boolean" + end + end + + # `T.noreturn`. + class NoReturn < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + NoReturn === other + end + + sig { override.returns(String) } + def to_rbi + "T.noreturn" + end + end + + # `T.self_type`. + class SelfType < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + SelfType === other + end + + sig { override.returns(String) } + def to_rbi + "T.self_type" + end + end + + # `T.untyped`. + class Untyped < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Untyped === other + end + + sig { override.returns(String) } + def to_rbi + "T.untyped" + end + end + + # `void`. + class Void < Type + extend T::Sig + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Void === other + end + + sig { override.returns(String) } + def to_rbi + "void" + end + end + + # Composites + + # The class of another type like `T::Class[Foo]`. + class Class < Type + extend T::Sig + + sig { returns(Type) } + attr_reader :type + + sig { params(type: Type).void } + def initialize(type) + super() + @type = type + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Class === other && @type == other.type + end + + sig { override.returns(String) } + def to_rbi + "T::Class[#{@type}]" + end + end + + # The singleton class of another type like `T.class_of(Foo)`. + class ClassOf < Type + extend T::Sig + + sig { returns(Simple) } + attr_reader :type + + sig { returns(T.nilable(Type)) } + attr_reader :type_parameter + + sig { params(type: Simple, type_parameter: T.nilable(Type)).void } + def initialize(type, type_parameter = nil) + super() + @type = type + @type_parameter = type_parameter + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + ClassOf === other && @type == other.type + end + + sig { override.returns(String) } + def to_rbi + if @type_parameter + "T.class_of(#{@type.to_rbi})[#{@type_parameter.to_rbi}]" + else + "T.class_of(#{@type.to_rbi})" + end + end + end + + # A type that can be `nil` like `T.nilable(String)`. + class Nilable < Type + extend T::Sig + + sig { returns(Type) } + attr_reader :type + + sig { params(type: Type).void } + def initialize(type) + super() + @type = type + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Nilable === other && @type == other.type + end + + sig { override.returns(String) } + def to_rbi + "T.nilable(#{@type.to_rbi})" + end + end + + # A type that is composed of multiple types like `T.all(String, Integer)`. + class Composite < Type + extend T::Sig + extend T::Helpers + + abstract! + + sig { returns(T::Array[Type]) } + attr_reader :types + + sig { params(types: T::Array[Type]).void } + def initialize(types) + super() + @types = types + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + self.class === other && @types.sort_by(&:to_rbi) == other.types.sort_by(&:to_rbi) + end + end + + # A type that is intersection of multiple types like `T.all(String, Integer)`. + class All < Composite + extend T::Sig + + sig { override.returns(String) } + def to_rbi + "T.all(#{@types.map(&:to_rbi).join(", ")})" + end + end + + # A type that is union of multiple types like `T.any(String, Integer)`. + class Any < Composite + extend T::Sig + + sig { override.returns(String) } + def to_rbi + "T.any(#{@types.map(&:to_rbi).join(", ")})" + end + + sig { returns(T::Boolean) } + def nilable? + @types.any? { |type| type.nilable? || (type.is_a?(Simple) && type.name == "NilClass") } + end + end + + # Generics + + # A generic type like `T::Array[String]` or `T::Hash[Symbol, Integer]`. + class Generic < Type + extend T::Sig + + sig { returns(String) } + attr_reader :name + + sig { returns(T::Array[Type]) } + attr_reader :params + + sig { params(name: String, params: Type).void } + def initialize(name, *params) + super() + @name = name + @params = T.let(params, T::Array[Type]) + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Generic === other && @name == other.name && @params == other.params + end + + sig { override.returns(String) } + def to_rbi + "#{@name}[#{@params.map(&:to_rbi).join(", ")}]" + end + end + + # A type parameter like `T.type_parameter(:U)`. + class TypeParameter < Type + extend T::Sig + + sig { returns(Symbol) } + attr_reader :name + + sig { params(name: Symbol).void } + def initialize(name) + super() + @name = name + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + TypeParameter === other && @name == other.name + end + + sig { override.returns(String) } + def to_rbi + "T.type_parameter(#{@name.inspect})" + end + end + + # Tuples and shapes + + # A tuple type like `[String, Integer]`. + class Tuple < Type + extend T::Sig + + sig { returns(T::Array[Type]) } + attr_reader :types + + sig { params(types: T::Array[Type]).void } + def initialize(types) + super() + @types = types + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Tuple === other && @types == other.types + end + + sig { override.returns(String) } + def to_rbi + "[#{@types.map(&:to_rbi).join(", ")}]" + end + end + + # A shape type like `{name: String, age: Integer}`. + class Shape < Type + extend T::Sig + + sig { returns(T::Hash[T.any(String, Symbol), Type]) } + attr_reader :types + + sig { params(types: T::Hash[T.any(String, Symbol), Type]).void } + def initialize(types) + super() + @types = types + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + Shape === other && @types.sort_by { |t| t.first.to_s } == other.types.sort_by { |t| t.first.to_s } + end + + sig { override.returns(String) } + def to_rbi + if @types.empty? + "{}" + else + "{ " + @types.map { |name, type| "#{name}: #{type.to_rbi}" }.join(", ") + " }" + end + end + end + + # Proc + + # A proc type like `T.proc.void`. + class Proc < Type + extend T::Sig + + sig { returns(T::Hash[Symbol, Type]) } + attr_reader :proc_params + + sig { returns(Type) } + attr_reader :proc_returns + + sig { returns(T.nilable(Type)) } + attr_reader :proc_bind + + sig { void } + def initialize + super + @proc_params = T.let({}, T::Hash[Symbol, Type]) + @proc_returns = T.let(Type.void, Type) + @proc_bind = T.let(nil, T.nilable(Type)) + end + + sig { override.params(other: BasicObject).returns(T::Boolean) } + def ==(other) + return false unless Proc === other + return false unless @proc_params == other.proc_params + return false unless @proc_returns == other.proc_returns + return false unless @proc_bind == other.proc_bind + + true + end + + sig { params(params: Type).returns(T.self_type) } + def params(**params) + @proc_params = params + self + end + + sig { params(type: T.untyped).returns(T.self_type) } + def returns(type) + @proc_returns = type + self + end + + sig { returns(T.self_type) } + def void + @proc_returns = RBI::Type.void + self + end + + sig { params(type: T.untyped).returns(T.self_type) } + def bind(type) + @proc_bind = type + self + end + + sig { override.returns(String) } + def to_rbi + rbi = +"T.proc" + + if @proc_bind + rbi << ".bind(#{@proc_bind})" + end + + unless @proc_params.empty? + rbi << ".params(" + rbi << @proc_params.map { |name, type| "#{name}: #{type.to_rbi}" }.join(", ") + rbi << ")" + end + + rbi << case @proc_returns + when Void + ".void" + else + ".returns(#{@proc_returns})" + end + + rbi + end + end + + # Type builder + + class << self + extend T::Sig + + # Simple + + # Builds a simple type like `String` or `::Foo::Bar`. + # + # It raises a `NameError` if the name is not a valid Ruby class identifier. + sig { params(name: String).returns(Simple) } + def simple(name) + # TODO: should we allow creating the instance anyway and move this to a `validate!` method? + raise NameError, "Invalid type name: `#{name}`" unless valid_identifier?(name) + + Simple.new(name) + end + + # Literals + + # Builds a type that represents `T.anything`. + sig { returns(Anything) } + def anything + Anything.new + end + + # Builds a type that represents `T.attached_class`. + sig { returns(AttachedClass) } + def attached_class + AttachedClass.new + end + + # Builds a type that represents `T::Boolean`. + sig { returns(Boolean) } + def boolean + Boolean.new + end + + # Builds a type that represents `T.noreturn`. + sig { returns(NoReturn) } + def noreturn + NoReturn.new + end + + # Builds a type that represents `T.self_type`. + sig { returns(SelfType) } + def self_type + SelfType.new + end + + # Builds a type that represents `T.untyped`. + sig { returns(Untyped) } + def untyped + Untyped.new + end + + # Builds a type that represents `void`. + sig { returns(Void) } + def void + Void.new + end + + # Composites + + # Builds a type that represents the class of another type like `T::Class[Foo]`. + sig { params(type: Type).returns(Class) } + def t_class(type) + Class.new(type) + end + + # Builds a type that represents the singleton class of another type like `T.class_of(Foo)`. + sig { params(type: Simple, type_parameter: T.nilable(Type)).returns(ClassOf) } + def class_of(type, type_parameter = nil) + ClassOf.new(type, type_parameter) + end + + # Builds a type that represents a nilable of another type like `T.nilable(String)`. + # + # Note that this method transforms types such as `T.nilable(T.untyped)` into `T.untyped`, so + # it may return something other than a `RBI::Type::Nilable`. + sig { params(type: Type).returns(Type) } + def nilable(type) + # TODO: should we move this logic to a `flatten!`, `normalize!` or `simplify!` method? + return type if type.is_a?(Untyped) + + if type.nilable? + type + else + Nilable.new(type) + end + end + + # Builds a type that represents an intersection of multiple types like `T.all(String, Integer)`. + # + # Note that this method transforms types such as `T.all(String, String)` into `String`, so + # it may return something other than a `All`. + sig { params(type1: Type, type2: Type, types: Type).returns(Type) } + def all(type1, type2, *types) + types = [type1, type2, *types] + + # TODO: should we move this logic to a `flatten!`, `normalize!` or `simplify!` method? + flattened = types.flatten.flat_map do |type| + case type + when All + type.types + else + type + end + end.uniq + + if flattened.size == 1 + T.must(flattened.first) + else + raise ArgumentError, "RBI::Type.all should have at least 2 types supplied" if flattened.size < 2 + + All.new(flattened) + end + end + + # Builds a type that represents a union of multiple types like `T.any(String, Integer)`. + # + # Note that this method transforms types such as `T.any(String, NilClass)` into `T.nilable(String)`, so + # it may return something other than a `Any`. + sig { params(type1: Type, type2: Type, types: Type).returns(Type) } + def any(type1, type2, *types) + types = [type1, type2, *types] + + # TODO: should we move this logic to a `flatten!`, `normalize!` or `simplify!` method? + flattened = types.flatten.flat_map do |type| + case type + when Any + type.types + else + type + end + end + + is_nilable = T.let(false, T::Boolean) + + types = flattened.filter_map do |type| + case type + when Simple + if type.name == "NilClass" + is_nilable = true + nil + else + type + end + when Nilable + is_nilable = true + type.type + else + type + end + end.uniq + + has_true_class = types.any? { |type| type.is_a?(Simple) && type.name == "TrueClass" } + has_false_class = types.any? { |type| type.is_a?(Simple) && type.name == "FalseClass" } + + if has_true_class && has_false_class + types = types.reject { |type| type.is_a?(Simple) && (type.name == "TrueClass" || type.name == "FalseClass") } + types << boolean + end + + type = case types.size + when 0 + if is_nilable + is_nilable = false + simple("NilClass") + else + raise ArgumentError, "RBI::Type.any should have at least 2 types supplied" + end + when 1 + T.must(types.first) + else + Any.new(types) + end + + if is_nilable + nilable(type) + else + type + end + end + + # Generics + + # Builds a type that represents a generic type like `T::Array[String]` or `T::Hash[Symbol, Integer]`. + sig { params(name: String, params: T.any(Type, T::Array[Type])).returns(Generic) } + def generic(name, *params) + T.unsafe(Generic).new(name, *params.flatten) + end + + # Builds a type that represents a type parameter like `T.type_parameter(:U)`. + sig { params(name: Symbol).returns(TypeParameter) } + def type_parameter(name) + TypeParameter.new(name) + end + + # Tuples and shapes + + # Builds a type that represents a tuple type like `[String, Integer]`. + sig { params(types: T.any(Type, T::Array[Type])).returns(Tuple) } + def tuple(*types) + Tuple.new(types.flatten) + end + + # Builds a type that represents a shape type like `{name: String, age: Integer}`. + sig { params(hash_types: T::Hash[T.any(String, Symbol), Type], types: Type).returns(Shape) } + def shape(hash_types = {}, **types) + types = hash_types.merge(types) + + Shape.new(types) + end + + # Proc + + # Builds a type that represents a proc type like `T.proc.void`. + sig { returns(Proc) } + def proc + Proc.new + end + + # We mark the constructor as `protected` because we want to force the use of factories on `Type` to create types + protected :new + + private + + sig { params(name: String).returns(T::Boolean) } + def valid_identifier?(name) + Prism.parse("class self::#{name.delete_prefix("::")}; end").success? + end + end + + sig { void } + def initialize + @nilable = T.let(false, T::Boolean) + end + + # Returns a new type that is `nilable` if it is not already. + # + # If the type is already nilable, it returns itself. + # ```ruby + # type = RBI::Type.simple("String") + # type.to_rbi # => "String" + # type.nilable.to_rbi # => "T.nilable(String)" + # type.nilable.nilable.to_rbi # => "T.nilable(String)" + # ``` + sig { returns(Type) } + def nilable + Type.nilable(self) + end + + # Returns the non-nilable version of the type. + # If the type is already non-nilable, it returns itself. + # If the type is nilable, it returns the inner type. + # + # ```ruby + # type = RBI::Type.nilable(RBI::Type.simple("String")) + # type.to_rbi # => "T.nilable(String)" + # type.non_nilable.to_rbi # => "String" + # type.non_nilable.non_nilable.to_rbi # => "String" + # ``` + sig { returns(Type) } + def non_nilable + # TODO: Should this logic be moved into a builder method? + case self + when Nilable + type + else + self + end + end + + # Returns whether the type is nilable. + sig { returns(T::Boolean) } + def nilable? + is_a?(Nilable) + end + + sig { abstract.params(other: BasicObject).returns(T::Boolean) } + def ==(other); end + + sig { params(other: BasicObject).returns(T::Boolean) } + def eql?(other) + self == other + end + + sig { override.returns(Integer) } + def hash + to_rbi.hash + end + + sig { abstract.returns(String) } + def to_rbi; end + + sig { override.returns(String) } + def to_s + to_rbi + end + end +end diff --git a/test/rbi/type_test.rb b/test/rbi/type_test.rb new file mode 100644 index 00000000..8ebed9af --- /dev/null +++ b/test/rbi/type_test.rb @@ -0,0 +1,441 @@ +# typed: true +# frozen_string_literal: true + +require "test_helper" + +module RBI + class TypeTest < Minitest::Test + def test_build_cant_call_new + assert_raises(NoMethodError) do + Type::Simple.new("String") + end + end + + def test_build_type_simple_raises_if_incorrect_name + Type.simple("String") + Type.simple("::String") + Type.simple("String::String") + Type.simple("S1_1::S1_1") + + exception = assert_raises(NameError) do + Type.simple("T.nilable(String)") + end + + assert_equal("Invalid type name: `T.nilable(String)`", exception.message.lines.first.strip) + + exception = assert_raises(NameError) do + Type.simple("String[Integer]") + end + + assert_equal("Invalid type name: `String[Integer]`", exception.message.lines.first.strip) + + exception = assert_raises(NameError) do + Type.simple("<< String") + end + + assert_equal("Invalid type name: `<< String`", exception.message.lines.first.strip) + end + + def test_build_type_string + type = Type.simple("String") + refute_predicate(type, :nilable?) + assert_equal("String", type.to_rbi) + end + + def test_build_type_anything + type = Type.anything + assert_equal("T.anything", type.to_rbi) + end + + def test_build_type_void + type = Type.void + assert_equal("void", type.to_rbi) + end + + def test_build_type_nilable + type = Type.simple("String") + refute_predicate(type, :nilable?) + assert_equal("String", type.to_rbi) + + type = type.nilable + assert_predicate(type, :nilable?) + assert_equal("T.nilable(String)", type.to_rbi) + end + + def test_build_type_nilable_of_untyped + type = Type.nilable(Type.untyped) + assert_instance_of(Type::Untyped, type) + assert_equal("T.untyped", type.to_rbi) + end + + def test_build_type_nilable_of_nilable + type = Type.nilable(Type.nilable(Type.simple("String"))) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(String)", type.to_rbi) + end + + def test_build_non_nilable_of_simple_type + type = Type.simple("String").non_nilable + refute_predicate(type, :nilable?) + assert_equal("String", type.to_rbi) + end + + def test_build_non_nilable_of_nilable_type + type = Type.simple("String").nilable.non_nilable + refute_predicate(type, :nilable?) + assert_equal("String", type.to_rbi) + end + + def test_build_type_all + type = Type.all( + Type.simple("String"), + Type.simple("Integer"), + ) + refute_predicate(type, :nilable?) + assert_equal("T.all(String, Integer)", type.to_rbi) + end + + def test_build_type_all_of_all + type = Type.all( + Type.simple("String"), + Type.simple("Integer"), + Type.all( + Type.simple("Numeric"), + Type.simple("Integer"), + ), + ) + assert_instance_of(Type::All, type) + assert_equal("T.all(String, Integer, Numeric)", type.to_rbi) + end + + def test_build_type_all_of_dup + type = Type.all( + Type.simple("String"), + Type.simple("String"), + ) + assert_instance_of(Type::Simple, type) + assert_equal("String", type.to_rbi) + end + + def test_build_type_any + type = Type.any( + Type.simple("String"), + Type.simple("Integer"), + ) + refute_predicate(type, :nilable?) + assert_equal("T.any(String, Integer)", type.to_rbi) + + type = Type.any( + Type.simple("String"), + Type.simple("String"), + Type.simple("Integer"), + Type.simple("Integer"), + ) + refute_predicate(type, :nilable?) + assert_equal("T.any(String, Integer)", type.to_rbi) + end + + def test_build_type_any_of_any + type = Type.any( + Type.any( + Type.simple("String"), + Type.simple("Integer"), + ), + Type.any( + Type.simple("String"), + Type.simple("Symbol"), + ), + ) + + assert_instance_of(Type::Any, type) + assert_equal("T.any(String, Integer, Symbol)", type.to_rbi) + end + + def test_build_type_any_of_any_of_any + type = Type.any( + Type.any( + Type.simple("String"), + Type.simple("Integer"), + ), + Type.any( + Type.simple("Numeric"), + Type.any( + Type.simple("String"), + Type.simple("Symbol"), + ), + ), + ) + + assert_instance_of(Type::Any, type) + assert_equal("T.any(String, Integer, Numeric, Symbol)", type.to_rbi) + end + + def test_build_type_any_of_uniq + type = Type.any( + Type.simple("String"), + Type.simple("String"), + ) + assert_instance_of(Type::Simple, type) + assert_equal("String", type.to_rbi) + end + + def test_build_type_any_of_uniq_and_nilable + type = Type.any( + Type.simple("String"), + Type.simple("String"), + Type.nilable(Type.simple("String")), + ) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(String)", type.to_rbi) + end + + def test_build_type_any_of_nilclass + type = Type.any( + Type.simple("String"), + Type.simple("NilClass"), + ) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(String)", type.to_rbi) + end + + def test_build_type_any_of_nilable + type = Type.any( + Type.simple("String"), + Type.nilable(Type.simple("Integer")), + ) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(T.any(String, Integer))", type.to_rbi) + end + + def test_build_type_any_of_trueclass_and_falseclass + type = Type.any( + Type.simple("TrueClass"), + Type.simple("String"), + Type.simple("FalseClass"), + ) + assert_equal("T.any(String, T::Boolean)", type.to_rbi) + end + + def test_build_type_any_of_trueclass_and_falseclass_and_nilclass + type = Type.any( + Type.simple("TrueClass"), + Type.simple("NilClass"), + Type.simple("FalseClass"), + ) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(T::Boolean)", type.to_rbi) + end + + def test_build_type_any_of_trueclass_and_falseclass_with_nilable + type = Type.any( + Type.simple("TrueClass"), + Type.nilable(Type.simple("FalseClass")), + ) + assert_predicate(type, :nilable?) + assert_equal("T.nilable(T::Boolean)", type.to_rbi) + end + + def test_build_type_tuple + type = Type.tuple( + Type.simple("String"), + Type.simple("Integer"), + ) + refute_predicate(type, :nilable?) + assert_equal("[String, Integer]", type.to_rbi) + end + + def test_build_type_empty_tuple + type = Type.tuple + refute_predicate(type, :nilable?) + assert_equal("[]", type.to_rbi) + end + + def test_build_type_shape + type = Type.shape( + foo: Type.simple("String"), + bar: Type.simple("Integer"), + ) + refute_predicate(type, :nilable?) + assert_equal("{ foo: String, bar: Integer }", type.to_rbi) + end + + def test_build_type_void_proc + type = Type.proc + refute_predicate(type, :nilable?) + assert_equal("T.proc.void", type.to_rbi) + end + + def test_build_type_void_proc_with_explicit_void_return_type + type = Type.proc.void + refute_predicate(type, :nilable?) + assert_equal("T.proc.void", type.to_rbi) + end + + def test_build_type_void_proc_with_explicit_returns_with_void + type = Type.proc.returns(Type.void) + refute_predicate(type, :nilable?) + assert_equal("T.proc.void", type.to_rbi) + end + + def test_build_type_void_proc_with_multiple_returns_specified + type = Type.proc.returns(Type.simple("Integer")).void + refute_predicate(type, :nilable?) + assert_equal("T.proc.void", type.to_rbi) + end + + def test_build_type_void_nilable_proc + type = Type.proc.nilable + assert_predicate(type, :nilable?) + assert_equal("T.nilable(T.proc.void)", type.to_rbi) + end + + def test_build_type_void_proc_with_params + type = Type.proc.params(foo: Type.simple("String"), bar: Type.simple("Integer")) + refute_predicate(type, :nilable?) + assert_equal("T.proc.params(foo: String, bar: Integer).void", type.to_rbi) + end + + def test_build_type_void_nilable_proc_with_params + type = Type.proc.params(foo: Type.simple("String"), bar: Type.simple("Integer")).nilable + assert_predicate(type, :nilable?) + assert_equal("T.nilable(T.proc.params(foo: String, bar: Integer).void)", type.to_rbi) + end + + def test_build_type_symbol_returning_proc_with_params + type = Type.proc.params(foo: Type.simple("String"), bar: Type.simple("Integer")).returns(Type.simple("Symbol")) + refute_predicate(type, :nilable?) + assert_equal("T.proc.params(foo: String, bar: Integer).returns(Symbol)", type.to_rbi) + end + + def test_build_type_symbol_returning_proc_with_params_and_bind + type = Type.proc + .params( + foo: Type.simple("String"), + bar: Type.simple("Integer"), + ) + .returns(Type.simple("Symbol")) + .bind(Type.class_of(Type.simple("Base"))) + refute_predicate(type, :nilable?) + assert_equal("T.proc.bind(T.class_of(Base)).params(foo: String, bar: Integer).returns(Symbol)", type.to_rbi) + end + + def test_build_type_void_proc_with_bind + type = Type.proc + .bind(Type.class_of(Type.simple("Base"))) + refute_predicate(type, :nilable?) + assert_equal("T.proc.bind(T.class_of(Base)).void", type.to_rbi) + end + + def test_build_type_empty_shape + type = Type.shape + refute_predicate(type, :nilable?) + assert_equal("{}", type.to_rbi) + end + + def test_build_type_generic + type = Type.generic("T::Array", Type.simple("String")) + refute_predicate(type, :nilable?) + assert_equal("T::Array[String]", type.to_rbi) + + type = Type.generic("T::Hash", Type.simple("Integer"), Type.simple("String")) + refute_predicate(type, :nilable?) + assert_equal("T::Hash[Integer, String]", type.to_rbi) + end + + def test_build_type_parameter + type = Type.type_parameter(:U) + assert_equal("T.type_parameter(:U)", type.to_rbi) + + type = Type.type_parameter(:" !") + assert_equal("T.type_parameter(:\" !\")", type.to_rbi) + end + + def test_build_type_class_of + type = Type.class_of(Type.simple("String")) + assert_equal("T.class_of(String)", type.to_rbi) + end + + def test_build_type_self_type + type = Type.self_type + assert_equal("T.self_type", type.to_rbi) + end + + def test_build_type_attached_class + type = Type.attached_class + assert_equal("T.attached_class", type.to_rbi) + end + + def test_build_type_untyped + type = Type.untyped + assert_equal("T.untyped", type.to_rbi) + end + + def test_types_comparison + type1 = Type.simple("String") + type2 = Type.simple("String") + assert_equal(type1, type2) + + type3 = Type.simple("Integer") + refute_equal(type1, type3) + + type4 = Type.nilable(Type.simple("String")) + refute_equal(type1, type4) + + type5 = Type.nilable(Type.simple("String")) + assert_equal(type4, type5) + + type6 = Type.generic("Foo", Type.simple("String")) + type7 = Type.generic("Foo", Type.simple("String")) + assert_equal(type6, type7) + + type8 = Type.generic("Foo", Type.simple("Integer")) + refute_equal(type6, type8) + + type9 = Type.generic("Bar", Type.simple("String")) + refute_equal(type6, type9) + + type10 = Type.any(Type.simple("String"), Type.simple("NilClass")) + assert_equal(type4, type10) + + type11 = Type.any(Type.simple("String"), Type.simple("NilClass")) + assert_equal(type10, type11) + + type12 = Type.any(Type.simple("String"), Type.simple("Integer")) + refute_equal(type10, type12) + + type13 = Type.any(Type.simple("Integer"), Type.simple("String")) + assert_equal(type12, type13) + + type15 = Type.untyped + refute_equal(type1, type15) + + type16 = Type.boolean + type17 = Type.any(Type.simple("TrueClass"), Type.simple("FalseClass")) + assert_equal(type16, type17) + + type18 = Type.nilable(Type.untyped) + assert_equal(type15, type18) + + type19 = Type.any(Type.simple("String"), Type.simple("String")) + assert_equal(type19, type2) + + type20 = Type.all(Type.simple("String"), Type.simple("Integer")) + type21 = Type.all(Type.simple("Integer"), Type.simple("String")) + assert_equal(type20, type21) + + type22 = Type.shape(foo: Type.simple("String"), bar: Type.simple("Integer")) + type23 = Type.shape(bar: Type.simple("Integer"), foo: Type.simple("String")) + assert_equal(type22, type23) + + type24 = Type.shape(foo: Type.simple("Integer"), bar: Type.simple("String")) + refute_equal(type22, type24) + + type25 = Type.tuple(Type.simple("String"), Type.simple("Integer")) + type26 = Type.tuple(Type.simple("String"), Type.simple("Integer")) + assert_equal(type25, type26) + + type27 = Type.tuple(Type.simple("Integer"), Type.simple("String")) + refute_equal(type25, type27) + end + end +end From 1597d40f40ef45ed91722290ff23e18b68fcb9d1 Mon Sep 17 00:00:00 2001 From: Alexandre Terrasa Date: Wed, 31 Jul 2024 11:26:19 -0400 Subject: [PATCH 2/2] Model accept either String or Type when a RBI type is expected Signed-off-by: Alexandre Terrasa Co-authored-by: Ufuk Kayserilioglu --- lib/rbi/model.rb | 38 ++++++++++++++---------- lib/rbi/parser.rb | 2 +- lib/rbi/printer.rb | 14 ++++----- lib/rbi/rewriters/attr_to_methods.rb | 4 +-- test/rbi/model_test.rb | 43 ++++++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 25 deletions(-) diff --git a/lib/rbi/model.rb b/lib/rbi/model.rb index f44cbbad..a8195540 100644 --- a/lib/rbi/model.rb +++ b/lib/rbi/model.rb @@ -574,7 +574,7 @@ def add_block_param(name) sig do params( params: T::Array[SigParam], - return_type: T.nilable(String), + return_type: T.any(String, Type), is_abstract: T::Boolean, is_override: T::Boolean, is_overridable: T::Boolean, @@ -586,7 +586,7 @@ def add_block_param(name) end def add_sig( params: [], - return_type: nil, + return_type: "void", is_abstract: false, is_override: false, is_overridable: false, @@ -928,8 +928,10 @@ def initialize(visibility, loc: nil, comments: []) @visibility = visibility end - sig { params(other: Visibility).returns(T::Boolean) } + sig { params(other: T.nilable(Object)).returns(T::Boolean) } def ==(other) + return false unless other.is_a?(Visibility) + visibility == other.visibility end @@ -1105,7 +1107,7 @@ class Sig < Node sig { returns(T::Array[SigParam]) } attr_reader :params - sig { returns(T.nilable(String)) } + sig { returns(T.any(Type, String)) } attr_accessor :return_type sig { returns(T::Boolean) } @@ -1120,7 +1122,7 @@ class Sig < Node sig do params( params: T::Array[SigParam], - return_type: T.nilable(String), + return_type: T.any(Type, String), is_abstract: T::Boolean, is_override: T::Boolean, is_overridable: T::Boolean, @@ -1133,7 +1135,7 @@ class Sig < Node end def initialize( params: [], - return_type: nil, + return_type: "void", is_abstract: false, is_override: false, is_overridable: false, @@ -1160,7 +1162,7 @@ def <<(param) @params << param end - sig { params(name: String, type: String).void } + sig { params(name: String, type: T.any(Type, String)).void } def add_param(name, type) @params << SigParam.new(name, type) end @@ -1169,7 +1171,7 @@ def add_param(name, type) def ==(other) return false unless other.is_a?(Sig) - params == other.params && return_type == other.return_type && is_abstract == other.is_abstract && + params == other.params && return_type.to_s == other.return_type.to_s && is_abstract == other.is_abstract && is_override == other.is_override && is_overridable == other.is_overridable && is_final == other.is_final && type_params == other.type_params && checked == other.checked end @@ -1179,12 +1181,15 @@ class SigParam < NodeWithComments extend T::Sig sig { returns(String) } - attr_reader :name, :type + attr_reader :name + + sig { returns(T.any(Type, String)) } + attr_reader :type sig do params( name: String, - type: String, + type: T.any(Type, String), loc: T.nilable(Loc), comments: T::Array[Comment], block: T.nilable(T.proc.params(node: SigParam).void), @@ -1199,7 +1204,7 @@ def initialize(name, type, loc: nil, comments: [], &block) sig { params(other: Object).returns(T::Boolean) } def ==(other) - other.is_a?(SigParam) && name == other.name && type == other.type + other.is_a?(SigParam) && name == other.name && type.to_s == other.type.to_s end end @@ -1229,7 +1234,10 @@ class TStructField < NodeWithComments abstract! sig { returns(String) } - attr_accessor :name, :type + attr_accessor :name + + sig { returns(T.any(Type, String)) } + attr_accessor :type sig { returns(T.nilable(String)) } attr_accessor :default @@ -1237,7 +1245,7 @@ class TStructField < NodeWithComments sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], @@ -1260,7 +1268,7 @@ class TStructConst < TStructField sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], @@ -1290,7 +1298,7 @@ class TStructProp < TStructField sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], diff --git a/lib/rbi/parser.rb b/lib/rbi/parser.rb index 541ed78e..1d311d40 100644 --- a/lib/rbi/parser.rb +++ b/lib/rbi/parser.rb @@ -839,7 +839,7 @@ def visit_call_node(node) end end when "void" - @current.return_type = nil + @current.return_type = "void" end visit(node.receiver) diff --git a/lib/rbi/printer.rb b/lib/rbi/printer.rb index 21e505cc..b41c78e1 100644 --- a/lib/rbi/printer.rb +++ b/lib/rbi/printer.rb @@ -611,7 +611,7 @@ def print_param_comment_leading_space(node, last:) def print_sig_param_comment_leading_space(node, last:) printn printt - print(" " * (node.name.size + node.type.size + 3)) + print(" " * (node.name.size + node.type.to_s.size + 3)) print(" ") unless last end @@ -654,10 +654,10 @@ def print_sig_as_line(node) print(").") end return_type = node.return_type - if node.return_type && node.return_type != "void" - print("returns(#{return_type})") - else + if node.return_type.to_s == "void" print("void") + else + print("returns(#{return_type})") end printn(" }") end @@ -707,10 +707,10 @@ def print_sig_as_block(node) print(".") if modifiers.any? || params.any? return_type = node.return_type - if return_type && return_type != "void" - print("returns(#{return_type})") - else + if return_type.to_s == "void" print("void") + else + print("returns(#{return_type})") end printn dedent diff --git a/lib/rbi/rewriters/attr_to_methods.rb b/lib/rbi/rewriters/attr_to_methods.rb index 0bf4257a..02ced33e 100644 --- a/lib/rbi/rewriters/attr_to_methods.rb +++ b/lib/rbi/rewriters/attr_to_methods.rb @@ -62,7 +62,7 @@ def convert_to_methods; end private - sig(:final) { returns([T.nilable(Sig), T.nilable(String)]) } + sig(:final) { returns([T.nilable(Sig), T.nilable(T.any(Type, String))]) } def parse_sig raise UnexpectedMultipleSigsError, self if 1 < sigs.count @@ -101,7 +101,7 @@ def create_getter_method(name, sig, visibility, loc, comments) params( name: String, sig: T.nilable(Sig), - attribute_type: T.nilable(String), + attribute_type: T.nilable(T.any(Type, String)), visibility: Visibility, loc: T.nilable(Loc), comments: T::Array[Comment], diff --git a/test/rbi/model_test.rb b/test/rbi/model_test.rb index ff8430f3..9d75752a 100644 --- a/test/rbi/model_test.rb +++ b/test/rbi/model_test.rb @@ -420,5 +420,48 @@ def test_model_nodes_as_strings mod << helper assert_equal("::Foo.foo!", helper.to_s) end + + # types + + def test_model_sig_builder_with_types + rbi = Tree.new do |tree| + tree << Method.new("foo") do |node| + node.add_param("x") + + node.add_sig do |sig| + sig.add_param("x", Type.untyped) + sig.return_type = Type.void + end + end + end + + assert_equal(<<~RBI, rbi.string) + sig { params(x: T.untyped).void } + def foo(x); end + RBI + end + + def test_model_sig_with_types + node = Sig.new + node << SigParam.new("x", Type.untyped) + node.return_type = Type.simple("Integer") + + assert_equal(<<~RBI, node.string) + sig { params(x: T.untyped).returns(Integer) } + RBI + end + + def test_t_struct_with_types + node = TStruct.new("MyStruct") + node << TStructConst.new("foo", Type.simple("Foo")) + node << TStructProp.new("bar", Type.simple("Bar")) + + assert_equal(<<~RBI, node.string) + class MyStruct < ::T::Struct + const :foo, Foo + prop :bar, Bar + end + RBI + end end end