From 3a07ca54ad3b80e7918e904426ad328a4d769428 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Wed, 15 Feb 2023 19:16:24 +0800 Subject: [PATCH] Support upsert Signed-off-by: lixinguo --- client/client.go | 2 + client/client_grpc_data.go | 92 +++++++++++++++++++++++++++++++++ client/client_grpc_data_test.go | 88 +++++++++++++++++++++++++++++++ client/client_test.go | 2 +- 4 files changed, 183 insertions(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index f9c0ff3c2..4441386aa 100644 --- a/client/client.go +++ b/client/client.go @@ -115,6 +115,8 @@ type Client interface { Flush(ctx context.Context, collName string, async bool) error // DeleteByPks deletes entries related to provided primary keys DeleteByPks(ctx context.Context, collName string, partitionName string, ids entity.Column) error + // Upsert column-based data of collection, returns id column values + Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error) // Search search with bool expression Search(ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) diff --git a/client/client_grpc_data.go b/client/client_grpc_data.go index d5085a1b1..ed21a2270 100644 --- a/client/client_grpc_data.go +++ b/client/client_grpc_data.go @@ -229,6 +229,98 @@ func (c *GrpcClient) DeleteByPks(ctx context.Context, collName string, partition return nil } +// Upsert Index into collection with column-based format +// collName is the collection name +// partitionName is the partition to upsert, if not specified(empty), default partition will be used +// columns are slice of the column-based data +func (c *GrpcClient) Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error) { + if c.Service == nil { + return nil, ErrClientNotReady + } + // 1. validation for all input params + // collection + if err := c.checkCollectionExists(ctx, collName); err != nil { + return nil, err + } + if partitionName != "" { + err := c.checkPartitionExists(ctx, collName, partitionName) + if err != nil { + return nil, err + } + } + // fields + var rowSize int + coll, err := c.DescribeCollection(ctx, collName) + if err != nil { + return nil, err + } + mNameField := make(map[string]*entity.Field) + for _, field := range coll.Schema.Fields { + mNameField[field.Name] = field + } + mNameColumn := make(map[string]entity.Column) + for _, column := range columns { + mNameColumn[column.Name()] = column + l := column.Len() + if rowSize == 0 { + rowSize = l + } else { + if rowSize != l { + return nil, errors.New("column size not match") + } + } + field, has := mNameField[column.Name()] + if !has { + return nil, fmt.Errorf("field %s does not exist in collection %s", column.Name(), collName) + } + if column.Type() != field.DataType { + return nil, fmt.Errorf("param column %s has type %v but collection field definition is %v", column.Name(), column.FieldData(), field.DataType) + } + if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector { + dim := 0 + switch column := column.(type) { + case *entity.ColumnFloatVector: + dim = column.Dim() + case *entity.ColumnBinaryVector: + dim = column.Dim() + } + if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] { + return nil, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim]) + } + } + } + for _, field := range coll.Schema.Fields { + _, has := mNameColumn[field.Name] + if !has && !field.AutoID { + return nil, fmt.Errorf("field %s not passed", field.Name) + } + } + + // 2. do upsert request + req := &server.UpsertRequest{ + DbName: "", // reserved + CollectionName: collName, + PartitionName: partitionName, + } + if req.PartitionName == "" { + req.PartitionName = "_default" // use default partition + } + req.NumRows = uint32(rowSize) + for _, column := range columns { + req.FieldsData = append(req.FieldsData, column.FieldData()) + } + resp, err := c.Service.Upsert(ctx, req) + if err != nil { + return nil, err + } + if err := handleRespStatus(resp.GetStatus()); err != nil { + return nil, err + } + MetaCache.setSessionTs(collName, resp.Timestamp) + // 3. parse id column + return entity.IDColumns(resp.GetIDs(), 0, -1) +} + // Search with bool expression func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) { diff --git a/client/client_grpc_data_test.go b/client/client_grpc_data_test.go index f110c2603..98f17c23a 100644 --- a/client/client_grpc_data_test.go +++ b/client/client_grpc_data_test.go @@ -177,6 +177,94 @@ func TestGrpcClientFlush(t *testing.T) { }) } +func TestGrpcClientUpsert(t *testing.T) { + ctx := context.Background() + + c := testClient(ctx, t) + + t.Run("test create failure due to meta", func(t *testing.T) { + mock.DelInjection(MHasCollection) // collection does not exist + ids, err := c.Upsert(ctx, testCollectionName, "") + assert.Nil(t, ids) + assert.NotNil(t, err) + + // partition not exists + mock.SetInjection(MHasCollection, hasCollectionDefault) + ids, err = c.Upsert(ctx, testCollectionName, "_part_not_exists") + assert.Nil(t, ids) + assert.NotNil(t, err) + + // field not in collection + mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) + vectors := generateFloatVector(10, testVectorDim) + ids, err = c.Upsert(ctx, testCollectionName, "", + entity.NewColumnInt64("extra_field", []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors)) + assert.Nil(t, ids) + assert.NotNil(t, err) + + // field type not match + mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) + ids, err = c.Upsert(ctx, testCollectionName, "", + entity.NewColumnInt32("int64", []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors)) + assert.Nil(t, ids) + assert.NotNil(t, err) + + // missing field + ids, err = c.Upsert(ctx, testCollectionName, "") + assert.Nil(t, ids) + assert.NotNil(t, err) + + // column len not match + ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}), + entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors)) + assert.Nil(t, ids) + assert.NotNil(t, err) + + // column len not match + ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{1, 2, 3}), + entity.NewColumnFloatVector(testVectorField, testVectorDim, vectors)) + assert.Nil(t, ids) + assert.NotNil(t, err) + + // dim not match + ids, err = c.Upsert(ctx, testCollectionName, "", entity.NewColumnInt64("int64", []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + entity.NewColumnFloatVector(testVectorField, testVectorDim*2, vectors)) + assert.Nil(t, ids) + assert.NotNil(t, err) + }) + + mock.SetInjection(MHasCollection, hasCollectionDefault) + mock.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) + + vector := generateFloatVector(4096, testVectorDim) + mock.SetInjection(MUpsert, func(_ context.Context, raw proto.Message) (proto.Message, error) { + req, ok := raw.(*server.UpsertRequest) + resp := &server.MutationResult{} + if !ok { + s, err := BadRequestStatus() + resp.Status = s + return resp, err + } + assert.EqualValues(t, 4096, req.GetNumRows()) + assert.Equal(t, testCollectionName, req.GetCollectionName()) + intIds := &schema.IDs_IntId{ + IntId: &schema.LongArray{ + Data: make([]int64, 4096), + }, + } + resp.IDs = &schema.IDs{ + IdField: intIds, + } + s, err := SuccessStatus() + resp.Status = s + return resp, err + }) + _, err := c.Upsert(ctx, testCollectionName, "", // use default partition + entity.NewColumnFloatVector(testVectorField, testVectorDim, vector)) + assert.Nil(t, err) + mock.DelInjection(MUpsert) +} + func TestGrpcDeleteByPks(t *testing.T) { ctx := context.Background() diff --git a/client/client_test.go b/client/client_test.go index 4db7a262f..b358d4b47 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -180,7 +180,7 @@ func TestGrpcClientNil(t *testing.T) { m.Name == "Search" || // type alias MetricType treated as string m.Name == "CalcDistance" || m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect - m.Name == "Insert" { // complex methods with ... + m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ... t.Skip("method", m.Name, "skipped") } ins := make([]reflect.Value, 0, mt.NumIn())