diff --git a/spec/std/hash_spec.cr b/spec/std/hash_spec.cr index 94d7b98b5d52..e241f09404a0 100644 --- a/spec/std/hash_spec.cr +++ b/spec/std/hash_spec.cr @@ -1187,4 +1187,42 @@ describe "Hash" do end end end + + describe "compare_by_identity" do + it "small hash" do + string = "foo" + h = {string => 1} + h.compare_by_identity?.should be_false + h.compare_by_identity + h.compare_by_identity?.should be_true + h[string]?.should eq(1) + h["fo" + "o"]?.should be_nil + end + + it "big hash" do + h = {} of String => Int32 + nums = (100..116).to_a + strings = nums.map(&.to_s) + strings.zip(nums) do |string, num| + h[string] = num + end + h.compare_by_identity + nums.each do |num| + h[num.to_s]?.should be_nil + end + strings.zip(nums) do |string, num| + h[string]?.should eq(num) + end + end + + it "retains compare_by_identity on dup" do + h = ({} of String => Int32).compare_by_identity + h.dup.compare_by_identity?.should be_true + end + + it "retains compare_by_identity on clone" do + h = ({} of String => Int32).compare_by_identity + h.clone.compare_by_identity?.should be_true + end + end end diff --git a/spec/std/set_spec.cr b/spec/std/set_spec.cr index 57ebed1a6958..d998ffcbde60 100644 --- a/spec/std/set_spec.cr +++ b/spec/std/set_spec.cr @@ -396,4 +396,29 @@ describe "Set" do end typeof(Set(Int32).new(initial_capacity: 1234)) + + describe "compare_by_identity" do + it "compares by identity" do + string = "foo" + set = Set{string, "bar", "baz"} + set.compare_by_identity?.should be_false + set.includes?(string).should be_true + + set.compare_by_identity + set.compare_by_identity?.should be_true + + set.includes?("fo" + "o").should be_false + set.includes?(string).should be_true + end + + it "retains compare_by_identity on dup" do + set = Set(String).new.compare_by_identity + set.dup.compare_by_identity?.should be_true + end + + it "retains compare_by_identity on clone" do + set = Set(String).new.compare_by_identity + set.clone.compare_by_identity?.should be_true + end + end end diff --git a/src/compiler/crystal/codegen/exception.cr b/src/compiler/crystal/codegen/exception.cr index 0d8e1782a257..90003c8bf0de 100644 --- a/src/compiler/crystal/codegen/exception.cr +++ b/src/compiler/crystal/codegen/exception.cr @@ -1,7 +1,7 @@ require "./codegen" class Crystal::CodeGenVisitor - @node_ensure_exception_handlers = {} of UInt64 => Handler + @node_ensure_exception_handlers : Hash(ASTNode, Handler) = ({} of ASTNode => Handler).compare_by_identity def visit(node : ExceptionHandler) # In this codegen, we assume that LLVM only provides us with a basic try/catch abstraction with no @@ -289,7 +289,7 @@ class Crystal::CodeGenVisitor end def execute_ensures_until(node) - stop_exception_handler = @node_ensure_exception_handlers[node.object_id]?.try &.node + stop_exception_handler = @node_ensure_exception_handlers[node]?.try &.node @ensure_exception_handlers.try &.reverse_each do |exception_handler| break if exception_handler.node.same?(stop_exception_handler) @@ -305,7 +305,7 @@ class Crystal::CodeGenVisitor def set_ensure_exception_handler(node) if eh = @ensure_exception_handlers.try &.last? - @node_ensure_exception_handlers[node.object_id] = eh + @node_ensure_exception_handlers[node] = eh end end diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index 6171ff1c0fe8..63954b1f1158 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -46,7 +46,7 @@ module Crystal # associated to a def's object id (the UInt64), and on an instantiation # we compare the new type with the previous one and check if it contains # the previous type. - getter splat_expansions = {} of UInt64 => Type + getter splat_expansions : Hash(Def, Type) = ({} of Def => Type).compare_by_identity # All FileModules indexed by their filename. # These store file-private defs, and top-level variables in files other diff --git a/src/compiler/crystal/semantic/bindings.cr b/src/compiler/crystal/semantic/bindings.cr index a1c7eb898156..1db25f097c6a 100644 --- a/src/compiler/crystal/semantic/bindings.cr +++ b/src/compiler/crystal/semantic/bindings.cr @@ -217,16 +217,16 @@ module Crystal owner_trace = [] of ASTNode node = self - visited = Set(typeof(object_id)).new + visited = Set(ASTNode).new.compare_by_identity owner_trace << node if node.type?.try &.includes_type?(owner) - visited.add node.object_id + visited.add node while deps = node.dependencies? - dependencies = deps.select { |dep| dep.type? && dep.type.includes_type?(owner) && !visited.includes?(dep.object_id) } + dependencies = deps.select { |dep| dep.type? && dep.type.includes_type?(owner) && !visited.includes?(dep) } if dependencies.size > 0 node = dependencies.first nil_reason = node.nil_reason if node.is_a?(MetaTypeVar) owner_trace << node if node - visited.add node.object_id + visited.add node else break end diff --git a/src/compiler/crystal/semantic/call_error.cr b/src/compiler/crystal/semantic/call_error.cr index fb1bbcf3e8dc..365a94850a1d 100644 --- a/src/compiler/crystal/semantic/call_error.cr +++ b/src/compiler/crystal/semantic/call_error.cr @@ -638,14 +638,14 @@ class Crystal::Call def check_recursive_splat_call(a_def, args) if a_def.splat_index current_splat_type = args.values.last.type - if previous_splat_type = program.splat_expansions[a_def.object_id]? + if previous_splat_type = program.splat_expansions[a_def]? if current_splat_type.has_in_type_vars?(previous_splat_type) raise "recursive splat expansion: #{previous_splat_type}, #{current_splat_type}, ..." end end - program.splat_expansions[a_def.object_id] = current_splat_type + program.splat_expansions[a_def] = current_splat_type yield - program.splat_expansions.delete a_def.object_id + program.splat_expansions.delete a_def else yield end diff --git a/src/compiler/crystal/semantic/cleanup_transformer.cr b/src/compiler/crystal/semantic/cleanup_transformer.cr index 00ef39f42912..267cceb6b6d0 100644 --- a/src/compiler/crystal/semantic/cleanup_transformer.cr +++ b/src/compiler/crystal/semantic/cleanup_transformer.cr @@ -57,8 +57,10 @@ module Crystal # idea on how to generate code for unreachable branches, because they have no type, # and for now the codegen only deals with typed nodes. class CleanupTransformer < Transformer + @transformed : Set(Def) + def initialize(@program : Program) - @transformed = Set(UInt64).new + @transformed = Set(Def).new.compare_by_identity @def_nest_count = 0 @last_is_truthy = false @last_is_falsey = false @@ -361,9 +363,7 @@ module Crystal end target_defs.each do |target_def| - unless @transformed.includes?(target_def.object_id) - @transformed.add(target_def.object_id) - + if @transformed.add?(target_def) node.bubbling_exception do @def_nest_count += 1 target_def.body = target_def.body.transform(self) diff --git a/src/compiler/crystal/semantic/fix_missing_types.cr b/src/compiler/crystal/semantic/fix_missing_types.cr index 55d66fdcb1f5..2f9f5d9ac678 100644 --- a/src/compiler/crystal/semantic/fix_missing_types.cr +++ b/src/compiler/crystal/semantic/fix_missing_types.cr @@ -2,11 +2,11 @@ require "../semantic" class Crystal::FixMissingTypes < Crystal::Visitor @program : Program - @fixed : Set(UInt64) + @fixed : Set(Def) def initialize(mod) @program = mod - @fixed = Set(typeof(object_id)).new + @fixed = Set(Def).new.compare_by_identity end def visit(node : Def) @@ -71,8 +71,7 @@ class Crystal::FixMissingTypes < Crystal::Visitor end node.target_defs.try &.each do |target_def| - unless @fixed.includes?(target_def.object_id) - @fixed.add(target_def.object_id) + if @fixed.add?(target_def) target_def.type = @program.no_return unless target_def.type? target_def.accept_children self end diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index d3134f0a11c2..4094e6c097e1 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -1633,7 +1633,7 @@ module Crystal @scope : Type @in_super : Int32 @callstack : Array(ASTNode) - @visited : Set(UInt64)? + @visited : Set(Def)? @vars : MetaVars def initialize(a_def, @scope, @vars) @@ -1687,10 +1687,10 @@ module Crystal node.target_defs.try &.each do |target_def| if target_def.owner == @scope - next if visited.try &.includes?(target_def.object_id) + next if visited.try &.includes?(target_def) - visited = @visited ||= Set(typeof(object_id)).new - visited << target_def.object_id + visited = @visited ||= Set(Def).new.compare_by_identity + visited << target_def @callstack.push(node) target_def.body.accept self diff --git a/src/compiler/crystal/tools/context.cr b/src/compiler/crystal/tools/context.cr index 2885a28f0364..9d9bf973b444 100644 --- a/src/compiler/crystal/tools/context.cr +++ b/src/compiler/crystal/tools/context.cr @@ -72,8 +72,10 @@ module Crystal end class RechableVisitor < Visitor + @visited_typed_defs : Set(Def) + def initialize(@context_visitor : Crystal::ContextVisitor) - @visited_typed_defs = Set(UInt64).new + @visited_typed_defs = Set(Def).new.compare_by_identity end def visit(node : Call) @@ -91,9 +93,7 @@ module Crystal end def visit(node : Def) - should_visit = !@visited_typed_defs.includes?(node.object_id) - @visited_typed_defs << node.object_id if should_visit - return should_visit + @visited_typed_defs.add?(node) end def visit(node) diff --git a/src/compiler/crystal/tools/formatter.cr b/src/compiler/crystal/tools/formatter.cr index 1801866b10c9..8161cbae13f4 100644 --- a/src/compiler/crystal/tools/formatter.cr +++ b/src/compiler/crystal/tools/formatter.cr @@ -82,7 +82,7 @@ module Crystal @assign_infos : Array(AlignInfo) @doc_comments : Array(CommentInfo) @current_doc_comment : CommentInfo? - @hash_in_same_line : Set(UInt64) + @hash_in_same_line : Set(ASTNode) @shebang : Bool @heredoc_fixes : Array(HeredocFix) @assign_length : Int32? @@ -136,7 +136,7 @@ module Crystal @assign_infos = [] of AlignInfo @doc_comments = [] of CommentInfo @current_doc_comment = nil - @hash_in_same_line = Set(UInt64).new + @hash_in_same_line = Set(ASTNode).new.compare_by_identity @shebang = @token.type == :COMMENT && @token.value.to_s.starts_with?("#!") @heredoc_fixes = [] of HeredocFix @last_is_heredoc = false @@ -896,7 +896,7 @@ module Crystal format_hash_entry nil, node_of end - if @hash_in_same_line.includes? node.object_id + if @hash_in_same_line.includes? node @hash_infos.reject! { |info| info.id == node.object_id } end @@ -930,8 +930,8 @@ module Crystal skip_space_or_newline accept entry.value - if found_in_same_line - @hash_in_same_line << hash.object_id + if hash && found_in_same_line + @hash_in_same_line << hash end end @@ -941,7 +941,7 @@ module Crystal format_literal_elements node.entries, :"{", :"}" @current_hash = old_hash - if @hash_in_same_line.includes? node.object_id + if @hash_in_same_line.includes? node @hash_infos.reject! { |info| info.id == node.object_id } end @@ -965,7 +965,7 @@ module Crystal accept entry.value if found_in_same_line - @hash_in_same_line << hash.object_id + @hash_in_same_line << hash end end diff --git a/src/hash.cr b/src/hash.cr index a03b21a87b45..eb7185606240 100644 --- a/src/hash.cr +++ b/src/hash.cr @@ -187,6 +187,9 @@ class Hash(K, V) # Otherwise guaranteed to be at least 3. @indices_size_pow2 : UInt8 + # Whether to compare objects using `object_id`. + @compare_by_identity : Bool = false + # The optional block that triggers on non-existing keys. @block : (self, K -> V)? @@ -386,7 +389,7 @@ class Hash(K, V) # We found a non-empty slot, let's see if the key we have matches entry = get_entry(entry_index) - if entry.matches?(hash, key) + if entry_matches?(entry, hash, key) # If it does we just update the entry set_entry(entry_index, Entry(K, V).new(hash, key, value)) return entry @@ -402,7 +405,7 @@ class Hash(K, V) private def update_linear_scan(key, value, hash) : Entry(K, V)? # Just do a linear scan... each_entry_with_index do |entry, index| - if entry.matches?(hash, key) + if entry_matches?(entry, hash, key) set_entry(index, Entry(K, V).new(entry.hash, entry.key, value)) return entry end @@ -438,7 +441,7 @@ class Hash(K, V) # We found a non-empty slot, let's see if the key we have matches entry = get_entry(entry_index) - if entry.matches?(hash, key) + if entry_matches?(entry, hash, key) delete_entry_and_update_counts(entry_index) return entry else @@ -452,7 +455,7 @@ class Hash(K, V) # Returns the deleted Entry, if it existed, `nil` otherwise. private def delete_linear_scan(key, hash) : Entry(K, V)? each_entry_with_index do |entry, index| - if entry.matches?(hash, key) + if entry_matches?(entry, hash, key) delete_entry_and_update_counts(index) return entry end @@ -487,7 +490,7 @@ class Hash(K, V) # We found a non-empty slot, let's see if the key we have matches entry = get_entry(entry_index) - if entry.matches?(hash, key) + if entry_matches?(entry, hash, key) # It does! return entry else @@ -504,12 +507,12 @@ class Hash(K, V) # computing a hash code of a complex structure). if entries_size <= 8 each_entry_with_index do |entry| - return entry if entry.key == key + return entry if entry_matches?(entry, key) end else hash = key_hash(key) each_entry_with_index do |entry| - return entry if entry.matches?(hash, key) + return entry if entry_matches?(entry, hash, key) end end @@ -636,6 +639,8 @@ class Hash(K, V) # Initializes a `dup` copy from the contents of `other`. protected def initialize_dup(other) + initialize_compare_by_identity(other) + return if other.empty? initialize_dup_entries(other) @@ -644,12 +649,18 @@ class Hash(K, V) # Initializes a `clone` copy from the contents of `other`. protected def initialize_clone(other) + initialize_compare_by_identity(other) + return if other.empty? initialize_clone_entries(other) initialize_copy_non_entries_vars(other) end + private def initialize_compare_by_identity(other) + compare_by_identity if other.compare_by_identity? + end + # Initializes `@entries` for a dup copy. # Here we only need tu duplicate the buffer. private def initialize_dup_entries(other) @@ -892,10 +903,51 @@ class Hash(K, V) # Computes the hash of a key. private def key_hash(key) - hash = key.hash.to_u32! + if @compare_by_identity && key.responds_to?(:object_id) + hash = key.object_id.hash.to_u32! + else + hash = key.hash.to_u32! + end hash == 0 ? UInt32::MAX : hash end + private def entry_matches?(entry, hash, key) + # Tiny optimization: for these primitive types it's faster to just + # compare the key instead of comparing the hash and the key. + # We still have to skip hashes with value 0 (means deleted). + {% if K == Bool || + K == Char || + K == Symbol || + K < Int::Primitive || + K < Float::Primitive || + K < Enum %} + entry.key == key && entry.hash != 0_u32 + {% else %} + entry.hash == hash && entry_matches?(entry, key) + {% end %} + end + + private def entry_matches?(entry, key) + entry_key = entry.key + + if @compare_by_identity + if entry_key.responds_to?(:object_id) + if key.responds_to?(:object_id) + entry_key.object_id == key.object_id + else + false + end + elsif key.responds_to?(:object_id) + # because entry_key doesn't respond to :object_id + false + else + entry_key == key + end + else + entry_key == key + end + end + # =========================================================================== # Internal implementation ends # =========================================================================== @@ -903,6 +955,29 @@ class Hash(K, V) # Returns the number of elements in this Hash. getter size : Int32 + # Makes this hash compare keys using their object identity (`object_id)` + # for types that define such method (`Reference` types, but also structs that + # might wrap other `Reference` types and delegate the `object_id` method to them). + # + # ``` + # h1 = {"foo" => 1, "bar" => 2} + # h1["fo" + "o"]? # => 1 + # + # h1.compare_by_identity + # h1.compare_by_identity? # => true + # h1["fo" + "o"]? # => nil # not the same String instance + # ``` + def compare_by_identity + @compare_by_identity = true + rehash + self + end + + # Returns `true` of this Hash is comparing keys by `object_id`. + # + # See `compare_by_identity`. + getter? compare_by_identity + # Sets the value of *key* to the given *value*. # # ``` @@ -1837,22 +1912,6 @@ class Hash(K, V) @hash == 0_u32 end - def matches?(hash, key) - # Tiny optimization: for these primitive types it's faster to just - # compare the key instead of comparing the hash and the key. - # We still have to skip hashes with value 0 (means deleted). - {% if K == Bool || - K == Char || - K == Symbol || - K < Int::Primitive || - K < Float::Primitive || - K < Enum %} - @key == key && @hash != 0_u32 - {% else %} - @hash == hash && @key == key - {% end %} - end - def clone Entry(K, V).new(hash, key, value.clone) end diff --git a/src/set.cr b/src/set.cr index 90514a94b076..5309613738bc 100644 --- a/src/set.cr +++ b/src/set.cr @@ -53,6 +53,30 @@ struct Set(T) Set(T).new.concat(enumerable) end + # Makes this set compare objects using their object identity (`object_id)` + # for types that define such method (`Reference` types, but also structs that + # might wrap other `Reference` types and delegate the `object_id` method to them). + # + # ``` + # s = Set{"foo", "bar"} + # s.includes?("fo" + "o") # => true + # + # s.compare_by_identity + # s.compare_by_identity? # => true + # s.includes?("fo" + "o") # => false # not the same String instance + # ``` + def compare_by_identity + @hash.compare_by_identity + self + end + + # Returns `true` of this Set is comparing objects by `object_id`. + # + # See `compare_by_identity`. + def compare_by_identity? + @hash.compare_by_identity? + end + # Alias for `add` def <<(object : T) add object @@ -323,12 +347,15 @@ struct Set(T) # Returns a new `Set` with all of the same elements. def dup - Set.new(self) + set = Set.new(self) + set.compare_by_identity if compare_by_identity? + set end # Returns a new `Set` with all of the elements cloned. def clone clone = Set(T).new(self.size) + clone.compare_by_identity if compare_by_identity? each do |element| clone << element.clone end