diff --git a/rai/client_test.go b/rai/client_test.go index 116e5b1..033829e 100644 --- a/rai/client_test.go +++ b/rai/client_test.go @@ -18,11 +18,14 @@ import ( "context" "fmt" "net/http" + "os" + "path/filepath" "strings" "testing" "time" "github.com/google/uuid" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -63,6 +66,32 @@ func findModel(models []Model, name string) *Model { return nil } +// deleteTokenCacheDir deletes the token file cache and the "~/.rai" directory if empty +func deleteTokenCacheDir(t *testing.T) { + fname, err := cachePath() + if err != nil { + t.Error("Failed to get token cache file name") + } + err = os.Remove(fname) + if err != nil { + t.Errorf("Failed to delete token cache file %s\n", fname) + } + // if the directory is not empty then the deletion request will fail, but we want the test to continue + _ = os.Remove(filepath.Dir(fname)) +} + +// assertTokenCacheFileCreated asserts that the token file has been created +func assertTokenCacheFileCreated(t *testing.T) { + fpath, err := cachePath() + if err != nil { + t.Error("Failed to get token cache file name") + } + + if _, err := os.Stat(fpath); err != nil { + t.Error(errors.Wrapf(err, "Failed to stat token cache file %s", fpath)) + } +} + func TestNewClient(t *testing.T) { var testClient *Client var cfg Config @@ -95,6 +124,30 @@ func TestNewClient(t *testing.T) { assert.NotNil(t, err) } +// Test token cache file creation +func TestTokenCacheFile(t *testing.T) { + deleteTokenCacheDir(t) + + var testClient *Client + var cfg Config + + err := getConfig(&cfg) + assert.Nil(t, err) + + opts := ClientOptions{Config: cfg} + testClient = NewClient(context.Background(), &opts) + + token, err := testClient.accessTokenHandler.GetAccessToken() + assert.Nil(t, err) + assert.NotNil(t, token) + + tokenCached, _ := testClient.accessTokenHandler.GetAccessToken() + + assert.Equal(t, token, tokenCached) + + assertTokenCacheFileCreated(t) +} + // Test database management APIs. func TestDatabase(t *testing.T) { client := test.client diff --git a/rai/handlers.go b/rai/handlers.go index c2f86b1..48ee8fe 100644 --- a/rai/handlers.go +++ b/rai/handlers.go @@ -22,6 +22,7 @@ import ( "os" "os/user" "path" + "path/filepath" "github.com/pkg/errors" ) @@ -56,8 +57,8 @@ func NewClientCredentialsHandler( return &ClientCredentialsHandler{client: c, creds: creds} } -// Returns the name of the token cache file. -func cacheName() (string, error) { +// Returns the path of the token cache file. +func cachePath() (string, error) { usr, err := user.Current() if err != nil { return "", err @@ -79,7 +80,7 @@ func readAccessToken(creds *ClientCredentials) (*AccessToken, error) { } func readTokenCache() (map[string]*AccessToken, error) { - fname, err := cacheName() + fname, err := cachePath() if err != nil { return nil, err } @@ -107,12 +108,20 @@ func writeAccessToken(clientID string, token *AccessToken) { } func writeTokenCache(cache map[string]*AccessToken) { - fname, err := cacheName() + fname, err := cachePath() if err != nil { return } + + dirName := filepath.Dir(fname) + err = os.MkdirAll(dirName, 0775) + if err != nil { + fmt.Println(errors.Wrapf(err, "failed to create token directory")) + } + f, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { + fmt.Println(errors.Wrapf(err, "failed to open token file")) return } if err := json.NewEncoder(f).Encode(cache); err != nil {