diff --git a/lib/sparsevec.dart b/lib/sparsevec.dart index 2d9de22..e5725fe 100644 --- a/lib/sparsevec.dart +++ b/lib/sparsevec.dart @@ -1,3 +1,5 @@ +import 'dart:typed_data'; + class SparseVector { final int dimensions; final List indices; @@ -18,6 +20,29 @@ class SparseVector { return SparseVector._(dimensions, indices, values); } + factory SparseVector.fromBinary(Uint8List bytes) { + var bdata = new ByteData.view(bytes.buffer, bytes.offsetInBytes); + var dimensions = bdata.getInt32(0); + var nnz = bdata.getInt32(4); + + var unused = bdata.getInt32(8); + if (unused != 0) { + throw FormatException('expected unused to be 0'); + } + + var indices = []; + for (var i = 0; i < nnz; i++) { + indices.add(bdata.getInt32(12 + i * 4)); + } + + var values = []; + for (var i = 0; i < nnz; i++) { + values.add(bdata.getFloat32(12 + 4 * nnz + i * 4)); + } + + return SparseVector._(dimensions, indices, values); + } + @override String toString() { var elements = [ diff --git a/test/postgres_test.dart b/test/postgres_test.dart index 8e5c294..1f36b23 100644 --- a/test/postgres_test.dart +++ b/test/postgres_test.dart @@ -36,13 +36,14 @@ void main() { List> results = await connection.execute( Sql.named( - "SELECT id, embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5"), + "SELECT id, embedding, sparse_embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5"), parameters: { "embedding": Vector([1, 1, 1]).toString() }); for (final row in results) { print(row[0]); print(Vector.fromBinary(row[1].bytes)); + print(SparseVector.fromBinary(row[2].bytes)); } await connection