diff --git a/bundle/config/mutator/python/override.go b/bundle/config/mutator/python/override.go deleted file mode 100644 index a6184c5e14..0000000000 --- a/bundle/config/mutator/python/override.go +++ /dev/null @@ -1,313 +0,0 @@ -package python - -import ( - "fmt" - - "github.com/databricks/cli/libs/dyn" - "github.com/databricks/cli/libs/dyn/merge" -) - -// overrideMode restricts which changes are allowed during the override process -type overrideMode string - -const OverrideModeAppend overrideMode = "append" -const OverrideModeAppendOrModify overrideMode = "append-or-update" - -// isModified keeps track of whether a value has been modified during the override process -type isModified struct { - value bool -} - -var Modified = isModified{value: true} -var NotModified = isModified{value: false} - -type overrideFunc func(dyn.Value, dyn.Value) (dyn.Value, isModified, error) - -// OverrideByPath overrides bundle config 'leftRoot' with 'rightRoot' values in 'path', keeping -// 'location' if values haven't changed. This allows to keep track of where the values -// were last modified for error reporting. -// -// 'path' must point to a mapping or not exist, otherwise, error is returned -// -// 'mode' allows to restrict what kinds of changes are allowed in 'path': -// - append: only new values can be added to mapping -// - append-or-modify: new values can be added, or existing values can be modified -func OverrideByPath(leftRoot dyn.Value, rightRoot dyn.Value, path dyn.Path, mode overrideMode) (dyn.Value, error) { - _, leftErr := dyn.GetByPath(leftRoot, path) - _, rightErr := dyn.GetByPath(rightRoot, path) - - if leftErr != nil && rightErr != nil { - return leftRoot, nil - } - - leftRoot, err := insertMappingIfAbsent(leftRoot, path) - - if err != nil { - return dyn.InvalidValue, err - } - - rightRoot, err = insertMappingIfAbsent(rightRoot, path) - - if err != nil { - return dyn.InvalidValue, err - } - - return dyn.Map(leftRoot, path.String(), func(p dyn.Path, left dyn.Value) (dyn.Value, error) { - leftMapping := left.MustMap() - rightMapping, err := dyn.GetByPath(rightRoot, p) - - if err != nil { - return dyn.InvalidValue, err - } - - out, _, err := overrideMapping( - leftMapping, - rightMapping.MustMap(), - func(left dyn.Value, right dyn.Value) (dyn.Value, isModified, error) { - merged, modified, err := override(left, right) - - if err != nil { - return dyn.InvalidValue, NotModified, err - } - - if modified == Modified { - if mode == OverrideModeAppend && left != dyn.NilValue { - return dyn.InvalidValue, NotModified, unexpectedChangeError(left.Location()) - } - - if mode == OverrideModeAppendOrModify && right == dyn.NilValue { - return dyn.InvalidValue, NotModified, unexpectedChangeError(left.Location()) - } - } - - return merged, modified, nil - }) - - if err != nil { - return dyn.InvalidValue, err - } - - return dyn.NewValue(out, left.Location()), nil - }) -} - -func overrideMapping(leftMapping dyn.Mapping, rightMapping dyn.Mapping, overrideFunc overrideFunc) (dyn.Mapping, isModified, error) { - out := dyn.NewMapping() - var modified = false - - for _, leftPair := range leftMapping.Pairs() { - // detect if key was removed - if _, ok := rightMapping.GetPair(leftPair.Key); !ok { - out, valueModified, err := overrideFunc(leftPair.Value, dyn.NilValue) - - if err != nil { - return dyn.NewMapping(), NotModified, err - } - - if out != dyn.NilValue { - return dyn.NewMapping(), NotModified, fmt.Errorf("'overrideFunc' didn't return return Nil") - } - - modified = modified || valueModified.value - } - } - - // iterating only right mapping will remove keys not present anymore - // and insert new keys - - for _, rightPair := range rightMapping.Pairs() { - var leftValue dyn.Value - var key dyn.Value - - if leftPair, ok := leftMapping.GetPair(rightPair.Key); ok { - leftValue = leftPair.Value - - // key was there before, so keep its location - key = leftPair.Key - } else { - leftValue = dyn.NilValue - key = rightPair.Key - - // always modified because we did insert here - modified = true - } - - newValue, keyModified, err := overrideFunc(leftValue, rightPair.Value) - modified = modified || keyModified.value - - if err != nil { - return dyn.NewMapping(), NotModified, err - } - - if newValue == dyn.NilValue { - continue - } - - err = out.Set(key, newValue) - - if err != nil { - return dyn.NewMapping(), NotModified, err - } - } - - if modified { - return out, Modified, nil - } else { - return leftMapping, NotModified, nil - } -} - -func override(left dyn.Value, right dyn.Value) (dyn.Value, isModified, error) { - if left == dyn.NilValue { - if right == dyn.NilValue { - return dyn.NilValue, NotModified, nil - } else { - return right, Modified, nil - } - } - - if left.Kind() != right.Kind() { - return right, Modified, nil - } - - if left.Kind() == dyn.KindMap { - merged, modified, err := overrideMapping(left.MustMap(), right.MustMap(), override) - - if err != nil { - return dyn.InvalidValue, modified, err - } - - if modified == Modified { - return dyn.NewValue(merged, left.Location()), modified, nil - } else { - return left, NotModified, nil - } - } else if left.Kind() == dyn.KindSequence { - // some sequences are keyed, and we can detect which elements are added/removed/updated, - // but we don't have this information - - leftSeq := left.MustSequence() - rightSeq := right.MustSequence() - minLen := min(len(leftSeq), len(rightSeq)) - - var values []dyn.Value - var modified = len(leftSeq) != len(rightSeq) - - for i := 0; i < minLen; i++ { - merged, elementModified, err := override(leftSeq[i], rightSeq[i]) - - if err != nil { - return dyn.InvalidValue, NotModified, err - } - - values = append(values, merged) - modified = modified || elementModified.value - } - - for i := minLen; i < len(rightSeq); i++ { - values = append(values, rightSeq[i]) - } - - if modified { - return dyn.NewValue(values, left.Location()), Modified, nil - } else { - return left, NotModified, nil - } - } else { - // primitive values are compared directly - modified, err := overridePrimitive(left, right) - - if err != nil { - return dyn.InvalidValue, NotModified, err - } - - if modified == Modified { - return right, Modified, nil - } else { - return left, NotModified, nil - } - } -} - -func insertMappingIfAbsent(root dyn.Value, path dyn.Path) (dyn.Value, error) { - value, err := dyn.GetByPath(root, path) - - if err == nil && value != dyn.NilValue && value != dyn.InvalidValue { - return root, nil - } - - var defaultValue = dyn.NewValue(dyn.NewMapping(), dyn.Location{}) - - // create an empty object like {path[0]: {path[1]: { ... { path[n]: {}}} - for i := range path { - name := path[len(path)-i-1].Key() - - defaultValue = dyn.NewValue(map[string]dyn.Value{name: defaultValue}, dyn.Location{}) - } - - updated, err := merge.Merge(root, defaultValue) - - if err != nil { - return dyn.InvalidValue, err - } - - return updated, nil -} - -func overridePrimitive(a dyn.Value, b dyn.Value) (isModified, error) { - switch a.Kind() { - case dyn.KindInvalid: - return Modified, nil - - case dyn.KindNil: - return Modified, nil - - case dyn.KindBool: - if a.MustBool() != b.MustBool() { - return Modified, nil - } else { - return NotModified, nil - } - - case dyn.KindInt: - if a.MustInt() != b.MustInt() { - return Modified, nil - } else { - return NotModified, nil - } - - case dyn.KindFloat: - // FIXME what is the example of float field? should we use epsilon? - if a.MustFloat() != b.MustFloat() { - return Modified, nil - } else { - return NotModified, nil - } - - case dyn.KindString: - if a.MustString() != b.MustString() { - return Modified, nil - } else { - return NotModified, nil - } - - case dyn.KindTime: - if a.MustTime() != b.MustTime() { - return Modified, nil - } else { - return NotModified, nil - } - - case dyn.KindMap: - return NotModified, fmt.Errorf("unexpected kind %s, expected primitive", a.Kind()) - - case dyn.KindSequence: - return NotModified, fmt.Errorf("unexpected kind %s, expected primitive", a.Kind()) - } - - return NotModified, fmt.Errorf("unexpected kind %s, expected primitive", a.Kind()) -} - -func unexpectedChangeError(location dyn.Location) error { - return fmt.Errorf("unexpectedly changed value at '%s'", location) -} diff --git a/libs/dyn/merge/override.go b/libs/dyn/merge/override.go new file mode 100644 index 0000000000..dda98ba570 --- /dev/null +++ b/libs/dyn/merge/override.go @@ -0,0 +1,204 @@ +package merge + +import ( + "fmt" + + "github.com/databricks/cli/libs/dyn" +) + +// OverrideVisitor is visiting the changes during the override process +// and allows to control what changes are allowed, or update the effective +// value. +// +// For instance, it can disallow changes outside the specific path(s), or update +// the location of the effective value. +// +// 'VisitDelete' is called when a value is removed from mapping or sequence +// 'VisitInsert' is called when a new value is added to mapping or sequence +// 'VisitUpdate' is called when a leaf value is updated +type OverrideVisitor struct { + VisitDelete func(valuePath dyn.Path, left dyn.Value) error + VisitInsert func(valuePath dyn.Path, right dyn.Value) (dyn.Value, error) + VisitUpdate func(valuePath dyn.Path, left dyn.Value, right dyn.Value) (dyn.Value, error) +} + +// Override overrides bundle config 'leftRoot' with 'rightRoot', keeping 'location' if values +// haven't changed. Preserving 'location' is important to preserve the original source of the value +// for error reporting. +func Override(leftRoot dyn.Value, rightRoot dyn.Value, visitor OverrideVisitor) (dyn.Value, error) { + return override(dyn.EmptyPath, leftRoot, rightRoot, visitor) +} + +func override(basePath dyn.Path, left dyn.Value, right dyn.Value, visitor OverrideVisitor) (dyn.Value, error) { + if left == dyn.NilValue && right == dyn.NilValue { + return dyn.NilValue, nil + } + + if left.Kind() != right.Kind() { + newValue, err := visitor.VisitUpdate(basePath, left, right) + + if err != nil { + return dyn.InvalidValue, err + } + + return newValue, nil + } + + switch left.Kind() { + case dyn.KindMap: + merged, err := overrideMapping(basePath, left.MustMap(), right.MustMap(), visitor) + + if err != nil { + return dyn.InvalidValue, err + } + + return dyn.NewValue(merged, left.Location()), nil + case dyn.KindSequence: + // some sequences are keyed, and we can detect which elements are added/removed/updated, + // but we don't have this information + + leftSeq := left.MustSequence() + rightSeq := right.MustSequence() + + merged, err := overrideSequence(basePath, leftSeq, rightSeq, visitor) + + if err != nil { + return dyn.InvalidValue, err + } + + return dyn.NewValue(merged, left.Location()), nil + + case dyn.KindString: + if left.MustString() == right.MustString() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindFloat: + // TODO consider comparison with epsilon if normalization doesn't help, where do we use floats? + + if left.MustFloat() == right.MustFloat() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindBool: + if left.MustBool() == right.MustBool() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindTime: + if left.MustTime() == right.MustTime() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + + case dyn.KindInt: + if left.MustInt() == right.MustInt() { + return left, nil + } else { + return visitor.VisitUpdate(basePath, left, right) + } + } + + return dyn.InvalidValue, fmt.Errorf("unexpected kind %s", left.Kind()) +} + +func overrideMapping(basePath dyn.Path, leftMapping dyn.Mapping, rightMapping dyn.Mapping, visitor OverrideVisitor) (dyn.Mapping, error) { + out := dyn.NewMapping() + + for _, leftPair := range leftMapping.Pairs() { + // detect if key was removed + if _, ok := rightMapping.GetPair(leftPair.Key); !ok { + path := basePath.Append(dyn.Key(leftPair.Key.MustString())) + + err := visitor.VisitDelete(path, leftPair.Value) + + if err != nil { + return dyn.NewMapping(), err + } + } + } + + // iterating only right mapping will remove keys not present anymore + // and insert new keys + + for _, rightPair := range rightMapping.Pairs() { + if leftPair, ok := leftMapping.GetPair(rightPair.Key); ok { + path := basePath.Append(dyn.Key(rightPair.Key.MustString())) + newValue, err := override(path, leftPair.Value, rightPair.Value, visitor) + + if err != nil { + return dyn.NewMapping(), err + } + + // key was there before, so keep its location + err = out.Set(leftPair.Key, newValue) + + if err != nil { + return dyn.NewMapping(), err + } + } else { + path := basePath.Append(dyn.Key(rightPair.Key.MustString())) + + newValue, err := visitor.VisitInsert(path, rightPair.Value) + + if err != nil { + return dyn.NewMapping(), err + } + + err = out.Set(rightPair.Key, newValue) + + if err != nil { + return dyn.NewMapping(), err + } + } + } + + return out, nil +} + +func overrideSequence(basePath dyn.Path, left []dyn.Value, right []dyn.Value, visitor OverrideVisitor) ([]dyn.Value, error) { + minLen := min(len(left), len(right)) + var values []dyn.Value + + for i := 0; i < minLen; i++ { + path := basePath.Append(dyn.Index(i)) + merged, err := override(path, left[i], right[i], visitor) + + if err != nil { + return nil, err + } + + values = append(values, merged) + } + + if len(right) > len(left) { + for i := minLen; i < len(right); i++ { + path := basePath.Append(dyn.Index(i)) + newValue, err := visitor.VisitInsert(path, right[i]) + + if err != nil { + return nil, err + } + + values = append(values, newValue) + } + } else { + for i := minLen; i < len(left); i++ { + path := basePath.Append(dyn.Index(i)) + err := visitor.VisitDelete(path, left[i]) + + if err != nil { + return nil, err + } + } + } + + return values, nil +} diff --git a/bundle/config/mutator/python/override_test.go b/libs/dyn/merge/override_test.go similarity index 56% rename from bundle/config/mutator/python/override_test.go rename to libs/dyn/merge/override_test.go index c13ebeeed2..dbf249d12d 100644 --- a/bundle/config/mutator/python/override_test.go +++ b/libs/dyn/merge/override_test.go @@ -1,160 +1,18 @@ -package python +package merge import ( - "bytes" "testing" "time" "github.com/databricks/cli/libs/dyn" assert "github.com/databricks/cli/libs/dyn/dynassert" - "github.com/databricks/cli/libs/dyn/yamlloader" ) -func TestOverrideByPath(t *testing.T) { - left := loadYaml("left.yml", ` -resources: - jobs: - job_0: - name: "job_0" -`) - - right := loadYaml("right.yml", ` -resources: - jobs: - job_0: - name: job_0 - job_1: - name: job_1 -`) - - updated, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppend) - assert.NoError(t, err) - - name0, _ := dyn.GetByPath( - updated, - newPath("resources", "jobs", "job_0", "name"), - ) - - name1, _ := dyn.GetByPath( - updated, - newPath("resources", "jobs", "job_1", "name"), - ) - - assert.Equal(t, "job_0", name0.MustString()) - assert.Equal(t, "left.yml", name0.Location().File) - - assert.Equal(t, "job_1", name1.MustString()) - assert.Equal(t, "right.yml", name1.Location().File) -} - -func TestOverrideByPath_Update(t *testing.T) { - left := loadYaml("left.yml", ` -resources: - jobs: - job_0: - name: "job_0" - description: job 0 -`) - - right := loadYaml("right.yml", ` -resources: - jobs: - job_0: - name: "job_0_updated" - description: job 0 -`) - - updated, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppendOrModify) - - name, _ := dyn.GetByPath( - updated, - newPath("resources", "jobs", "job_0", "name"), - ) - - description, _ := dyn.GetByPath( - updated, - newPath("resources", "jobs", "job_0", "description"), - ) - - assert.NoError(t, err) - - if err != nil { - return - } - - assert.Equal(t, "job_0_updated", name.MustString()) - assert.Equal(t, "right.yml", name.Location().File) - - // description hasn't changed, so it should keep its location - assert.Equal(t, "job 0", description.MustString()) - assert.Equal(t, "left.yml", description.Location().File) -} - -func TestOverrideByPath_UpdateNotAllowed(t *testing.T) { - left := loadYaml("left.yml", ` -resources: - jobs: - job_0: - name: "job_0" -`) - - right := loadYaml("right.yml", ` -resources: - jobs: - job_0: - name: "job_0_updated" -`) - - _, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppend) - - assert.EqualError(t, err, "unexpectedly changed value at 'left.yml:5:7'") -} - -func TestOverrideByPath_RemoveNotAllowed(t *testing.T) { - left := loadYaml("left.yml", ` -resources: - jobs: - job_0: - name: "job_0" - job_1: - name: "job_1" -`) - - right := loadYaml("right.yml", ` -resources: - jobs: - job_0: - name: "job_0" -`) - - t.Run("append or modify", func(t *testing.T) { - _, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppendOrModify) - - assert.EqualError(t, err, "unexpectedly changed value at 'left.yml:7:7'") - }) - - t.Run("append", func(t *testing.T) { - _, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppend) - - assert.EqualError(t, err, "unexpectedly changed value at 'left.yml:7:7'") - }) -} - -func TestOverrideByPath_Empty(t *testing.T) { - left := loadYaml("left.yml", "") - right := loadYaml("right.yml", "") - - updated, err := OverrideByPath(left, right, newPath("resources", "jobs"), OverrideModeAppend) - - assert.NoError(t, err) - assert.Equal(t, left, updated) -} - type overrideTestCase struct { name string left dyn.Value right dyn.Value - modified isModified + state visitorState expected dyn.Value } @@ -164,85 +22,88 @@ func TestOverride_Primitive(t *testing.T) { modifiedTestCases := []overrideTestCase{ { - name: "string (modified)", - modified: Modified, + name: "string (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue("a", leftLocation), right: dyn.NewValue("b", rightLocation), expected: dyn.NewValue("b", rightLocation), }, { - name: "string (not modified)", - modified: NotModified, + name: "string (not updated)", + state: visitorState{}, left: dyn.NewValue("a", leftLocation), right: dyn.NewValue("a", rightLocation), expected: dyn.NewValue("a", leftLocation), }, { - name: "bool (modified)", - modified: Modified, + name: "bool (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue(true, leftLocation), right: dyn.NewValue(false, rightLocation), expected: dyn.NewValue(false, rightLocation), }, { - name: "bool (modified)", - modified: NotModified, + name: "bool (not updated)", + state: visitorState{}, left: dyn.NewValue(true, leftLocation), right: dyn.NewValue(true, rightLocation), expected: dyn.NewValue(true, leftLocation), }, { - name: "int (modified)", - modified: Modified, + name: "int (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue(1, leftLocation), right: dyn.NewValue(2, rightLocation), expected: dyn.NewValue(2, rightLocation), }, { - name: "int (not modified)", - modified: NotModified, + name: "int (not updated)", + state: visitorState{}, left: dyn.NewValue(int32(1), leftLocation), right: dyn.NewValue(int64(1), rightLocation), expected: dyn.NewValue(int32(1), leftLocation), }, { - name: "float (modified)", - modified: Modified, + name: "float (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue(1.0, leftLocation), right: dyn.NewValue(2.0, rightLocation), expected: dyn.NewValue(2.0, rightLocation), }, { - name: "float (not modified)", - modified: NotModified, + name: "float (not updated)", + state: visitorState{}, left: dyn.NewValue(float32(1.0), leftLocation), right: dyn.NewValue(float64(1.0), rightLocation), expected: dyn.NewValue(float32(1.0), leftLocation), }, { - name: "time (modified)", - modified: Modified, + name: "time (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue(time.UnixMilli(10000), leftLocation), right: dyn.NewValue(time.UnixMilli(10001), rightLocation), expected: dyn.NewValue(time.UnixMilli(10001), rightLocation), }, { - name: "time (not modified)", - modified: NotModified, + name: "time (not updated)", + state: visitorState{}, left: dyn.NewValue(time.UnixMilli(10000), leftLocation), right: dyn.NewValue(time.UnixMilli(10000), rightLocation), expected: dyn.NewValue(time.UnixMilli(10000), leftLocation), }, { - name: "different types (modified)", - modified: Modified, + name: "different types (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue("a", leftLocation), right: dyn.NewValue(42, rightLocation), expected: dyn.NewValue(42, rightLocation), }, { - name: "map - remove 'a', update 'b'", - modified: Modified, + name: "map - remove 'a', update 'b'", + state: visitorState{ + removed: []string{"root.a"}, + updated: []string{"root.b"}, + }, left: dyn.NewValue( map[string]dyn.Value{ "a": dyn.NewValue(42, leftLocation), @@ -264,23 +125,26 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "map - remove 'b'", - modified: Modified, + name: "map - add 'a'", + state: visitorState{ + added: []string{"root.a"}, + }, left: dyn.NewValue( map[string]dyn.Value{ - "a": dyn.NewValue(42, leftLocation), "b": dyn.NewValue(10, leftLocation), }, leftLocation, ), right: dyn.NewValue( map[string]dyn.Value{ + "a": dyn.NewValue(42, rightLocation), "b": dyn.NewValue(10, rightLocation), }, leftLocation, ), expected: dyn.NewValue( map[string]dyn.Value{ + "a": dyn.NewValue(42, rightLocation), // location hasn't changed because value hasn't changed "b": dyn.NewValue(10, leftLocation), }, @@ -288,8 +152,10 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "map - remove 'b'", - modified: Modified, + name: "map - remove 'a'", + state: visitorState{ + removed: []string{"root.a"}, + }, left: dyn.NewValue( map[string]dyn.Value{ "a": dyn.NewValue(42, leftLocation), @@ -312,8 +178,10 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "map - add nested key", - modified: Modified, + name: "map - add 'jobs.job_1'", + state: visitorState{ + added: []string{"root.jobs.job_1"}, + }, left: dyn.NewValue( map[string]dyn.Value{ "jobs": dyn.NewValue( @@ -351,8 +219,8 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "map - remove nested key", - modified: Modified, + name: "map - remove nested key", + state: visitorState{removed: []string{"root.jobs.job_1"}}, left: dyn.NewValue( map[string]dyn.Value{ "jobs": dyn.NewValue( @@ -389,8 +257,8 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "sequence - append", - modified: Modified, + name: "sequence - add", + state: visitorState{added: []string{"root[1]"}}, left: dyn.NewValue( []dyn.Value{ dyn.NewValue(42, leftLocation), @@ -413,8 +281,8 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "sequence - remove", - modified: Modified, + name: "sequence - remove", + state: visitorState{removed: []string{"root[1]"}}, left: dyn.NewValue( []dyn.Value{ dyn.NewValue(42, leftLocation), @@ -437,8 +305,8 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "sequence (not modified)", - modified: NotModified, + name: "sequence (not updated)", + state: visitorState{}, left: dyn.NewValue( []dyn.Value{ dyn.NewValue(42, leftLocation), @@ -459,22 +327,22 @@ func TestOverride_Primitive(t *testing.T) { ), }, { - name: "nil (not modified)", - modified: NotModified, + name: "nil (not updated)", + state: visitorState{}, left: dyn.NilValue, right: dyn.NilValue, expected: dyn.NilValue, }, { - name: "nil (modified)", - modified: Modified, + name: "nil (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NilValue, right: dyn.NewValue(42, rightLocation), expected: dyn.NewValue(42, rightLocation), }, { - name: "change kind (modified)", - modified: Modified, + name: "change kind (updated)", + state: visitorState{updated: []string{"root"}}, left: dyn.NewValue(42.0, leftLocation), right: dyn.NewValue(42, rightLocation), expected: dyn.NewValue(42, rightLocation), @@ -483,10 +351,11 @@ func TestOverride_Primitive(t *testing.T) { for _, tc := range modifiedTestCases { t.Run(tc.name, func(t *testing.T) { - out, modified, err := override(tc.left, tc.right) + s, visitor := createVisitor() + out, err := override(dyn.NewPath(dyn.Key("root")), tc.left, tc.right, visitor) assert.NoError(t, err) - assert.Equal(t, tc.modified, modified) + assert.Equal(t, tc.state, *s) assert.Equal(t, tc.expected, out) }) } @@ -507,9 +376,13 @@ func TestOverride_PreserveMappingKeys(t *testing.T) { right := dyn.NewMapping() right.Set(dyn.NewValue("a", rightKeyLocation), dyn.NewValue(7, rightValueLocation)) - out, modified, err := override( + state, visitor := createVisitor() + + out, err := override( + dyn.EmptyPath, dyn.NewValue(left, leftLocation), dyn.NewValue(right, rightLocation), + visitor, ) assert.NoError(t, err) @@ -517,7 +390,7 @@ func TestOverride_PreserveMappingKeys(t *testing.T) { if err != nil { outPairs := out.MustMap().Pairs() - assert.Equal(t, Modified, modified) + assert.Equal(t, visitorState{updated: []string{"a"}}, state) assert.Equal(t, 1, len(outPairs)) // mapping was first defined in left, so it should keep its location @@ -532,57 +405,30 @@ func TestOverride_PreserveMappingKeys(t *testing.T) { } } -func TestInsertMappingIfAbsent(t *testing.T) { - root := dyn.NewValue(map[string]dyn.Value{ - "resources": dyn.NewValue(map[string]dyn.Value{ - "pipelines": dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}), - }, dyn.Location{}), - }, dyn.Location{}) - - expected := dyn.NewValue(map[string]dyn.Value{ - "resources": dyn.NewValue(map[string]dyn.Value{ - "pipelines": dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}), - "jobs": dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}), - }, dyn.Location{}), - }, dyn.Location{}) - - out, err := insertMappingIfAbsent(root, newPath("resources", "jobs")) - - assert.NoError(t, err) - assert.Equal(t, expected, out) +type visitorState struct { + added []string + removed []string + updated []string } -func TestInsertMappingIfAbsent_Present(t *testing.T) { - root := dyn.NewValue(map[string]dyn.Value{ - "resources": dyn.NewValue(map[string]dyn.Value{ - "pipelines": dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}), - }, dyn.Location{}), - }, dyn.Location{}) - - expected := dyn.NewValue(map[string]dyn.Value{ - "resources": dyn.NewValue(map[string]dyn.Value{ - "pipelines": dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}), - }, dyn.Location{}), - }, dyn.Location{}) +func createVisitor() (*visitorState, OverrideVisitor) { + s := visitorState{} - out, err := insertMappingIfAbsent(root, newPath("resources", "pipelines")) + return &s, OverrideVisitor{ + VisitUpdate: func(valuePath dyn.Path, left dyn.Value, right dyn.Value) (dyn.Value, error) { + s.updated = append(s.updated, valuePath.String()) - assert.NoError(t, err) - assert.Equal(t, expected, out) -} + return right, nil + }, + VisitDelete: func(valuePath dyn.Path, left dyn.Value) error { + s.removed = append(s.removed, valuePath.String()) -func loadYaml(name string, content string) dyn.Value { - v, err := yamlloader.LoadYAML(name, bytes.NewReader([]byte(content))) - if err != nil { - panic(err) - } - return v -} + return nil + }, + VisitInsert: func(valuePath dyn.Path, right dyn.Value) (dyn.Value, error) { + s.added = append(s.added, valuePath.String()) -func newPath(keys ...string) dyn.Path { - p := dyn.NewPath() - for _, key := range keys { - p = p.Append(dyn.Key(key)) + return right, nil + }, } - return p }