From 4f82c59e3d1a96be82c4a62b063c803fa181302e Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 9 Sep 2024 23:04:40 -0700 Subject: [PATCH] Added == and hashCode to type classes --- lib/src/halfvec.dart | 9 +++++++++ lib/src/sparsevec.dart | 11 +++++++++++ lib/src/utils.dart | 13 +++++++++++++ lib/src/vector.dart | 8 ++++++++ test/halfvec_test.dart | 9 +++++++++ test/postgres_test.dart | 6 +++--- test/sparsevec_test.dart | 9 +++++++++ test/vector_test.dart | 9 +++++++++ 8 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 lib/src/utils.dart diff --git a/lib/src/halfvec.dart b/lib/src/halfvec.dart index 85c7afc..1ab798d 100644 --- a/lib/src/halfvec.dart +++ b/lib/src/halfvec.dart @@ -1,3 +1,5 @@ +import 'utils.dart'; + class HalfVector { final List _vec; @@ -11,4 +13,11 @@ class HalfVector { String toString() { return _vec.toString(); } + + @override + bool operator ==(Object other) => + other is HalfVector && listEquals(other._vec, _vec); + + @override + int get hashCode => _vec.hashCode; } diff --git a/lib/src/sparsevec.dart b/lib/src/sparsevec.dart index 1d8bb42..27d137c 100644 --- a/lib/src/sparsevec.dart +++ b/lib/src/sparsevec.dart @@ -1,4 +1,5 @@ import 'dart:typed_data'; +import 'utils.dart'; class SparseVector { final int dimensions; @@ -70,4 +71,14 @@ class SparseVector { ].join(','); return '{${elements}}/${dimensions}'; } + + @override + bool operator ==(Object other) => + other is SparseVector && + other.dimensions == dimensions && + listEquals(other.indices, indices) && + listEquals(other.values, values); + + @override + int get hashCode => Object.hash(dimensions, indices, values); } diff --git a/lib/src/utils.dart b/lib/src/utils.dart new file mode 100644 index 0000000..bb9c186 --- /dev/null +++ b/lib/src/utils.dart @@ -0,0 +1,13 @@ +bool listEquals(List a, List b) { + if (a.length != b.length) { + return false; + } + + for (var i = 0; i < a.length; i++) { + if (a[i] != b[i]) { + return false; + } + } + + return true; +} diff --git a/lib/src/vector.dart b/lib/src/vector.dart index e1b8bda..f9518f1 100644 --- a/lib/src/vector.dart +++ b/lib/src/vector.dart @@ -1,4 +1,5 @@ import 'dart:typed_data'; +import 'utils.dart'; class Vector { final List _vec; @@ -30,4 +31,11 @@ class Vector { String toString() { return _vec.toString(); } + + @override + bool operator ==(Object other) => + other is Vector && listEquals(other._vec, _vec); + + @override + int get hashCode => _vec.hashCode; } diff --git a/test/halfvec_test.dart b/test/halfvec_test.dart index 45279f8..8f6b9e0 100644 --- a/test/halfvec_test.dart +++ b/test/halfvec_test.dart @@ -7,4 +7,13 @@ void main() { expect(vec.toString(), equals('[1.0, 2.0, 3.0]')); expect(vec.toList(), equals([1, 2, 3])); }); + + test('equals', () { + var a = HalfVector([1, 2, 3]); + var b = HalfVector([1, 2, 3]); + var c = HalfVector([1, 2, 4]); + + expect(a, equals(b)); + expect(a, isNot(equals(c))); + }); } diff --git a/test/postgres_test.dart b/test/postgres_test.dart index c17e648..12dc1ce 100644 --- a/test/postgres_test.dart +++ b/test/postgres_test.dart @@ -44,9 +44,9 @@ void main() { "embedding": Vector([1, 1, 1]).toString() }); expect(results.map((r) => r[0]), equals([1, 3, 2])); - expect(Vector.fromBinary(results[1][1].bytes).toList(), equals([1, 1, 2])); - expect(SparseVector.fromBinary(results[1][2].bytes).toList(), - equals([1, 1, 2])); + expect(Vector.fromBinary(results[1][1].bytes), equals(Vector([1, 1, 2]))); + expect(SparseVector.fromBinary(results[1][2].bytes), + equals(SparseVector([1, 1, 2]))); await connection .execute("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)"); diff --git a/test/sparsevec_test.dart b/test/sparsevec_test.dart index 2016f1d..af8cbbf 100644 --- a/test/sparsevec_test.dart +++ b/test/sparsevec_test.dart @@ -19,4 +19,13 @@ void main() { expect(vec.indices, equals([0, 2, 4])); expect(vec.values, equals([1, 2, 3])); }); + + test('equals', () { + var a = SparseVector([1, 2, 3]); + var b = SparseVector([1, 2, 3]); + var c = SparseVector([1, 2, 4]); + + expect(a, equals(b)); + expect(a, isNot(equals(c))); + }); } diff --git a/test/vector_test.dart b/test/vector_test.dart index c7e3809..a7618ec 100644 --- a/test/vector_test.dart +++ b/test/vector_test.dart @@ -7,4 +7,13 @@ void main() { expect(vec.toString(), equals('[1.0, 2.0, 3.0]')); expect(vec.toList(), equals([1, 2, 3])); }); + + test('equals', () { + var a = Vector([1, 2, 3]); + var b = Vector([1, 2, 3]); + var c = Vector([1, 2, 4]); + + expect(a, equals(b)); + expect(a, isNot(equals(c))); + }); }