diff --git a/cmd/datasetRetriever/main.go b/cmd/datasetRetriever/main.go index 2f8adaf..77f52a4 100644 --- a/cmd/datasetRetriever/main.go +++ b/cmd/datasetRetriever/main.go @@ -21,10 +21,11 @@ import ( "net/http" "os" "os/exec" - "github.com/paulscherrerinstitute/scicat/datasetUtils" "strings" "time" + "github.com/paulscherrerinstitute/scicat/datasetUtils" + "github.com/fatih/color" ) @@ -69,11 +70,26 @@ func main() { flag.Parse() + // param test only + if datasetUtils.TestFlags != nil { + datasetUtils.TestFlags(map[string]interface{}{ + "retrieve": *retrieveFlag, + "user": *userpass, + "token": *token, + "nochksum": *nochksumFlag, + "dataset": *datasetId, + "ownergroup": *ownerGroup, + "testenv": *testenvFlag, + "devenv": *devenvFlag, + "version": *showVersion}) + return + } + if *showVersion { fmt.Printf("%s\n", VERSION) return } - + datasetUtils.CheckForNewVersion(client, APP, VERSION) var env string diff --git a/cmd/datasetRetriever/main_test.go b/cmd/datasetRetriever/main_test.go new file mode 100644 index 0000000..ab823de --- /dev/null +++ b/cmd/datasetRetriever/main_test.go @@ -0,0 +1,86 @@ +package main + +import ( + "flag" + "os" + "testing" + + "github.com/paulscherrerinstitute/scicat/datasetUtils" +) + +func TestMainFlags(t *testing.T) { + // test cases + tests := []struct { + name string + flags map[string]interface{} + args []string + }{ + { + name: "Test without flags", + flags: map[string]interface{}{ + "retrieve": false, + "nochksum": false, + "testenv": false, + "devenv": false, + "version": false, + "user": "", + "token": "", + "dataset": "", + "ownergroup": "", + }, + args: []string{"test"}, + }, + { + name: "Set all flags", + flags: map[string]interface{}{ + "retrieve": true, + "nochksum": true, + "testenv": true, + "devenv": true, + "version": true, + "user": "usertest:passtest", + "token": "token", + "dataset": "some dataset", + "ownergroup": "some owners", + }, + args: []string{ + "test", + "--retrieve", + "--nochksum", + "--testenv", + "--devenv", + "--user", + "usertest:passtest", + "--token", + "token", + "--dataset", + "some dataset", + "--ownergroup", + "some owners", + "--version", + }, + }, + } + + // running test cases + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + flag.CommandLine = flag.NewFlagSet(test.name, flag.ExitOnError) + datasetUtils.TestFlags = func(flags map[string]interface{}) { + passing := true + for flag := range test.flags { + if flags[flag] != test.flags[flag] { + t.Logf("%s's value should be \"%v\" but it's \"%v\", or non-matching type", flag, test.flags[flag], flags[flag]) + passing = false + } + } + if !passing { + t.Fail() + } + } + + os.Args = test.args + main() + }) + } +}