diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..c4d1853ee --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,43 @@ + +What +---- + + +Checklist +------------------ +- [ ] Contains customer facing changes? Including API/behavior changes +- [ ] Did you add sufficient unit test and/or integration test coverage for this PR? + - If not, please explain why it is not required + +References +---------- +JIRA: + + +Test & Review +------------ + + +Open questions / Follow-ups +-------------------------- + + + diff --git a/schemaregistry/config.go b/schemaregistry/config.go index 3826bff0e..39352a23c 100644 --- a/schemaregistry/config.go +++ b/schemaregistry/config.go @@ -44,7 +44,7 @@ func NewConfig(url string) *Config { c.ConnectionTimeoutMs = 10000 c.RequestTimeoutMs = 10000 - c.MaxRetries = 2 + c.MaxRetries = 3 c.RetriesWaitMs = 1000 c.RetriesMaxWaitMs = 20000 diff --git a/schemaregistry/internal/client_config.go b/schemaregistry/internal/client_config.go index 6578947e9..7016d803e 100644 --- a/schemaregistry/internal/client_config.go +++ b/schemaregistry/internal/client_config.go @@ -22,7 +22,7 @@ import ( // ClientConfig is used to pass multiple configuration options to the Schema Registry client. type ClientConfig struct { - // SchemaRegistryURL determines the URL of Schema Registry. + // SchemaRegistryURL is a comma-space separated list of URLs for the Schema Registry. SchemaRegistryURL string // BasicAuthUserInfo specifies the user info in the form of {username}:{password}. diff --git a/schemaregistry/internal/rest_service.go b/schemaregistry/internal/rest_service.go index 1daca2170..87f98393c 100644 --- a/schemaregistry/internal/rest_service.go +++ b/schemaregistry/internal/rest_service.go @@ -112,7 +112,7 @@ func NewRequest(method string, endpoint string, body interface{}, arguments ...i // RestService represents a REST client type RestService struct { - url *url.URL + urls []*url.URL headers http.Header maxRetries int retriesWaitMs int @@ -124,21 +124,22 @@ type RestService struct { // NewRestService returns a new REST client for the Confluent Schema Registry func NewRestService(conf *ClientConfig) (*RestService, error) { urlConf := conf.SchemaRegistryURL - u, err := url.Parse(urlConf) - - if err != nil { - return nil, err + urlStrs := strings.Split(urlConf, ",") + urls := make([]*url.URL, len(urlStrs)) + for i, urlStr := range urlStrs { + u, err := url.Parse(strings.TrimSpace(urlStr)) + if err != nil { + return nil, err + } + urls[i] = u } - headers, err := NewAuthHeader(u, conf) + headers, err := NewAuthHeader(urls[0], conf) if err != nil { return nil, err } headers.Add("Content-Type", "application/vnd.schemaregistry.v1+json") - if err != nil { - return nil, err - } if conf.HTTPClient == nil { transport, err := configureTransport(conf) @@ -155,7 +156,7 @@ func NewRestService(conf *ClientConfig) (*RestService, error) { } return &RestService{ - url: u, + urls: urls, headers: headers, maxRetries: conf.MaxRetries, retriesWaitMs: conf.RetriesWaitMs, @@ -337,19 +338,51 @@ func NewAuthHeader(service *url.URL, conf *ClientConfig) (http.Header, error) { return header, err } -// HandleRequest sends a HTTP(S) request to the Schema Registry, placing results into the response object +// HandleRequest sends a request to the Schema Registry, iterating over the list of URLs func (rs *RestService) HandleRequest(request *API, response interface{}) error { - urlPath := path.Join(rs.url.Path, fmt.Sprintf(request.endpoint, request.arguments...)) - endpoint, err := rs.url.Parse(urlPath) - if err != nil { + var resp *http.Response + var err error + for i, u := range rs.urls { + resp, err = rs.HandleHTTPRequest(u, request) + if err != nil { + if i == len(rs.urls)-1 { + return err + } + continue + } + if isSuccess(resp.StatusCode) || !isRetriable(resp.StatusCode) || i >= rs.maxRetries { + break + } + } + defer resp.Body.Close() + if isSuccess(resp.StatusCode) { + if err = json.NewDecoder(resp.Body).Decode(response); err != nil { + return err + } + return nil + } + + var failure rest.Error + if err = json.NewDecoder(resp.Body).Decode(&failure); err != nil { return err } + return &failure +} + +// HandleHTTPRequest sends a HTTP(S) request to the Schema Registry, placing results into the response object +func (rs *RestService) HandleHTTPRequest(url *url.URL, request *API) (*http.Response, error) { + urlPath := path.Join(url.Path, fmt.Sprintf(request.endpoint, request.arguments...)) + endpoint, err := url.Parse(urlPath) + if err != nil { + return nil, err + } + var readCloser io.ReadCloser if request.body != nil { outbuf, err := json.Marshal(request.body) if err != nil { - return err + return nil, err } readCloser = ioutil.NopCloser(bytes.NewBuffer(outbuf)) } @@ -365,30 +398,16 @@ func (rs *RestService) HandleRequest(request *API, response interface{}) error { for i := 0; i < rs.maxRetries+1; i++ { resp, err = rs.Do(req) if err != nil { - return err + return nil, err } if isSuccess(resp.StatusCode) || !isRetriable(resp.StatusCode) || i >= rs.maxRetries { - break + return resp, nil } time.Sleep(rs.fullJitter(i)) } - - defer resp.Body.Close() - if resp.StatusCode == 200 { - if err = json.NewDecoder(resp.Body).Decode(response); err != nil { - return err - } - return nil - } - - var failure rest.Error - if err := json.NewDecoder(resp.Body).Decode(&failure); err != nil { - return err - } - - return &failure + return nil, fmt.Errorf("failed to send request after %d retries", rs.maxRetries) } func (rs *RestService) fullJitter(retriesAttempted int) time.Duration { diff --git a/schemaregistry/rules/cel/cel_executor.go b/schemaregistry/rules/cel/cel_executor.go index e28c1dd2e..ed141c446 100644 --- a/schemaregistry/rules/cel/cel_executor.go +++ b/schemaregistry/rules/cel/cel_executor.go @@ -35,24 +35,18 @@ func init() { // Register registers the CEL rule executor func Register() { + serde.RegisterRuleExecutor(NewExecutor()) + serde.RegisterRuleExecutor(NewFieldExecutor()) +} + +// NewExecutor creates a new CEL rule executor +func NewExecutor() serde.RuleExecutor { env, _ := DefaultEnv() - e := &Executor{ + return &Executor{ env: env, cache: map[string]cel.Program{}, } - serde.RegisterRuleExecutor(e) - - a := &serde.AbstractFieldRuleExecutor{} - f := &FieldExecutor{ - AbstractFieldRuleExecutor: *a, - executor: Executor{ - env: env, - cache: map[string]cel.Program{}, - }, - } - f.FieldRuleExecutor = f - serde.RegisterRuleExecutor(f) } // Executor is a CEL rule executor @@ -199,19 +193,22 @@ func typeToCELType(arg interface{}) *cel.Type { func (c *Executor) newProgram(expr string, msg interface{}, decls []cel.EnvOption) (cel.Program, error) { typ := reflect.TypeOf(msg) - if typ.Kind() == reflect.Pointer { + if typ.Kind() == reflect.Pointer || typ.Kind() == reflect.Interface { typ = typ.Elem() } protoType, ok := msg.(proto.Message) var declType cel.EnvOption if ok { declType = cel.Types(protoType) - } else { + } else if typ.Kind() == reflect.Struct { declType = ext.NativeTypes(typ) } - envOptions := make([]cel.EnvOption, len(decls)) - copy(envOptions, decls) - envOptions = append(envOptions, declType) + envOptions := decls + if declType != nil { + envOptions = make([]cel.EnvOption, len(decls)) + copy(envOptions, decls) + envOptions = append(envOptions, declType) + } env, err := c.env.Extend(envOptions...) if err != nil { return nil, err diff --git a/schemaregistry/rules/cel/cel_field_executor.go b/schemaregistry/rules/cel/cel_field_executor.go index a2f1cf8b4..24a1112fa 100644 --- a/schemaregistry/rules/cel/cel_field_executor.go +++ b/schemaregistry/rules/cel/cel_field_executor.go @@ -19,8 +19,25 @@ package cel import ( "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry" "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/serde" + "github.com/google/cel-go/cel" ) +// NewFieldExecutor creates a new CEL field rule executor +func NewFieldExecutor() serde.RuleExecutor { + env, _ := DefaultEnv() + + a := &serde.AbstractFieldRuleExecutor{} + f := &FieldExecutor{ + AbstractFieldRuleExecutor: *a, + executor: Executor{ + env: env, + cache: map[string]cel.Program{}, + }, + } + f.FieldRuleExecutor = f + return f +} + // FieldExecutor is a CEL field rule executor type FieldExecutor struct { serde.AbstractFieldRuleExecutor diff --git a/schemaregistry/rules/encryption/awskms/aws_client.go b/schemaregistry/rules/encryption/awskms/aws_client.go index ad3ac2f5c..4cec7aed1 100644 --- a/schemaregistry/rules/encryption/awskms/aws_client.go +++ b/schemaregistry/rules/encryption/awskms/aws_client.go @@ -51,7 +51,7 @@ func NewClient(keyURI string, creds aws.CredentialsProvider) (registry.KMSClient // Supported returns true if keyURI starts with the URI prefix provided when // creating the client. func (c *awsClient) Supported(keyURI string) bool { - return strings.HasPrefix(keyURI, prefix) + return strings.HasPrefix(keyURI, c.keyURI) } // GetAEAD returns an implementation of the AEAD interface which performs @@ -64,7 +64,7 @@ func (c *awsClient) Supported(keyURI string) bool { // See https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html func (c *awsClient) GetAEAD(keyURI string) (tink.AEAD, error) { if !c.Supported(keyURI) { - return nil, fmt.Errorf("keyURI must start with prefix %s, but got %s", prefix, keyURI) + return nil, fmt.Errorf("keyURI must start with prefix %s, but got %s", c.keyURI, keyURI) } uri := strings.TrimPrefix(keyURI, prefix) return NewAEAD(uri, c.creds) diff --git a/schemaregistry/rules/encryption/awskms/aws_driver.go b/schemaregistry/rules/encryption/awskms/aws_driver.go index 2b28bdfed..c70a0454b 100644 --- a/schemaregistry/rules/encryption/awskms/aws_driver.go +++ b/schemaregistry/rules/encryption/awskms/aws_driver.go @@ -17,16 +17,26 @@ package awskms import ( + "context" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/encryption" "github.com/tink-crypto/tink-go/v2/core/registry" + "os" + "strings" ) const ( prefix = "aws-kms://" accessKeyID = "access.key.id" secretAccessKey = "secret.access.key" + profile = "profile" + roleArn = "role.arn" + roleSessionName = "role.session.name" + roleExternalID = "role.external.id" ) func init() { @@ -51,13 +61,55 @@ func (l *awsDriver) NewKMSClient(conf map[string]string, keyURL *string) (regist if keyURL != nil { uriPrefix = *keyURL } + arn := conf[roleArn] + if arn == "" { + arn = os.Getenv("AWS_ROLE_ARN") + } + sessionName := conf[roleSessionName] + if sessionName == "" { + sessionName = os.Getenv("AWS_ROLE_SESSION_NAME") + } + externalID := conf[roleExternalID] + if externalID == "" { + externalID = os.Getenv("AWS_ROLE_EXTERNAL_ID") + } var creds aws.CredentialsProvider - key, ok := conf[accessKeyID] - if ok { - secret, ok := conf[secretAccessKey] - if ok { - creds = credentials.NewStaticCredentialsProvider(key, secret, "") + key := conf[accessKeyID] + secret := conf[secretAccessKey] + sourceProfile := conf[profile] + if key != "" && secret != "" { + creds = credentials.NewStaticCredentialsProvider(key, secret, "") + } else if sourceProfile != "" { + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(sourceProfile), + ) + if err != nil { + return nil, err + } + creds = cfg.Credentials + } + if arn != "" { + region, err := getRegion(strings.TrimPrefix(uriPrefix, prefix)) + if err != nil { + return nil, err + } + stsSvc := sts.New(sts.Options{ + Credentials: creds, + Region: region, + }) + if sessionName == "" { + sessionName = "confluent-encrypt" } + var extID *string + if externalID != "" { + extID = &externalID + } + creds = stscreds.NewAssumeRoleProvider(stsSvc, arn, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = sessionName + o.ExternalID = extID + }) + creds = aws.NewCredentialsCache(creds) } + return NewClient(uriPrefix, creds) } diff --git a/schemaregistry/rules/encryption/azurekms/azure_client.go b/schemaregistry/rules/encryption/azurekms/azure_client.go index 026bf8eba..5fc1a160e 100644 --- a/schemaregistry/rules/encryption/azurekms/azure_client.go +++ b/schemaregistry/rules/encryption/azurekms/azure_client.go @@ -51,14 +51,14 @@ func NewClient(keyURI string, creds azcore.TokenCredential, algorithm azkeys.Enc // Supported true if this client does support keyURI func (c *azureClient) Supported(keyURI string) bool { - return strings.HasPrefix(keyURI, prefix) + return strings.HasPrefix(keyURI, c.keyURI) } // GetAEAD gets an AEAD backend by keyURI. // keyURI must have the following format: 'azure-kms://https://{vaultURL}/keys/{keyName}/{keyVersion}" func (c *azureClient) GetAEAD(keyURI string) (tink.AEAD, error) { if !c.Supported(keyURI) { - return nil, fmt.Errorf("keyURI must start with prefix %s, but got %s", prefix, keyURI) + return nil, fmt.Errorf("keyURI must start with prefix %s, but got %s", c.keyURI, keyURI) } uri := strings.TrimPrefix(keyURI, prefix) return NewAEAD(uri, c.creds, c.algorithm) diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index b143e4c76..83740f668 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -42,16 +42,27 @@ func init() { // Register registers the encryption rule executor func Register() { - c := clock{} - RegisterWithClock(&c) + serde.RegisterRuleExecutor(NewExecutor()) } // RegisterWithClock registers the encryption rule executor with a given clock func RegisterWithClock(c Clock) *FieldEncryptionExecutor { + f := NewExecutorWithClock(c) + serde.RegisterRuleExecutor(f) + return f +} + +// NewExecutor creates a new encryption rule executor +func NewExecutor() serde.RuleExecutor { + c := clock{} + return NewExecutorWithClock(&c) +} + +// NewExecutorWithClock creates a new encryption rule executor with a given clock +func NewExecutorWithClock(c Clock) *FieldEncryptionExecutor { a := &serde.AbstractFieldRuleExecutor{} f := &FieldEncryptionExecutor{*a, nil, nil, c} f.FieldRuleExecutor = f - serde.RegisterRuleExecutor(f) return f } diff --git a/schemaregistry/rules/encryption/localkms/local_driver.go b/schemaregistry/rules/encryption/localkms/local_driver.go index 9d1d90e04..949a0d63a 100644 --- a/schemaregistry/rules/encryption/localkms/local_driver.go +++ b/schemaregistry/rules/encryption/localkms/local_driver.go @@ -24,6 +24,7 @@ import ( "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/subtle" "google.golang.org/protobuf/proto" + "os" "strings" agpb "github.com/tink-crypto/tink-go/v2/proto/aes_gcm_go_proto" @@ -57,8 +58,11 @@ func (l *localDriver) NewKMSClient(config map[string]string, keyURL *string) (re if keyURL != nil { uriPrefix = *keyURL } - secretKey, ok := config[secret] - if !ok { + secretKey := config[secret] + if secretKey == "" { + secretKey = os.Getenv("LOCAL_SECRET") + } + if secretKey == "" { return nil, errors.New("cannot load secret") } return NewLocalClient(uriPrefix, secretKey) diff --git a/schemaregistry/rules/jsonata/jsonata_executor.go b/schemaregistry/rules/jsonata/jsonata_executor.go index fa7e6448f..59e172c67 100644 --- a/schemaregistry/rules/jsonata/jsonata_executor.go +++ b/schemaregistry/rules/jsonata/jsonata_executor.go @@ -29,10 +29,14 @@ func init() { // Register registers the JSONata rule executor func Register() { - e := &Executor{ + serde.RegisterRuleExecutor(NewExecutor()) +} + +// NewExecutor creates a new JSONata rule executor +func NewExecutor() serde.RuleExecutor { + return &Executor{ cache: map[string]*jsonata.Expr{}, } - serde.RegisterRuleExecutor(e) } // Executor is a JSONata rule executor diff --git a/schemaregistry/serde/avro/avro_generic.go b/schemaregistry/serde/avro/avro_generic.go index 047c95154..8199fb50a 100644 --- a/schemaregistry/serde/avro/avro_generic.go +++ b/schemaregistry/serde/avro/avro_generic.go @@ -56,7 +56,7 @@ func (s *GenericSerializer) Serialize(topic string, msg interface{}) ([]byte, er return nil, nil } val := reflect.ValueOf(msg) - if val.Kind() == reflect.Ptr { + if val.Kind() == reflect.Pointer { // avro.TypeOf expects an interface containing a non-pointer msg = val.Elem().Interface() } diff --git a/schemaregistry/serde/avrov2/avro.go b/schemaregistry/serde/avrov2/avro.go index f58b72eca..0ff8d4574 100644 --- a/schemaregistry/serde/avrov2/avro.go +++ b/schemaregistry/serde/avrov2/avro.go @@ -44,6 +44,8 @@ type Deserializer struct { // Serde represents an Avro serde type Serde struct { + // we don't have a way to pass a resolver to the api, so we track both separately + api avro.API resolver *avro.TypeResolver schemaToTypeCache cache.Cache schemaToTypeCacheLock sync.RWMutex @@ -59,6 +61,7 @@ func NewSerializer(client schemaregistry.Client, serdeType serde.Type, conf *Ser return nil, err } ps := &Serde{ + api: avro.Config{}.Freeze(), resolver: avro.NewTypeResolver(), schemaToTypeCache: schemaToTypeCache, } @@ -122,7 +125,7 @@ func (s *Serializer) Serialize(topic string, msg interface{}) ([]byte, error) { } // Convert pointer to non-pointer msg = reflect.ValueOf(msg).Elem().Interface() - msgBytes, err := avro.Marshal(avroSchema, msg) + msgBytes, err := s.api.Marshal(avroSchema, msg) if err != nil { return nil, err } @@ -140,6 +143,7 @@ func NewDeserializer(client schemaregistry.Client, serdeType serde.Type, conf *D return nil, err } ps := &Serde{ + api: avro.Config{}.Freeze(), resolver: avro.NewTypeResolver(), schemaToTypeCache: schemaToTypeCache, } @@ -201,7 +205,7 @@ func (s *Deserializer) deserialize(topic string, payload []byte, result interfac } var msg interface{} if len(migrations) > 0 { - err = avro.Unmarshal(writer, payload[5:], &msg) + err = s.api.Unmarshal(writer, payload[5:], &msg) if err != nil { return nil, err } @@ -215,7 +219,7 @@ func (s *Deserializer) deserialize(topic string, payload []byte, result interfac return nil, err } var bytes []byte - bytes, err = avro.Marshal(reader, msg) + bytes, err = s.api.Marshal(reader, msg) if err != nil { return nil, err } @@ -227,7 +231,7 @@ func (s *Deserializer) deserialize(topic string, payload []byte, result interfac } else { msg = result } - err = avro.Unmarshal(reader, bytes, msg) + err = s.api.Unmarshal(reader, bytes, msg) if err != nil { return nil, err } @@ -254,12 +258,12 @@ func (s *Deserializer) deserialize(topic string, payload []byte, result interfac return nil, err } } - err = avro.Unmarshal(reader, payload[5:], msg) + err = s.api.Unmarshal(reader, payload[5:], msg) if err != nil { return nil, err } } else { - err = avro.Unmarshal(writer, payload[5:], msg) + err = s.api.Unmarshal(writer, payload[5:], msg) if err != nil { return nil, err } @@ -280,6 +284,7 @@ func (s *Deserializer) deserialize(topic string, payload []byte, result interfac // RegisterType registers a type with the Avro Serde func (s *Serde) RegisterType(name string, msgType interface{}) { + s.api.Register(name, msgType) s.resolver.Register(name, msgType) } diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index b1bd72b19..ab5b1363f 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -18,6 +18,7 @@ package avrov2 import ( "errors" + "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/cel" "reflect" "testing" "time" @@ -504,6 +505,107 @@ func TestAvroSerdeWithReferences(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(msg, &obj)) } +func TestAvroSerdeUnionWithReferences(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + _ = ser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) + _ = ser.RegisterTypeFromMessageFactory("ComplexSchema", testMessageFactory) + + info := schemaregistry.SchemaInfo{ + Schema: string(demoSchema), + SchemaType: "AVRO", + } + + id, err := client.Register("demo-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + info = schemaregistry.SchemaInfo{ + Schema: string(complexSchema), + SchemaType: "AVRO", + } + + id, err = client.Register("complex-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + info = schemaregistry.SchemaInfo{ + Schema: `[ "DemoSchema", "ComplexSchema" ]`, + SchemaType: "AVRO", + References: []schemaregistry.Reference{ + { + Name: "DemoSchema", + Subject: "demo-value", + Version: 1, + }, + { + Name: "Complexchema", + Subject: "complex-value", + Version: 1, + }, + }, + } + + id, err = client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := DemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deser, err := NewDeserializer(client, serde.ValueSerde, NewDeserializerConfig()) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + _ = deser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) + _ = deser.RegisterTypeFromMessageFactory("ComplexSchema", testMessageFactory) + + oldmap := map[string]interface{}{ + "DemoSchema": map[string]interface{}{ + "IntField": 123, + "DoubleField": 45.67, + "StringField": "hi", + "BoolField": true, + "BytesField": []byte{1, 2}, + }, + } + + // deserialize into map + var newmap map[string]interface{} + err = deser.DeserializeInto("topic1", bytes, &newmap) + serde.MaybeFail("deserialization into", err, serde.Expect(newmap, oldmap)) + + // deserialize into interface{} + var newany interface{} + err = deser.DeserializeInto("topic1", bytes, &newany) + var newobj = newany.(DemoSchema) + serde.MaybeFail("deserialization into", err, serde.Expect(newobj, obj)) +} + func TestAvroSchemaEvolution(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -795,6 +897,74 @@ func TestAvroSerdeWithCELConditionIgnoreFail(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) } +func TestAvroSerdeWithCELFieldTransformDisable(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-cel", + Kind: "TRANSFORM", + Mode: "WRITE", + Type: "CEL_FIELD", + Expr: "name == 'StringField' ; value + '-suffix'", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "AVRO", + RuleSet: &ruleSet, + } + + registry := serde.NewRuleRegistry() + registry.RegisterExecutor(cel.NewFieldExecutor()) + registry.RegisterOverride(&serde.RuleOverride{ + Type: "CEL_FIELD", + OnSuccess: nil, + OnFailure: nil, + Disabled: &[]bool{true}[0], + }) + ser.RuleRegistry = ®istry + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := DemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + + newobj, err := deser.Deserialize("topic1", bytes) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*DemoSchema).StringField, "hi")) +} + func TestAvroSerdeWithCELFieldTransform(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1707,7 +1877,7 @@ func TestAvroSerdeEncryptionWithPointerReferences(t *testing.T) { } ser, err := NewSerializer(client, serde.ValueSerde, serConfig) serde.MaybeFail("Serializer configuration", err) - ser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) + _ = ser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) info := schemaregistry.SchemaInfo{ Schema: string(demoSchema), @@ -1771,7 +1941,7 @@ func TestAvroSerdeEncryptionWithPointerReferences(t *testing.T) { serde.MaybeFail("Deserializer configuration", err) deser.Client = ser.Client deser.MessageFactory = testMessageFactory - deser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) + _ = deser.RegisterTypeFromMessageFactory("DemoSchema", testMessageFactory) newobj, err := deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) diff --git a/schemaregistry/serde/avrov2/avro_util.go b/schemaregistry/serde/avrov2/avro_util.go index 5075b83a7..74498f50c 100644 --- a/schemaregistry/serde/avrov2/avro_util.go +++ b/schemaregistry/serde/avrov2/avro_util.go @@ -270,7 +270,7 @@ func resolveUnion(resolver *avro.TypeResolver, schema avro.Schema, msg *reflect. } func deref(val *reflect.Value) *reflect.Value { - if val.Kind() == reflect.Pointer { + if val.Kind() == reflect.Pointer || val.Kind() == reflect.Interface { v := val.Elem() return &v } diff --git a/schemaregistry/serde/jsonschema/json_schema_test.go b/schemaregistry/serde/jsonschema/json_schema_test.go index 35ccc2c51..487581e00 100644 --- a/schemaregistry/serde/jsonschema/json_schema_test.go +++ b/schemaregistry/serde/jsonschema/json_schema_test.go @@ -19,7 +19,6 @@ package jsonschema import ( "encoding/base64" "errors" - _ "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/cel" _ "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/encryption/awskms" _ "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/encryption/azurekms" @@ -91,6 +90,30 @@ const ( } } } +` + demoSchemaNested = ` +{ + "type": "object", + "properties": { + "OtherField": { + "type": "object", + "properties": { + "IntField": { "type": "integer" }, + "DoubleField": { "type": "number" }, + "StringField": { + "type": "string", + "confluent:tags": [ "PII" ] + }, + "BoolField": { "type": "boolean" }, + "BytesField": { + "type": "string", + "contentEncoding": "base64", + "confluent:tags": [ "PII" ] + } + } + } + } +} ` complexSchema = ` { @@ -160,6 +183,38 @@ const ( "version": { "type": "integer" } } } +` + defSchema = ` +{ + "$schema" : "http://json-schema.org/draft-07/schema#", + "additionalProperties" : false, + "definitions" : { + "Address" : { + "additionalProperties" : false, + "properties" : { + "doornumber" : { + "type" : "integer" + }, + "doorpin" : { + "confluent:tags" : [ "PII" ], + "type" : "string" + } + }, + "type" : "object" + } + }, + "properties" : { + "address" : { + "$ref" : "#/definitions/Address" + }, + "name" : { + "confluent:tags" : [ "PII" ], + "type" : "string" + } + }, + "title" : "Sample Event", + "type" : "object" +} ` ) @@ -565,6 +620,204 @@ func TestJSONSchemaSerdeWithCELFieldTransform(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj2)) } +func TestJSONSchemaSerdeWithCELFieldTransformWithDef(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-cel", + Kind: "TRANSFORM", + Mode: "WRITE", + Type: "CEL_FIELD", + Tags: []string{"PII"}, + Expr: "value + '-suffix'", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: defSchema, + SchemaType: "JSON", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + addr := Address{} + addr.DoorNumber = 123 + addr.DoorPin = "1234" + obj := JSONPerson{} + obj.Name = "bob" + obj.Address = addr + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + addr2 := Address{} + addr2.DoorNumber = 123 + addr2.DoorPin = "1234-suffix" + obj2 := JSONPerson{} + obj2.Name = "bob-suffix" + obj2.Address = addr2 + + var newobj JSONPerson + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj2)) +} + +func TestJSONSchemaSerdeWithCELFieldTransformWithSimpleMap(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-cel", + Kind: "TRANSFORM", + Mode: "WRITE", + Type: "CEL_FIELD", + Expr: "name == 'StringField' ; value + '-suffix'", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "JSON", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := make(map[string]interface{}) + obj["IntField"] = 123 + obj["DoubleField"] = 45.67 + obj["StringField"] = "hi" + obj["BoolField"] = true + obj["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deser, err := NewDeserializer(client, serde.ValueSerde, NewDeserializerConfig()) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + obj2 := JSONDemoSchema{} + // JSON decoding produces floats + obj2.IntField = 123.0 + obj2.DoubleField = 45.67 + obj2.StringField = "hi-suffix" + obj2.BoolField = true + obj2.BytesField = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + + var newobj JSONDemoSchema + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj2)) +} + +func TestJSONSchemaSerdeWithCELFieldTransformWithNestedMap(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-cel", + Kind: "TRANSFORM", + Mode: "WRITE", + Type: "CEL_FIELD", + Expr: "name == 'StringField' ; value + '-suffix'", + } + ruleSet := schemaregistry.RuleSet{ + DomainRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchemaNested, + SchemaType: "JSON", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + nested := make(map[string]interface{}) + nested["IntField"] = 123 + nested["DoubleField"] = 45.67 + nested["StringField"] = "hi" + nested["BoolField"] = true + nested["BytesField"] = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + obj := make(map[string]interface{}) + obj["OtherField"] = nested + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deser, err := NewDeserializer(client, serde.ValueSerde, NewDeserializerConfig()) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + nested2 := JSONDemoSchema{} + // JSON decoding produces floats + nested2.IntField = 123.0 + nested2.DoubleField = 45.67 + nested2.StringField = "hi-suffix" + nested2.BoolField = true + nested2.BytesField = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 1}) + obj2 := JSONNestedTestRecord{nested2} + + var newobj JSONNestedTestRecord + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj2)) +} + func TestJSONSchemaSerdeWithCELFieldTransformComplex(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1421,3 +1674,15 @@ type NewerWidget struct { Version int `json:"version"` } + +type Address struct { + DoorNumber int `json:"doornumber"` + + DoorPin string `json:"doorpin"` +} + +type JSONPerson struct { + Name string `json:"name"` + + Address Address `json:"address"` +} diff --git a/schemaregistry/serde/jsonschema/json_schema_util.go b/schemaregistry/serde/jsonschema/json_schema_util.go index 2316f8eca..6400d5107 100644 --- a/schemaregistry/serde/jsonschema/json_schema_util.go +++ b/schemaregistry/serde/jsonschema/json_schema_util.go @@ -298,7 +298,7 @@ func validate(schema *jsonschema2.Schema, msg *reflect.Value) (bool, error) { } func deref(val *reflect.Value) *reflect.Value { - if val.Kind() == reflect.Pointer { + if val.Kind() == reflect.Pointer || val.Kind() == reflect.Interface { v := val.Elem() return &v } diff --git a/schemaregistry/serde/protobuf/protobuf.go b/schemaregistry/serde/protobuf/protobuf.go index c88008ad8..a021e61d3 100644 --- a/schemaregistry/serde/protobuf/protobuf.go +++ b/schemaregistry/serde/protobuf/protobuf.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "log" + "reflect" "strings" "sync" @@ -518,7 +519,13 @@ func (s *Deserializer) Deserialize(topic string, payload []byte) (interface{}, e // DeserializeInto implements deserialization of Protobuf data to the given object func (s *Deserializer) DeserializeInto(topic string, payload []byte, msg interface{}) error { - _, err := s.deserialize(topic, payload, msg) + result, err := s.deserialize(topic, payload, msg) + // Copy the result into the target since we may have created a clone during transformations + value := reflect.ValueOf(msg) + if value.Kind() == reflect.Pointer { + rv := value.Elem() + rv.Set(reflect.ValueOf(result).Elem()) + } return err } diff --git a/schemaregistry/serde/protobuf/protobuf_test.go b/schemaregistry/serde/protobuf/protobuf_test.go index bb39f990d..d09ae423d 100644 --- a/schemaregistry/serde/protobuf/protobuf_test.go +++ b/schemaregistry/serde/protobuf/protobuf_test.go @@ -137,6 +137,9 @@ func TestProtobufSerdeWithSimple(t *testing.T) { newobj, err = deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj.(proto.Message).ProtoReflect(), obj.ProtoReflect())) + + err = deser.DeserializeInto("topic1", bytes, newobj) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(proto.Message).ProtoReflect(), obj.ProtoReflect())) } func TestProtobufSerdeWithSecondMessage(t *testing.T) { @@ -647,7 +650,10 @@ func TestProtobufSerdeEncryption(t *testing.T) { serde.MaybeFail("register message", err) newobj, err := deser.Deserialize("topic1", bytes) - serde.MaybeFail("deserialization", err, serde.Expect(newobj.(proto.Message).ProtoReflect(), obj.ProtoReflect())) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*test.Author).Name, obj.Name)) + + err = deser.DeserializeInto("topic1", bytes, newobj) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*test.Author).Name, obj.Name)) } func TestProtobufSerdeJSONataFullyCompatible(t *testing.T) { diff --git a/schemaregistry/serde/rule_registry.go b/schemaregistry/serde/rule_registry.go index 8d374da8b..adc16d2fe 100644 --- a/schemaregistry/serde/rule_registry.go +++ b/schemaregistry/serde/rule_registry.go @@ -21,18 +21,38 @@ import ( ) var ( - globalInstance = RuleRegistry{ - ruleExecutors: make(map[string]RuleExecutor), - ruleActions: make(map[string]RuleAction), - } + globalInstance = NewRuleRegistry() ) +// RuleOverride represents a rule override +type RuleOverride struct { + // Rule type + Type string + // Rule action on success + OnSuccess *string + // Rule action on failure + OnFailure *string + // Whether the rule is disabled + Disabled *bool +} + // RuleRegistry is used to store all registered rule executors and actions. type RuleRegistry struct { ruleExecutorsMu sync.RWMutex ruleExecutors map[string]RuleExecutor ruleActionsMu sync.RWMutex ruleActions map[string]RuleAction + ruleOverridesMu sync.RWMutex + ruleOverrides map[string]*RuleOverride +} + +// NewRuleRegistry creates a Rule Registry instance. +func NewRuleRegistry() RuleRegistry { + return RuleRegistry{ + ruleExecutors: make(map[string]RuleExecutor), + ruleActions: make(map[string]RuleAction), + ruleOverrides: make(map[string]*RuleOverride), + } } // RegisterExecutor is used to register a new rule executor. @@ -85,8 +105,38 @@ func (r *RuleRegistry) GetActions() []RuleAction { return result } +// RegisterOverride is used to register a new global rule override. +func (r *RuleRegistry) RegisterOverride(ruleOverride *RuleOverride) { + r.ruleOverridesMu.Lock() + defer r.ruleOverridesMu.Unlock() + r.ruleOverrides[ruleOverride.Type] = ruleOverride +} + +// GetOverride fetches a rule override by a given name. +func (r *RuleRegistry) GetOverride(name string) *RuleOverride { + r.ruleOverridesMu.RLock() + defer r.ruleOverridesMu.RUnlock() + return r.ruleOverrides[name] +} + +// GetOverrides fetches all rule overrides +func (r *RuleRegistry) GetOverrides() []*RuleOverride { + r.ruleOverridesMu.RLock() + defer r.ruleOverridesMu.RUnlock() + var result []*RuleOverride + for _, v := range r.ruleOverrides { + result = append(result, v) + } + return result +} + // Clear clears all registered rules func (r *RuleRegistry) Clear() { + r.ruleOverridesMu.Lock() + defer r.ruleOverridesMu.Unlock() + for k := range r.ruleOverrides { + delete(r.ruleOverrides, k) + } r.ruleActionsMu.Lock() defer r.ruleActionsMu.Unlock() for k, v := range r.ruleActions { @@ -115,3 +165,8 @@ func RegisterRuleExecutor(ruleExecutor RuleExecutor) { func RegisterRuleAction(ruleAction RuleAction) { globalInstance.RegisterAction(ruleAction) } + +// RegisterRuleOverride is used to register a new global rule override. +func RegisterRuleOverride(ruleOverride *RuleOverride) { + globalInstance.RegisterOverride(ruleOverride) +} diff --git a/schemaregistry/serde/serde.go b/schemaregistry/serde/serde.go index a2c8df12f..ddccefe72 100644 --- a/schemaregistry/serde/serde.go +++ b/schemaregistry/serde/serde.go @@ -593,7 +593,7 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist } } for i, rule := range rules { - if rule.Disabled { + if s.isDisabled(rule) { continue } mode, ok := schemaregistry.ParseMode(rule.Mode) @@ -628,7 +628,7 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist } ruleExecutor := s.RuleRegistry.GetExecutor(rule.Type) if ruleExecutor == nil { - err := s.runAction(ctx, ruleMode, rule, rule.OnFailure, msg, + err := s.runAction(ctx, ruleMode, rule, s.getOnFailure(rule), msg, fmt.Errorf("could not find rule executor of type %s", rule.Type), "ERROR") if err != nil { return nil, err @@ -638,7 +638,7 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist var err error result, err := ruleExecutor.Transform(ctx, msg) if result == nil || err != nil { - err = s.runAction(ctx, ruleMode, rule, rule.OnFailure, msg, err, "ERROR") + err = s.runAction(ctx, ruleMode, rule, s.getOnFailure(rule), msg, err, "ERROR") if err != nil { return nil, err } @@ -647,7 +647,7 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist case "CONDITION": condResult, ok2 := result.(bool) if ok2 && !condResult { - err = s.runAction(ctx, ruleMode, rule, rule.OnFailure, msg, err, "ERROR") + err = s.runAction(ctx, ruleMode, rule, s.getOnFailure(rule), msg, err, "ERROR") if err != nil { return nil, RuleConditionErr{ Rule: ctx.Rule, @@ -659,12 +659,36 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist msg = result } // ignore error, since rule succeeded - _ = s.runAction(ctx, ruleMode, rule, rule.OnSuccess, msg, nil, "NONE") + _ = s.runAction(ctx, ruleMode, rule, s.getOnSuccess(rule), msg, nil, "NONE") } } return msg, nil } +func (s *Serde) getOnSuccess(rule schemaregistry.Rule) string { + override := s.RuleRegistry.GetOverride(rule.Type) + if override != nil && override.OnSuccess != nil { + return *override.OnSuccess + } + return rule.OnSuccess +} + +func (s *Serde) getOnFailure(rule schemaregistry.Rule) string { + override := s.RuleRegistry.GetOverride(rule.Type) + if override != nil && override.OnFailure != nil { + return *override.OnFailure + } + return rule.OnFailure +} + +func (s *Serde) isDisabled(rule schemaregistry.Rule) bool { + override := s.RuleRegistry.GetOverride(rule.Type) + if override != nil && override.Disabled != nil { + return *override.Disabled + } + return rule.Disabled +} + func reverseRules(rules []schemaregistry.Rule) []schemaregistry.Rule { newRules := make([]schemaregistry.Rule, len(rules)) copy(newRules, rules) diff --git a/service.yml b/service.yml index 80f95640b..545183fe5 100644 --- a/service.yml +++ b/service.yml @@ -8,3 +8,5 @@ github: repo_name: confluentinc/confluent-kafka-go semaphore: enable: true +sonarqube: + enable: true