diff --git a/cmd/download.go b/cmd/download.go index 4a60b61e9..ca8bda2d3 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -54,16 +54,6 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error { return err } - track, err := flags.GetString("track") - if err != nil { - return err - } - - team, err := flags.GetString("team") - if err != nil { - return err - } - identifier, err := downloadIdentifier(flags) if err != nil { return err @@ -80,25 +70,10 @@ func runDownload(cfg config.Config, flags *pflag.FlagSet, args []string) error { return err } - uuid, err := flags.GetString("uuid") + req, err = addQueryToDownloadRequest(flags, req) if err != nil { return err } - slug, err := flags.GetString("exercise") - if err != nil { - return err - } - if uuid == "" { - q := req.URL.Query() - q.Add("exercise_id", slug) - if track != "" { - q.Add("track_id", track) - } - if team != "" { - q.Add("team_id", team) - } - req.URL.RawQuery = q.Encode() - } res, err := client.Do(req) if err != nil { @@ -269,6 +244,39 @@ func downloadIdentifier(flags *pflag.FlagSet) (string, error) { return identifier, nil } +func addQueryToDownloadRequest(flags *pflag.FlagSet, req *http.Request) (*http.Request, error) { + uuid, err := flags.GetString("uuid") + if err != nil { + return req, err + } + slug, err := flags.GetString("exercise") + if err != nil { + return req, err + } + track, err := flags.GetString("track") + if err != nil { + return req, err + } + + team, err := flags.GetString("team") + if err != nil { + return req, err + } + + if uuid == "" { + q := req.URL.Query() + q.Add("exercise_id", slug) + if track != "" { + q.Add("track_id", track) + } + if team != "" { + q.Add("team_id", team) + } + req.URL.RawQuery = q.Encode() + } + return req, nil +} + func setupDownloadFlags(flags *pflag.FlagSet) { flags.StringP("uuid", "u", "", "the solution UUID") flags.StringP("track", "t", "", "the track ID")