Skip to content

Commit

Permalink
RAI-24760 Handle token directory creation (#103)
Browse files Browse the repository at this point in the history
* Handle token directory creation

* Add test for token file creation
  • Loading branch information
ginal authored Jul 12, 2024
1 parent 4614e6d commit 136ae26
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 4 deletions.
53 changes: 53 additions & 0 deletions rai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ import (
"context"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -63,6 +66,32 @@ func findModel(models []Model, name string) *Model {
return nil
}

// deleteTokenCacheDir deletes the token file cache and the "~/.rai" directory if empty
func deleteTokenCacheDir(t *testing.T) {
fname, err := cachePath()
if err != nil {
t.Error("Failed to get token cache file name")
}
err = os.Remove(fname)
if err != nil {
t.Errorf("Failed to delete token cache file %s\n", fname)
}
// if the directory is not empty then the deletion request will fail, but we want the test to continue
_ = os.Remove(filepath.Dir(fname))
}

// assertTokenCacheFileCreated asserts that the token file has been created
func assertTokenCacheFileCreated(t *testing.T) {
fpath, err := cachePath()
if err != nil {
t.Error("Failed to get token cache file name")
}

if _, err := os.Stat(fpath); err != nil {
t.Error(errors.Wrapf(err, "Failed to stat token cache file %s", fpath))
}
}

func TestNewClient(t *testing.T) {
var testClient *Client
var cfg Config
Expand Down Expand Up @@ -95,6 +124,30 @@ func TestNewClient(t *testing.T) {
assert.NotNil(t, err)
}

// Test token cache file creation
func TestTokenCacheFile(t *testing.T) {
deleteTokenCacheDir(t)

var testClient *Client
var cfg Config

err := getConfig(&cfg)
assert.Nil(t, err)

opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)

token, err := testClient.accessTokenHandler.GetAccessToken()
assert.Nil(t, err)
assert.NotNil(t, token)

tokenCached, _ := testClient.accessTokenHandler.GetAccessToken()

assert.Equal(t, token, tokenCached)

assertTokenCacheFileCreated(t)
}

// Test database management APIs.
func TestDatabase(t *testing.T) {
client := test.client
Expand Down
17 changes: 13 additions & 4 deletions rai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"
"os/user"
"path"
"path/filepath"

"github.com/pkg/errors"
)
Expand Down Expand Up @@ -56,8 +57,8 @@ func NewClientCredentialsHandler(
return &ClientCredentialsHandler{client: c, creds: creds}
}

// Returns the name of the token cache file.
func cacheName() (string, error) {
// Returns the path of the token cache file.
func cachePath() (string, error) {
usr, err := user.Current()
if err != nil {
return "", err
Expand All @@ -79,7 +80,7 @@ func readAccessToken(creds *ClientCredentials) (*AccessToken, error) {
}

func readTokenCache() (map[string]*AccessToken, error) {
fname, err := cacheName()
fname, err := cachePath()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -107,12 +108,20 @@ func writeAccessToken(clientID string, token *AccessToken) {
}

func writeTokenCache(cache map[string]*AccessToken) {
fname, err := cacheName()
fname, err := cachePath()
if err != nil {
return
}

dirName := filepath.Dir(fname)
err = os.MkdirAll(dirName, 0775)
if err != nil {
fmt.Println(errors.Wrapf(err, "failed to create token directory"))
}

f, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
fmt.Println(errors.Wrapf(err, "failed to open token file"))
return
}
if err := json.NewEncoder(f).Encode(cache); err != nil {
Expand Down

0 comments on commit 136ae26

Please sign in to comment.