From 0a8885dbae2661b28ad0d1b53bc3ee84f08aacde Mon Sep 17 00:00:00 2001 From: Aiden Storey Date: Mon, 8 Jul 2024 05:53:35 +0000 Subject: [PATCH] Fix missing array signature for ActiveRecordRelation #create --- .../dsl/compilers/active_record_relations.rb | 57 +++++++++++++------ .../compilers/active_record_relations_spec.rb | 18 ++++++ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/lib/tapioca/dsl/compilers/active_record_relations.rb b/lib/tapioca/dsl/compilers/active_record_relations.rb index 4198d7290..d89cd4fce 100644 --- a/lib/tapioca/dsl/compilers/active_record_relations.rb +++ b/lib/tapioca/dsl/compilers/active_record_relations.rb @@ -220,7 +220,7 @@ def gather_constants [:find_or_create_by, :find_or_create_by!, :find_or_initialize_by, :create_or_find_by, :create_or_find_by!], T::Array[Symbol], ) - BUILDER_METHODS = T.let([:new, :build, :create, :create!], T::Array[Symbol]) + BUILDER_METHODS = T.let([:new, :create, :create!, :build], T::Array[Symbol]) TO_ARRAY_METHODS = T.let([:to_ary, :to_a], T::Array[Symbol]) private @@ -991,25 +991,50 @@ def create_common_methods end FIND_OR_CREATE_METHODS.each do |method_name| - block_type = "T.nilable(T.proc.params(object: #{constant_name}).void)" - create_common_method( - method_name, - parameters: [ - create_param("attributes", type: "T.untyped"), - create_block_param("block", type: block_type), - ], - return_type: constant_name, + sigs = [ + common_relation_methods_module.create_sig( + parameters: { + attributes: "T::Array[T.untyped]", + block: "T.nilable(T.proc.params(objects: #{constant_name}).void)", + }, + return_type: "T::Array[#{constant_name}]", + ), + common_relation_methods_module.create_sig( + parameters: { + attributes: "T.untyped", + block: "T.nilable(T.proc.params(object: #{constant_name}).void)", + }, + return_type: constant_name, + ), + ] + common_relation_methods_module.create_method_with_sigs( + method_name.to_s, + sigs: sigs, + parameters: [RBI::ReqParam.new("attributes"), RBI::BlockParam.new("block")], ) end BUILDER_METHODS.each do |method_name| - create_common_method( - method_name, - parameters: [ - create_opt_param("attributes", type: "T.untyped", default: "nil"), - create_block_param("block", type: "T.nilable(T.proc.params(object: #{constant_name}).void)"), - ], - return_type: constant_name, + sigs = [ + common_relation_methods_module.create_sig( + parameters: { + attributes: "T::Array[T.untyped]", + block: "T.nilable(T.proc.params(objects: #{constant_name}).void)", + }, + return_type: "T::Array[#{constant_name}]", + ), + common_relation_methods_module.create_sig( + parameters: { + attributes: "T.untyped", + block: "T.nilable(T.proc.params(object: #{constant_name}).void)", + }, + return_type: constant_name, + ), + ] + common_relation_methods_module.create_method_with_sigs( + method_name.to_s, + sigs: sigs, + parameters: [RBI::OptParam.new("attributes", "nil"), RBI::BlockParam.new("block")], ) end end diff --git a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb index 8aee085ce..c7b8fec15 100644 --- a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb +++ b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb @@ -97,6 +97,7 @@ def any?(&block); end sig { params(column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) } def average(column_name); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def build(attributes = nil, &block); end @@ -107,15 +108,19 @@ def calculate(operation, column_name); end sig { params(column_name: NilClass, block: T.proc.params(object: ::Post).void).returns(Integer) } def count(column_name = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create(attributes = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create!(attributes = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create_or_find_by(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create_or_find_by!(attributes, &block); end @@ -150,12 +155,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) } def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_create_by(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_create_by!(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_initialize_by(attributes, &block); end @@ -224,6 +232,7 @@ def member?(record); end sig { params(column_name: T.any(String, Symbol)).returns(T.untyped) } def minimum(column_name); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def new(attributes = nil, &block); end @@ -790,6 +799,7 @@ def any?(&block); end sig { params(column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) } def average(column_name); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def build(attributes = nil, &block); end @@ -800,15 +810,19 @@ def calculate(operation, column_name); end sig { params(column_name: NilClass, block: T.proc.params(object: ::Post).void).returns(Integer) } def count(column_name = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create(attributes = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create!(attributes = nil, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create_or_find_by(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def create_or_find_by!(attributes, &block); end @@ -847,12 +861,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) } def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_create_by(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_create_by!(attributes, &block); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def find_or_initialize_by(attributes, &block); end @@ -921,6 +938,7 @@ def member?(record); end sig { params(column_name: T.any(String, Symbol)).returns(T.untyped) } def minimum(column_name); end + sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) } sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) } def new(attributes = nil, &block); end