Skip to content

Commit

Permalink
Add test cases for sparse vector
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
  • Loading branch information
ThreadDao committed Apr 18, 2024
1 parent e697167 commit 7745e56
Show file tree
Hide file tree
Showing 15 changed files with 964 additions and 20 deletions.
3 changes: 2 additions & 1 deletion client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error {
if field.DataType == entity.FieldTypeFloatVector ||
field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeFloat16Vector {
field.DataType == entity.FieldTypeFloat16Vector ||
field.DataType == entity.FieldTypeSparseVector {
vectors++
}
}
Expand Down
4 changes: 4 additions & 0 deletions entity/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ func AnyToColumns(rows []interface{}, schemas ...*Schema) ([]Column, error) {
}
col := NewColumnBFloat16Vector(field.Name, int(dim), data)
nameColumns[field.Name] = col
case FieldTypeSparseVector:
data := make([]SparseEmbedding, 0, rowsLen)
col := NewColumnSparseVectors(field.Name, data)
nameColumns[field.Name] = col
}
}

Expand Down
2 changes: 0 additions & 2 deletions test/common/response_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ func CheckOutputFields(t *testing.T, actualColumns []entity.Column, expFields []
for _, actualColumn := range actualColumns {
actualFields = append(actualFields, actualColumn.Name())
}
log.Printf("actual fields: %v", actualFields)
log.Printf("expected fields: %v", expFields)
require.ElementsMatchf(t, expFields, actualFields, fmt.Sprintf("Expected search output fields: %v, actual: %v", expFields, actualFields))
}

Expand Down
96 changes: 94 additions & 2 deletions test/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
DefaultBinaryVecFieldName = "binaryVec"
DefaultFloat16VecFieldName = "fp16Vec"
DefaultBFloat16VecFieldName = "bf16Vec"
DefaultSparseVecFieldName = "sparseVec"
DefaultDynamicNumberField = "dynamicNumber"
DefaultDynamicStringField = "dynamicString"
DefaultDynamicBoolField = "dynamicBool"
Expand Down Expand Up @@ -223,6 +224,21 @@ func GenBinaryVector(dim int64) []byte {
return vector
}

// GenSparseVector builds a random sparse embedding for test data.
//
// The number of entries is drawn with rand.Intn from [1, maxLen+1]. Entry i is
// placed at the odd position 2*i+1 and given a value from rand.Float32. If the
// embedding cannot be constructed the process exits through log.Fatalf.
func GenSparseVector(maxLen int) entity.SparseEmbedding {
	count := rand.Intn(maxLen+1) + 1
	var (
		indices = make([]uint32, count)
		weights = make([]float32, count)
	)
	for idx := range indices {
		indices[idx] = uint32(idx*2 + 1)
		weights[idx] = rand.Float32()
	}

	embedding, err := entity.NewSliceSparseEmbedding(indices, weights)
	if err != nil {
		log.Fatalf("Generate vector failed %s", err)
	}
	return embedding
}

// --- common utils ---

// --- gen fields ---
Expand Down Expand Up @@ -405,6 +421,13 @@ func GenColumnData(start int, nb int, fieldType entity.FieldType, fieldName stri
bf16Vectors = append(bf16Vectors, vec)
}
return entity.NewColumnBFloat16Vector(fieldName, int(opt.dim), bf16Vectors)
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, 0, nb)
for i := start; i < start+nb; i++ {
vec := GenSparseVector(opt.maxLenSparse)
vectors = append(vectors, vec)
}
return entity.NewColumnSparseVectors(fieldName, vectors)
default:
return nil
}
Expand Down Expand Up @@ -984,6 +1007,53 @@ func GenDefaultArrayRows(start int, nb int, dim int64, enableDynamicField bool,
return rows
}

