diff --git a/cmd/root.go b/cmd/root.go index 57e127b..b683d0e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -148,7 +148,7 @@ func initConfig(cmd *cobra.Command) error { sess.App.Version = version sess.App.SemVer = semver - if err := session.CreateDefaultConfigIfMissing(configRoot); err != nil { + if _, err := session.CreateDefaultConfigIfMissing(configRoot); err != nil { fmt.Printf("can't create default session: %v\n", err) os.Exit(1) diff --git a/process/process.go b/process/process.go index 8be8c2c..e6fa75a 100644 --- a/process/process.go +++ b/process/process.go @@ -423,9 +423,9 @@ func generateJSON(results *findHostsResults) (json.RawMessage, error) { return out, nil } -func New(config *session.Session) (Processor, error) { +func New(sess *session.Session) (Processor, error) { p := Processor{ - Session: config, + Session: sess, } return p, nil diff --git a/process/process_test.go b/process/process_test.go index f2ef9d2..436c89a 100644 --- a/process/process_test.go +++ b/process/process_test.go @@ -1 +1,16 @@ package process + +import ( + "github.com/jonhadfield/ipscout/session" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNew(t *testing.T) { + t.Run("New", func(t *testing.T) { + n, err := New(session.New()) + require.NoError(t, err) + require.NotNil(t, n) + }) + +} diff --git a/session/session.go b/session/session.go index 94b4ede..46421bb 100644 --- a/session/session.go +++ b/session/session.go @@ -211,40 +211,39 @@ func unmarshalConfig(data []byte) (*Session, error) { return &conf, nil } -func CreateDefaultConfigIfMissing(path string) error { +// CreateDefaultConfigIfMissing creates a default session configuration file if it does not exist +// and returns true if it was created, or false if it already exists +func CreateDefaultConfigIfMissing(path string) (bool, error) { if path == "" { - return fmt.Errorf("session path not specified") + return false, fmt.Errorf("session path not specified") } - var err error - // check if session already exists - _, err = os.Stat(filepath.Join(path, DefaultConfigFileName)) - + _, err := os.Stat(filepath.Join(path, DefaultConfigFileName)) switch { case err == nil: - return nil + return false, nil case os.IsNotExist(err): // check default session is valid if _, err = unmarshalConfig([]byte(defaultConfig)); err != nil { - return fmt.Errorf("default session invalid: %w", err) + return false, fmt.Errorf("default session invalid: %w", err) } // create dir specified in path argument if missing if _, err = os.Stat(path); os.IsNotExist(err) { if err = os.MkdirAll(path, 0o700); err != nil { - return fmt.Errorf("failed to create session directory: %w", err) + return false, fmt.Errorf("failed to create session directory: %w", err) } } if err = os.WriteFile(filepath.Join(path, DefaultConfigFileName), []byte(defaultConfig), 0o600); err != nil { - return fmt.Errorf("failed to write default session: %w", err) + return false, fmt.Errorf("failed to write default session: %w", err) } case err != nil: - return fmt.Errorf("failed to stat session directory: %w", err) + return false, fmt.Errorf("failed to stat session directory: %w", err) } - return nil + return true, nil } // CreateConfigPathStructure creates all the necessary paths under session root if they don't exist diff --git a/session/session_test.go b/session/session_test.go index 5eee1d3..98ec583 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -39,22 +39,25 @@ func TestUnmarshalConfig(t *testing.T) { func TestCreateDefaultConfig(t *testing.T) { t.Run("PathExists", func(t *testing.T) { path := "/tmp" - err := CreateDefaultConfigIfMissing(path) + created, err := CreateDefaultConfigIfMissing(path) require.NoError(t, err) + require.False(t, created) }) t.Run("PathDoesNotExist", func(t *testing.T) { - path := "/tmp/nonexistent" - err := CreateDefaultConfigIfMissing(path) + path := t.TempDir() + created, err := CreateDefaultConfigIfMissing(path) require.NoError(t, err) + require.True(t, created) _, err = os.Stat(path) require.NoError(t, err) }) t.Run("InvalidPath", func(t *testing.T) { path := "" - err := CreateDefaultConfigIfMissing(path) + created, err := CreateDefaultConfigIfMissing(path) require.Error(t, err) + require.False(t, created) }) } @@ -65,10 +68,12 @@ func TestCreateCachePathIfNotExist(t *testing.T) { configRoot := GetConfigRoot(tempDir, AppName) // create session root (required for cache path) - require.NoError(t, CreateDefaultConfigIfMissing(configRoot)) + created, err := CreateDefaultConfigIfMissing(configRoot) + require.NoError(t, err) + require.True(t, created) // check session root exists - _, err := os.Stat(configRoot) + _, err = os.Stat(configRoot) require.NoError(t, err) // check cache path does not exist @@ -89,9 +94,11 @@ func TestCreateCachePathIfNotExist(t *testing.T) { configRoot := GetConfigRoot(tempDir, AppName) // create session root (required for cache path) - require.NoError(t, CreateDefaultConfigIfMissing(configRoot)) + created, err := CreateDefaultConfigIfMissing(configRoot) + require.NoError(t, err) + require.True(t, created) - err := CreateConfigPathStructure(configRoot) + err = CreateConfigPathStructure(configRoot) require.NoError(t, err) for _, dir := range []string{"backups", "cache"} {