diff --git a/pgx_test.go b/pgx_test.go index 76e7e39..56221e1 100644 --- a/pgx_test.go +++ b/pgx_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/pgvector/pgvector-go" pgxvector "github.com/pgvector/pgvector-go/pgx" ) @@ -70,23 +71,25 @@ func TestPgx(t *testing.T) { CreatePgxItems(ctx, conn) - rows, err := conn.Query(ctx, "SELECT id, embedding, half_embedding, sparse_embedding, embedding <-> $1 FROM pgx_items ORDER BY embedding <-> $1 LIMIT 5", pgvector.NewVector([]float32{1, 1, 1})) + rows, err := conn.Query(ctx, "SELECT id, embedding, half_embedding, binary_embedding, sparse_embedding, embedding <-> $1 FROM pgx_items ORDER BY embedding <-> $1 LIMIT 5", pgvector.NewVector([]float32{1, 1, 1})) if err != nil { panic(err) } defer rows.Close() var items []PgxItem + var binaryEmbeddings []pgtype.Bits var distances []float64 for rows.Next() { var item PgxItem + var binaryEmbedding pgtype.Bits var distance float64 - // TODO scan BinaryEmbedding - err = rows.Scan(&item.Id, &item.Embedding, &item.HalfEmbedding, &item.SparseEmbedding, &distance) + err = rows.Scan(&item.Id, &item.Embedding, &item.HalfEmbedding, &binaryEmbedding, &item.SparseEmbedding, &distance) if err != nil { panic(err) } items = append(items, item) + binaryEmbeddings = append(binaryEmbeddings, binaryEmbedding) distances = append(distances, distance) } @@ -103,6 +106,9 @@ func TestPgx(t *testing.T) { if !reflect.DeepEqual(items[1].HalfEmbedding.Slice(), []float32{1, 1, 2}) { t.Error() } + if binaryEmbeddings[1].Bytes[0] != (7<<5) || binaryEmbeddings[1].Len != 3 { + t.Error() + } if !reflect.DeepEqual(items[1].SparseEmbedding.Slice(), []float32{1, 1, 2}) { t.Error() }