From 3e0197a13c7f69a3996612904c2496557c6220c2 Mon Sep 17 00:00:00 2001 From: Likith B Date: Fri, 2 Feb 2024 11:06:28 +0530 Subject: [PATCH 1/8] MB-59616: Adding vector_base64 field - Added a new field type called vector_base64. - Acts similar to vector in most cases. - When a new document arrives in the bleve layer, during the parsing of all its fields in processProperty, if the field mapping type is vector-base64, then its value is decoded into a vector field and processed like a vector. - The standard golang base64 library is used for the decode operation. --- document/field_vector_base64.go | 128 +++++++++++++++++++++++ mapping/document.go | 2 + mapping/index.go | 1 - mapping/mapping_no_vectors.go | 9 ++ mapping/mapping_vectors.go | 47 ++++++++- mapping_vector.go | 4 + search_knn_test.go | 179 ++++++++++++++++++++++++++++++++ 7 files changed, 368 insertions(+), 2 deletions(-) create mode 100644 document/field_vector_base64.go diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go new file mode 100644 index 000000000..a55c4758b --- /dev/null +++ b/document/field_vector_base64.go @@ -0,0 +1,128 @@ +// Copyright (c) 2024 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package document + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + index "github.com/blevesearch/bleve_index_api" +) + +type VectorBase64Field struct { + vectorField *VectorField + encodedValue string +} + +func (n *VectorBase64Field) Size() int { + return n.vectorField.Size() +} + +func (n *VectorBase64Field) Name() string { + return n.vectorField.Name() +} + +func (n *VectorBase64Field) ArrayPositions() []uint64 { + return n.vectorField.ArrayPositions() +} + +func (n *VectorBase64Field) Options() index.FieldIndexingOptions { + return n.vectorField.Options() +} + +func (n *VectorBase64Field) NumPlainTextBytes() uint64 { + return n.vectorField.NumPlainTextBytes() +} + +func (n *VectorBase64Field) AnalyzedLength() int { + return n.vectorField.AnalyzedLength() +} + +func (n *VectorBase64Field) EncodedFieldType() byte { + return 'e' // CHECK +} + +func (n *VectorBase64Field) AnalyzedTokenFrequencies() index.TokenFrequencies { + return n.vectorField.AnalyzedTokenFrequencies() +} + +func (n *VectorBase64Field) Analyze() { + // CHECK +} + +func (n *VectorBase64Field) Value() []byte { + return n.vectorField.Value() +} + +func (n *VectorBase64Field) GoString() string { + return fmt.Sprintf("&document.vectorFieldBase64Field{Name:%s, Options: %s, "+ + "Value: %+v}", n.vectorField.Name(), n.vectorField.Options(), n.vectorField.Value()) +} + +// For the sake of not polluting the API, we are keeping arrayPositions as a +// parameter, but it is not used. +func NewVectorBase64Field(name string, arrayPositions []uint64, encodedValue string, + dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) { + + vector, err := decodeVector(encodedValue) + if err != nil { + return nil, err + } + + return &VectorBase64Field{ + vectorField: NewVectorFieldWithIndexingOptions(name, arrayPositions, + vector, dims, similarity, + vectorIndexOptimizedFor, DefaultVectorIndexingOptions), + + encodedValue: encodedValue, + }, nil +} + +func decodeVector(encodedValue string) ([]float32, error) { + decodedString, err := base64.StdEncoding.DecodeString(encodedValue) + if err != nil { + fmt.Println("Error decoding string:", err) + return nil, err + } + + var decodedVector []float32 + err = json.Unmarshal(decodedString, decodedVector) + if err != nil { + fmt.Println("Error decoding string:", err) + return nil, err + } + + return decodedVector, nil +} + +func (n *VectorBase64Field) Vector() []float32 { + return n.vectorField.Vector() +} + +func (n *VectorBase64Field) Dims() int { + return n.vectorField.Dims() +} + +func (n *VectorBase64Field) Similarity() string { + return n.vectorField.Similarity() +} + +func (n *VectorBase64Field) IndexOptimizedFor() string { + return n.vectorField.IndexOptimizedFor() +} diff --git a/mapping/document.go b/mapping/document.go index 73bb124db..fe72f0802 100644 --- a/mapping/document.go +++ b/mapping/document.go @@ -443,6 +443,8 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, fieldMapping.processGeoShape(property, pathString, path, indexes, context) } else if fieldMapping.Type == "geopoint" { fieldMapping.processGeoPoint(property, pathString, path, indexes, context) + } else if fieldMapping.Type == "vector-base64" { + fieldMapping.processVectorBase64(property, pathString, path, indexes, context) } else { fieldMapping.processString(propertyValueString, pathString, path, indexes, context) } diff --git a/mapping/index.go b/mapping/index.go index 171ee1a72..6ff229ae8 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -320,7 +320,6 @@ func (im *IndexMappingImpl) determineType(data interface{}) string { return im.DefaultType } - func (im *IndexMappingImpl) MapDocument(doc *document.Document, data interface{}) error { docType := im.determineType(data) docMapping := im.mappingForType(docType) diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go index f9f35f57c..90cb1e225 100644 --- a/mapping/mapping_no_vectors.go +++ b/mapping/mapping_no_vectors.go @@ -21,11 +21,20 @@ func NewVectorFieldMapping() *FieldMapping { return nil } +func NewVectorBase64FieldMapping() *FieldMapping { + return nil +} + func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, pathString string, path []string, indexes []uint64, context *walkContext) bool { return false } +func (fm *FieldMapping) processVectorBase64(propertyMightBeVector interface{}, + pathString string, path []string, indexes []uint64, context *walkContext) { + +} + // ----------------------------------------------------------------------------- // document validation functions diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index a0b712608..640dcf5e8 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -18,6 +18,8 @@ package mapping import ( + "encoding/base64" + "encoding/json" "fmt" "reflect" @@ -43,6 +45,17 @@ func NewVectorFieldMapping() *FieldMapping { } } +func NewVectorBase64FieldMapping() *FieldMapping { + return &FieldMapping{ + Type: "vector-base64", + Store: false, + Index: true, + IncludeInAll: false, + DocValues: false, + SkipFreqNorm: true, + } +} + // validate and process a flat vector func processFlatVector(vecV reflect.Value, dims int) ([]float32, bool) { if vecV.Len() != dims { @@ -121,6 +134,27 @@ func processVector(vecI interface{}, dims int) ([]float32, bool) { return rv, true } +func processVectorBase64(vecBase64 interface{}) (interface{}, bool) { + + vecEncoded, ok := vecBase64.(string) + if !ok { + return nil, false + } + + vecData, err := base64.StdEncoding.DecodeString(vecEncoded) + if err != nil { + return nil, false + } + + var vector interface{} + err = json.Unmarshal(vecData, &vector) + if err != nil { + return nil, false + } + + return vector, true +} + func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, pathString string, path []string, indexes []uint64, context *walkContext) bool { vector, ok := processVector(propertyMightBeVector, fm.Dims) @@ -140,13 +174,24 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, return true } +func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{}, + pathString string, path []string, indexes []uint64, context *walkContext) { + + propertyMightBeVector, ok := processVectorBase64(propertyMightBeVectorBase64) + if !ok { + return + } + + fm.processVector(propertyMightBeVector, pathString, path, indexes, context) +} + // ----------------------------------------------------------------------------- // document validation functions func validateFieldMapping(field *FieldMapping, parentName string, fieldAliasCtx map[string]*FieldMapping) error { switch field.Type { - case "vector": + case "vector", "vector-base64": return validateVectorFieldAlias(field, parentName, fieldAliasCtx) default: // non-vector field return validateFieldType(field) diff --git a/mapping_vector.go b/mapping_vector.go index 594313861..c73dac9e5 100644 --- a/mapping_vector.go +++ b/mapping_vector.go @@ -22,3 +22,7 @@ import "github.com/blevesearch/bleve/v2/mapping" func NewVectorFieldMapping() *mapping.FieldMapping { return mapping.NewVectorFieldMapping() } + +func NewVectorBase64FieldMapping() *mapping.FieldMapping { + return mapping.NewVectorBase64FieldMapping() +} diff --git a/search_knn_test.go b/search_knn_test.go index b54ce5a93..3da8bda09 100644 --- a/search_knn_test.go +++ b/search_knn_test.go @@ -19,6 +19,7 @@ package bleve import ( "archive/zip" + "encoding/base64" "encoding/json" "fmt" "math" @@ -397,6 +398,162 @@ func min(a, b int) int { return b } +func TestVectorBase64Index(t *testing.T) { + + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + + _, searchRequestsCopy, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + + err = encodeVectors(documents) + if err != nil { + t.Fatal(err) + } + + modifySearchRequests(searchRequestsCopy) + + contentFM := NewTextFieldMapping() + contentFM.Analyzer = en.AnalyzerName + + vecFML2 := mapping.NewVectorFieldMapping() + vecFML2.Dims = testDatasetDims + vecFML2.Similarity = index.EuclideanDistance + + vecBFML2 := mapping.NewVectorBase64FieldMapping() + vecBFML2.Dims = testDatasetDims + vecBFML2.Similarity = index.EuclideanDistance + + vecFMDot := mapping.NewVectorFieldMapping() + vecFMDot.Dims = testDatasetDims + vecFMDot.Similarity = index.CosineSimilarity + + vecBFMDot := mapping.NewVectorBase64FieldMapping() + vecBFMDot.Dims = testDatasetDims + vecBFMDot.Similarity = index.CosineSimilarity + + indexMappingL2 := NewIndexMapping() + indexMappingL2.DefaultMapping.AddFieldMappingsAt("content", contentFM) + indexMappingL2.DefaultMapping.AddFieldMappingsAt("vector", vecFML2) + indexMappingL2.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFML2) + + indexMappingDot := NewIndexMapping() + indexMappingDot.DefaultMapping.AddFieldMappingsAt("content", contentFM) + indexMappingDot.DefaultMapping.AddFieldMappingsAt("vector", vecFMDot) + indexMappingDot.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFMDot) + + tmpIndexPathL2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPathL2) + + tmpIndexPathDot := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPathDot) + + indexL2, err := New(tmpIndexPathL2, indexMappingL2) + if err != nil { + t.Fatal(err) + } + defer func() { + err := indexL2.Close() + if err != nil { + t.Fatal(err) + } + }() + + indexDot, err := New(tmpIndexPathDot, indexMappingDot) + if err != nil { + t.Fatal(err) + } + defer func() { + err := indexDot.Close() + if err != nil { + t.Fatal(err) + } + }() + + batchL2 := indexL2.NewBatch() + batchDot := indexDot.NewBatch() + + for _, doc := range documents { + err = batchL2.Index(doc["id"].(string), doc) + if err != nil { + t.Fatal(err) + } + err = batchDot.Index(doc["id"].(string), doc) + if err != nil { + t.Fatal(err) + } + } + + err = indexL2.Batch(batchL2) + if err != nil { + t.Fatal(err) + } + + err = indexDot.Batch(batchDot) + if err != nil { + t.Fatal(err) + } + + for i, _ := range searchRequests { + for _, operator := range knnOperators { + controlQuery := searchRequests[i] + testQuery := searchRequestsCopy[i] + + controlQuery.AddKNNOperator(operator) + testQuery.AddKNNOperator(operator) + + controlResultL2, err := indexL2.Search(controlQuery) + if err != nil { + t.Fatal(err) + } + testResultL2, err := indexL2.Search(testQuery) + if err != nil { + t.Fatal(err) + } + + if controlResultL2 != nil && testResultL2 != nil { + if len(controlResultL2.Hits) == len(testResultL2.Hits) { + for j, _ := range controlResultL2.Hits { + if controlResultL2.Hits[j].ID != testResultL2.Hits[j].ID { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultL2.Hits[j].ID, testResultL2.Hits[j].ID) + } + } + } + } else if (controlResultL2 == nil && testResultL2 != nil) || + (controlResultL2 != nil && testResultL2 == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, controlResultL2, testResultL2) + } + + controlResultDot, err := indexDot.Search(controlQuery) + if err != nil { + t.Fatal(err) + } + testResultDot, err := indexDot.Search(testQuery) + if err != nil { + t.Fatal(err) + } + + if controlResultDot != nil && testResultDot != nil { + if len(controlResultDot.Hits) == len(testResultDot.Hits) { + for j, _ := range controlResultDot.Hits { + if controlResultDot.Hits[j].ID != testResultDot.Hits[j].ID { + t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultDot.Hits[j].ID, testResultDot.Hits[j].ID) + } + } + } + } else if (controlResultDot == nil && testResultDot != nil) || + (controlResultDot != nil && testResultDot == nil) { + t.Fatalf("testcase %d failed: expected result %s, got result %s", i, controlResultDot, testResultDot) + } + } + } +} + type testDocument struct { ID string `json:"id"` Content string `json:"content"` @@ -434,6 +591,28 @@ func readDatasetAndQueries(fileName string) ([]testDocument, []*SearchRequest, e return dataset, queries, nil } +func encodeVectors(docs []map[string]interface{}) error { + + for _, doc := range docs { + vec, err := json.Marshal(doc["vector"]) + if err != nil { + return err + } + doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(vec) + } + + return nil +} + +func modifySearchRequests(srs []*SearchRequest) { + + for _, sr := range srs { + for _, kr := range sr.KNN { + kr.Field = "vectorEncoded" + } + } +} + func makeDatasetIntoDocuments(dataset []testDocument) []map[string]interface{} { documents := make([]map[string]interface{}, len(dataset)) for i := 0; i < len(dataset); i++ { From d84e8249d725d97583b04db0867040a3c9edaac3 Mon Sep 17 00:00:00 2001 From: Likith B Date: Wed, 27 Mar 2024 15:09:21 +0530 Subject: [PATCH 2/8] MB-59616: VectorBase64 Queries - Added VectorBase64 of type string to the KNNRequest struct - ValidateKNN will handle the decoding of the encoded string and fill the vector field with the vector value --- document/field_vector_base64.go | 6 +++--- mapping/document.go | 2 +- mapping/index.go | 1 + mapping/mapping_vectors.go | 4 ++-- search_knn.go | 20 ++++++++++++++++---- 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go index a55c4758b..be844eb66 100644 --- a/document/field_vector_base64.go +++ b/document/field_vector_base64.go @@ -80,7 +80,7 @@ func (n *VectorBase64Field) GoString() string { func NewVectorBase64Field(name string, arrayPositions []uint64, encodedValue string, dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) { - vector, err := decodeVector(encodedValue) + vector, err := DecodeVector(encodedValue) if err != nil { return nil, err } @@ -94,7 +94,7 @@ func NewVectorBase64Field(name string, arrayPositions []uint64, encodedValue str }, nil } -func decodeVector(encodedValue string) ([]float32, error) { +func DecodeVector(encodedValue string) ([]float32, error) { decodedString, err := base64.StdEncoding.DecodeString(encodedValue) if err != nil { fmt.Println("Error decoding string:", err) @@ -102,7 +102,7 @@ func decodeVector(encodedValue string) ([]float32, error) { } var decodedVector []float32 - err = json.Unmarshal(decodedString, decodedVector) + err = json.Unmarshal(decodedString, &decodedVector) if err != nil { fmt.Println("Error decoding string:", err) return nil, err diff --git a/mapping/document.go b/mapping/document.go index fe72f0802..3131f33bf 100644 --- a/mapping/document.go +++ b/mapping/document.go @@ -443,7 +443,7 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, fieldMapping.processGeoShape(property, pathString, path, indexes, context) } else if fieldMapping.Type == "geopoint" { fieldMapping.processGeoPoint(property, pathString, path, indexes, context) - } else if fieldMapping.Type == "vector-base64" { + } else if fieldMapping.Type == "vector_base64" { fieldMapping.processVectorBase64(property, pathString, path, indexes, context) } else { fieldMapping.processString(propertyValueString, pathString, path, indexes, context) diff --git a/mapping/index.go b/mapping/index.go index 6ff229ae8..171ee1a72 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -320,6 +320,7 @@ func (im *IndexMappingImpl) determineType(data interface{}) string { return im.DefaultType } + func (im *IndexMappingImpl) MapDocument(doc *document.Document, data interface{}) error { docType := im.determineType(data) docMapping := im.mappingForType(docType) diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index 640dcf5e8..789c7797b 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -47,7 +47,7 @@ func NewVectorFieldMapping() *FieldMapping { func NewVectorBase64FieldMapping() *FieldMapping { return &FieldMapping{ - Type: "vector-base64", + Type: "vector_base64", Store: false, Index: true, IncludeInAll: false, @@ -191,7 +191,7 @@ func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interfac func validateFieldMapping(field *FieldMapping, parentName string, fieldAliasCtx map[string]*FieldMapping) error { switch field.Type { - case "vector", "vector-base64": + case "vector", "vector_base64": return validateVectorFieldAlias(field, parentName, fieldAliasCtx) default: // non-vector field return validateFieldType(field) diff --git a/search_knn.go b/search_knn.go index 683771418..28b62e550 100644 --- a/search_knn.go +++ b/search_knn.go @@ -23,6 +23,7 @@ import ( "fmt" "sort" + "github.com/blevesearch/bleve/v2/document" "github.com/blevesearch/bleve/v2/search" "github.com/blevesearch/bleve/v2/search/collector" "github.com/blevesearch/bleve/v2/search/query" @@ -67,10 +68,11 @@ type SearchRequest struct { } type KNNRequest struct { - Field string `json:"field"` - Vector []float32 `json:"vector"` - K int64 `json:"k"` - Boost *query.Boost `json:"boost,omitempty"` + Field string `json:"field"` + Vector []float32 `json:"vector"` + VectorBase64 string `json:"vectorbase64"` + K int64 `json:"k"` + Boost *query.Boost `json:"boost,omitempty"` } func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost float64) { @@ -230,6 +232,16 @@ func validateKNN(req *SearchRequest) error { if q == nil { return fmt.Errorf("knn query cannot be nil") } + if q.VectorBase64 != "" { + if q.Vector == nil { + vec, err := document.DecodeVector(q.VectorBase64) + if err != nil { + return err + } + + q.Vector = vec + } + } if q.K <= 0 || len(q.Vector) == 0 { return fmt.Errorf("k must be greater than 0 and vector must be non-empty") } From 297f7020a63847d68a4bc2d49ae62fe2a08c4179 Mon Sep 17 00:00:00 2001 From: Likith B Date: Thu, 11 Apr 2024 12:37:13 +0530 Subject: [PATCH 3/8] Addressing Review Comments --- document/field_vector_base64.go | 13 ++++++----- mapping/mapping_vectors.go | 1 - search_knn.go | 3 ++- search_knn_test.go | 38 ++++++++++----------------------- 4 files changed, 19 insertions(+), 36 deletions(-) diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go index be844eb66..d41688f2a 100644 --- a/document/field_vector_base64.go +++ b/document/field_vector_base64.go @@ -26,8 +26,8 @@ import ( ) type VectorBase64Field struct { - vectorField *VectorField - encodedValue string + vectorField *VectorField + base64Encoding string } func (n *VectorBase64Field) Size() int { @@ -55,7 +55,7 @@ func (n *VectorBase64Field) AnalyzedLength() int { } func (n *VectorBase64Field) EncodedFieldType() byte { - return 'e' // CHECK + return 'e' } func (n *VectorBase64Field) AnalyzedTokenFrequencies() index.TokenFrequencies { @@ -63,7 +63,6 @@ func (n *VectorBase64Field) AnalyzedTokenFrequencies() index.TokenFrequencies { } func (n *VectorBase64Field) Analyze() { - // CHECK } func (n *VectorBase64Field) Value() []byte { @@ -77,10 +76,10 @@ func (n *VectorBase64Field) GoString() string { // For the sake of not polluting the API, we are keeping arrayPositions as a // parameter, but it is not used. -func NewVectorBase64Field(name string, arrayPositions []uint64, encodedValue string, +func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 string, dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) { - vector, err := DecodeVector(encodedValue) + vector, err := DecodeVector(vectorBase64) if err != nil { return nil, err } @@ -90,7 +89,7 @@ func NewVectorBase64Field(name string, arrayPositions []uint64, encodedValue str vector, dims, similarity, vectorIndexOptimizedFor, DefaultVectorIndexingOptions), - encodedValue: encodedValue, + base64Encoding: vectorBase64, }, nil } diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index 789c7797b..d236f8972 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -135,7 +135,6 @@ func processVector(vecI interface{}, dims int) ([]float32, bool) { } func processVectorBase64(vecBase64 interface{}) (interface{}, bool) { - vecEncoded, ok := vecBase64.(string) if !ok { return nil, false diff --git a/search_knn.go b/search_knn.go index 28b62e550..7f43e9988 100644 --- a/search_knn.go +++ b/search_knn.go @@ -67,10 +67,11 @@ type SearchRequest struct { sortFunc func(sort.Interface) } +// Vector takes precedence over vectorBase64 in case both fields are given type KNNRequest struct { Field string `json:"field"` Vector []float32 `json:"vector"` - VectorBase64 string `json:"vectorbase64"` + VectorBase64 string `json:"vector_base64"` K int64 `json:"k"` Boost *query.Boost `json:"boost,omitempty"` } diff --git a/search_knn_test.go b/search_knn_test.go index 3da8bda09..c1629f427 100644 --- a/search_knn_test.go +++ b/search_knn_test.go @@ -399,7 +399,6 @@ func min(a, b int) int { } func TestVectorBase64Index(t *testing.T) { - dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) if err != nil { t.Fatal(err) @@ -411,12 +410,19 @@ func TestVectorBase64Index(t *testing.T) { t.Fatal(err) } - err = encodeVectors(documents) - if err != nil { - t.Fatal(err) + for _, doc := range documents { + vec, err := json.Marshal(doc["vector"]) + if err != nil { + t.Fatal(err) + } + doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(vec) } - modifySearchRequests(searchRequestsCopy) + for _, sr := range searchRequestsCopy { + for _, kr := range sr.KNN { + kr.Field = "vectorEncoded" + } + } contentFM := NewTextFieldMapping() contentFM.Analyzer = en.AnalyzerName @@ -591,28 +597,6 @@ func readDatasetAndQueries(fileName string) ([]testDocument, []*SearchRequest, e return dataset, queries, nil } -func encodeVectors(docs []map[string]interface{}) error { - - for _, doc := range docs { - vec, err := json.Marshal(doc["vector"]) - if err != nil { - return err - } - doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(vec) - } - - return nil -} - -func modifySearchRequests(srs []*SearchRequest) { - - for _, sr := range srs { - for _, kr := range sr.KNN { - kr.Field = "vectorEncoded" - } - } -} - func makeDatasetIntoDocuments(dataset []testDocument) []map[string]interface{} { documents := make([]map[string]interface{}, len(dataset)) for i := 0; i < len(dataset); i++ { From ec3cfee16a4bacac43a10aabf6e093da9d24ab11 Mon Sep 17 00:00:00 2001 From: Likith B Date: Fri, 12 Apr 2024 11:14:17 +0530 Subject: [PATCH 4/8] Changed decode algorithm --- document/field_vector_base64.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go index d41688f2a..1f05d04ed 100644 --- a/document/field_vector_base64.go +++ b/document/field_vector_base64.go @@ -19,8 +19,9 @@ package document import ( "encoding/base64" - "encoding/json" + "encoding/binary" "fmt" + "math" index "github.com/blevesearch/bleve_index_api" ) @@ -100,11 +101,12 @@ func DecodeVector(encodedValue string) ([]float32, error) { return nil, err } - var decodedVector []float32 - err = json.Unmarshal(decodedString, &decodedVector) - if err != nil { - fmt.Println("Error decoding string:", err) - return nil, err + dims := int(len(decodedString) / 4) + decodedVector := make([]float32, dims) + + for i := 0; i < dims; i++ { + bytes := decodedString[i*4 : (i+1)*4] + decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes)) } return decodedVector, nil From 6f7cc986412dac8813e5039e2f72f3b4c3fc1b3b Mon Sep 17 00:00:00 2001 From: Abhinav Dangeti Date: Fri, 12 Apr 2024 11:14:05 -0600 Subject: [PATCH 5/8] Drop empty line --- mapping/mapping_vectors.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index d236f8972..b1ce19568 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -175,7 +175,6 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{}, pathString string, path []string, indexes []uint64, context *walkContext) { - propertyMightBeVector, ok := processVectorBase64(propertyMightBeVectorBase64) if !ok { return From e65010c4f1b6ba3f3a63fbb6ed0706586630265d Mon Sep 17 00:00:00 2001 From: Likith B Date: Mon, 15 Apr 2024 15:51:30 +0530 Subject: [PATCH 6/8] Addressing Review Comments --- document/field_vector_base64.go | 12 +++++++++++- mapping/mapping_vectors.go | 30 +++++++----------------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go index 1f05d04ed..7ef12bf37 100644 --- a/document/field_vector_base64.go +++ b/document/field_vector_base64.go @@ -94,16 +94,26 @@ func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 str }, nil } +// This function takes a base64 encoded string and decodes it into +// a vector. func DecodeVector(encodedValue string) ([]float32, error) { + + // We first decode the encoded string into a byte array. decodedString, err := base64.StdEncoding.DecodeString(encodedValue) if err != nil { - fmt.Println("Error decoding string:", err) return nil, err } + // The array is expected to be divisible by 4 because each float32 + // should occupy 4 bytes + if len(decodedString)%4 != 0 { + return nil, fmt.Errorf("Decoded byte array not divisible by 4") + } dims := int(len(decodedString) / 4) decodedVector := make([]float32, dims) + // We iterate through the array 4 bytes at a time and convert each of + // them to a float32 value by reading them in a little endian notation for i := 0; i < dims; i++ { bytes := decodedString[i*4 : (i+1)*4] decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes)) diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index b1ce19568..cda41f304 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -18,8 +18,6 @@ package mapping import ( - "encoding/base64" - "encoding/json" "fmt" "reflect" @@ -134,26 +132,6 @@ func processVector(vecI interface{}, dims int) ([]float32, bool) { return rv, true } -func processVectorBase64(vecBase64 interface{}) (interface{}, bool) { - vecEncoded, ok := vecBase64.(string) - if !ok { - return nil, false - } - - vecData, err := base64.StdEncoding.DecodeString(vecEncoded) - if err != nil { - return nil, false - } - - var vector interface{} - err = json.Unmarshal(vecData, &vector) - if err != nil { - return nil, false - } - - return vector, true -} - func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, pathString string, path []string, indexes []uint64, context *walkContext) bool { vector, ok := processVector(propertyMightBeVector, fm.Dims) @@ -175,11 +153,17 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{}, pathString string, path []string, indexes []uint64, context *walkContext) { - propertyMightBeVector, ok := processVectorBase64(propertyMightBeVectorBase64) + + encodedString, ok := propertyMightBeVectorBase64.(string) if !ok { return } + propertyMightBeVector, err := document.DecodeVector(encodedString) + if err != nil { + return + } + fm.processVector(propertyMightBeVector, pathString, path, indexes, context) } From 1d102da7df0270f97df65e34e9390105c04c3051 Mon Sep 17 00:00:00 2001 From: Likith B Date: Tue, 16 Apr 2024 11:10:51 +0530 Subject: [PATCH 7/8] Adding Testcases --- document/field_vector_base64.go | 9 ++- document/field_vector_base64_test.go | 110 +++++++++++++++++++++++++++ mapping/mapping_vectors.go | 1 - 3 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 document/field_vector_base64_test.go diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go index 7ef12bf37..e62dbe0a2 100644 --- a/document/field_vector_base64.go +++ b/document/field_vector_base64.go @@ -23,6 +23,7 @@ import ( "fmt" "math" + "github.com/blevesearch/bleve/v2/size" index "github.com/blevesearch/bleve_index_api" ) @@ -106,16 +107,16 @@ func DecodeVector(encodedValue string) ([]float32, error) { // The array is expected to be divisible by 4 because each float32 // should occupy 4 bytes - if len(decodedString)%4 != 0 { - return nil, fmt.Errorf("Decoded byte array not divisible by 4") + if len(decodedString)%size.SizeOfFloat32 != 0 { + return nil, fmt.Errorf("Decoded byte array not divisible by %d", size.SizeOfFloat32) } - dims := int(len(decodedString) / 4) + dims := int(len(decodedString) / size.SizeOfFloat32) decodedVector := make([]float32, dims) // We iterate through the array 4 bytes at a time and convert each of // them to a float32 value by reading them in a little endian notation for i := 0; i < dims; i++ { - bytes := decodedString[i*4 : (i+1)*4] + bytes := decodedString[i*size.SizeOfFloat32 : (i+1)*size.SizeOfFloat32] decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes)) } diff --git a/document/field_vector_base64_test.go b/document/field_vector_base64_test.go new file mode 100644 index 000000000..b23cc7fc6 --- /dev/null +++ b/document/field_vector_base64_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package document + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "fmt" + "math/rand" + "testing" +) + +func TestDecodeVector(t *testing.T) { + vec := make([]float32, 2048) + for i := range vec { + vec[i] = rand.Float32() + } + + vecBytes := bytifyVec(vec) + encodedVec := base64.StdEncoding.EncodeToString(vecBytes) + + decodedVec, err := DecodeVector(encodedVec) + if err != nil { + t.Error(err) + } + if len(decodedVec) != len(vec) { + t.Errorf("Decoded vector dimensions not same as original vector dimensions") + } + + for i := range vec { + if vec[i] != decodedVec[i] { + t.Errorf("Decoded vector not the same as original vector") + } + } +} + +func BenchmarkDecodeVector128(b *testing.B) { + vec := make([]float32, 128) + for i := range vec { + vec[i] = rand.Float32() + } + + vecBytes := bytifyVec(vec) + encodedVec := base64.StdEncoding.EncodeToString(vecBytes) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = DecodeVector(encodedVec) + } +} + +func BenchmarkDecodeVector784(b *testing.B) { + vec := make([]float32, 784) + for i := range vec { + vec[i] = rand.Float32() + } + + vecBytes := bytifyVec(vec) + encodedVec := base64.StdEncoding.EncodeToString(vecBytes) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = DecodeVector(encodedVec) + } +} + +func BenchmarkDecodeVector1536(b *testing.B) { + vec := make([]float32, 1536) + for i := range vec { + vec[i] = rand.Float32() + } + + vecBytes := bytifyVec(vec) + encodedVec := base64.StdEncoding.EncodeToString(vecBytes) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = DecodeVector(encodedVec) + } +} + +func bytifyVec(vec []float32) []byte { + + buf := new(bytes.Buffer) + + for _, v := range vec { + err := binary.Write(buf, binary.LittleEndian, v) + if err != nil { + fmt.Println(err) + } + } + + return buf.Bytes() +} diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index cda41f304..0ec7c0f9f 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -153,7 +153,6 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{}, pathString string, path []string, indexes []uint64, context *walkContext) { - encodedString, ok := propertyMightBeVectorBase64.(string) if !ok { return From 8877da8251a7ec8200ab718c5ea935d152fc1353 Mon Sep 17 00:00:00 2001 From: Abhinav Dangeti Date: Tue, 16 Apr 2024 16:51:46 -0600 Subject: [PATCH 8/8] Test file under go:build vectors --- document/field_vector_base64_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/document/field_vector_base64_test.go b/document/field_vector_base64_test.go index b23cc7fc6..ac4bd8d4e 100644 --- a/document/field_vector_base64_test.go +++ b/document/field_vector_base64_test.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build vectors +// +build vectors + package document import (