Skip to content

Commit

Permalink
Added NormalizedAttribute
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 2, 2024
1 parent 3fb189b commit b0a0c8c
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 5 deletions.
1 change: 1 addition & 0 deletions lib/neighbor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def register_vector_type(m)
ActiveSupport.on_load(:active_record) do
require_relative "neighbor/attribute"
require_relative "neighbor/model"
require_relative "neighbor/normalized_attribute"
require_relative "neighbor/type/cube"
require_relative "neighbor/type/halfvec"
require_relative "neighbor/type/mysql_vector"
Expand Down
2 changes: 1 addition & 1 deletion lib/neighbor/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def self.neighbor_attributes
else
attribute_names.each do |attribute_name|
attribute attribute_name do |cast_type|
raise "todo"
Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
end
end
end
Expand Down
21 changes: 21 additions & 0 deletions lib/neighbor/normalized_attribute.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Neighbor
class NormalizedAttribute < ActiveRecord::Type::Value
delegate :type, :serialize, :deserialize, to: :@cast_type

def initialize(cast_type:, model:, attribute_name:)
@cast_type = cast_type
@model = model
@attribute_name = attribute_name.to_s
end

def cast(value)
Neighbor::Utils.normalize(@cast_type.cast(value), column_info: @model.columns_hash[@attribute_name])
end

private

def cast_value(...)
@cast_type.send(:cast_value, ...)
end
end
end
12 changes: 10 additions & 2 deletions test/cube_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,19 @@ def test_normalize

def test_insert
CosineItem.insert!({cube_embedding: [0, 3, 4]})
assert_elements_in_delta [0, 0.6, 0.8], Item.last.cube_embedding
if supports_normalizes?
assert_elements_in_delta [0, 0.6, 0.8], Item.last.cube_embedding
else
assert_elements_in_delta [0, 3, 4], Item.last.cube_embedding
end
end

def test_insert_all
CosineItem.insert_all!([{cube_embedding: [0, 3, 4]}])
assert_elements_in_delta [0, 0.6, 0.8], Item.last.cube_embedding
if supports_normalizes?
assert_elements_in_delta [0, 0.6, 0.8], Item.last.cube_embedding
else
assert_elements_in_delta [0, 3, 4], Item.last.cube_embedding
end
end
end
12 changes: 10 additions & 2 deletions test/mariadb_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,19 @@ def test_normalize

def test_insert
MariadbCosineItem.insert!({embedding: [0, 3, 4]})
assert_elements_in_delta [0, 0.6, 0.8], MariadbItem.last.embedding
if supports_normalizes?
assert_elements_in_delta [0, 0.6, 0.8], MariadbItem.last.embedding
else
assert_elements_in_delta [0, 3, 4], MariadbItem.last.embedding
end
end

def test_insert_all
MariadbCosineItem.insert_all!([{embedding: [0, 3, 4]}])
assert_elements_in_delta [0, 0.6, 0.8], MariadbItem.last.embedding
if supports_normalizes?
assert_elements_in_delta [0, 0.6, 0.8], MariadbItem.last.embedding
else
assert_elements_in_delta [0, 3, 4], MariadbItem.last.embedding
end
end
end
4 changes: 4 additions & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ def assert_index_scan(relation)
assert_match "Index Scan", relation.limit(5).explain.inspect
end
end

def supports_normalizes?
ActiveRecord::VERSION::STRING.to_f >= 7.1
end
end

0 comments on commit b0a0c8c

Please sign in to comment.