diff --git a/globus_sdk/authorizers/refresh_token.py b/globus_sdk/authorizers/refresh_token.py index 87015b3ad..cbff2518e 100644 --- a/globus_sdk/authorizers/refresh_token.py +++ b/globus_sdk/authorizers/refresh_token.py @@ -87,6 +87,9 @@ def _extract_token_data(self, res): """ Get the tokens .by_resource_server, Ensure that only one token was gotten, and return that token. + + If the token_data includes a "refresh_token" field, update self.refresh_token to + that value. """ token_data = res.by_resource_server.values() if len(token_data) != 1: @@ -95,4 +98,11 @@ def _extract_token_data(self, res): "didn't return exactly one token. Possible service error." ) - return next(iter(token_data)) + token_data = next(iter(token_data)) + + # handle refresh_token being present + # mandated by OAuth2: https://tools.ietf.org/html/rfc6749#section-6 + if "refresh_token" in token_data: + self.refresh_token = token_data["refresh_token"] + + return token_data diff --git a/globus_sdk/authorizers/renewing.py b/globus_sdk/authorizers/renewing.py index dc2dfb49b..b4025bfd8 100644 --- a/globus_sdk/authorizers/renewing.py +++ b/globus_sdk/authorizers/renewing.py @@ -129,6 +129,9 @@ def _get_new_access_token(self): Given token data from _get_token_response and _extract_token_data, set the access token and expiration time, calculate the new token hash, and call on_refresh + + If a "refresh_token" is sent back as part of token_data, update + self.refresh_token as well. """ # get the first (and only) token res = self._get_token_response() diff --git a/tests/unit/authorizers/test_refresh_token_authorizer.py b/tests/unit/authorizers/test_refresh_token_authorizer.py index ab7cef1e9..2363151b9 100644 --- a/tests/unit/authorizers/test_refresh_token_authorizer.py +++ b/tests/unit/authorizers/test_refresh_token_authorizer.py @@ -13,12 +13,19 @@ EXPIRES_AT = -1 -@pytest.fixture -def response(): +@pytest.fixture(params=["simple", "with_new_refresh_token"]) +def response(request): r = mock.Mock() r.by_resource_server = { - "rs1": {"expires_at_seconds": -1, "access_token": "access_token_2"} - } + "simple": {"rs1": {"expires_at_seconds": -1, "access_token": "access_token_2"}}, + "with_new_refresh_token": { + "rs1": { + "expires_at_seconds": -1, + "access_token": "access_token_2", + "refresh_token": "refresh_token_2", + } + }, + }[request.param] return r @@ -50,9 +57,8 @@ def test_get_token_response(authorizer, client, response): def test_multiple_resource_servers(authorizer, response): """ - Sets the mock ConfidentialAppAuthClient to return multiple resource - servers. Confirms GlobusError is raised when _extract_token_data is - called. + Sets the mock client to return multiple resource servers. + Confirms GlobusError is raised when _extract_token_data is called. """ response.by_resource_server["rs2"] = { "expires_at_seconds": -1, @@ -61,3 +67,19 @@ def test_multiple_resource_servers(authorizer, response): with pytest.raises(ValueError) as excinfo: authorizer._extract_token_data(response) assert "didn't return exactly one token" in str(excinfo.value) + + +def test_conditional_refresh_token_update(authorizer, response): + """ + Call check_expiration_time (triggering a refresh) + Confirm that the authorizer always udpates its access token and only updates + refresh_token if one was present in the response + """ + authorizer.check_expiration_time() # trigger refresh + token_data = response.by_resource_server["rs1"] + if "refresh_token" in token_data: # if present, confirm refresh token was updated + assert authorizer.access_token == "access_token_2" + assert authorizer.refresh_token == "refresh_token_2" + else: # otherwise, confirm no change + assert authorizer.access_token == "access_token_2" + assert authorizer.refresh_token == "refresh_token_1"