Skip to content

Commit

Permalink
feat: optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
olexnzarov committed Sep 20, 2023
1 parent 58b4739 commit 4a12946
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions protomask.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@ type FieldMask interface {
IsValid(protoreflect.ProtoMessage) bool
}

type property struct {
message *protoreflect.Message
field *protoreflect.FieldDescriptor
incomplete bool
}

func (p *property) set(value protoreflect.Value) {
(*p.message).Set(*p.field, value)
}

func (p *property) clear() {
(*p.message).Clear(*p.field)
}

// Update updates the targetMessage with values from the updateMessage, updateMask specifies which fields need to be updated.
//
// FieldMask should contain field names like in .proto file. Nested paths are supported (e.g. "foo.bar.xyz").
Expand All @@ -30,30 +44,46 @@ func Update[T protoreflect.ProtoMessage](targetMessage T, updateMessage T, updat

targetMessageRef, updateMessageRef := targetMessage.ProtoReflect(), updateMessage.ProtoReflect()
for _, path := range updateMask.GetPaths() {
target, targetField, err := populateMessageProperty(targetMessageRef, path)
fieldPath, err := splitPath(path)
if err != nil {
return err
}
updateProperty, err := populateMessageProperty(updateMessageRef, fieldPath, true)
if err != nil {
return err
}
update, updateField, err := populateMessageProperty(updateMessageRef, path)
value := (*updateProperty.message).Get(*updateProperty.field)
isNilValue := isNil(value)

targetProperty, err := populateMessageProperty(targetMessageRef, fieldPath, !isNilValue)
if err != nil {
return err
}

value := update.Get(updateField)
// If the search for the end property was aborted due to parent being nil.
// This will only happen if we're trying to also set nil (aka "value" is nil).
if targetProperty.incomplete {
continue
}

if isNil(value) {
target.Clear(targetField)
if isNilValue {
targetProperty.clear()
} else {
target.Set(
targetField,
value,
)
targetProperty.set(value)
}
}

return nil
}

func splitPath(path string) ([]string, error) {
paths := strings.Split(path, ".")
if len(paths) == 0 {
return nil, ErrInvalidPath
}
return paths, nil
}

func getFieldByName(message protoreflect.Message, fieldName string) (protoreflect.FieldDescriptor, error) {
field := message.Descriptor().Fields().ByName(protoreflect.Name(fieldName))
if field == nil {
Expand All @@ -80,36 +110,27 @@ func isNil(value protoreflect.Value) bool {
}
}

func populateMessageProperty(message protoreflect.Message, path string) (protoreflect.Message, protoreflect.FieldDescriptor, error) {
fields := strings.Split(path, ".")
switch len(fields) {
case 0:
return nil, nil, ErrInvalidPath
case 1:
field, err := getFieldByName(message, fields[0])
if err != nil {
return nil, nil, err
}
return message, field, nil
}

func populateMessageProperty(message protoreflect.Message, fields []string, recursive bool) (*property, error) {
for i := 0; i < len(fields)-1; i++ {
// These nil checks are redundant if we use Google's fieldmaskpb package.
// But it's better to be safe than sorry and account for other implementations not implementing their "IsValid" function correctly.
messageField := message.Descriptor().Fields().ByName(protoreflect.Name(fields[i]))
if messageField == nil {
return nil, nil, fmt.Errorf("unknown field: '%s'", strings.Join(fields[:i+1], "."))
return nil, fmt.Errorf("unknown field: '%s'", strings.Join(fields[:i+1], "."))
}

nextMessage := toMessage(message.Get(messageField))
if nextMessage == nil {
return nil, nil, fmt.Errorf("unsupported nested type: '%s'", strings.Join(fields[:i+1], "."))
return nil, fmt.Errorf("unsupported nested type: '%s'", strings.Join(fields[:i+1], "."))
}

// We need to make sure the message value is not nil.
// For example, we have a path of "a.b.c". It would be possible to get a "c" of "b" even if the value of "b" is nil.
// Therefore, before we can set a value of "c", we need to initialize "b", or we'll get a nil pointer exception.
if !nextMessage.IsValid() {
if !recursive {
return &property{incomplete: true}, nil
}
value := message.NewField(messageField)
message.Set(messageField, value)
nextMessage = toMessage(value)
Expand All @@ -120,7 +141,7 @@ func populateMessageProperty(message protoreflect.Message, path string) (protore

field, err := getFieldByName(message, fields[len(fields)-1])
if err != nil {
return nil, nil, err
return nil, err
}
return message, field, nil
return &property{message: &message, field: &field}, nil
}

0 comments on commit 4a12946

Please sign in to comment.