diff --git a/auth.go b/auth.go index f473471ef..a95b968ba 100644 --- a/auth.go +++ b/auth.go @@ -365,10 +365,10 @@ func authenticate( logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) if sessionParameters[clientRequestMfaToken] == true { - deleteCredential(sc, mfaToken) + credentialsStorage.deleteCredential(sc, mfaToken) } if sessionParameters[clientStoreTemporaryCredential] == true { - deleteCredential(sc, idToken) + credentialsStorage.deleteCredential(sc, idToken) } code, err := strconv.Atoi(respd.Code) if err != nil { @@ -384,11 +384,11 @@ func authenticate( sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) if sessionParameters[clientRequestMfaToken] == true { token := respd.Data.MfaToken - setCredential(sc, mfaToken, token) + credentialsStorage.setCredential(sc, mfaToken, token) } if sessionParameters[clientStoreTemporaryCredential] == true { token := respd.Data.IDToken - setCredential(sc, idToken, token) + credentialsStorage.setCredential(sc, idToken, token) } return &respd.Data, nil } @@ -575,9 +575,9 @@ func authenticateWithConfig(sc *snowflakeConn) error { } func fillCachedIDToken(sc *snowflakeConn) { - getCredential(sc, idToken) + credentialsStorage.getCredential(sc, idToken) } func fillCachedMfaToken(sc *snowflakeConn) { - getCredential(sc, mfaToken) + credentialsStorage.getCredential(sc, mfaToken) } diff --git a/secure_storage_manager.go b/secure_storage_manager.go index da8610e36..fc392662b 100644 --- a/secure_storage_manager.go +++ b/secure_storage_manager.go @@ -4,6 +4,7 @@ package gosnowflake import ( "encoding/json" + "fmt" "os" "path/filepath" "runtime" @@ -20,68 +21,231 @@ const ( credCacheFileName = "temporary_credential.json" ) -var ( - credCacheDir = "" - credCache = "" - localCredCache = map[string]string{} -) +type secureStorageManager interface { + setCredential(sc *snowflakeConn, credType, token string) + getCredential(sc *snowflakeConn, credType string) + deleteCredential(sc *snowflakeConn, credType string) +} -var ( - credCacheLock sync.RWMutex -) +var credentialsStorage = newSecureStorageManager() + +func newSecureStorageManager() secureStorageManager { + switch runtime.GOOS { + case "linux": + return newFileBasedSecureStorageManager() + case "darwin", "windows": + return newKeyringBasedSecureStorageManager() + default: + return newNoopSecureStorageManager() + } +} + +type fileBasedSecureStorageManager struct { + credCacheFilePath string + localCredCache map[string]string + credCacheLock sync.RWMutex +} + +func newFileBasedSecureStorageManager() secureStorageManager { + ssm := &fileBasedSecureStorageManager{ + localCredCache: map[string]string{}, + credCacheLock: sync.RWMutex{}, + } + credCacheDir := ssm.buildCredCacheDirPath() + if err := ssm.createCacheDir(credCacheDir); err != nil { + logger.Debugf("failed to create credentials cache dir. %v", err) + return newNoopSecureStorageManager() + } + credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName) + logger.Infof("Credentials cache path: %v", credCacheFilePath) + ssm.credCacheFilePath = credCacheFilePath + return ssm +} + +func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) error { + _, err := os.Stat(credCacheDir) + if os.IsNotExist(err) { + if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil { + return fmt.Errorf("failed to create cache directory. %v, err: %v", credCacheDir, err) + } + return nil + } + return err +} + +func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string { + credCacheDir := os.Getenv(credCacheDirEnv) + if credCacheDir != "" { + return credCacheDir + } + home := os.Getenv("HOME") + if home == "" { + logger.Info("HOME is blank") + return "" + } + credCacheDir = filepath.Join(home, ".cache", "snowflake") + return credCacheDir +} + +func (ssm *fileBasedSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) { + if token == "" { + logger.Debug("no token provided") + } else { + credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) + ssm.credCacheLock.Lock() + defer ssm.credCacheLock.Unlock() + ssm.localCredCache[credentialsKey] = token + + j, err := json.Marshal(ssm.localCredCache) + if err != nil { + logger.Warnf("failed to convert credential to JSON.") + return + } + + logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) + credCacheLockFileName := ssm.credCacheFilePath + ".lck" + logger.Debugf("Creating lock file. %v", credCacheLockFileName) + err = os.Mkdir(credCacheLockFileName, 0600) -func createCredentialCacheDir() { - credCacheDir = os.Getenv(credCacheDirEnv) - if credCacheDir == "" { - switch runtime.GOOS { - case "windows": - credCacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches") - case "darwin": - home := os.Getenv("HOME") - if home == "" { - logger.Info("HOME is blank.") + switch { + case os.IsExist(err): + statinfo, err := os.Stat(credCacheLockFileName) + if err != nil { + logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) + return + } + if time.Since(statinfo.ModTime()) < 15*time.Minute { + logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) + return + } + if err = os.Remove(credCacheLockFileName); err != nil { + logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) + return } - credCacheDir = filepath.Join(home, "Library", "Caches", "Snowflake") - default: - home := os.Getenv("HOME") - if home == "" { - logger.Info("HOME is blank") + if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { + logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) + return } - credCacheDir = filepath.Join(home, ".cache", "snowflake") + } + defer os.RemoveAll(credCacheLockFileName) + + if err = os.WriteFile(ssm.credCacheFilePath, j, 0644); err != nil { + logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) } } +} - if _, err := os.Stat(credCacheDir); os.IsNotExist(err) { - if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil { - logger.Debugf("Failed to create cache directory. %v, err: %v. ignored\n", credCacheDir, err) +func (ssm *fileBasedSecureStorageManager) getCredential(sc *snowflakeConn, credType string) { + credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) + ssm.credCacheLock.Lock() + defer ssm.credCacheLock.Unlock() + localCredCache := ssm.readTemporaryCacheFile() + cred := localCredCache[credentialsKey] + if cred != "" { + logger.Debug("Successfully read token. Returning as string") + } else { + logger.Debug("Returned credential is empty") + } + + if credType == idToken { + sc.cfg.IDToken = cred + } else if credType == mfaToken { + sc.cfg.MfaToken = cred + } else { + logger.Debugf("Unrecognized type %v for local cached credential", credType) + } +} + +func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string { + jsonData, err := os.ReadFile(ssm.credCacheFilePath) + if err != nil { + logger.Debugf("Failed to read credential file: %v", err) + return nil + } + err = json.Unmarshal([]byte(jsonData), &ssm.localCredCache) + if err != nil { + logger.Debugf("failed to read JSON. Err: %v", err) + return nil + } + + return ssm.localCredCache +} + +func (ssm *fileBasedSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) { + ssm.credCacheLock.Lock() + defer ssm.credCacheLock.Unlock() + credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) + delete(ssm.localCredCache, credentialsKey) + j, err := json.Marshal(ssm.localCredCache) + if err != nil { + logger.Warnf("failed to convert credential to JSON.") + return + } + ssm.writeTemporaryCacheFile(j) +} + +func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte) { + logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath) + credCacheLockFileName := ssm.credCacheFilePath + ".lck" + err := os.Mkdir(credCacheLockFileName, 0600) + logger.Debugf("Creating lock file. %v", credCacheLockFileName) + + switch { + case os.IsExist(err): + statinfo, err := os.Stat(credCacheLockFileName) + if err != nil { + logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err) + return + } + if time.Since(statinfo.ModTime()) < 15*time.Minute { + logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath) + return } + if err = os.Remove(credCacheLockFileName); err != nil { + logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) + return + } + if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { + logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) + return + } + } + defer os.RemoveAll(credCacheLockFileName) + + if err = os.WriteFile(ssm.credCacheFilePath, input, 0644); err != nil { + logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err) } - credCache = filepath.Join(credCacheDir, credCacheFileName) - logger.Infof("Cache directory: %v", credCache) } -func setCredential(sc *snowflakeConn, credType, token string) { +type keyringSecureStorageManager struct { +} + +func newKeyringBasedSecureStorageManager() secureStorageManager { + return &keyringSecureStorageManager{} +} + +func (ssm *keyringSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) { if token == "" { logger.Debug("no token provided") } else { - var target string + var credentialsKey string if runtime.GOOS == "windows" { - target = driverName + ":" + credType + credentialsKey = driverName + ":" + credType ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(sc.cfg.Host), ServiceName: strings.ToUpper(sc.cfg.User), }) item := keyring.Item{ - Key: target, + Key: credentialsKey, Data: []byte(token), } if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to Windows credential manager. Err: %v", err) } } else if runtime.GOOS == "darwin" { - target = convertTarget(sc.cfg.Host, sc.cfg.User, credType) + credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) ring, _ := keyring.Open(keyring.Config{ - ServiceName: target, + ServiceName: credentialsKey, }) account := strings.ToUpper(sc.cfg.User) item := keyring.Item{ @@ -91,33 +255,28 @@ func setCredential(sc *snowflakeConn, credType, token string) { if err := ring.Set(item); err != nil { logger.Debugf("Failed to write to keychain. Err: %v", err) } - } else if runtime.GOOS == "linux" { - createCredentialCacheDir() - writeTemporaryCredential(sc, credType, token) - } else { - logger.Debug("OS not supported for Local Secure Storage") } } } -func getCredential(sc *snowflakeConn, credType string) { - var target string +func (ssm *keyringSecureStorageManager) getCredential(sc *snowflakeConn, credType string) { + var credentialsKey string cred := "" if runtime.GOOS == "windows" { - target = driverName + ":" + credType + credentialsKey = driverName + ":" + credType ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(sc.cfg.Host), ServiceName: strings.ToUpper(sc.cfg.User), }) - i, err := ring.Get(target) + i, err := ring.Get(credentialsKey) if err != nil { - logger.Debugf("Failed to read target or could not find it in Windows Credential Manager. Error: %v", err) + logger.Debugf("Failed to read credentialsKey or could not find it in Windows Credential Manager. Error: %v", err) } cred = string(i.Data) } else if runtime.GOOS == "darwin" { - target = convertTarget(sc.cfg.Host, sc.cfg.User, credType) + credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) ring, _ := keyring.Open(keyring.Config{ - ServiceName: target, + ServiceName: credentialsKey, }) account := strings.ToUpper(sc.cfg.User) i, err := ring.Get(account) @@ -130,11 +289,6 @@ func getCredential(sc *snowflakeConn, credType string) { } else { logger.Debug("Successfully read token. Returning as string") } - } else if runtime.GOOS == "linux" { - createCredentialCacheDir() - cred = readTemporaryCredential(sc, credType) - } else { - logger.Debug("OS not supported for Local Secure Storage") } if credType == idToken { @@ -146,139 +300,49 @@ func getCredential(sc *snowflakeConn, credType string) { } } -func deleteCredential(sc *snowflakeConn, credType string) { - target := driverName + ":" + credType +func (ssm *keyringSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) { + credentialsKey := driverName + ":" + credType if runtime.GOOS == "windows" { ring, _ := keyring.Open(keyring.Config{ WinCredPrefix: strings.ToUpper(sc.cfg.Host), ServiceName: strings.ToUpper(sc.cfg.User), }) - err := ring.Remove(target) + err := ring.Remove(credentialsKey) if err != nil { - logger.Debugf("Failed to delete target in Windows Credential Manager. Error: %v", err) + logger.Debugf("Failed to delete credentialsKey in Windows Credential Manager. Error: %v", err) } } else if runtime.GOOS == "darwin" { - target = convertTarget(sc.cfg.Host, sc.cfg.User, credType) + credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType) ring, _ := keyring.Open(keyring.Config{ - ServiceName: target, + ServiceName: credentialsKey, }) account := strings.ToUpper(sc.cfg.User) err := ring.Remove(account) if err != nil { - logger.Debugf("Failed to delete target in keychain. Error: %v", err) + logger.Debugf("Failed to delete credentialsKey in keychain. Error: %v", err) } - } else if runtime.GOOS == "linux" { - deleteTemporaryCredential(sc, credType) } } -// Reads temporary credential file when OS is Linux. -func readTemporaryCredential(sc *snowflakeConn, credType string) string { - target := convertTarget(sc.cfg.Host, sc.cfg.User, credType) - credCacheLock.Lock() - defer credCacheLock.Unlock() - localCredCache := readTemporaryCacheFile() - cred := localCredCache[target] - if cred != "" { - logger.Debug("Successfully read token. Returning as string") - } else { - logger.Debug("Returned credential is empty") - } - return cred +func buildCredentialsKey(host, user, credType string) string { + host = strings.ToUpper(host) + user = strings.ToUpper(user) + credType = strings.ToUpper(credType) + return host + ":" + user + ":" + driverName + ":" + credType } -// Writes to temporary credential file when OS is Linux. -func writeTemporaryCredential(sc *snowflakeConn, credType, token string) { - target := convertTarget(sc.cfg.Host, sc.cfg.User, credType) - credCacheLock.Lock() - defer credCacheLock.Unlock() - localCredCache[target] = token - - j, err := json.Marshal(localCredCache) - if err != nil { - logger.Warnf("failed to convert credential to JSON.") - return - } - writeTemporaryCacheFile(j) +type noopSecureStorageManager struct { } -func deleteTemporaryCredential(sc *snowflakeConn, credType string) { - if credCacheDir == "" { - logger.Debug("Cache file doesn't exist. Skipping deleting credential file.") - } else { - credCacheLock.Lock() - defer credCacheLock.Unlock() - target := convertTarget(sc.cfg.Host, sc.cfg.User, credType) - delete(localCredCache, target) - j, err := json.Marshal(localCredCache) - if err != nil { - logger.Warnf("failed to convert credential to JSON.") - return - } - writeTemporaryCacheFile(j) - } +func newNoopSecureStorageManager() secureStorageManager { + return &noopSecureStorageManager{} } -func readTemporaryCacheFile() map[string]string { - if credCache == "" { - logger.Debug("Cache file doesn't exist. Skipping reading credential file.") - return nil - } - jsonData, err := os.ReadFile(credCache) - if err != nil { - logger.Debugf("Failed to read credential file: %v", err) - return nil - } - err = json.Unmarshal([]byte(jsonData), &localCredCache) - if err != nil { - logger.Debugf("failed to read JSON. Err: %v", err) - return nil - } - - return localCredCache +func (ssm *noopSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) { } -func writeTemporaryCacheFile(input []byte) { - if credCache == "" { - logger.Debug("Cache file doesn't exist. Skipping writing temporary credential file.") - } else { - logger.Debugf("writing credential cache file. %v\n", credCache) - credCacheLockFileName := credCache + ".lck" - err := os.Mkdir(credCacheLockFileName, 0600) - logger.Debugf("Creating lock file. %v", credCacheLockFileName) - - switch { - case os.IsExist(err): - statinfo, err := os.Stat(credCacheLockFileName) - if err != nil { - logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", credCache, err) - return - } - if time.Since(statinfo.ModTime()) < 15*time.Minute { - logger.Debugf("other process locks the cache file. %v. ignored.\n", credCache) - return - } - if err = os.Remove(credCacheLockFileName); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - if err = os.Mkdir(credCacheLockFileName, 0600); err != nil { - logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err) - return - } - } - defer os.RemoveAll(credCacheLockFileName) - - if err = os.WriteFile(credCache, input, 0644); err != nil { - logger.Debugf("Failed to write the cache file. File: %v err: %v.", credCache, err) - } - } +func (ssm *noopSecureStorageManager) getCredential(sc *snowflakeConn, credType string) { } -func convertTarget(host, user, credType string) string { - host = strings.ToUpper(host) - user = strings.ToUpper(user) - credType = strings.ToUpper(credType) - target := host + ":" + user + ":" + driverName + ":" + credType - return target +func (ssm *noopSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) { //TODO implement me } diff --git a/secure_storage_manager_test.go b/secure_storage_manager_test.go index 2f1ceac6d..3b3c2325a 100644 --- a/secure_storage_manager_test.go +++ b/secure_storage_manager_test.go @@ -3,10 +3,6 @@ package gosnowflake import ( - "errors" - "io" - "os" - "runtime" "testing" ) @@ -23,114 +19,44 @@ type tcCredentials struct { } func TestSetAndGetCredentialMfa(t *testing.T) { - if runtime.GOOS == "darwin" { - t.Skip("MacOS requires keychain password to be manually entered.") - } else { - fakeMfaToken := "fakeMfaToken" - expectedMfaToken := "fakeMfaToken" - sc := getDefaultSnowflakeConn() - sc.cfg.Host = "testhost" - setCredential(sc, mfaToken, fakeMfaToken) - getCredential(sc, mfaToken) - - if sc.cfg.MfaToken != expectedMfaToken { - t.Fatalf("Expected mfa token %v but got %v", expectedMfaToken, sc.cfg.MfaToken) - } + skipOnMac(t, "keyring asks for password") + fakeMfaToken := "fakeMfaToken" + expectedMfaToken := "fakeMfaToken" + sc := getDefaultSnowflakeConn() + sc.cfg.Host = "testhost" + credentialsStorage.setCredential(sc, mfaToken, fakeMfaToken) + credentialsStorage.getCredential(sc, mfaToken) - // delete credential and check it no longer exists - deleteCredential(sc, mfaToken) - getCredential(sc, mfaToken) - if sc.cfg.MfaToken != "" { - t.Fatalf("Expected mfa token to be empty but got %v", sc.cfg.MfaToken) - } + if sc.cfg.MfaToken != expectedMfaToken { + t.Fatalf("Expected mfa token %v but got %v", expectedMfaToken, sc.cfg.MfaToken) } -} - -func TestSetAndGetCredentialIdToken(t *testing.T) { - if runtime.GOOS == "darwin" { - t.Skip("MacOS requires keychain password to be manually entered.") - } else { - fakeIDToken := "fakeIDToken" - expectedIDToken := "fakeIDToken" - sc := getDefaultSnowflakeConn() - sc.cfg.Host = "testhost" - setCredential(sc, idToken, fakeIDToken) - getCredential(sc, idToken) - if sc.cfg.IDToken != expectedIDToken { - t.Fatalf("Expected id token %v but got %v", expectedIDToken, sc.cfg.IDToken) - } - - // delete credential and check it no longer exists - deleteCredential(sc, idToken) - getCredential(sc, idToken) - if sc.cfg.IDToken != "" { - t.Fatalf("Expected id token to be empty but got %v", sc.cfg.IDToken) - } + // delete credential and check it no longer exists + credentialsStorage.deleteCredential(sc, mfaToken) + credentialsStorage.getCredential(sc, mfaToken) + if sc.cfg.MfaToken != "" { + t.Fatalf("Expected mfa token to be empty but got %v", sc.cfg.MfaToken) } } -func TestCreateCredentialCache(t *testing.T) { - skipOnJenkins(t, "cannot write to file system") - if runningOnGithubAction() { - t.Skip("cannot write to github file system") - } - dirName, err := os.UserHomeDir() - if err != nil { - t.Error(err) - } - srcFileName := dirName + "/.cache/snowflake/temporary_credential.json" - tmpFileName := srcFileName + "_tmp" - dst, err := os.Create(tmpFileName) - if err != nil { - t.Error(err) - } - defer dst.Close() - var src *os.File - if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) { - // file does not exist - if err = os.MkdirAll(dirName+"/.cache/snowflake/", os.ModePerm); err != nil { - t.Error(err) - } - if _, err = os.Create(srcFileName); err != nil { - t.Error(err) - } - } else if err != nil { - t.Error(err) - } else { - // file exists - src, err = os.Open(srcFileName) - if err != nil { - t.Error(err) - } - defer src.Close() - // copy original contents to temporary file - if _, err = io.Copy(dst, src); err != nil { - t.Error(err) - } - if err = os.Remove(srcFileName); err != nil { - t.Error(err) - } - } +func TestSetAndGetCredentialIdToken(t *testing.T) { + skipOnMac(t, "keyring asks for password") + fakeIDToken := "fakeIDToken" + expectedIDToken := "fakeIDToken" + sc := getDefaultSnowflakeConn() + sc.cfg.Host = "testhost" + credentialsStorage.setCredential(sc, idToken, fakeIDToken) + credentialsStorage.getCredential(sc, idToken) - createCredentialCacheDir() - if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) { - t.Error(err) - } else if err != nil { - t.Error(err) + if sc.cfg.IDToken != expectedIDToken { + t.Fatalf("Expected id token %v but got %v", expectedIDToken, sc.cfg.IDToken) } - // cleanup - src, _ = os.Open(tmpFileName) - defer src.Close() - dst, _ = os.OpenFile(srcFileName, os.O_WRONLY, readWriteFileMode) - defer dst.Close() - // copy temporary file contents back to original file - if _, err = io.Copy(dst, src); err != nil { - t.Fatal(err) - } - if err = os.Remove(tmpFileName); err != nil { - t.Error(err) + // delete credential and check it no longer exists + credentialsStorage.deleteCredential(sc, idToken) + credentialsStorage.getCredential(sc, idToken) + if sc.cfg.IDToken != "" { + t.Fatalf("Expected id token to be empty but got %v", sc.cfg.IDToken) } } @@ -140,39 +66,42 @@ func TestStoreTemporaryCredental(t *testing.T) { } testcases := []tcCredentials{ - {"mfaToken", "598ghFnjfh8BBgmf45mmhgkfRR45mgkt5"}, - {"IdToken", "090Arftf54Jk3gh57ggrVvf09lJa3DD"}, - } - createCredentialCacheDir() - if credCache == "" { - t.Fatalf("failed to create credential cache") + {mfaToken, "598ghFnjfh8BBgmf45mmhgkfRR45mgkt5"}, + {idToken, "090Arftf54Jk3gh57ggrVvf09lJa3DD"}, } + + ssm := newFileBasedSecureStorageManager() + _, ok := ssm.(*fileBasedSecureStorageManager) + assertTrueF(t, ok) + sc := getDefaultSnowflakeConn() for _, test := range testcases { t.Run(test.token, func(t *testing.T) { - writeTemporaryCredential(sc, test.credType, test.token) - target := convertTarget(sc.cfg.Host, sc.cfg.User, test.credType) - _, ok := localCredCache[target] - if !ok { - t.Fatalf("failed to write credential to local cache") + ssm.setCredential(sc, test.credType, test.token) + ssm.getCredential(sc, test.credType) + if test.credType == mfaToken { + assertEqualE(t, sc.cfg.MfaToken, test.token) + } else { + assertEqualE(t, sc.cfg.IDToken, test.token) } - tmpCred := readTemporaryCredential(sc, test.credType) - if tmpCred == "" { - t.Fatalf("failed to read credential from temporary cache") + ssm.deleteCredential(sc, test.credType) + ssm.getCredential(sc, test.credType) + if test.credType == mfaToken { + assertEqualE(t, sc.cfg.MfaToken, "") } else { - deleteTemporaryCredential(sc, test.credType) + assertEqualE(t, sc.cfg.IDToken, "") } }) } } -func TestConvertTarget(t *testing.T) { +func TestBuildCredentialsKey(t *testing.T) { testcases := []tcTargets{ {"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:MFATOKEN"}, {"testaccount.snowflakecomputing.com", "testuser", "IdToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:IDTOKEN"}, } for _, test := range testcases { - target := convertTarget(test.host, test.user, test.credType) + target := buildCredentialsKey(test.host, test.user, test.credType) if target != test.out { t.Fatalf("failed to convert target. expected: %v, but got: %v", test.out, target) } diff --git a/util_test.go b/util_test.go index 807123214..f67912dbb 100644 --- a/util_test.go +++ b/util_test.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "os" + "runtime" "strconv" "sync" "testing" @@ -397,6 +398,12 @@ func runOnlyOnDockerContainer(t *testing.T, message string) { } } +func skipOnMac(t *testing.T, reason string) { + if runtime.GOOS == "darwin" && runningOnGithubAction() { + t.Skip("skipped on Mac: " + reason) + } +} + func randomString(n int) string { r := rand.New(rand.NewSource(time.Now().UnixNano())) alpha := []rune("abcdefghijklmnopqrstuvwxyz")