// GenDefaultSparseRows generates nb row structs with primary keys in
// [start, start+nb). Each row carries an int64 key, its decimal string as the
// varchar, a float vector of the given dim and a random sparse vector built
// with GenSparseVector(maxLenSparse).
//
// When enableDynamicField is true each row is a DynamicRow value that also
// holds a Dynamic payload; otherwise each row is a pointer to a BaseRow.
func GenDefaultSparseRows(start int, nb int, dim int64, maxLenSparse int, enableDynamicField bool) []interface{} {
	type BaseRow struct {
		Int64     int64                  `json:"int64" milvus:"name:int64"`
		Varchar   string                 `json:"varchar" milvus:"name:varchar"`
		FloatVec  []float32              `json:"floatVec" milvus:"name:floatVec"`
		SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
	}

	type DynamicRow struct {
		Int64     int64                  `json:"int64" milvus:"name:int64"`
		Varchar   string                 `json:"varchar" milvus:"name:varchar"`
		FloatVec  []float32              `json:"floatVec" milvus:"name:floatVec"`
		SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
		Dynamic   Dynamic                `json:"dynamic" milvus:"name:dynamic"`
	}

	result := make([]interface{}, 0, nb)
	end := start + nb
	for id := start; id < end; id++ {
		base := BaseRow{
			Int64:     int64(id),
			Varchar:   strconv.Itoa(id),
			FloatVec:  GenFloatVector(dim),
			SparseVec: GenSparseVector(maxLenSparse),
		}

		// Static schema: the base row is appended by pointer.
		if !enableDynamicField {
			result = append(result, &base)
			continue
		}

		// Dynamic schema: copy the base columns and attach the dynamic payload.
		result = append(result, DynamicRow{
			Int64:     base.Int64,
			Varchar:   base.Varchar,
			FloatVec:  base.FloatVec,
			SparseVec: base.SparseVec,
			Dynamic: Dynamic{
				Number: int32(id),
				String: strconv.Itoa(id),
				Bool:   id%2 == 0,
				List:   []int64{int64(id), int64(id + 1)},
			},
		})
	}
	return result
}

func GenAllVectorsRows(start int, nb int, dim int64, enableDynamicField bool) []interface{} {
rows := make([]interface{}, 0, nb)
type BaseRow struct {
Expand Down Expand Up @@ -1234,11 +1304,28 @@ var SupportBinIvfFlatMetricType = []entity.MetricType{
entity.HAMMING,
}

// UnsupportedSparseVecMetricsType lists metric types grouped as unsupported
// for sparse vector fields. entity.IP is not part of this list.
// NOTE(review): presumably each entry is expected to be rejected when used to
// index or search a sparse vector field — confirm against the server behaviour.
var UnsupportedSparseVecMetricsType = []entity.MetricType{
entity.L2,
entity.COSINE,
entity.JACCARD,
entity.HAMMING,
entity.SUBSTRUCTURE,
entity.SUPERSTRUCTURE,
}

// GenAllFloatIndex gen all float vector index
func GenAllFloatIndex() []entity.Index {
func GenAllFloatIndex(metricTypes ...entity.MetricType) []entity.Index {
nlist := 128
var allFloatIndex []entity.Index
for _, metricType := range SupportFloatMetricType {
var allMetricTypes []entity.MetricType
log.Println(metricTypes)
if len(metricTypes) == 0 {
allMetricTypes = SupportFloatMetricType
} else {
allMetricTypes = metricTypes
}
for _, metricType := range allMetricTypes {
log.Println(metricType)
idxFlat, _ := entity.NewIndexFlat(metricType)
idxIvfFlat, _ := entity.NewIndexIvfFlat(metricType, nlist)
idxIvfSq8, _ := entity.NewIndexIvfSQ8(metricType, nlist)
Expand Down Expand Up @@ -1279,6 +1366,11 @@ func GenSearchVectors(nq int, dim int64, dataType entity.FieldType) []entity.Vec
vector := GenBFloat16Vector(dim)
vectors = append(vectors, entity.BFloat16Vector(vector))
}
case entity.FieldTypeSparseVector:
for i := 0; i < nq; i++ {
vec := GenSparseVector(int(dim))
vectors = append(vectors, vec)
}
}
return vectors
}
Expand Down
13 changes: 10 additions & 3 deletions test/common/utils_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ func GenSchema(name string, autoID bool, fields []*entity.Field, opts ...CreateS
// GenColumnDataOption -- create column data --
type GenColumnDataOption func(opt *genDataOpt)
type genDataOpt struct {
dim int64
ElementType entity.FieldType
capacity int64
dim int64
ElementType entity.FieldType
capacity int64
maxLenSparse int
}

func WithVectorDim(dim int64) GenColumnDataOption {
Expand All @@ -137,4 +138,10 @@ func WithArrayCapacity(capacity int64) GenColumnDataOption {
}
}

// WithSparseVectorLen returns a column-data option that stores length in the
// maxLenSparse field of the generation options.
func WithSparseVectorLen(length int) GenColumnDataOption {
	setLen := func(o *genDataOpt) { o.maxLenSparse = length }
	return setLen
}

// -- create column data --
44 changes: 44 additions & 0 deletions test/testcases/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,50 @@ func TestCreateMultiVectorExceed(t *testing.T) {
common.CheckErr(t, errCreateCollection, false, "maximum vector field's number should be limited to 4")
}

// TestCreateCollectionSparseVectorWithDim checks that creating a collection is
// rejected when a dim is specified on a sparse vector field.
func TestCreateCollectionSparseVectorWithDim(t *testing.T) {
	ctx := createContext(t, time.Second*common.DefaultTimeout)
	mc := createMilvusClient(ctx, t)

	// schema: int64 primary key + sparse vector field carrying an explicit dim
	pkField := common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false))
	sparseField := common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector, common.WithDim(common.DefaultDim))
	name := common.GenRandomString(6)
	schema := common.GenSchema(name, false, []*entity.Field{pkField, sparseField})

	// creation must fail with the dim error
	err := mc.CreateCollection(ctx, schema, common.DefaultShards)
	common.CheckErr(t, err, false, "dim should not be specified for sparse vector field sparseVec(0)")
}

// TestCreateCollectionSparseVector creates a collection with an int64 primary
// key, a varchar field and a sparse vector field, then describes it and checks
// the returned schema.
func TestCreateCollectionSparseVector(t *testing.T) {
	ctx := createContext(t, time.Second*common.DefaultTimeout)
	mc := createMilvusClient(ctx, t)

	pkField := common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false))
	varcharField := common.GenField(common.DefaultVarcharFieldName, entity.FieldTypeVarChar, common.WithMaxLength(common.TestMaxLen))
	sparseField := common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector)
	name := common.GenRandomString(6)
	schema := common.GenSchema(name, false, []*entity.Field{pkField, varcharField, sparseField})

	// create collection -> expect success
	createErr := mc.CreateCollection(ctx, schema, common.DefaultShards)
	common.CheckErr(t, createErr, true)

	// describe collection and compare with the requested schema
	described, describeErr := mc.DescribeCollection(ctx, name)
	common.CheckErr(t, describeErr, true)
	common.CheckCollection(t, described, name, common.DefaultShards, schema, common.DefaultConsistencyLevel)
	require.Len(t, described.Schema.Fields, 3)

	// any field reported as a sparse vector must carry the sparse field name
	for _, f := range described.Schema.Fields {
		if f.DataType != entity.FieldTypeSparseVector {
			continue
		}
		require.Equal(t, common.DefaultSparseVecFieldName, f.Name)
	}
}

