Add test cases for sparse vector #720

Merged: 1 commit, merged Apr 18, 2024
3 changes: 2 additions & 1 deletion client/collection.go
@@ -218,7 +218,8 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error {
if field.DataType == entity.FieldTypeFloatVector ||
field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeFloat16Vector {
field.DataType == entity.FieldTypeFloat16Vector ||
field.DataType == entity.FieldTypeSparseVector {
vectors++
}
}
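In effect, a sparse vector field now increments the same vector-field counter that validateSchema applies its later checks to. A rough sketch of the consequence, reusing the test helpers and the error message exercised further down in this PR (hypothetical schema, not part of the diff; ctx, mc and t come from the usual test harness):

fields := []*entity.Field{
	common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
	common.GenField("floatVec1", entity.FieldTypeFloatVector, common.WithDim(common.DefaultDim)),
	common.GenField("floatVec2", entity.FieldTypeFloatVector, common.WithDim(common.DefaultDim)),
	common.GenField("floatVec3", entity.FieldTypeFloatVector, common.WithDim(common.DefaultDim)),
	common.GenField("floatVec4", entity.FieldTypeFloatVector, common.WithDim(common.DefaultDim)),
	common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector),
}
schema := common.GenSchema(common.GenRandomString(6), false, fields)
// four dense vector fields plus one sparse field now count as five vector fields,
// so collection creation is expected to fail against the multi-vector cap
errCreate := mc.CreateCollection(ctx, schema, common.DefaultShards)
common.CheckErr(t, errCreate, false, "maximum vector field's number should be limited to 4")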
4 changes: 4 additions & 0 deletions entity/rows.go
@@ -389,6 +389,10 @@ func AnyToColumns(rows []interface{}, schemas ...*Schema) ([]Column, error) {
}
col := NewColumnBFloat16Vector(field.Name, int(dim), data)
nameColumns[field.Name] = col
case FieldTypeSparseVector:
data := make([]SparseEmbedding, 0, rowsLen)
col := NewColumnSparseVectors(field.Name, data)
nameColumns[field.Name] = col
}
}

2 changes: 0 additions & 2 deletions test/common/response_check.go
@@ -190,8 +190,6 @@ func CheckOutputFields(t *testing.T, actualColumns []entity.Column, expFields []
for _, actualColumn := range actualColumns {
actualFields = append(actualFields, actualColumn.Name())
}
log.Printf("actual fields: %v", actualFields)
log.Printf("expected fields: %v", expFields)
require.ElementsMatchf(t, expFields, actualFields, fmt.Sprintf("Expected search output fields: %v, actual: %v", expFields, actualFields))
}

96 changes: 94 additions & 2 deletions test/common/utils.go
@@ -33,6 +33,7 @@ const (
DefaultBinaryVecFieldName = "binaryVec"
DefaultFloat16VecFieldName = "fp16Vec"
DefaultBFloat16VecFieldName = "bf16Vec"
DefaultSparseVecFieldName = "sparseVec"
DefaultDynamicNumberField = "dynamicNumber"
DefaultDynamicStringField = "dynamicString"
DefaultDynamicBoolField = "dynamicBool"
@@ -223,6 +224,21 @@ func GenBinaryVector(dim int64) []byte {
return vector
}

// GenSparseVector generates a sparse embedding with 1 to maxLen+1 non-zero entries at odd positions and random float32 values
func GenSparseVector(maxLen int) entity.SparseEmbedding {
length := 1 + rand.Intn(1+maxLen)
positions := make([]uint32, length)
values := make([]float32, length)
for i := 0; i < length; i++ {
positions[i] = uint32(2*i + 1)
values[i] = rand.Float32()
}
vector, err := entity.NewSliceSparseEmbedding(positions, values)
if err != nil {
log.Fatalf("Generate vector failed %s", err)
}
return vector
}
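A minimal usage sketch, assuming entity.SparseEmbedding exposes Dim, Len and Get accessors as in the SDK's entity package:

vec := common.GenSparseVector(10) // up to 11 entries, since length is 1 + rand.Intn(1+maxLen)
log.Printf("dim=%d, nnz=%d", vec.Dim(), vec.Len())
for i := 0; i < vec.Len(); i++ {
	if pos, val, ok := vec.Get(i); ok {
		log.Printf("position=%d, value=%f", pos, val)
	}
}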

// --- common utils ---

// --- gen fields ---
@@ -405,6 +421,13 @@ func GenColumnData(start int, nb int, fieldType entity.FieldType, fieldName stri
bf16Vectors = append(bf16Vectors, vec)
}
return entity.NewColumnBFloat16Vector(fieldName, int(opt.dim), bf16Vectors)
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, 0, nb)
for i := start; i < start+nb; i++ {
vec := GenSparseVector(opt.maxLenSparse)
vectors = append(vectors, vec)
}
return entity.NewColumnSparseVectors(fieldName, vectors)
default:
return nil
}
@@ -984,6 +1007,53 @@ func GenDefaultArrayRows(start int, nb int, dim int64, enableDynamicField bool,
return rows
}

// GenDefaultSparseRows generates row-based data with an int64 primary key, a varchar field, a float vector and a sparse vector
func GenDefaultSparseRows(start int, nb int, dim int64, maxLenSparse int, enableDynamicField bool) []interface{} {
rows := make([]interface{}, 0, nb)
type BaseRow struct {
Int64 int64 `json:"int64" milvus:"name:int64"`
Varchar string `json:"varchar" milvus:"name:varchar"`
FloatVec []float32 `json:"floatVec" milvus:"name:floatVec"`
SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
}

type DynamicRow struct {
Int64 int64 `json:"int64" milvus:"name:int64"`
Varchar string `json:"varchar" milvus:"name:varchar"`
FloatVec []float32 `json:"floatVec" milvus:"name:floatVec"`
SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
Dynamic Dynamic `json:"dynamic" milvus:"name:dynamic"`
}

for i := start; i < start+nb; i++ {
baseRow := BaseRow{
Int64: int64(i),
Varchar: strconv.Itoa(i),
FloatVec: GenFloatVector(dim),
SparseVec: GenSparseVector(maxLenSparse),
}
// json and dynamic field
dynamicJSON := Dynamic{
Number: int32(i),
String: strconv.Itoa(i),
Bool: i%2 == 0,
List: []int64{int64(i), int64(i + 1)},
}
if enableDynamicField {
dynamicRow := DynamicRow{
Int64: baseRow.Int64,
Varchar: baseRow.Varchar,
FloatVec: baseRow.FloatVec,
SparseVec: baseRow.SparseVec,
Dynamic: dynamicJSON,
}
rows = append(rows, dynamicRow)
} else {
rows = append(rows, &baseRow)
}
}
return rows
}
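To tie this helper back to the entity/rows.go change above, here is a rough sketch that pushes the generated rows through AnyToColumns; the field-name constants are assumed to match the BaseRow struct tags, and the schema literal is hypothetical:

fields := []*entity.Field{
	common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
	common.GenField(common.DefaultVarcharFieldName, entity.FieldTypeVarChar, common.WithMaxLength(common.TestMaxLen)),
	common.GenField(common.DefaultFloatVecFieldName, entity.FieldTypeFloatVector, common.WithDim(common.DefaultDim)),
	common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector),
}
schema := common.GenSchema("sparse_rows_demo", false, fields)

rows := common.GenDefaultSparseRows(0, 100, common.DefaultDim, 50, false)
// the new FieldTypeSparseVector case in AnyToColumns collects the SparseVec values
// into a sparse vector column alongside the other field columns
columns, err := entity.AnyToColumns(rows, schema)
if err != nil {
	log.Fatalf("convert rows failed: %s", err)
}
log.Printf("converted %d rows into %d columns", len(rows), len(columns))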

func GenAllVectorsRows(start int, nb int, dim int64, enableDynamicField bool) []interface{} {
rows := make([]interface{}, 0, nb)
type BaseRow struct {
@@ -1234,11 +1304,28 @@ var SupportBinIvfFlatMetricType = []entity.MetricType{
entity.HAMMING,
}

var UnsupportedSparseVecMetricsType = []entity.MetricType{
entity.L2,
entity.COSINE,
entity.JACCARD,
entity.HAMMING,
entity.SUBSTRUCTURE,
entity.SUPERSTRUCTURE,
}
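This list reads as input for a negative-path test; a minimal sketch of such a test body (hypothetical, with ctx, mc, collName and t taken from the usual harness, and no specific error message asserted):

for _, mt := range common.UnsupportedSparseVecMetricsType {
	idx, err := entity.NewIndexSparseInverted(mt, 0.2)
	if err == nil {
		// the metric type passed client-side construction, so the server is expected to reject it
		err = mc.CreateIndex(ctx, collName, common.DefaultSparseVecFieldName, idx, false)
	}
	require.Error(t, err, "sparse vector indexes are expected to accept only IP")
}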

// GenAllFloatIndex generates all float vector indexes, optionally restricted to the given metric types
func GenAllFloatIndex() []entity.Index {
func GenAllFloatIndex(metricTypes ...entity.MetricType) []entity.Index {
nlist := 128
var allFloatIndex []entity.Index
for _, metricType := range SupportFloatMetricType {
var allMetricTypes []entity.MetricType
log.Println(metricTypes)
if len(metricTypes) == 0 {
allMetricTypes = SupportFloatMetricType
} else {
allMetricTypes = metricTypes
}
for _, metricType := range allMetricTypes {
log.Println(metricType)
idxFlat, _ := entity.NewIndexFlat(metricType)
idxIvfFlat, _ := entity.NewIndexIvfFlat(metricType, nlist)
idxIvfSq8, _ := entity.NewIndexIvfSQ8(metricType, nlist)
@@ -1279,6 +1366,11 @@ func GenSearchVectors(nq int, dim int64, dataType entity.FieldType) []entity.Vec
vector := GenBFloat16Vector(dim)
vectors = append(vectors, entity.BFloat16Vector(vector))
}
case entity.FieldTypeSparseVector:
for i := 0; i < nq; i++ {
vec := GenSparseVector(int(dim))
vectors = append(vectors, vec)
}
}
return vectors
}
13 changes: 10 additions & 3 deletions test/common/utils_client.go
@@ -114,9 +114,10 @@ func GenSchema(name string, autoID bool, fields []*entity.Field, opts ...CreateS
// GenColumnDataOption -- create column data --
type GenColumnDataOption func(opt *genDataOpt)
type genDataOpt struct {
dim int64
ElementType entity.FieldType
capacity int64
dim int64
ElementType entity.FieldType
capacity int64
maxLenSparse int
}

func WithVectorDim(dim int64) GenColumnDataOption {
@@ -137,4 +138,10 @@ func WithArrayCapacity(capacity int64) GenColumnDataOption {
}
}

func WithSparseVectorLen(length int) GenColumnDataOption {
return func(opt *genDataOpt) {
opt.maxLenSparse = length
}
}

// -- create column data --
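A quick illustration of the new option combined with the sparse case added to GenColumnData in test/common/utils.go (a sketch using only names from this diff):

// a 1000-row sparse vector column whose embeddings carry at most ~50 non-zero entries
sparseColumn := common.GenColumnData(0, 1000, entity.FieldTypeSparseVector,
	common.DefaultSparseVecFieldName, common.WithSparseVectorLen(50))

When the option is omitted, maxLenSparse stays at its zero value and GenSparseVector falls back to single-entry embeddings.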
44 changes: 44 additions & 0 deletions test/testcases/collection_test.go
@@ -640,6 +640,50 @@ func TestCreateMultiVectorExceed(t *testing.T) {
common.CheckErr(t, errCreateCollection, false, "maximum vector field's number should be limited to 4")
}

// specify dim for sparse vector -> error
func TestCreateCollectionSparseVectorWithDim(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
mc := createMilvusClient(ctx, t)
allFields := []*entity.Field{
common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector, common.WithDim(common.DefaultDim)),
}
collName := common.GenRandomString(6)
schema := common.GenSchema(collName, false, allFields)

// create collection
errCreateCollection := mc.CreateCollection(ctx, schema, common.DefaultShards)
common.CheckErr(t, errCreateCollection, false, "dim should not be specified for sparse vector field sparseVec(0)")
}

// create collection with sparse vector
func TestCreateCollectionSparseVector(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
mc := createMilvusClient(ctx, t)
allFields := []*entity.Field{
common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
common.GenField(common.DefaultVarcharFieldName, entity.FieldTypeVarChar, common.WithMaxLength(common.TestMaxLen)),
common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector),
}
collName := common.GenRandomString(6)
schema := common.GenSchema(collName, false, allFields)

// create collection
errCreateCollection := mc.CreateCollection(ctx, schema, common.DefaultShards)
common.CheckErr(t, errCreateCollection, true)

// describe collection
collection, err := mc.DescribeCollection(ctx, collName)
common.CheckErr(t, err, true)
common.CheckCollection(t, collection, collName, common.DefaultShards, schema, common.DefaultConsistencyLevel)
require.Len(t, collection.Schema.Fields, 3)
for _, field := range collection.Schema.Fields {
if field.DataType == entity.FieldTypeSparseVector {
require.Equal(t, common.DefaultSparseVecFieldName, field.Name)
}
}
}
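A possible follow-up not included in this PR: inserting a small batch into the collection created above. This sketch assumes GenColumnData also covers the Int64 and VarChar cases and that Insert follows the column-based client API used by the existing dense-vector tests:

intColumn := common.GenColumnData(0, 100, entity.FieldTypeInt64, common.DefaultIntFieldName)
varcharColumn := common.GenColumnData(0, 100, entity.FieldTypeVarChar, common.DefaultVarcharFieldName)
sparseColumn := common.GenColumnData(0, 100, entity.FieldTypeSparseVector,
	common.DefaultSparseVecFieldName, common.WithSparseVectorLen(20))
_, errInsert := mc.Insert(ctx, collName, "", intColumn, varcharColumn, sparseColumn)
common.CheckErr(t, errInsert, true)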

// -- Get Collection Statistics --

func TestGetStaticsCollectionNotExisted(t *testing.T) {
64 changes: 64 additions & 0 deletions test/testcases/groupby_search_test.go
@@ -171,6 +171,70 @@ func TestSearchGroupByFloatDefault(t *testing.T) {
}
}

// test groupBy search with sparse vectors
func TestGroupBySearchSparseVector(t *testing.T) {
t.Parallel()
idxInverted, _ := entity.NewIndexSparseInverted(entity.IP, 0.3)
idxWand, _ := entity.NewIndexSparseWAND(entity.IP, 0.2)
for _, idx := range []entity.Index{idxInverted, idxWand} {
ctx := createContext(t, time.Second*common.DefaultTimeout*2)
// connect
mc := createMilvusClient(ctx, t)

// create -> insert (200 rows x 100 batches) -> flush -> index -> load
cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}
collName := createCollection(ctx, t, mc, cp, client.WithConsistencyLevel(entity.ClStrong))

// insert data
dp := DataParams{DoInsert: true, CollectionName: collName, CollectionFieldsType: Int64VarcharSparseVec, start: 0,
nb: 200, dim: common.DefaultDim, EnableDynamicField: true}
for i := 0; i < 100; i++ {
_, _ = insertData(ctx, t, mc, dp)
}
mc.Flush(ctx, collName, false)

// index and load
idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
mc.CreateIndex(ctx, collName, common.DefaultFloatVecFieldName, idxHnsw, false)
mc.CreateIndex(ctx, collName, common.DefaultSparseVecFieldName, idx, false)
mc.LoadCollection(ctx, collName, false)

// groupBy search
queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
sp, _ := entity.NewIndexSparseInvertedSearchParam(0.2)
resGroupBy, _ := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName}, queryVec,
common.DefaultSparseVecFieldName, entity.IP, common.DefaultTopK, sp, client.WithGroupByField(common.DefaultVarcharFieldName))

// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for _, rs := range resGroupBy {
for i := 0; i < rs.ResultCount; i++ {
groupByValue, _ := rs.GroupByValue.Get(i)
pkValue, _ := rs.IDs.GetAsInt64(i)
expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)

// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
common.DefaultVarcharFieldName}, queryVec, common.DefaultSparseVecFieldName, entity.IP, 1, sp)
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}

// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search on field %s with nq=%d and limit=%d: hitsNum=%d, hitsRate=%v\n",
common.DefaultSparseVecFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate)
log.Println(_str)
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
}

// binary vector -> not supported
func TestSearchGroupByBinaryDefault(t *testing.T) {
t.Parallel()
51 changes: 50 additions & 1 deletion test/testcases/hybrid_search_test.go
@@ -48,7 +48,6 @@ func TestHybridSearchDefault(t *testing.T) {

// hybrid search default -> verify success
func TestHybridSearchMultiVectorsDefault(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/32222")
t.Parallel()
ctx := createContext(t, time.Second*common.DefaultTimeout*3)
// connect
@@ -311,3 +310,53 @@ func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) {
}
}
}

func TestHybridSearchSparseVector(t *testing.T) {
t.Parallel()
idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
for _, idx := range []entity.Index{idxInverted, idxWand} {
ctx := createContext(t, time.Second*common.DefaultTimeout*2)
// connect
mc := createMilvusClient(ctx, t)

// create -> insert [0, 3000) -> flush -> index -> load
cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}

dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 3,
dim: common.DefaultDim, EnableDynamicField: true}

// index params
idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
ips := []IndexParams{
{BuildIndex: true, Index: idx, FieldName: common.DefaultSparseVecFieldName, async: false},
{BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false},
}
collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))

// search
queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector)
queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
sp1, _ := entity.NewIndexSparseInvertedSearchParam(0.2)
sp2, _ := entity.NewIndexHNSWSearchParam(20)
expr := fmt.Sprintf("%s > 1", common.DefaultIntFieldName)
sReqs := []*client.ANNSearchRequest{
client.NewANNSearchRequest(common.DefaultSparseVecFieldName, entity.IP, expr, queryVec1, sp1, common.DefaultTopK),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec2, sp2, common.DefaultTopK),
}
for _, reranker := range []client.Reranker{
client.NewRRFReranker(),
client.NewWeightedReranker([]float64{0.5, 0.6}),
} {
// hybrid search
searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, reranker, sReqs)
common.CheckErr(t, errSearch, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK)
common.CheckErr(t, errSearch, true)
outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName,
common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName}
common.CheckOutputFields(t, searchRes[0].Fields, outputFields)
}
}
}