diff --git a/libs/dyn/merge/override.go b/libs/dyn/merge/override.go new file mode 100644 index 0000000000..97e8f10098 --- /dev/null +++ b/libs/dyn/merge/override.go @@ -0,0 +1,198 @@ +package merge + +import ( + "fmt" + + "github.com/databricks/cli/libs/dyn" +) + +// OverrideVisitor is visiting the changes during the override process +// and allows to control what changes are allowed, or update the effective +// value. +// +// For instance, it can disallow changes outside the specific path(s), or update +// the location of the effective value. +// +// 'VisitDelete' is called when a value is removed from mapping or sequence +// 'VisitInsert' is called when a new value is added to mapping or sequence +// 'VisitUpdate' is called when a leaf value is updated +type OverrideVisitor struct { + VisitDelete func(valuePath dyn.Path, left dyn.Value) error + VisitInsert func(valuePath dyn.Path, right dyn.Value) (dyn.Value, error) + VisitUpdate func(valuePath dyn.Path, left dyn.Value, right dyn.Value) (dyn.Value, error) +} + +// Override overrides value 'leftRoot' with 'rightRoot', keeping 'location' if values +// haven't changed. Preserving 'location' is important to preserve the original source of the value +// for error reporting. +func Override(leftRoot dyn.Value, rightRoot dyn.Value, visitor OverrideVisitor) (dyn.Value, error) { + return override(dyn.EmptyPath, leftRoot, rightRoot, visitor) +} + +func override(basePath dyn.Path, left dyn.Value, right dyn.Value, visitor OverrideVisitor) (dyn.Value, error) { + if left == dyn.NilValue && right == dyn.NilValue { + return dyn.NilValue, nil + } + + if left.Kind() != right.Kind() { + return visitor.VisitUpdate(basePath, left, right) + } + + // NB: we only call 'VisitUpdate' on leaf values, and for sequences and mappings + // we don't know if value was updated or not + + switch left.Kind() { + case dyn.KindMap: + merged, err := overrideMapping(basePath, left.MustMap(), right.MustMap(), visitor) + + if err != nil { + return dyn.InvalidValue, err + } + + return dyn.NewValue(merged, left.Location()), nil + + case dyn.KindSequence: + // some sequences are keyed, and we can detect which elements are added/removed/updated, + // but we don't have this information + merged, err := overrideSequence(basePath, left.MustSequence(), right.MustSequence(), visitor) + + if err != nil { + return dyn.InvalidValue, err + } + + return dyn.NewValue(merged, left.Location()), nil + + case dyn.KindString: + if left.MustString() == right.MustString() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindFloat: + // TODO consider comparison with epsilon if normalization doesn't help, where do we use floats? + + if left.MustFloat() == right.MustFloat() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindBool: + if left.MustBool() == right.MustBool() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindTime: + if left.MustTime() == right.MustTime() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindInt: + if left.MustInt() == right.MustInt() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + } + + return dyn.InvalidValue, fmt.Errorf("unexpected kind %s", left.Kind()) +} + +func overrideMapping(basePath dyn.Path, leftMapping dyn.Mapping, rightMapping dyn.Mapping, visitor OverrideVisitor) (dyn.Mapping, error) { + out := dyn.NewMapping() + + for _, leftPair := range leftMapping.Pairs() { + // detect if key was removed + if _, ok := rightMapping.GetPair(leftPair.Key); !ok { + path := basePath.Append(dyn.Key(leftPair.Key.MustString())) + + err := visitor.VisitDelete(path, leftPair.Value) + + if err != nil { + return dyn.NewMapping(), err + } + } + } + + // iterating only right mapping will remove keys not present anymore + // and insert new keys + + for _, rightPair := range rightMapping.Pairs() { + if leftPair, ok := leftMapping.GetPair(rightPair.Key); ok { + path := basePath.Append(dyn.Key(rightPair.Key.MustString())) + newValue, err := override(path, leftPair.Value, rightPair.Value, visitor) + + if err != nil { + return dyn.NewMapping(), err + } + + // key was there before, so keep its location + err = out.Set(leftPair.Key, newValue) + + if err != nil { + return dyn.NewMapping(), err + } + } else { + path := basePath.Append(dyn.Key(rightPair.Key.MustString())) + + newValue, err := visitor.VisitInsert(path, rightPair.Value) + + if err != nil { + return dyn.NewMapping(), err + } + + err = out.Set(rightPair.Key, newValue) + + if err != nil { + return dyn.NewMapping(), err + } + } + } + + return out, nil +} + +func overrideSequence(basePath dyn.Path, left []dyn.Value, right []dyn.Value, visitor OverrideVisitor) ([]dyn.Value, error) { + minLen := min(len(left), len(right)) + var values []dyn.Value + + for i := 0; i < minLen; i++ { + path := basePath.Append(dyn.Index(i)) + merged, err := override(path, left[i], right[i], visitor) + + if err != nil { + return nil, err + } + + values = append(values, merged) + } + + if len(right) > len(left) { + for i := minLen; i < len(right); i++ { + path := basePath.Append(dyn.Index(i)) + newValue, err := visitor.VisitInsert(path, right[i]) + + if err != nil { + return nil, err + } + + values = append(values, newValue) + } + } else if len(left) > len(right) { + for i := minLen; i < len(left); i++ { + path := basePath.Append(dyn.Index(i)) + err := visitor.VisitDelete(path, left[i]) + + if err != nil { + return nil, err + } + } + } + + return values, nil +} diff --git a/libs/dyn/merge/override_test.go b/libs/dyn/merge/override_test.go new file mode 100644 index 0000000000..dbf249d12d --- /dev/null +++ b/libs/dyn/merge/override_test.go @@ -0,0 +1,434 @@ +package merge + +import ( + "testing" + "time" + + "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" +) + +type overrideTestCase struct { + name string + left dyn.Value + right dyn.Value + state visitorState + expected dyn.Value +} + +func TestOverride_Primitive(t *testing.T) { + leftLocation := dyn.Location{File: "left.yml", Line: 1, Column: 1} + rightLocation := dyn.Location{File: "right.yml", Line: 1, Column: 1} + + modifiedTestCases := []overrideTestCase{ + { + name: "string (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue("a", leftLocation), + right: dyn.NewValue("b", rightLocation), + expected: dyn.NewValue("b", rightLocation), + }, + { + name: "string (not updated)", + state: visitorState{}, + left: dyn.NewValue("a", leftLocation), + right: dyn.NewValue("a", rightLocation), + expected: dyn.NewValue("a", leftLocation), + }, + { + name: "bool (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue(true, leftLocation), + right: dyn.NewValue(false, rightLocation), + expected: dyn.NewValue(false, rightLocation), + }, + { + name: "bool (not updated)", + state: visitorState{}, + left: dyn.NewValue(true, leftLocation), + right: dyn.NewValue(true, rightLocation), + expected: dyn.NewValue(true, leftLocation), + }, + { + name: "int (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue(1, leftLocation), + right: dyn.NewValue(2, rightLocation), + expected: dyn.NewValue(2, rightLocation), + }, + { + name: "int (not updated)", + state: visitorState{}, + left: dyn.NewValue(int32(1), leftLocation), + right: dyn.NewValue(int64(1), rightLocation), + expected: dyn.NewValue(int32(1), leftLocation), + }, + { + name: "float (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue(1.0, leftLocation), + right: dyn.NewValue(2.0, rightLocation), + expected: dyn.NewValue(2.0, rightLocation), + }, + { + name: "float (not updated)", + state: visitorState{}, + left: dyn.NewValue(float32(1.0), leftLocation), + right: dyn.NewValue(float64(1.0), rightLocation), + expected: dyn.NewValue(float32(1.0), leftLocation), + }, + { + name: "time (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue(time.UnixMilli(10000), leftLocation), + right: dyn.NewValue(time.UnixMilli(10001), rightLocation), + expected: dyn.NewValue(time.UnixMilli(10001), rightLocation), + }, + { + name: "time (not updated)", + state: visitorState{}, + left: dyn.NewValue(time.UnixMilli(10000), leftLocation), + right: dyn.NewValue(time.UnixMilli(10000), rightLocation), + expected: dyn.NewValue(time.UnixMilli(10000), leftLocation), + }, + { + name: "different types (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue("a", leftLocation), + right: dyn.NewValue(42, rightLocation), + expected: dyn.NewValue(42, rightLocation), + }, + { + name: "map - remove 'a', update 'b'", + state: visitorState{ + removed: []string{"root.a"}, + updated: []string{"root.b"}, + }, + left: dyn.NewValue( + map[string]dyn.Value{ + "a": dyn.NewValue(42, leftLocation), + "b": dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + map[string]dyn.Value{ + "b": dyn.NewValue(20, rightLocation), + }, + rightLocation, + ), + expected: dyn.NewValue( + map[string]dyn.Value{ + "b": dyn.NewValue(20, rightLocation), + }, + leftLocation, + ), + }, + { + name: "map - add 'a'", + state: visitorState{ + added: []string{"root.a"}, + }, + left: dyn.NewValue( + map[string]dyn.Value{ + "b": dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + map[string]dyn.Value{ + "a": dyn.NewValue(42, rightLocation), + "b": dyn.NewValue(10, rightLocation), + }, + leftLocation, + ), + expected: dyn.NewValue( + map[string]dyn.Value{ + "a": dyn.NewValue(42, rightLocation), + // location hasn't changed because value hasn't changed + "b": dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + }, + { + name: "map - remove 'a'", + state: visitorState{ + removed: []string{"root.a"}, + }, + left: dyn.NewValue( + map[string]dyn.Value{ + "a": dyn.NewValue(42, leftLocation), + "b": dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + map[string]dyn.Value{ + "b": dyn.NewValue(10, rightLocation), + }, + leftLocation, + ), + expected: dyn.NewValue( + map[string]dyn.Value{ + // location hasn't changed because value hasn't changed + "b": dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + }, + { + name: "map - add 'jobs.job_1'", + state: visitorState{ + added: []string{"root.jobs.job_1"}, + }, + left: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + }, + leftLocation, + ), + right: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, rightLocation), + "job_1": dyn.NewValue(1337, rightLocation), + }, + rightLocation, + ), + }, + rightLocation, + ), + expected: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, leftLocation), + "job_1": dyn.NewValue(1337, rightLocation), + }, + leftLocation, + ), + }, + leftLocation, + ), + }, + { + name: "map - remove nested key", + state: visitorState{removed: []string{"root.jobs.job_1"}}, + left: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, leftLocation), + "job_1": dyn.NewValue(1337, rightLocation), + }, + leftLocation, + ), + }, + leftLocation, + ), + right: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, rightLocation), + }, + rightLocation, + ), + }, + rightLocation, + ), + expected: dyn.NewValue( + map[string]dyn.Value{ + "jobs": dyn.NewValue( + map[string]dyn.Value{ + "job_0": dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + }, + leftLocation, + ), + }, + { + name: "sequence - add", + state: visitorState{added: []string{"root[1]"}}, + left: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, rightLocation), + dyn.NewValue(10, rightLocation), + }, + rightLocation, + ), + expected: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, leftLocation), + dyn.NewValue(10, rightLocation), + }, + leftLocation, + ), + }, + { + name: "sequence - remove", + state: visitorState{removed: []string{"root[1]"}}, + left: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, leftLocation), + dyn.NewValue(10, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, rightLocation), + }, + rightLocation, + ), + expected: dyn.NewValue( + []dyn.Value{ + // location hasn't changed because value hasn't changed + dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + }, + { + name: "sequence (not updated)", + state: visitorState{}, + left: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + right: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, rightLocation), + }, + rightLocation, + ), + expected: dyn.NewValue( + []dyn.Value{ + dyn.NewValue(42, leftLocation), + }, + leftLocation, + ), + }, + { + name: "nil (not updated)", + state: visitorState{}, + left: dyn.NilValue, + right: dyn.NilValue, + expected: dyn.NilValue, + }, + { + name: "nil (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NilValue, + right: dyn.NewValue(42, rightLocation), + expected: dyn.NewValue(42, rightLocation), + }, + { + name: "change kind (updated)", + state: visitorState{updated: []string{"root"}}, + left: dyn.NewValue(42.0, leftLocation), + right: dyn.NewValue(42, rightLocation), + expected: dyn.NewValue(42, rightLocation), + }, + } + + for _, tc := range modifiedTestCases { + t.Run(tc.name, func(t *testing.T) { + s, visitor := createVisitor() + out, err := override(dyn.NewPath(dyn.Key("root")), tc.left, tc.right, visitor) + + assert.NoError(t, err) + assert.Equal(t, tc.state, *s) + assert.Equal(t, tc.expected, out) + }) + } +} + +func TestOverride_PreserveMappingKeys(t *testing.T) { + leftLocation := dyn.Location{File: "left.yml", Line: 1, Column: 1} + leftKeyLocation := dyn.Location{File: "left.yml", Line: 2, Column: 1} + leftValueLocation := dyn.Location{File: "left.yml", Line: 3, Column: 1} + + rightLocation := dyn.Location{File: "right.yml", Line: 1, Column: 1} + rightKeyLocation := dyn.Location{File: "right.yml", Line: 2, Column: 1} + rightValueLocation := dyn.Location{File: "right.yml", Line: 3, Column: 1} + + left := dyn.NewMapping() + left.Set(dyn.NewValue("a", leftKeyLocation), dyn.NewValue(42, leftValueLocation)) + + right := dyn.NewMapping() + right.Set(dyn.NewValue("a", rightKeyLocation), dyn.NewValue(7, rightValueLocation)) + + state, visitor := createVisitor() + + out, err := override( + dyn.EmptyPath, + dyn.NewValue(left, leftLocation), + dyn.NewValue(right, rightLocation), + visitor, + ) + + assert.NoError(t, err) + + if err != nil { + outPairs := out.MustMap().Pairs() + + assert.Equal(t, visitorState{updated: []string{"a"}}, state) + assert.Equal(t, 1, len(outPairs)) + + // mapping was first defined in left, so it should keep its location + assert.Equal(t, leftLocation, out.Location()) + + // if there is a validation error for key value, it should point + // to where it was initially defined + assert.Equal(t, leftKeyLocation, outPairs[0].Key.Location()) + + // the value should have updated location, because it has changed + assert.Equal(t, rightValueLocation, outPairs[0].Value.Location()) + } +} + +type visitorState struct { + added []string + removed []string + updated []string +} + +func createVisitor() (*visitorState, OverrideVisitor) { + s := visitorState{} + + return &s, OverrideVisitor{ + VisitUpdate: func(valuePath dyn.Path, left dyn.Value, right dyn.Value) (dyn.Value, error) { + s.updated = append(s.updated, valuePath.String()) + + return right, nil + }, + VisitDelete: func(valuePath dyn.Path, left dyn.Value) error { + s.removed = append(s.removed, valuePath.String()) + + return nil + }, + VisitInsert: func(valuePath dyn.Path, right dyn.Value) (dyn.Value, error) { + s.added = append(s.added, valuePath.String()) + + return right, nil + }, + } +}