Skip to content

Commit

Permalink
Improvement python delta-sharing client: convert expires_in as string…
Browse files Browse the repository at this point in the history
… to int, if returned as string (#628) (#630)

This PR cherry-picks #628 to python-branch-1.3.

**TL;DR:** This PR enhances the OAuth client to support token responses in which the `expires_in` field is returned as a string instead of an integer. The OAuth 2.0 specification requires `expires_in` to be an integer (see [RFC 6749 Section 5.1](https://datatracker.ietf.org/doc/html/rfc6749#section-5.1)), but some OAuth servers return it as a string. Previously the client rejected such responses with a "Missing 'expires_in' field" error.

Certain OAuth implementations deviate from the standard and return expires_in as a string, e.g.:

```
{
  "access_token": "example-token",
  "expires_in": "3600",  // Returned as a string
  "token_type": "Bearer"
}
```
This causes failures when the client expects the field to always be an integer.

Solution

This PR updates the token parsing logic to:
1. Check that the `expires_in` field is present in the token response.
2. Convert the value to an integer if it is provided as a string, and raise an error if it cannot be converted.
3. Maintain backward compatibility with the standard integer format.
  • Loading branch information
moderakh authored Dec 30, 2024
1 parent dad242f commit 040a6b6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
24 changes: 22 additions & 2 deletions python/delta_sharing/_internal_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,34 @@ def client_credentials(self) -> OAuthClientCredentials:
def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials:
if not response:
raise RuntimeError("Empty response from OAuth token endpoint")
# Parsing the response per oauth spec
# https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
json_node = json.loads(response)
if 'access_token' not in json_node or not isinstance(json_node['access_token'], str):
raise RuntimeError("Missing 'access_token' field in OAuth token response")
if 'expires_in' not in json_node or not isinstance(json_node['expires_in'], int):
if 'expires_in' not in json_node:
raise RuntimeError("Missing 'expires_in' field in OAuth token response")
try:
# OAuth spec requires 'expires_in' to be an integer, e.g., 3600.
# See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
# But some token endpoints return `expires_in` as a string e.g., "3600".
# This ensures that we support both integer and string values for 'expires_in' field.
# Example request resulting in 'expires_in' as a string:
# curl -X POST \
# https://login.windows.net/$TENANT_ID/oauth2/token \
# -H "Content-Type: application/x-www-form-urlencoded" \
# -d "grant_type=client_credentials" \
# -d "client_id=$CLIENT_ID" \
# -d "client_secret=$CLIENT_SECRET" \
# -d "scope=https://graph.microsoft.com/.default"
expires_in = int(json_node['expires_in']) # Convert to int if it's a string
except ValueError:
raise RuntimeError(
"'expires_in' field must be an integer or a string convertible to integer"
)
return OAuthClientCredentials(
json_node['access_token'],
json_node['expires_in'],
expires_in,
int(datetime.now().timestamp())
)

Expand Down
28 changes: 24 additions & 4 deletions python/delta_sharing/tests/test_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,30 @@ def mock_server():
yield server


def test_oauth_client_should_parse_token_response_correctly(mock_server):
@pytest.mark.parametrize("response_data, expected_expires_in, expected_access_token", [
# OAuth spec requires 'expires_in' to be an integer, e.g., 3600.
# See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
# But some token endpoints return `expires_in` as a string e.g., "3600".
# This test ensures the client can handle such cases.
# The test case ensures that we support both integer and string values for 'expires_in' field.
(
'{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}',
3600,
"test-access-token"
),
(
'{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}',
3600,
"test-access-token"
)
])
def test_oauth_client_should_parse_token_response_correctly(mock_server,
response_data,
expected_expires_in,
expected_access_token):
mock_server.add_response(
200,
'{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}')
response_data)

with patch('requests.post') as mock_post:
mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response()
Expand All @@ -59,8 +79,8 @@ def test_oauth_client_should_parse_token_response_correctly(mock_server):
token = oauth_client.client_credentials()
end = datetime.now().timestamp()

assert token.access_token == "test-access-token"
assert token.expires_in == 3600
assert token.access_token == expected_access_token
assert token.expires_in == expected_expires_in
assert int(start) <= token.creation_timestamp
assert token.creation_timestamp <= int(end)

Expand Down

0 comments on commit 040a6b6

Please sign in to comment.