Skip to content

Commit

Permalink
Improvement spark delta-sharing client: convert expires_in as string …
Browse files Browse the repository at this point in the history
…to int, if returned as string (#631)

TL;DR This PR improves oauth-client in spark-client to support parsing expires_in as string similar to the changes for the python client: #628


Detail:
This PR enhances the OAuth client to support cases where the expires_in field in the token response is returned as a string instead of an integer. While the OAuth 2.0 specification mandates that expires_in should be an integer [RFC 6749 Section 4.1.4](https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.4), some OAuth servers return it as a string, leading to potential compatibility issues.

Certain OAuth implementations deviate from the standard and return expires_in as a string, e.g.:

```
{
  "access_token": "example-token",
  "expires_in": "3600",  // Returned as a string
  "token_type": "Bearer"
}
```
This causes failures when the client expects the field to always be an integer.

Solution

This PR updates the token parsing logic to:
1. Check the type of the expires_in field.
2. Convert the value to an integer if it is provided as a string.
3. Maintain backward compatibility with the standard integer format.
  • Loading branch information
moderakh authored Dec 31, 2024
1 parent b72872d commit 0866c30
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,47 @@ private[client] class OAuthClient(httpClient:
}

private def parseOAuthTokenResponse(response: String): OAuthClientCredentials = {
// Parsing the response per oauth spec
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
if (response == null || response.isEmpty) {
throw new RuntimeException("Empty response from OAuth token endpoint")
}
val jsonNode = JsonUtils.readTree(response)
if (!jsonNode.has("access_token") || !jsonNode.get("access_token").isTextual) {
throw new RuntimeException("Missing 'access_token' field in OAuth token response")
}
if (!jsonNode.has("expires_in") || !jsonNode.get("expires_in").isNumber) {
if (!jsonNode.has("expires_in")) {
throw new RuntimeException("Missing 'expires_in' field in OAuth token response")
}

// OAuth spec requires 'expires_in' to be an integer, e.g., 3600.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
// But some token endpoints return `expires_in` as a string e.g., "3600".
// This ensures that we support both integer and string values for 'expires_in' field.
// Example request resulting in 'expires_in' as a string:
// curl -X POST \
// https://login.windows.net/$TENANT_ID/oauth2/token \
// -H "Content-Type: application/x-www-form-urlencoded" \
// -d "grant_type=client_credentials" \
// -d "client_id=$CLIENT_ID" \
// -d "client_secret=$CLIENT_SECRET" \
// -d "scope=https://graph.microsoft.com/.default"
val expiresIn : Long = jsonNode.get("expires_in") match {
case n if n.isNumber => n.asLong()
case n if n.isTextual =>
try {
n.asText().toLong
} catch {
case _: NumberFormatException =>
throw new RuntimeException("Invalid 'expires_in' field in OAuth token response")
}
case _ =>
throw new RuntimeException("Invalid 'expires_in' field in OAuth token response")
}

OAuthClientCredentials(
jsonNode.get("access_token").asText(),
jsonNode.get("expires_in").asLong(),
expiresIn,
System.currentTimeMillis()
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.http.impl.bootstrap.{HttpServer, ServerBootstrap}
import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}
import org.apache.http.protocol.{HttpContext, HttpRequestHandler}
import org.apache.spark.SparkFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks

class OAuthClientSuite extends SparkFunSuite {
class OAuthClientSuite extends SparkFunSuite with TableDrivenPropertyChecks {
var server: HttpServer = _

def startServer(handler: HttpRequestHandler): Unit = {
Expand Down Expand Up @@ -58,40 +59,66 @@ class OAuthClientSuite extends SparkFunSuite {
throw new RuntimeException(s"Port $port is not released after $timeoutMillis milliseconds")
}

test("OAuthClient should parse token response correctly") {
val handler = new HttpRequestHandler {
@throws[HttpException]
@throws[IOException]
override def handle(request: HttpRequest,
response: HttpResponse,
context: HttpContext): Unit = {
val responseBody =
"""{
| "access_token": "test-access-token",
| "expires_in": 3600,
| "token_type": "bearer"
|}""".stripMargin
response.setEntity(new StringEntity(responseBody, ContentType.APPLICATION_JSON))
response.setStatusCode(200)
case class TokenExchangeSuccessScenario(responseBody: String,
expectedAccessToken: String,
expectedExpiresIn: Long)

// OAuth spec requires 'expires_in' to be an integer, e.g., 3600.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
// But some token endpoints return `expires_in` as a string e.g., "3600".
// This test ensures the client can handle such cases.
// The test case ensures that we support both integer and string values for 'expires_in' field.
private val tokenExchangeSuccessScenarios = Table(
"testScenario",
TokenExchangeSuccessScenario(
responseBody = """{
| "access_token": "test-access-token",
| "expires_in": 3600,
| "token_type": "bearer"
|}""".stripMargin,
expectedAccessToken = "test-access-token",
expectedExpiresIn = 3600
),
TokenExchangeSuccessScenario(
responseBody = """{
| "access_token": "test-access-token",
| "expires_in": "3600",
| "token_type": "bearer"
|}""".stripMargin,
expectedAccessToken = "test-access-token",
expectedExpiresIn = 3600
)
)

forAll(tokenExchangeSuccessScenarios) { testScenario =>
test(s"OAuthClient should parse token response correctly for ${testScenario.responseBody}") {
val handler = new HttpRequestHandler {
@throws[HttpException]
@throws[IOException]
override def handle(request: HttpRequest,
response: HttpResponse,
context: HttpContext): Unit = {
response.setEntity(
new StringEntity(testScenario.responseBody, ContentType.APPLICATION_JSON))
response.setStatusCode(200)
}
}
}
startServer(handler)
startServer(handler)

val httpClient: CloseableHttpClient = HttpClients.createDefault()
val oauthClient = new OAuthClient(httpClient, AuthConfig(),
"http://localhost:1080/token", "client-id", "client-secret")

val start = System.currentTimeMillis()
val httpClient: CloseableHttpClient = HttpClients.createDefault()
val oauthClient = new OAuthClient(httpClient, AuthConfig(),
"http://localhost:1080/token", "client-id", "client-secret")

val token = oauthClient.clientCredentials()
val start = System.currentTimeMillis()
val token = oauthClient.clientCredentials()
val end = System.currentTimeMillis()

val end = System.currentTimeMillis()
assert(token.accessToken == testScenario.expectedAccessToken)
assert(token.expiresIn == testScenario.expectedExpiresIn)
assert(token.creationTimestamp >= start && token.creationTimestamp <= end)

assert(token.accessToken == "test-access-token")
assert(token.expiresIn == 3600)
assert(token.creationTimestamp >= start && token.creationTimestamp <= end)

stopServer()
stopServer()
}
}

test("OAuthClient should handle 401 Unauthorized response") {
Expand Down

0 comments on commit 0866c30

Please sign in to comment.