diff --git a/cmd/download.go b/cmd/download.go index e4c19377b..7bf278bff 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -64,6 +64,10 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error { metadata := download.payload.metadata() dir := metadata.Exercise(usrCfg.GetString("workspace")).MetadataDir() + if _, err = os.Stat(dir); !download.forceoverwrite && err == nil { + return fmt.Errorf("directory '%s' already exists, use --force to overwrite", dir) + } + if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return err } @@ -103,7 +107,6 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error { continue } - // TODO: handle collisions path := sf.relativePath() dir := filepath.Join(metadata.Dir, filepath.Dir(path)) if err = os.MkdirAll(dir, os.FileMode(0755)); err != nil { @@ -133,7 +136,8 @@ type download struct { token, apibaseurl, workspace string // optional - track, team string + track, team string + forceoverwrite bool payload *downloadPayload } @@ -158,6 +162,11 @@ func newDownload(flags *pflag.FlagSet, usrCfg *viper.Viper) (*download, error) { return nil, err } + d.forceoverwrite, err = flags.GetBool("force") + if err != nil { + return nil, err + } + d.token = usrCfg.GetString("token") d.apibaseurl = usrCfg.GetString("apibaseurl") d.workspace = usrCfg.GetString("workspace") @@ -354,6 +363,7 @@ func setupDownloadFlags(flags *pflag.FlagSet) { flags.StringP("track", "t", "", "the track ID") flags.StringP("exercise", "e", "", "the exercise slug") flags.StringP("team", "T", "", "the team slug") + flags.BoolP("force", "F", false, "overwrite existing exercise directory") } func init() { diff --git a/cmd/download_test.go b/cmd/download_test.go index 2eb0418df..676d90518 100644 --- a/cmd/download_test.go +++ b/cmd/download_test.go @@ -209,6 +209,108 @@ func TestDownload(t *testing.T) { } } +func TestDownloadToExistingDirectory(t *testing.T) { + co := newCapturedOutput() + co.override() + defer co.reset() + + testCases := []struct { + exerciseDir string + flags map[string]string + }{ + { + exerciseDir: filepath.Join("bogus-track", "bogus-exercise"), + flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track"}, + }, + { + exerciseDir: filepath.Join("teams", "bogus-team", "bogus-track", "bogus-exercise"), + flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track", "team": "bogus-team"}, + }, + } + + for _, tc := range testCases { + tmpDir, err := ioutil.TempDir("", "download-cmd") + defer os.RemoveAll(tmpDir) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(tmpDir, tc.exerciseDir), os.FileMode(0755)) + assert.NoError(t, err) + + ts := fakeDownloadServer("true", "") + defer ts.Close() + + v := viper.New() + v.Set("workspace", tmpDir) + v.Set("apibaseurl", ts.URL) + v.Set("token", "abc123") + + cfg := config.Config{ + UserViperConfig: v, + } + flags := pflag.NewFlagSet("fake", pflag.PanicOnError) + setupDownloadFlags(flags) + for name, value := range tc.flags { + flags.Set(name, value) + } + + err = runDownload(cfg, flags, []string{}) + + if assert.Error(t, err) { + assert.Regexp(t, "directory '.+' already exists", err.Error()) + } + } +} + +func TestDownloadToExistingDirectoryWithForce(t *testing.T) { + co := newCapturedOutput() + co.override() + defer co.reset() + + testCases := []struct { + exerciseDir string + flags map[string]string + }{ + { + exerciseDir: filepath.Join("bogus-track", "bogus-exercise"), + flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track"}, + }, + { + exerciseDir: filepath.Join("teams", "bogus-team", "bogus-track", "bogus-exercise"), + flags: map[string]string{"exercise": "bogus-exercise", "track": "bogus-track", "team": "bogus-team"}, + }, + } + + for _, tc := range testCases { + tmpDir, err := ioutil.TempDir("", "download-cmd") + defer os.RemoveAll(tmpDir) + assert.NoError(t, err) + + err = os.MkdirAll(filepath.Join(tmpDir, tc.exerciseDir), os.FileMode(0755)) + assert.NoError(t, err) + + ts := fakeDownloadServer("true", "") + defer ts.Close() + + v := viper.New() + v.Set("workspace", tmpDir) + v.Set("apibaseurl", ts.URL) + v.Set("token", "abc123") + + cfg := config.Config{ + UserViperConfig: v, + } + flags := pflag.NewFlagSet("fake", pflag.PanicOnError) + setupDownloadFlags(flags) + for name, value := range tc.flags { + flags.Set(name, value) + } + flags.Set("force", "true") + + err = runDownload(cfg, flags, []string{}) + assert.NoError(t, err) + } +} + func fakeDownloadServer(requestor, teamSlug string) *httptest.Server { mux := http.NewServeMux() server := httptest.NewServer(mux)