From 29eef08c889c5531c403ed901e9557a4088c1bcc Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 26 Jul 2024 19:07:51 +0530 Subject: [PATCH] fix: jwks cache and test fixes --- .gitignore | 3 ++ .../emailpassword/network_interceptor_test.go | 16 +++---- recipe/emailpassword/querier_test.go | 28 ++++++------- .../passwordless/passwordless_email_test.go | 3 ++ recipe/session/constants.go | 1 - recipe/session/recipeImplementation.go | 8 +++- recipe/session/sessionFunctions.go | 2 +- recipe/session/session_test.go | 42 +++++++++---------- recipe/session/sessmodels/models.go | 2 + recipe/session/utils.go | 6 +++ 10 files changed, 65 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 138047d5..1e0644cf 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ releasePassword .vscode/ .idea/ /test_report + +build-errors.log +main diff --git a/recipe/emailpassword/network_interceptor_test.go b/recipe/emailpassword/network_interceptor_test.go index fd719c04..b04b5085 100644 --- a/recipe/emailpassword/network_interceptor_test.go +++ b/recipe/emailpassword/network_interceptor_test.go @@ -19,9 +19,9 @@ func TestNetworkInterceptorDuringSignIn(t *testing.T) { configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) *http.Request { + NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) (*http.Request, error) { isNetworkIntercepted = true - return request + return request, nil }, }, AppInfo: supertokens.AppInfo{ @@ -105,11 +105,11 @@ func TestNetworkInterceptorIncorrectCoreURL(t *testing.T) { configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) *http.Request { + NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) (*http.Request, error) { isNetworkIntercepted = true newRequest := request newRequest.URL.Path = "/public/recipe/incorrect/path" - return newRequest + return newRequest, nil }, }, AppInfo: supertokens.AppInfo{ @@ -149,12 +149,12 @@ func TestNetworkInterceptorIncorrectQueryParams(t *testing.T) { configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) (*http.Request, error) { isNetworkIntercepted = true newRequest := r q := url.Values{} newRequest.URL.RawQuery = q.Encode() - return newRequest + return newRequest, nil }, }, AppInfo: supertokens.AppInfo{ @@ -191,12 +191,12 @@ func TestNetworkInterceptorRequestBody(t *testing.T) { configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) (*http.Request, error) { isNetworkIntercepted = true newBody := bytes.NewReader([]byte(`{"newKey": "newValue"}`)) req, _ := http.NewRequest(r.Method, r.URL.String(), newBody) req.Header = r.Header - return req + return req, nil }, }, AppInfo: supertokens.AppInfo{ diff --git a/recipe/emailpassword/querier_test.go b/recipe/emailpassword/querier_test.go index f78ca68f..a8d9eb47 100644 --- a/recipe/emailpassword/querier_test.go +++ b/recipe/emailpassword/querier_test.go @@ -17,9 +17,9 @@ func TestCachingWorks(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ @@ -79,9 +79,9 @@ func TestNoCachingIfDisabledByUser(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, DisableCoreCallCache: true, }, @@ -124,9 +124,9 @@ func TestNoCachingIfHeadersAreDifferent(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ @@ -176,9 +176,9 @@ func TestCachingGetsClearWhenQueryWithoutUserContext(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ @@ -223,9 +223,9 @@ func TestCachingDoesNotGetClearWithNonGetIfKeepAlive(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ @@ -295,9 +295,9 @@ func TestCachingGetsClearWithNonGetIfKeepAliveIsFalse(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ @@ -367,9 +367,9 @@ func TestCachingGetsClearWithNonGetIfKeepAliveIsNotSet(t *testing.T) { config := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ ConnectionURI: "http://localhost:8080", - NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request { + NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) { calledCore = true - return r + return r, nil }, }, AppInfo: supertokens.AppInfo{ diff --git a/recipe/passwordless/passwordless_email_test.go b/recipe/passwordless/passwordless_email_test.go index f886c06e..c2f0e592 100644 --- a/recipe/passwordless/passwordless_email_test.go +++ b/recipe/passwordless/passwordless_email_test.go @@ -970,6 +970,9 @@ func TestThatMagicLinkUsesRightValueFromOriginFunction(t *testing.T) { APIDomain: "api.supertokens.io", AppName: "SuperTokens", GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + if request == nil { + return "https://supertokens.com", nil + } // read request body decoder := json.NewDecoder(request.Body) var requestBody map[string]interface{} diff --git a/recipe/session/constants.go b/recipe/session/constants.go index 1db241f1..d2393829 100644 --- a/recipe/session/constants.go +++ b/recipe/session/constants.go @@ -32,7 +32,6 @@ const ( CookieSameSite_STRICT = "strict" ) -var JWKCacheMaxAgeInMs int64 = 60000 var JWKRefreshRateLimit = 500 var protectedProps = []string{ "sub", diff --git a/recipe/session/recipeImplementation.go b/recipe/session/recipeImplementation.go index c6b2a464..2a81024c 100644 --- a/recipe/session/recipeImplementation.go +++ b/recipe/session/recipeImplementation.go @@ -38,6 +38,12 @@ var mutex sync.RWMutex func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult { mutex.RLock() defer mutex.RUnlock() + + sessionInstance, err := getRecipeInstanceOrThrowError() + if err != nil { + return nil + } + if jwksCache != nil { // This means that we have valid JWKs for the given core path // We check if we need to refresh before returning @@ -48,7 +54,7 @@ func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult { // Note that this also means that the SDK will not try to query any other Core (if there are multiple) // if it has a valid cache entry from one of the core URLs. It will only attempt to fetch // from the cores again after the entry in the cache is expired - if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs { + if (currentTime - jwksCache.LastFetched) < int64(sessionInstance.Config.JWKSRefreshIntervalSec*1000) { if supertokens.IsRunningInTestMode() { if len(returnedFromCache) == cap(returnedFromCache) { // need to clear the channel if full because it's not being consumed in the test close(returnedFromCache) diff --git a/recipe/session/sessionFunctions.go b/recipe/session/sessionFunctions.go index fb7be614..c1de5dd2 100644 --- a/recipe/session/sessionFunctions.go +++ b/recipe/session/sessionFunctions.go @@ -95,7 +95,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens // We check if the token was created since the last time we refreshed the keys from the core // Since we do not know the exact timing of the last refresh, we check against the max age - if timeCreated <= (GetCurrTimeInMS() - uint64(JWKCacheMaxAgeInMs)) { + if timeCreated <= (GetCurrTimeInMS() - config.JWKSRefreshIntervalSec*1000) { return sessmodels.GetSessionResponse{}, err } } else { diff --git a/recipe/session/session_test.go b/recipe/session/session_test.go index 16ad207e..eb1b7fc5 100644 --- a/recipe/session/session_test.go +++ b/recipe/session/session_test.go @@ -1496,10 +1496,9 @@ This test verifies that the SDK calls the well known API properly in the normal */ func TestThatJWKSIsFetchedAsExpected(t *testing.T) { originalRefreshlimit := JWKRefreshRateLimit - originalCacheAge := JWKCacheMaxAgeInMs JWKRefreshRateLimit = 100 - JWKCacheMaxAgeInMs = 2000 + var JWKCacheMaxAgeInSec uint64 = 2 lastLineBeforeTest := unittesting.GetInfoLogData(t, "").LastLine @@ -1513,7 +1512,9 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) { APIDomain: "api.supertokens.io", }, RecipeList: []supertokens.Recipe{ - Init(nil), + Init(&sessmodels.TypeInput{ + JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec, + }), }, } BeforeEach() @@ -1548,7 +1549,7 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) { t.Error(err.Error()) } - time.Sleep(time.Duration(JWKCacheMaxAgeInMs) * time.Millisecond) + time.Sleep(time.Duration(JWKCacheMaxAgeInSec) * time.Second) logInfoAfterWaiting := unittesting.GetInfoLogData(t, lastLineBeforeTest) wellKnownCallLogs = []string{} @@ -1562,7 +1563,6 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) { assert.Equal(t, len(wellKnownCallLogs), 1) JWKRefreshRateLimit = originalRefreshlimit - JWKCacheMaxAgeInMs = originalCacheAge } /* @@ -1578,10 +1578,9 @@ cache expired and the keys need to be refetched. */ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) { originalRefreshlimit := JWKRefreshRateLimit - originalCacheAge := JWKCacheMaxAgeInMs JWKRefreshRateLimit = 100 - JWKCacheMaxAgeInMs = 2000 + JWKCacheMaxAgeInSec := uint64(2) configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ @@ -1593,7 +1592,9 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) { APIDomain: "api.supertokens.io", }, RecipeList: []supertokens.Recipe{ - Init(nil), + Init(&sessmodels.TypeInput{ + JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec, + }), }, } BeforeEach() @@ -1633,7 +1634,6 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) { assert.True(t, len(newKeys) != 0) JWKRefreshRateLimit = originalRefreshlimit - JWKCacheMaxAgeInMs = originalCacheAge } /* @@ -1794,10 +1794,9 @@ This test verifies the behaviour of the JWKS cache maintained by the SDK */ func TestJWKSCacheLogic(t *testing.T) { originalRefreshlimit := JWKRefreshRateLimit - originalCacheAge := JWKCacheMaxAgeInMs JWKRefreshRateLimit = 100 - JWKCacheMaxAgeInMs = 2000 + var JWKCacheMaxAgeInSec uint64 = 2 configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ @@ -1809,7 +1808,9 @@ func TestJWKSCacheLogic(t *testing.T) { APIDomain: "api.supertokens.io", }, RecipeList: []supertokens.Recipe{ - Init(nil), + Init(&sessmodels.TypeInput{ + JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec, + }), }, } BeforeEach() @@ -1849,7 +1850,6 @@ func TestJWKSCacheLogic(t *testing.T) { assert.NotNil(t, jwksCache) JWKRefreshRateLimit = originalRefreshlimit - JWKCacheMaxAgeInMs = originalCacheAge } /* @@ -1940,10 +1940,9 @@ This test ensures that the SDK's caching logic for fetching JWKs works fine */ func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) { originalRefreshlimit := JWKRefreshRateLimit - originalCacheAge := JWKCacheMaxAgeInMs JWKRefreshRateLimit = 100 - JWKCacheMaxAgeInMs = 2000 + var JWKCacheMaxAgeInSec uint64 = 2 configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ @@ -1955,7 +1954,9 @@ func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) { APIDomain: "api.supertokens.io", }, RecipeList: []supertokens.Recipe{ - Init(nil), + Init(&sessmodels.TypeInput{ + JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec, + }), }, } BeforeEach() @@ -2002,7 +2003,6 @@ func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) { assert.Equal(t, <-returnedFromCache, false) JWKRefreshRateLimit = originalRefreshlimit - JWKCacheMaxAgeInMs = originalCacheAge } /* @@ -2205,10 +2205,9 @@ func TestSessionVerificationOfJWTBasedOnSessionPayloadWithCheckDatabase(t *testi func TestThatLockingForJWKSCacheWorksFine(t *testing.T) { originalRefreshlimit := JWKRefreshRateLimit - originalCacheAge := JWKCacheMaxAgeInMs JWKRefreshRateLimit = 100 - JWKCacheMaxAgeInMs = 2000 + var JWKCacheMaxAgeInSec uint64 = 2 configValue := supertokens.TypeInput{ Supertokens: &supertokens.ConnectionInfo{ @@ -2220,7 +2219,9 @@ func TestThatLockingForJWKSCacheWorksFine(t *testing.T) { WebsiteDomain: "supertokens.io", }, RecipeList: []supertokens.Recipe{ - Init(nil), + Init(&sessmodels.TypeInput{ + JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec, + }), }, } BeforeEach() @@ -2295,7 +2296,6 @@ func TestThatLockingForJWKSCacheWorksFine(t *testing.T) { assert.Equal(t, notReturnFromCacheCount, 5) JWKRefreshRateLimit = originalRefreshlimit - JWKCacheMaxAgeInMs = originalCacheAge } func TestThatGetSessionThrowsWIthDynamicKeysIfSessionWasCreatedWithStaticKeys(t *testing.T) { diff --git a/recipe/session/sessmodels/models.go b/recipe/session/sessmodels/models.go index 02ecba4c..0d029e0a 100644 --- a/recipe/session/sessmodels/models.go +++ b/recipe/session/sessmodels/models.go @@ -111,6 +111,7 @@ type TypeInput struct { GetTokenTransferMethod func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) TokenTransferMethod ExposeAccessTokenToFrontendInCookieBasedAuth bool UseDynamicAccessTokenSigningKey *bool + JWKSRefreshIntervalSec *uint64 } type OverrideStruct struct { @@ -141,6 +142,7 @@ type TypeNormalisedInput struct { GetTokenTransferMethod func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) TokenTransferMethod ExposeAccessTokenToFrontendInCookieBasedAuth bool UseDynamicAccessTokenSigningKey bool + JWKSRefreshIntervalSec uint64 } type AntiCsrfFunctionOrString struct { diff --git a/recipe/session/utils.go b/recipe/session/utils.go index d7ba8155..d65201e7 100644 --- a/recipe/session/utils.go +++ b/recipe/session/utils.go @@ -222,6 +222,11 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config useDynamicSigningKey = *config.UseDynamicAccessTokenSigningKey } + var jwksRefreshIntervalSec uint64 = 4 * 3600 // 4 hours + if config != nil && config.JWKSRefreshIntervalSec != nil { + jwksRefreshIntervalSec = *config.JWKSRefreshIntervalSec + } + typeNormalisedInput := sessmodels.TypeNormalisedInput{ RefreshTokenPath: appInfo.APIBasePath.AppendPath(refreshAPIPath), CookieDomain: cookieDomain, @@ -233,6 +238,7 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config AntiCsrfFunctionOrString: antiCsrfFunctionOrString, ExposeAccessTokenToFrontendInCookieBasedAuth: config.ExposeAccessTokenToFrontendInCookieBasedAuth, UseDynamicAccessTokenSigningKey: useDynamicSigningKey, + JWKSRefreshIntervalSec: jwksRefreshIntervalSec, ErrorHandlers: errorHandlers, GetTokenTransferMethod: config.GetTokenTransferMethod, Override: sessmodels.OverrideStruct{