diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4ee2225..c5cb3e8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,7 @@
 ## 0.2.1 (unreleased)
 
 - Added `EncodeBinary` and `DecodeBinary` methods to `Vector`
+- Added `EncodeBinary` and `DecodeBinary` methods to `SparseVector`
 
 ## 0.2.0 (2024-06-25)
 
diff --git a/sparsevec.go b/sparsevec.go
index 2346c5f..e6565cb 100644
--- a/sparsevec.go
+++ b/sparsevec.go
@@ -3,7 +3,9 @@ package pgvector
 import (
 	"database/sql"
 	"database/sql/driver"
+	"encoding/binary"
 	"fmt"
+	"math"
 	"strconv"
 	"strings"
 )
@@ -120,6 +122,79 @@ func (v *SparseVector) Parse(s string) error {
 	return nil
 }
 
+// EncodeBinary appends the binary representation of the sparse vector to buf
+// and returns the extended slice.
+//
+// The layout is big-endian: the dimension count, the number of stored
+// elements and a reserved field written as 0 (a uint32 each), followed by
+// every index as a uint32 and then every value as its float32 bit pattern.
+// The returned error is always nil.
+func (v SparseVector) EncodeBinary(buf []byte) (newBuf []byte, err error) {
+	buf = binary.BigEndian.AppendUint32(buf, uint32(v.dim))
+	buf = binary.BigEndian.AppendUint32(buf, uint32(len(v.indices)))
+	buf = binary.BigEndian.AppendUint32(buf, 0)
+	// The loop variables are not named v so they do not shadow the receiver.
+	for _, index := range v.indices {
+		buf = binary.BigEndian.AppendUint32(buf, uint32(index))
+	}
+	for _, value := range v.values {
+		buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(value))
+	}
+	return buf, nil
+}
+
+// DecodeBinary decodes a binary representation of a sparse vector, in the
+// layout written by EncodeBinary, into v.
+//
+// It returns an error and leaves v unchanged when buf is shorter than the
+// 12-byte header, when the reserved header field is not 0, when the
+// dimension count does not fit in an int32, or when buf is too short for the
+// number of elements the header announces. Bytes after the last value are
+// ignored.
+func (v *SparseVector) DecodeBinary(buf []byte) error {
+	// The header is three uint32 fields: dimensions, element count, reserved.
+	const headerLen = 12
+	if len(buf) < headerLen {
+		return fmt.Errorf("expected at least %d bytes, got %d", headerLen, len(buf))
+	}
+
+	dim := binary.BigEndian.Uint32(buf[0:4])
+	nnz := binary.BigEndian.Uint32(buf[4:8])
+
+	unused := binary.BigEndian.Uint32(buf[8:12])
+	if unused != 0 {
+		return fmt.Errorf("expected unused to be 0")
+	}
+
+	// v.dim is an int32, so a larger count would wrap to a negative value.
+	if dim > math.MaxInt32 {
+		return fmt.Errorf("expected at most %d dimensions, got %d", math.MaxInt32, dim)
+	}
+
+	// Every element takes 4 bytes for its index and 4 for its value. The
+	// comparison is done in uint64 so it cannot overflow on any platform.
+	if need := 8 * uint64(nnz); uint64(len(buf)-headerLen) < need {
+		return fmt.Errorf("expected at least %d bytes for %d elements, got %d", headerLen+need, nnz, len(buf))
+	}
+
+	// Size the slices by the element count rather than by dim, which can be
+	// far larger for a sparse vector.
+	count := int(nnz)
+	indices := make([]int32, count)
+	values := make([]float32, count)
+	indexData := buf[headerLen : headerLen+4*count]
+	valueData := buf[headerLen+4*count : headerLen+8*count]
+	for i := range indices {
+		indices[i] = int32(binary.BigEndian.Uint32(indexData[4*i:]))
+		values[i] = math.Float32frombits(binary.BigEndian.Uint32(valueData[4*i:]))
+	}
+
+	v.dim = int32(dim)
+	v.indices = indices
+	v.values = values
+	return nil
+}
+
 // statically assert that SparseVector implements sql.Scanner.
 var _ sql.Scanner = (*SparseVector)(nil)
 