From 13c9ba81cb04c02d144f927e26065869712a7a9e Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 20 Jun 2024 17:31:02 -0700 Subject: [PATCH] Added more functions to SparseVector --- sparsevec.go | 15 +++++++++++++++ sparsevec_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/sparsevec.go b/sparsevec.go index 808fe97..e574603 100644 --- a/sparsevec.go +++ b/sparsevec.go @@ -42,6 +42,21 @@ func NewSparseVectorFromMap(elements map[int32]float32, dim int32) SparseVector return SparseVector{dim: dim, indices: indices, values: values} } +// Dimensions returns the number of dimensions. +func (v SparseVector) Dimensions() int32 { + return v.dim +} + +// Indices returns the non-zero indices. +func (v SparseVector) Indices() []int32 { + return v.indices +} + +// Values returns the non-zero values. +func (v SparseVector) Values() []float32 { + return v.values +} + // Slice returns a slice of float32. func (v SparseVector) Slice() []float32 { vec := make([]float32, v.dim) diff --git a/sparsevec_test.go b/sparsevec_test.go index 8512e56..dcbf8e4 100644 --- a/sparsevec_test.go +++ b/sparsevec_test.go @@ -8,6 +8,13 @@ import ( "github.com/pgvector/pgvector-go" ) +func TestNewSparseVector(t *testing.T) { + vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0}) + if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) { + t.Error() + } +} + func TestNewSparseVectorFromMap(t *testing.T) { vec := pgvector.NewSparseVectorFromMap(map[int32]float32{0: 1, 2: 2, 4: 3}, 6) if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) { @@ -15,6 +22,27 @@ func TestNewSparseVectorFromMap(t *testing.T) { } } +func TestDimensions(t *testing.T) { + vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0}) + if vec.Dimensions() != 6 { + t.Error() + } +} + +func TestIndices(t *testing.T) { + vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0}) + if !reflect.DeepEqual(vec.Indices(), []int32{0, 2, 4}) { + t.Error() + } +} + +func TestValues(t *testing.T) { + vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0}) + if !reflect.DeepEqual(vec.Values(), []float32{1, 2, 3}) { + t.Error() + } +} + func TestSparseVectorSlice(t *testing.T) { vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0}) if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) {