diff --git a/issue230_test.go b/issue230_test.go new file mode 100644 index 0000000..129f037 --- /dev/null +++ b/issue230_test.go @@ -0,0 +1,85 @@ +package mergo_test + +import ( + "testing" + + "github.com/imdario/mergo" +) + +var testDataM = []struct { + M1 mapTest + M2 mapTest + WithOverrideEmptyValue bool + ExpectedMap map[int]int +}{ + { + M1: mapTest{ + M: map[int]int{1: 1, 3: 3}, + }, + M2: mapTest{ + M: map[int]int{1: 2, 2: 2}, + }, + WithOverrideEmptyValue: true, + ExpectedMap: map[int]int{1: 1, 3: 3}, + }, + { + M1: mapTest{ + M: map[int]int{1: 1, 3: 3}, + }, + M2: mapTest{ + M: map[int]int{1: 2, 2: 2}, + }, + WithOverrideEmptyValue: false, + ExpectedMap: map[int]int{1: 1, 2: 2, 3: 3}, + }, + { + M1: mapTest{ + M: map[int]int{}, + }, + M2: mapTest{ + M: map[int]int{1: 2, 2: 2}, + }, + WithOverrideEmptyValue: true, + ExpectedMap: map[int]int{}, + }, + { + M1: mapTest{ + M: map[int]int{}, + }, + M2: mapTest{ + M: map[int]int{1: 2, 2: 2}, + }, + WithOverrideEmptyValue: false, + ExpectedMap: map[int]int{1: 2, 2: 2}, + }, +} + +func withOverrideEmptyValue(enable bool) func(*mergo.Config) { + if enable { + return mergo.WithOverwriteWithEmptyValue + } + + return mergo.WithOverride +} + +func TestMergeMapWithOverride(t *testing.T) { + t.Parallel() + + for _, data := range testDataM { + err := mergo.Merge(&data.M2, data.M1, withOverrideEmptyValue(data.WithOverrideEmptyValue)) + if err != nil { + t.Errorf("Error while merging %s", err) + } + + if len(data.M2.M) != len(data.ExpectedMap) { + t.Errorf("Got %d elements in map, but expected %d", len(data.M2.M), len(data.ExpectedMap)) + return + } + + for i, val := range data.M2.M { + if val != data.ExpectedMap[i] { + t.Errorf("Expected value: %d, but got %d while merging map", data.ExpectedMap[i], val) + } + } + } +} diff --git a/issue89_test.go b/issue89_test.go index e47a0df..1e138a9 100644 --- a/issue89_test.go +++ b/issue89_test.go @@ -36,10 +36,6 @@ func TestIssue89MergeWithEmptyValue(t *testing.T) { expected interface{} key string }{ - { - 3, - "A", - }, { "", "B", diff --git a/merge.go b/merge.go index 4b47d0b..1bf26cf 100644 --- a/merge.go +++ b/merge.go @@ -205,6 +205,16 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co dst.SetMapIndex(key, srcElement) } } + + // Ensure that all keys in dst are deleted if they are not in src. + if overwriteWithEmptySrc { + for _, key := range dst.MapKeys() { + srcElement := src.MapIndex(key) + if !srcElement.IsValid() { + dst.SetMapIndex(key, reflect.Value{}) + } + } + } case reflect.Slice: if !dst.CanSet() { break