// -- Get Collection Statistics --

func TestGetStaticsCollectionNotExisted(t *testing.T) {
Expand Down
64 changes: 64 additions & 0 deletions test/testcases/groupby_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,70 @@ func TestSearchGroupByFloatDefault(t *testing.T) {
}
}

// TestGroupBySearchSparseVector runs a groupBy search on the sparse vector
// field, once with a sparse inverted index and once with a sparse WAND index.
// For every entity returned by the groupBy search it repeats the search
// filtered on that entity's group value with limit 1, and requires that at
// least 80% of the groupBy hits are the top1 of their group.
func TestGroupBySearchSparseVector(t *testing.T) {
	t.Parallel()
	idxInverted, err := entity.NewIndexSparseInverted(entity.IP, 0.3)
	common.CheckErr(t, err, true)
	idxWand, err := entity.NewIndexSparseWAND(entity.IP, 0.2)
	common.CheckErr(t, err, true)

	for _, idx := range []entity.Index{idxInverted, idxWand} {
		ctx := createContext(t, time.Second*common.DefaultTimeout*2)
		// connect
		mc := createMilvusClient(ctx, t)

		// create collection
		cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
			ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}
		collName := createCollection(ctx, t, mc, cp, client.WithConsistencyLevel(entity.ClStrong))

		// insert data: 100 batches of 200 rows, every batch generated from start 0
		dp := DataParams{DoInsert: true, CollectionName: collName, CollectionFieldsType: Int64VarcharSparseVec, start: 0,
			nb: 200, dim: common.DefaultDim, EnableDynamicField: true}
		for i := 0; i < 100; i++ {
			_, errInsert := insertData(ctx, t, mc, dp)
			common.CheckErr(t, errInsert, true)
		}
		errFlush := mc.Flush(ctx, collName, false)
		common.CheckErr(t, errFlush, true)

		// index and load
		idxHnsw, errIdx := entity.NewIndexHNSW(entity.L2, 8, 96)
		common.CheckErr(t, errIdx, true)
		errIdx = mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idxHnsw, false)
		common.CheckErr(t, errIdx, true)
		errIdx = mc.CreateIndex(ctx, collName, common.DefaultSparseVecFieldName, idx, false)
		common.CheckErr(t, errIdx, true)
		errLoad := mc.LoadCollection(ctx, collName, false)
		common.CheckErr(t, errLoad, true)

		// groupBy search
		queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
		sp, errSp := entity.NewIndexSparseInvertedSearchParam(0.2)
		common.CheckErr(t, errSp, true)
		resGroupBy, errSearch := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName}, queryVec,
			common.DefaultSparseVecFieldName, entity.IP, common.DefaultTopK, sp, client.WithGroupByField(common.DefaultVarcharFieldName))
		common.CheckErr(t, errSearch, true)

		// verify each topK entity is the top1 of the whole group
		hitsNum := 0
		total := 0
		for _, rs := range resGroupBy {
			for i := 0; i < rs.ResultCount; i++ {
				groupByValue, errGet := rs.GroupByValue.Get(i)
				require.NoError(t, errGet)
				pkValue, errGet := rs.IDs.GetAsInt64(i)
				require.NoError(t, errGet)
				expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)

				// search filter with groupByValue is the top1
				resFilter, errFilter := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
					common.DefaultVarcharFieldName}, queryVec, common.DefaultSparseVecFieldName, entity.IP, 1, sp)
				common.CheckErr(t, errFilter, true)
				// guard the [0] access below so a bad response fails the test instead of panicking
				require.NotEmpty(t, resFilter)
				filterTop1Pk, errGet := resFilter[0].IDs.GetAsInt64(0)
				require.NoError(t, errGet)
				if filterTop1Pk == pkValue {
					hitsNum++
				}
				total++
			}
		}

		// an empty groupBy result would make the hit rate 0/0 (NaN); fail explicitly instead
		require.NotZero(t, total, "groupBy search returned no results")

		// verify hits rate
		hitsRate := float32(hitsNum) / float32(total)
		_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
			common.DefaultSparseVecFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate)
		log.Println(_str)
		require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
	}
}

