diff --git a/auth.go b/auth.go index d1565cd16..4126355e6 100644 --- a/auth.go +++ b/auth.go @@ -215,7 +215,7 @@ func postAuth( client *http.Client, params *url.Values, headers map[string]string, - body []byte, + bodyCreator bodyCreatorType, timeout time.Duration) ( data *authResponse, err error) { params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) @@ -223,7 +223,7 @@ func postAuth( fullURL := sr.getFullURL(loginRequestPath, params) logger.Infof("full URL: %v", fullURL) - resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, body, timeout, true) + resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true) if err != nil { return nil, err } @@ -287,6 +287,23 @@ func authenticate( samlResponse []byte, proofKey []byte, ) (resp *authResponseMain, err error) { + if sc.cfg.Authenticator == AuthTypeTokenAccessor { + logger.Info("Bypass authentication using existing token from token accessor") + sessionInfo := authResponseSessionInfo{ + DatabaseName: sc.cfg.Database, + SchemaName: sc.cfg.Schema, + WarehouseName: sc.cfg.Warehouse, + RoleName: sc.cfg.Role, + } + token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens() + return &authResponseMain{ + Token: token, + MasterToken: masterToken, + SessionID: sessionID, + SessionInfo: sessionInfo, + }, nil + } + headers := getHeaders() clientEnvironment := authRequestClientEnvironment{ Application: sc.cfg.Application, @@ -310,6 +327,67 @@ func authenticate( if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { sessionParameters[clientStoreTemporaryCredential] = true } + bodyCreator := func() ([]byte, error) { + return createRequestBody(sc, sessionParameters, clientEnvironment, proofKey, samlResponse) + } + + params := &url.Values{} + if sc.cfg.Database != "" { + params.Add("databaseName", sc.cfg.Database) + } + if sc.cfg.Schema != "" { + params.Add("schemaName", sc.cfg.Schema) + } + if sc.cfg.Warehouse != "" { + params.Add("warehouse", sc.cfg.Warehouse) + } + if sc.cfg.Role != "" { + params.Add("roleName", sc.cfg.Role) + } + + logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v", + params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String()) + + respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout) + if err != nil { + return nil, err + } + if !respd.Success { + logger.Errorln("Authentication FAILED") + sc.rest.TokenAccessor.SetTokens("", "", -1) + if sessionParameters[clientRequestMfaToken] == true { + deleteCredential(sc, mfaToken) + } + if sessionParameters[clientStoreTemporaryCredential] == true { + deleteCredential(sc, idToken) + } + code, err := strconv.Atoi(respd.Code) + if err != nil { + code = -1 + return nil, err + } + return nil, (&SnowflakeError{ + Number: code, + SQLState: SQLStateConnectionRejected, + Message: respd.Message, + }).exceptionTelemetry(sc) + } + logger.Info("Authentication SUCCESS") + sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) + if sessionParameters[clientRequestMfaToken] == true { + token := respd.Data.MfaToken + setCredential(sc, mfaToken, token) + } + if sessionParameters[clientStoreTemporaryCredential] == true { + token := respd.Data.IDToken + setCredential(sc, idToken, token) + } + return &respd.Data, nil +} + +func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface{}, + clientEnvironment authRequestClientEnvironment, proofKey []byte, samlResponse []byte, +) ([]byte, error) { requestMain := authRequestData{ ClientAppID: clientType, ClientAppVersion: SnowflakeGoDriverVersion, @@ -362,83 +440,16 @@ func authenticate( if sc.cfg.MfaToken != "" { requestMain.Token = sc.cfg.MfaToken } - case AuthTypeTokenAccessor: - logger.Info("Bypass authentication using existing token from token accessor") - sessionInfo := authResponseSessionInfo{ - DatabaseName: sc.cfg.Database, - SchemaName: sc.cfg.Schema, - WarehouseName: sc.cfg.Warehouse, - RoleName: sc.cfg.Role, - } - token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens() - return &authResponseMain{ - Token: token, - MasterToken: masterToken, - SessionID: sessionID, - SessionInfo: sessionInfo, - }, nil } authRequest := authRequest{ Data: requestMain, } - params := &url.Values{} - if sc.cfg.Database != "" { - params.Add("databaseName", sc.cfg.Database) - } - if sc.cfg.Schema != "" { - params.Add("schemaName", sc.cfg.Schema) - } - if sc.cfg.Warehouse != "" { - params.Add("warehouse", sc.cfg.Warehouse) - } - if sc.cfg.Role != "" { - params.Add("roleName", sc.cfg.Role) - } - jsonBody, err := json.Marshal(authRequest) - if err != nil { - return - } - - logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v", - params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String()) - - respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, jsonBody, sc.rest.LoginTimeout) if err != nil { return nil, err } - if !respd.Success { - logger.Errorln("Authentication FAILED") - sc.rest.TokenAccessor.SetTokens("", "", -1) - if sessionParameters[clientRequestMfaToken] == true { - deleteCredential(sc, mfaToken) - } - if sessionParameters[clientStoreTemporaryCredential] == true { - deleteCredential(sc, idToken) - } - code, err := strconv.Atoi(respd.Code) - if err != nil { - code = -1 - return nil, err - } - return nil, (&SnowflakeError{ - Number: code, - SQLState: SQLStateConnectionRejected, - Message: respd.Message, - }).exceptionTelemetry(sc) - } - logger.Info("Authentication SUCCESS") - sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) - if sessionParameters[clientRequestMfaToken] == true { - token := respd.Data.MfaToken - setCredential(sc, mfaToken, token) - } - if sessionParameters[clientStoreTemporaryCredential] == true { - token := respd.Data.IDToken - setCredential(sc, idToken, token) - } - return &respd.Data, nil + return jsonBody, nil } // Generate a JWT token in string given the configuration diff --git a/auth_test.go b/auth_test.go index 190bd39c8..35b80e327 100644 --- a/auth_test.go +++ b/auth_test.go @@ -26,51 +26,54 @@ func TestUnitPostAuth(t *testing.T) { FuncAuthPost: postAuthTestAfterRenew, } var err error - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0) + bodyCreator := func() ([]byte, error) { + return []byte{0x12, 0x34}, nil + } + _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err != nil { t.Fatalf("err: %v", err) } sr.FuncAuthPost = postAuthTestError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0) + _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppBadGatewayError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0) + _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppForbiddenError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0) + _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppUnexpectedError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0) + _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } } -func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return nil, &SnowflakeError{ Number: ErrCodeServiceUnavailable, } } -func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return nil, &SnowflakeError{ Number: ErrCodeFailedToConnect, } } -func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return nil, &SnowflakeError{ Number: ErrFailedToAuth, } } -func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Code: "98765", @@ -78,7 +81,7 @@ func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *htt }, nil } -func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: false, Code: "abcdef", @@ -86,7 +89,7 @@ func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, }, nil } -func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { +func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, Data: authResponseMain{ @@ -99,8 +102,9 @@ func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ * }, nil } -func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -126,9 +130,10 @@ func postAuthCheckOAuth( _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, - jsonBody []byte, + bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -153,8 +158,9 @@ func postAuthCheckOAuth( }, nil } -func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -173,8 +179,9 @@ func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Clien }, nil } -func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -195,8 +202,9 @@ func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ * // JWT token validate callback function to check the JWT token // It uses the public key paired with the testPrivKey -func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -231,8 +239,9 @@ func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Clien }, nil } -func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -253,8 +262,9 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ }, nil } -func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -275,8 +285,9 @@ func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestfu }, nil } -func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -292,8 +303,9 @@ func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestf }, nil } -func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -314,8 +326,9 @@ func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *htt }, nil } -func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } @@ -336,8 +349,9 @@ func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ }, nil } -func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) { +func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) { var ar authRequest + jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { return nil, err } diff --git a/ci/scripts/hang_webserver.py b/ci/scripts/hang_webserver.py index c08d6c85b..dbcee5c16 100755 --- a/ci/scripts/hang_webserver.py +++ b/ci/scripts/hang_webserver.py @@ -4,18 +4,35 @@ from socketserver import ThreadingMixIn import threading import time +import json class HTTPRequestHandler(BaseHTTPRequestHandler): + invocations = 0 + def do_POST(self): - if self.path.startswith('/403'): + if self.path.startswith('/reset'): + print("Resetting HTTP mocks") + HTTPRequestHandler.invocations = 0 + self.__respond(200) + elif self.path.startswith('/invocations'): + self.__respond(200, body=str(HTTPRequestHandler.invocations)) + elif self.path.startswith('/ocsp'): + print("ocsp") + self.ocspMocks() + elif self.path.startswith('/session/v1/login-request'): + self.authMocks() + + def ocspMocks(self): + if self.path.startswith('/ocsp/403'): self.send_response(403) self.send_header('Content-Type', 'text/plain') self.end_headers() - elif self.path.startswith('/404'): + elif self.path.startswith('/ocsp/404'): self.send_response(404) self.send_header('Content-Type', 'text/plain') self.end_headers() - elif self.path.startswith('/hang'): + elif self.path.startswith('/ocsp/hang'): + print("Hanging") time.sleep(300) self.send_response(200, 'OK') self.send_header('Content-Type', 'text/plain') @@ -24,6 +41,36 @@ def do_POST(self): self.send_response(200, 'OK') self.send_header('Content-Type', 'text/plain') self.end_headers() + + def authMocks(self): + content_length = int(self.headers.get('content-length', 0)) + body = self.rfile.read(content_length) + jsonBody = json.loads(body) + if jsonBody['data']['ACCOUNT_NAME'] == "jwtAuthTokenTimeout": + HTTPRequestHandler.invocations += 1 + if HTTPRequestHandler.invocations >= 3: + self.__respond(200, body='''{ + "data": { + "token": "someToken" + }, + "success": true + }''') + else: + time.sleep(2000) + self.send_response(200) + else: + print("Unknown auth request") + self.send_response(500) + + def __respond(self, http_code, content_type='application/json', body=None): + print("responding:", body) + self.send_response(http_code) + self.send_header('Content-Type', content_type) + self.end_headers() + if body != None: + responseBody = bytes(body, "utf-8") + self.wfile.write(responseBody) + do_GET = do_POST class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): diff --git a/driver_ocsp_test.go b/driver_ocsp_test.go index f76fc0f60..9070c4c88 100644 --- a/driver_ocsp_test.go +++ b/driver_ocsp_test.go @@ -387,7 +387,7 @@ func TestOCSPFailOpenCacheServerTimeout(t *testing.T) { cleanup() defer cleanup() - setenv(cacheServerURLEnv, "http://localhost:12345/hang") + setenv(cacheServerURLEnv, "http://localhost:12345/ocsp/hang") setenv(ocspTestResponseCacheServerTimeoutEnv, "1000") config := &Config{ @@ -426,7 +426,7 @@ func TestOCSPFailClosedCacheServerTimeout(t *testing.T) { cleanup() defer cleanup() - setenv(cacheServerURLEnv, "http://localhost:12345/hang") + setenv(cacheServerURLEnv, "http://localhost:12345/ocsp/hang") setenv(ocspTestResponseCacheServerTimeoutEnv, "1000") config := &Config{ @@ -482,7 +482,7 @@ func TestOCSPFailOpenResponderTimeout(t *testing.T) { defer cleanup() setenv(cacheServerEnabledEnv, "false") - setenv(ocspTestResponderURLEnv, "http://localhost:12345/hang") + setenv(ocspTestResponderURLEnv, "http://localhost:12345/ocsp/hang") setenv(ocspTestResponderTimeoutEnv, "1000") config := &Config{ @@ -522,7 +522,7 @@ func TestOCSPFailClosedResponderTimeout(t *testing.T) { defer cleanup() setenv(cacheServerEnabledEnv, "false") - setenv(ocspTestResponderURLEnv, "http://localhost:12345/hang") + setenv(ocspTestResponderURLEnv, "http://localhost:12345/ocsp/hang") setenv(ocspTestResponderTimeoutEnv, "1000") config := &Config{ @@ -566,7 +566,7 @@ func TestOCSPFailOpenResponder404(t *testing.T) { defer cleanup() setenv(cacheServerEnabledEnv, "false") - setenv(ocspTestResponderURLEnv, "http://localhost:12345/404") + setenv(ocspTestResponderURLEnv, "http://localhost:12345/ocsp/404") config := &Config{ Account: "fakeaccount10", @@ -605,7 +605,7 @@ func TestOCSPFailClosedResponder404(t *testing.T) { defer cleanup() setenv(cacheServerEnabledEnv, "false") - setenv(ocspTestResponderURLEnv, "http://localhost:12345/404") + setenv(ocspTestResponderURLEnv, "http://localhost:12345/ocsp/404") config := &Config{ Account: "fakeaccount11", diff --git a/priv_key_test.go b/priv_key_test.go index 3f7888b5c..c3f2646f6 100644 --- a/priv_key_test.go +++ b/priv_key_test.go @@ -7,6 +7,7 @@ package gosnowflake import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -115,3 +116,26 @@ func TestJWTAuthentication(t *testing.T) { db.Close() } + +func TestJWTTokenTimeout(t *testing.T) { + resetHTTPMocks(t) + + dsn := "user:pass@localhost:12345/db/schema?account=jwtAuthTokenTimeout&protocol=http&jwtClientTimeout=1" + dsn = appendPrivateKeyString(&dsn, testPrivKey) + db, err := sql.Open("snowflake", dsn) + if err != nil { + t.Fatalf(err.Error()) + } + defer db.Close() + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf(err.Error()) + } + defer conn.Close() + + invocations := getMocksInvocations(t) + if invocations != 3 { + t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations) + } +} diff --git a/restful.go b/restful.go index 2acbbd758..327871767 100644 --- a/restful.go +++ b/restful.go @@ -44,9 +44,14 @@ const ( type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error) - funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error) + funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error) + bodyCreatorType func() ([]byte, error) ) +var emptyBodyCreator = func() ([]byte, error) { + return []byte{}, nil +} + type snowflakeRestful struct { Host string Port int @@ -70,7 +75,7 @@ type snowflakeRestful struct { FuncCloseSession func(context.Context, *snowflakeRestful, time.Duration) error FuncCancelQuery func(context.Context, *snowflakeRestful, UUID, time.Duration) error - FuncPostAuth func(context.Context, *snowflakeRestful, *http.Client, *url.Values, map[string]string, []byte, time.Duration) (*authResponse, error) + FuncPostAuth func(context.Context, *snowflakeRestful, *http.Client, *url.Values, map[string]string, bodyCreatorType, time.Duration) (*authResponse, error) FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error) FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error) FuncGetSSO func(context.Context, *snowflakeRestful, *url.Values, map[string]string, string, time.Duration) ([]byte, error) @@ -159,8 +164,11 @@ func postRestful( timeout time.Duration, raise4XX bool) ( *http.Response, error) { - return newRetryHTTP( - ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).doPost().setBody(body).doRaise4XX(raise4XX).execute() + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout). + doPost(). + setBody(body). + doRaise4XX(raise4XX). + execute() } func getRestful( @@ -170,8 +178,7 @@ func getRestful( headers map[string]string, timeout time.Duration) ( *http.Response, error) { - return newRetryHTTP( - ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).execute() + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).execute() } func postAuthRestful( @@ -179,12 +186,15 @@ func postAuthRestful( client *http.Client, fullURL *url.URL, headers map[string]string, - body []byte, + bodyCreator bodyCreatorType, timeout time.Duration, raise4XX bool) ( *http.Response, error) { - return newRetryHTTP( - ctx, client, http.NewRequest, fullURL, headers, timeout).doPost().setBody(body).doRaise4XX(raise4XX).execute() + return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout). + doPost(). + setBodyCreator(bodyCreator). + doRaise4XX(raise4XX). + execute() } func postRestfulQuery( diff --git a/restful_test.go b/restful_test.go index a49b2a3e7..ec2721387 100644 --- a/restful_test.go +++ b/restful_test.go @@ -22,7 +22,7 @@ func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, errors.New("failed to run post method") } -func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -43,7 +43,7 @@ func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.U }, nil } -func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -57,14 +57,14 @@ func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.UR }, nil } -func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusInsufficientStorage, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, nil } -func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, diff --git a/retry.go b/retry.go index 07a7709fb..95efbc771 100644 --- a/retry.go +++ b/retry.go @@ -161,15 +161,15 @@ type clientInterface interface { } type retryHTTP struct { - ctx context.Context - client clientInterface - req requestFunc - method string - fullURL *url.URL - headers map[string]string - body []byte - timeout time.Duration - raise4XX bool + ctx context.Context + client clientInterface + req requestFunc + method string + fullURL *url.URL + headers map[string]string + bodyCreator bodyCreatorType + timeout time.Duration + raise4XX bool } func newRetryHTTP(ctx context.Context, @@ -185,8 +185,8 @@ func newRetryHTTP(ctx context.Context, instance.method = "GET" instance.fullURL = fullURL instance.headers = headers - instance.body = nil instance.timeout = timeout + instance.bodyCreator = emptyBodyCreator instance.raise4XX = false return &instance } @@ -202,7 +202,14 @@ func (r *retryHTTP) doPost() *retryHTTP { } func (r *retryHTTP) setBody(body []byte) *retryHTTP { - r.body = body + r.bodyCreator = func() ([]byte, error) { + return body, nil + } + return r +} + +func (r *retryHTTP) setBodyCreator(bodyCreator bodyCreatorType) *retryHTTP { + r.bodyCreator = bodyCreator return r } @@ -217,7 +224,11 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { for { logger.Debugf("retry count: %v", retryCounter) - req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(r.body)) + body, err := r.bodyCreator() + if err != nil { + return nil, err + } + req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(body)) if err != nil { return nil, err } diff --git a/retry_test.go b/retry_test.go index 57d019ebb..5acb38039 100644 --- a/retry_test.go +++ b/retry_test.go @@ -3,7 +3,9 @@ package gosnowflake import ( + "bytes" "context" + "fmt" "io" "net/http" "net/url" @@ -50,10 +52,17 @@ type fakeHTTPClient struct { success bool // return success after retry in cnt times timeout bool // timeout body []byte // return body + reqBody []byte // last request body statusCode int // status code } func (c *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + if req != nil { + buf := new(bytes.Buffer) + buf.ReadFrom(req.Body) + c.reqBody = buf.Bytes() + } + c.cnt-- if c.cnt < 0 { c.cnt = 0 @@ -269,6 +278,33 @@ func TestRetryLoginRequest(t *testing.T) { } } +func TestRetryAuthLoginRequest(t *testing.T) { + logger.Info("Retry N times always with newer body") + client := &fakeHTTPClient{ + cnt: 3, + success: true, + timeout: true, + } + urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") + if err != nil { + t.Fatal("failed to parse the test URL") + } + execID := 0 + bodyCreator := func() ([]byte, error) { + execID++ + return []byte(fmt.Sprintf("execID: %d", execID)), nil + } + _, err = newRetryHTTP(context.TODO(), + client, + http.NewRequest, urlPtr, make(map[string]string), 60*time.Second).doPost().setBodyCreator(bodyCreator).execute() + if err != nil { + t.Fatal("failed to run retry") + } + if lastReqBody := string(client.reqBody); lastReqBody != "execID: 3" { + t.Fatalf("body should be updated on each request, expected: execID: 3, last body: %v", lastReqBody) + } +} + func TestLoginRetry429(t *testing.T) { client := &fakeHTTPClient{ cnt: 3, diff --git a/test_util.go b/test_util.go new file mode 100644 index 000000000..f9f7a9b25 --- /dev/null +++ b/test_util.go @@ -0,0 +1,33 @@ +// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "io" + "net/http" + "strconv" + "testing" +) + +func resetHTTPMocks(t *testing.T) { + _, err := http.Post("http://localhost:12345/reset", "text/plain", nil) + if err != nil { + t.Fatalf("Cannot reset HTTP mocks") + } +} + +func getMocksInvocations(t *testing.T) int { + resp, err := http.Get("http://localhost:12345/invocations") + if err != nil { + t.Fatalf(err.Error()) + } + bytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf(err.Error()) + } + ret, err := strconv.Atoi(string(bytes)) + if err != nil { + t.Fatalf(err.Error()) + } + return ret +}