From 13397911e207d6f7f7ef134a0bb19cee2cec0f3c Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 9 Sep 2024 20:49:13 -0700 Subject: [PATCH] Added SparseVector class --- CHANGELOG.md | 2 +- lib/pgvector.dart | 2 ++ lib/sparsevec.dart | 28 ++++++++++++++++++++++++++++ test/postgres_test.dart | 9 ++++++--- 4 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 lib/sparsevec.dart diff --git a/CHANGELOG.md b/CHANGELOG.md index e2328df..571e279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ ## 0.1.1 (unreleased) -- Added `Vector` and `HalfVector` classes +- Added `Vector`, `HalfVector`, and `SparseVector` classes ## 0.1.0 (2023-10-17) diff --git a/lib/pgvector.dart b/lib/pgvector.dart index 9b28eea..13c319d 100644 --- a/lib/pgvector.dart +++ b/lib/pgvector.dart @@ -2,9 +2,11 @@ import 'dart:convert'; import 'dart:typed_data'; import 'halfvec.dart'; +import 'sparsevec.dart'; import 'vector.dart'; export 'halfvec.dart' show HalfVector; +export 'sparsevec.dart' show SparseVector; export 'vector.dart' show Vector; class Pgvector { diff --git a/lib/sparsevec.dart b/lib/sparsevec.dart new file mode 100644 index 0000000..2d9de22 --- /dev/null +++ b/lib/sparsevec.dart @@ -0,0 +1,28 @@ +class SparseVector { + final int dimensions; + final List indices; + final List values; + + SparseVector._(this.dimensions, this.indices, this.values); + + factory SparseVector(List value) { + var dimensions = value.length; + var indices = []; + var values = []; + for (var i = 0; i < value.length; i++) { + if (value[i] != 0) { + indices.add(i); + values.add(value[i]); + } + } + return SparseVector._(dimensions, indices, values); + } + + @override + String toString() { + var elements = [ + for (var i = 0; i < indices.length; i++) "${indices[i] + 1}:${values[i]}" + ].join(","); + return "{${elements}}/${this.dimensions}"; + } +} diff --git a/test/postgres_test.dart b/test/postgres_test.dart index 34407ef..8e5c294 100644 --- a/test/postgres_test.dart +++ b/test/postgres_test.dart @@ -17,18 +17,21 @@ void main() { await connection.execute("DROP TABLE IF EXISTS items"); await connection.execute( - "CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3))"); + "CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), sparse_embedding sparsevec(3))"); await connection.execute( Sql.named( - "INSERT INTO items (embedding, half_embedding) VALUES (@a, @d), (@b, @e), (@c, @f)"), + "INSERT INTO items (embedding, half_embedding, sparse_embedding) VALUES (@a, @d, @g), (@b, @e, @h), (@c, @f, @i)"), parameters: { "a": Vector([1, 1, 1]).toString(), "b": Vector([2, 2, 2]).toString(), "c": Vector([1, 1, 2]).toString(), "d": HalfVector([1, 1, 1]).toString(), "e": HalfVector([2, 2, 2]).toString(), - "f": HalfVector([1, 1, 2]).toString() + "f": HalfVector([1, 1, 2]).toString(), + "g": SparseVector([1, 1, 1]).toString(), + "h": SparseVector([2, 2, 2]).toString(), + "i": SparseVector([1, 1, 2]).toString() }); List> results = await connection.execute(