diff --git a/providers/file/file.go b/providers/file/file.go index 20c3630b..ed61051e 100644 --- a/providers/file/file.go +++ b/providers/file/file.go @@ -95,18 +95,6 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { evFile := filepath.Clean(event.Name) - // Since the event is triggered on a directory, is this - // one on the file being watched? - if evFile != realPath && evFile != f.path { - continue - } - - // The file was removed. - if event.Op&fsnotify.Remove != 0 { - cb(nil, fmt.Errorf("file %s was removed", event.Name)) - break loop - } - // Resolve symlink to get the real path, in case the symlink's // target has changed. curPath, err := filepath.EvalSymlinks(f.path) @@ -114,15 +102,26 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { cb(nil, err) break loop } - realPath = filepath.Clean(curPath) + curPath = filepath.Clean(curPath) - // Finally, we only care about create and write. - if event.Op&(fsnotify.Write|fsnotify.Create) == 0 { - continue - } + onWatchedFile := evFile == realPath || evFile == f.path - // Trigger event. - cb(nil, nil) + // Since the event is triggered on a directory, is this + // a create or write on the file being watched? + // + // Or has the real path of the file being watched changed? + // + // If either of the above are true, trigger the callback. + if event.Has(fsnotify.Create|fsnotify.Write) && (onWatchedFile || + (curPath != "" && curPath != realPath)) { + realPath = curPath + + // Trigger event. + cb(nil, nil) + } else if onWatchedFile && event.Has(fsnotify.Remove) { + cb(nil, fmt.Errorf("file %s was removed", event.Name)) + break loop + } // There's an error. case err, ok := <-f.w.Errors: diff --git a/tests/koanf_test.go b/tests/koanf_test.go index dc549c3f..ffb951d4 100644 --- a/tests/koanf_test.go +++ b/tests/koanf_test.go @@ -469,11 +469,13 @@ func TestWatchFile(t *testing.T) { // Wait a second and change the file. time.Sleep(1 * time.Second) os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name2"}}`), 0600) - wg.Wait() - - assert.Condition(func() bool { - return strings.Compare(<-changedC, "name2") == 0 - }, "file watch reload didn't change config") + if waitTimeout(&wg, time.Second*10) { + assert.Fail("timeout waiting for file watch trigger") + } else { + assert.Condition(func() bool { + return strings.Compare(<-changedC, "name2") == 0 + }, "file watch reload didn't change config") + } } func TestWatchFileSymlink(t *testing.T) { @@ -526,11 +528,96 @@ func TestWatchFileSymlink(t *testing.T) { // symlink. We do this to avoid removing the symlink and triggering a REMOVE event. time.Sleep(1 * time.Second) assert.NoError(os.Rename(symPath2, symPath), "error swaping symlink to another file type") - wg.Wait() + if waitTimeout(&wg, time.Second*10) { + assert.Fail("timeout waiting for file watch trigger") + } else { + assert.Condition(func() bool { + return strings.Compare(<-changedC, "yml") == 0 + }, "symlink watch reload didn't change config") + } +} + +func TestWatchFileDirectorySymlink(t *testing.T) { + var ( + assert = assert.New(t) + k = koanf.New(delim) + ) + tmpDir := t.TempDir() + + baseDir := filepath.Join(tmpDir, "base_dir") + baseDir2 := filepath.Join(tmpDir, "base_dir2") + + err := os.Mkdir(baseDir, 0700) + assert.NoError(err, "error creating base dir") + + err = os.Mkdir(baseDir2, 0700) + assert.NoError(err, "error creating base dir 2") + + wd, err := os.Getwd() + assert.NoError(err, "error getting working dir") + + jsonFile := filepath.Join(wd, mockJSON) + yamlFile := filepath.Join(wd, mockYAML) + + jsonData, err := os.ReadFile(jsonFile) + assert.NoError(err, "error reading JSON file") + + err = os.WriteFile(filepath.Join(baseDir, "config"), jsonData, 0600) + assert.NoError(err, "error writing JSON file to base dir") + + yamlData, err := os.ReadFile(yamlFile) + assert.NoError(err, "error reading YAML file") + + err = os.WriteFile(filepath.Join(baseDir2, "config"), yamlData, 0600) + assert.NoError(err, "error writing YAML file to base dir 2") + + // Create a symlink. + symDir := filepath.Join(tmpDir, "koanf_test_symlink") + symDir2 := filepath.Join(tmpDir, "koanf_test_symlink2") + symPath := filepath.Join(tmpDir, "config") + + // Create a symlink to the JSON file which will be swapped out later. + assert.NoError(os.Symlink(baseDir, symDir), "error creating symlink dir") + assert.NoError(os.Symlink(baseDir2, symDir2), "error creating symlink dir2") + assert.NoError(os.Symlink(filepath.Join(symDir, "config"), symPath), "error creating symlink") + + // Load the symlink (to the JSON) file. + f := file.Provider(symPath) + k.Load(f, json.Parser()) + + // Watch for changes. + changedC := make(chan string, 1) + var wg sync.WaitGroup + wg.Add(1) // our assurance that cb is called max once + f.Watch(func(event interface{}, err error) { + // The File watcher always returns a nil `event`, which can + // be ignored. + if err != nil { + // TODO: make use of Error Wrapping-Scheme and assert.ErrorIs() checks as of go v1.13 + assert.Condition(func() bool { + return strings.Contains(err.Error(), "no such file or directory") + }, "received unexpected error. err: %s", err) + return + } + // Reload the config. + k.Load(f, yaml.Parser()) + changedC <- k.String("type") + wg.Done() + }) - assert.Condition(func() bool { - return strings.Compare(<-changedC, "yml") == 0 - }, "symlink watch reload didn't change config") + // Wait a second and swap the symlink target from the JSON file to the YAML file. + // Create a temp symlink to the YAML file and rename the old symlink to the new + // symlink. We do this to avoid removing the symlink and triggering a REMOVE event. + time.Sleep(1 * time.Second) + assert.NoError(os.Rename(symDir2, symDir), "error swapping symlink dir to another symlink dir") + + if waitTimeout(&wg, time.Second*10) { + assert.Fail("timeout waiting for file watch trigger") + } else { + assert.Condition(func() bool { + return strings.Compare(<-changedC, "yml") == 0 + }, "symlink watch reload didn't change config") + } } func TestUnwatchFile(t *testing.T) { @@ -716,7 +803,7 @@ func TestFlags(t *testing.T) { bf.String("parent1.child1.type", "flag", "") bf.String("parent2.child2.name", "override-default", "") bf.Set("parent1.child1.type", "basicflag") - assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".",nil), nil), "error loading basicflag") + assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".", nil), nil), "error loading basicflag") assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match") assert.Equal("override-default", k.String("parent2.child2.name"), "basicflag default value override failed") } @@ -728,7 +815,7 @@ func TestFlags(t *testing.T) { bf.String("parent1.child1.name", "override-default", "") bf.String("parent2.child2.name", "override-default", "") bf.Set("parent2.child2.name", "custom") - assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".",nil, def),nil), "error loading basicflag") + assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".", nil, def), nil), "error loading basicflag") assert.Equal("child1", k.String("parent1.child1.name"), "basicflag default overwrote") assert.Equal("custom", k.String("parent2.child2.name"), "basicflag set failed") } @@ -1402,3 +1489,19 @@ func TestGetStringsMap(t *testing.T) { assert.Equal(map[string][]string{"k2": {"value"}}, k.StringsMap("ifaces2"), "types don't match") assert.Equal(map[string][]string{"k2": {"value"}}, k.StringsMap("ifaces3"), "types don't match") } + +// waitTimeout waits for the waitgroup for the specified max timeout. +// Returns true if waiting timed out. +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false // completed normally + case <-time.After(timeout): + return true // timed out + } +}