// binary vector -> not supported
func TestSearchGroupByBinaryDefault(t *testing.T) {
t.Parallel()
Expand Down
51 changes: 50 additions & 1 deletion test/testcases/hybrid_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ func TestHybridSearchDefault(t *testing.T) {

// hybrid search default -> verify success
func TestHybridSearchMultiVectorsDefault(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/32222")
t.Parallel()
ctx := createContext(t, time.Second*common.DefaultTimeout*3)
// connect
Expand Down Expand Up @@ -311,3 +310,53 @@ func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) {
}
}
}

// TestHybridSearchSparseVector runs a hybrid search that combines an ANN
// request on the sparse vector field with one on the float vector field. It is
// repeated for a SPARSE_INVERTED_INDEX and a SPARSE_WAND index on the sparse
// field, and for both the RRF and the weighted reranker, checking the result
// size and the returned output fields each time.
func TestHybridSearchSparseVector(t *testing.T) {
	t.Parallel()
	idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
	idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
	for _, idx := range []entity.Index{idxInverted, idxWand} {
		ctx := createContext(t, time.Second*common.DefaultTimeout*2)
		// connect
		mc := createMilvusClient(ctx, t)

		// create -> insert -> flush -> index -> load
		cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
			ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}

		dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 3,
			dim: common.DefaultDim, EnableDynamicField: true}

		// index params
		idxHnsw, errIdx := entity.NewIndexHNSW(entity.L2, 8, 96)
		common.CheckErr(t, errIdx, true)
		ips := []IndexParams{
			{BuildIndex: true, Index: idx, FieldName: common.DefaultSparseVecFieldName, async: false},
			{BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false},
		}
		collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))

		// search
		queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector)
		queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
		sp1, errSp := entity.NewIndexSparseInvertedSearchParam(0.2)
		common.CheckErr(t, errSp, true)
		sp2, errSp := entity.NewIndexHNSWSearchParam(20)
		common.CheckErr(t, errSp, true)
		expr := fmt.Sprintf("%s > 1", common.DefaultIntFieldName)
		sReqs := []*client.ANNSearchRequest{
			client.NewANNSearchRequest(common.DefaultSparseVecFieldName, entity.IP, expr, queryVec1, sp1, common.DefaultTopK),
			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec2, sp2, common.DefaultTopK),
		}
		for _, reranker := range []client.Reranker{
			client.NewRRFReranker(),
			client.NewWeightedReranker([]float64{0.5, 0.6}),
		} {
			// hybrid search
			searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, reranker, sReqs)
			common.CheckErr(t, errSearch, true)
			// the result-size check runs before searchRes[0] is read below
			common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK)
			outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName,
				common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName}
			common.CheckOutputFields(t, searchRes[0].Fields, outputFields)
		}
	}
}
Loading

0 comments on commit 7745e56

Please sign in to comment.