Skip to content

Commit

Permalink
Support upsert
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <xinguo.li@zilliz.com>
  • Loading branch information
lixinguo committed Feb 20, 2023
1 parent f5d28d7 commit 3a07ca5
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 1 deletion.
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ type Client interface {
Flush(ctx context.Context, collName string, async bool) error
// DeleteByPks deletes entries related to provided primary keys
DeleteByPks(ctx context.Context, collName string, partitionName string, ids entity.Column) error
// Upsert column-based data of collection, returns id column values
Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error)
// Search search with bool expression
Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error)
Expand Down
92 changes: 92 additions & 0 deletions client/client_grpc_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,98 @@ func (c *GrpcClient) DeleteByPks(ctx context.Context, collName string, partition
return nil
}

// Upsert Index into collection with column-based format
// collName is the collection name
// partitionName is the partition to upsert, if not specified(empty), default partition will be used
// columns are slice of the column-based data
func (c *GrpcClient) Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error) {
if c.Service == nil {
return nil, ErrClientNotReady
}
// 1. validation for all input params
// collection
if err := c.checkCollectionExists(ctx, collName); err != nil {
return nil, err
}
if partitionName != "" {
err := c.checkPartitionExists(ctx, collName, partitionName)
if err != nil {
return nil, err
}
}
// fields
var rowSize int
coll, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
mNameField := make(map[string]*entity.Field)
for _, field := range coll.Schema.Fields {
mNameField[field.Name] = field
}
mNameColumn := make(map[string]entity.Column)
for _, column := range columns {
mNameColumn[column.Name()] = column
l := column.Len()
if rowSize == 0 {
rowSize = l
} else {
if rowSize != l {
return nil, errors.New("column size not match")
}
}
field, has := mNameField[column.Name()]
if !has {
return nil, fmt.Errorf("field %s does not exist in collection %s", column.Name(), collName)
}
if column.Type() != field.DataType {
return nil, fmt.Errorf("param column %s has type %v but collection field definition is %v", column.Name(), column.FieldData(), field.DataType)
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
dim := 0
switch column := column.(type) {
case *entity.ColumnFloatVector:
dim = column.Dim()
case *entity.ColumnBinaryVector:
dim = column.Dim()
}
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
return nil, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
}
}
}
for _, field := range coll.Schema.Fields {
_, has := mNameColumn[field.Name]
if !has && !field.AutoID {
return nil, fmt.Errorf("field %s not passed", field.Name)
}
}

// 2. do upsert request
req := &server.UpsertRequest{
DbName: "", // reserved
CollectionName: collName,
PartitionName: partitionName,
}
if req.PartitionName == "" {
req.PartitionName = "_default" // use default partition
}
req.NumRows = uint32(rowSize)
for _, column := range columns {
req.FieldsData = append(req.FieldsData, column.FieldData())
}
resp, err := c.Service.Upsert(ctx, req)
if err != nil {
return nil, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return nil, err
}
MetaCache.setSessionTs(collName, resp.Timestamp)
// 3. parse id column
return entity.IDColumns(resp.GetIDs(), 0, -1)
}

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
Expand Down
88 changes: 88 additions & 0 deletions client/client_grpc_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,94 @@ func TestGrpcClientFlush(t *testing.T) {
})
}

func TestGrpcClientUpsert(t *testing.T) {
ctx := context.Background()

c := testClient(ctx, t)

t.Run("test create failure due to meta", func(t *testing.T) {
mock.DelInjection(MHasCollection) // collection does not exist
ids, err := c.Upsert(ctx, testCollectionName, "")
assert.Nil(t, ids)
assert.NotNil(t, err)

// partition not exists
mock.SetInjection(MHasCollection, hasCollectionDefault)
ids, err = c.Upsert(ctx, testCollectionName, "_part_not_exists")
assert.Nil(t, ids)
assert.NotNil(t, err)

// field not in collection
mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
vectors := generateFloatVector(10, testVectorDim)
ids, err = c.Upsert(ctx, testCollectionName, "",
entity.NewColumnInt64("extra_field", []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors))
assert.Nil(t, ids)
assert.NotNil(t, err)

// field type not match
mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
ids, err = c.Upsert(ctx, testCollectionName, "",
entity.NewColumnInt32("int64", []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors))
assert.Nil(t, ids)
assert.NotNil(t, err)

// missing field
ids, err = c.Upsert(ctx, testCollectionName, "")
assert.Nil(t, ids)
assert.NotNil(t, err)

// column len not match
ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}),
entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors))
assert.Nil(t, ids)
assert.NotNil(t, err)

// column len not match
ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{1, 2, 3}),
entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors))
assert.Nil(t, ids)
assert.NotNil(t, err)

// dim not match
ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
entity.NewColumnFloatVector(testVectorField, testVectorDim*2, vectors))
assert.Nil(t, ids)
assert.NotNil(t, err)
})

mock.SetInjection(MHasCollection, hasCollectionDefault)
mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))

vector := generateFloatVector(4096, testVectorDim)
mock.SetInjection(MUpsert, func(_ context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*server.UpsertRequest)
resp := &server.MutationResult{}
if !ok {
s, err := BadRequestStatus()
resp.Status = s
return resp, err
}
assert.EqualValues(t, 4096, req.GetNumRows())
assert.Equal(t, testCollectionName, req.GetCollectionName())
intIds := &schema.IDs_IntId{
IntId: &schema.LongArray{
Data: make([]int64, 4096),
},
}
resp.IDs = &schema.IDs{
IdField: intIds,
}
s, err := SuccessStatus()
resp.Status = s
return resp, err
})
_, err := c.Upsert(ctx, testCollectionName, "", // use default partition
entity.NewColumnFloatVector(testVectorField, testVectorDim, vector))
assert.Nil(t, err)
mock.DelInjection(MUpsert)
}

func TestGrpcDeleteByPks(t *testing.T) {
ctx := context.Background()

Expand Down
2 changes: 1 addition & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func TestGrpcClientNil(t *testing.T) {
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
m.Name == "Insert" { // complex methods with ...
m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
t.Skip("method", m.Name, "skipped")
}
ins := make([]reflect.Value, 0, mt.NumIn())
Expand Down

0 comments on commit 3a07ca5

Please sign in to comment.