Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Authorization Refresh #3

Merged
merged 1 commit into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Returns the latitude and longitude of the address you specify.
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand All @@ -53,7 +53,7 @@ Returns an array of addresses present at the coordinates you provide.
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand All @@ -75,7 +75,7 @@ Find places by name or by specific search criteria.
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand All @@ -97,7 +97,7 @@ Find results that you can use to autocomplete searches.
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand All @@ -119,7 +119,7 @@ Find directions by specific criteria.
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand All @@ -144,7 +144,7 @@ Returns the estimated time of arrival (ETA) and distance between starting and en
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand Down Expand Up @@ -177,7 +177,7 @@ Obtain a set of ``Place`` objects for a given set of Place IDs or get a list of
import AppleMapsKit
import AsyncHTTPClient

let client = try await AppleMapsClient(
let client = AppleMapsClient(
httpClient: HTTPClient(...),
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand Down
79 changes: 8 additions & 71 deletions Sources/AppleMapsKit/AppleMapsClient.swift
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import AsyncHTTPClient
import Foundation
import JWTKit
import NIOCore
import NIOFoundationCompat
import NIOHTTP1

/// Methods to make calls to APIs such as geocode, search, and so on.
public struct AppleMapsClient: Sendable {
static let apiServer = "https://maps-api.apple.com"
private let httpClient: HTTPClient
private let accessToken: String
private let authorizationProvider: AuthorizationProvider

private let decoder = JSONDecoder()

Expand All @@ -22,11 +20,14 @@ public struct AppleMapsClient: Sendable {
/// - teamID: A 10-character Team ID obtained from your Apple Developer account.
/// - keyID: A 10-character key identifier that provides the ID of the private key that you obtain from your Apple Developer account.
/// - key: A MapKit JS private key.
public init(httpClient: HTTPClient, teamID: String, keyID: String, key: String) async throws {
public init(httpClient: HTTPClient, teamID: String, keyID: String, key: String) {
self.httpClient = httpClient
self.accessToken = try await Self.getAccessToken(
self.authorizationProvider = AuthorizationProvider(
httpClient: httpClient,
authToken: Self.createJWT(teamID: teamID, keyID: keyID, key: key)
apiServer: Self.apiServer,
teamID: teamID,
keyID: keyID,
key: key
)
}

Expand Down Expand Up @@ -518,6 +519,7 @@ public struct AppleMapsClient: Sendable {
/// - Throws: Error response object.
private func httpGet(url: URL) async throws -> ByteBuffer {
var headers = HTTPHeaders()
let accessToken = try await authorizationProvider.validToken().accessToken
headers.add(name: "Authorization", value: "Bearer \(accessToken)")

var request = HTTPClientRequest(url: url.absoluteString)
Expand Down Expand Up @@ -550,68 +552,3 @@ public struct AppleMapsClient: Sendable {
return (latitude, longitude)
}
}

// MARK: - auth/c & auth/z
extension AppleMapsClient {
/// Creates a JWT token, which is auth token in this context.
///
/// - Parameters:
/// - teamID: A 10-character Team ID obtained from your Apple Developer account.
/// - keyID: A 10-character key identifier that provides the ID of the private key that you obtain from your Apple Developer account.
/// - key: A MapKit JS private key.
///
/// - Returns: A JWT token represented as `String`.
private static func createJWT(teamID: String, keyID: String, key: String) async throws -> String {
let keys = try await JWTKeyCollection().add(ecdsa: ES256PrivateKey(pem: key))

var header = JWTHeader()
header.alg = "ES256"
header.kid = keyID
header.typ = "JWT"

struct Payload: JWTPayload {
let iss: IssuerClaim
let iat: IssuedAtClaim
let exp: ExpirationClaim

func verify(using key: some JWTAlgorithm) throws {
try self.exp.verifyNotExpired()
}
}

let payload = Payload(
iss: IssuerClaim(value: teamID),
iat: IssuedAtClaim(value: Date()),
exp: .init(value: Date().addingTimeInterval(30 * 60))
)

return try await keys.sign(payload, header: header)
}

/// Makes an HTTP request to exchange Auth token for Access token.
///
/// - Parameters:
/// - httpClient: The HTTP client to use.
/// - authToken: The authorization token.
///
/// - Throws: Error response object.
///
/// - Returns: An access token.
private static func getAccessToken(httpClient: HTTPClient, authToken: String) async throws -> String {
var headers = HTTPHeaders()
headers.add(name: "Authorization", value: "Bearer \(authToken)")

var request = HTTPClientRequest(url: "\(apiServer)/v1/token")
request.headers = headers

let response = try await httpClient.execute(request, timeout: .seconds(30))

if response.status == .ok {
return try await JSONDecoder()
.decode(TokenResponse.self, from: response.body.collect(upTo: 1024 * 1024))
.accessToken
} else {
throw try await JSONDecoder().decode(ErrorResponse.self, from: response.body.collect(upTo: 1024 * 1024))
}
}
}
133 changes: 133 additions & 0 deletions Sources/AppleMapsKit/Authorization/AuthorizationProvider.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//
// AuthorizationProvider.swift
// apple-maps-kit
//
// Created by FarouK on 11/10/2024.
//

import AsyncHTTPClient
import Foundation
import JWTKit
import NIOHTTP1

// MARK: - auth/c & auth/z
internal actor AuthorizationProvider {

private let httpClient: HTTPClient
private let apiServer: String
private let teamID: String
private let keyID: String
private let key: String

private var currentToken: TokenResponse?
private var refreshTask: Task<TokenResponse, any Error>?

internal init(httpClient: HTTPClient, apiServer: String, teamID: String, keyID: String, key: String) {
self.httpClient = httpClient
self.apiServer = apiServer
self.teamID = teamID
self.keyID = keyID
self.key = key
}

func validToken() async throws -> TokenResponse {
// If we're currently refreshing a token, await the value for our refresh task to make sure we return the refreshed token.
if let handle = refreshTask {
return try await handle.value
}

// If we don't have a current token, we request a new one.
guard let token = currentToken else {
return try await refreshToken()
}

if token.isValid {
return token
}

// None of the above applies so we'll need to refresh the token.
return try await refreshToken()
}

private func refreshToken() async throws -> TokenResponse {
if let refreshTask = refreshTask {
return try await refreshTask.value
}

let task = Task { () throws -> TokenResponse in
defer { refreshTask = nil }
let authToken = try await createJWT(teamID: teamID, keyID: keyID, key: key)
let newToken = try await getAccessToken(authToken: authToken)
currentToken = newToken
return newToken
}

self.refreshTask = task
return try await task.value
}
}

// MARK: - HELPERS
extension AuthorizationProvider {

/// Makes an HTTP request to exchange Auth token for Access token.
///
/// - Parameters:
/// - httpClient: The HTTP client to use.
/// - authToken: The authorization token.
///
/// - Throws: Error response object.
///
/// - Returns: An access token.
fileprivate func getAccessToken(authToken: String) async throws -> TokenResponse {
var headers = HTTPHeaders()
headers.add(name: "Authorization", value: "Bearer \(authToken)")

var request = HTTPClientRequest(url: "\(apiServer)/v1/token")
request.headers = headers

let response = try await httpClient.execute(request, timeout: .seconds(30))

if response.status == .ok {
return try await JSONDecoder()
.decode(TokenResponse.self, from: response.body.collect(upTo: 1024 * 1024))
} else {
throw try await JSONDecoder().decode(ErrorResponse.self, from: response.body.collect(upTo: 1024 * 1024))
}
}

/// Creates a JWT token, which is auth token in this context.
///
/// - Parameters:
/// - teamID: A 10-character Team ID obtained from your Apple Developer account.
/// - keyID: A 10-character key identifier that provides the ID of the private key that you obtain from your Apple Developer account.
/// - key: A MapKit JS private key.
///
/// - Returns: A JWT token represented as `String`.
fileprivate func createJWT(teamID: String, keyID: String, key: String) async throws -> String {
let keys = try await JWTKeyCollection().add(ecdsa: ES256PrivateKey(pem: key))

var header = JWTHeader()
header.alg = "ES256"
header.kid = keyID
header.typ = "JWT"

struct Payload: JWTPayload {
let iss: IssuerClaim
let iat: IssuedAtClaim
let exp: ExpirationClaim

func verify(using key: some JWTAlgorithm) throws {
try self.exp.verifyNotExpired()
}
}

let payload = Payload(
iss: IssuerClaim(value: teamID),
iat: IssuedAtClaim(value: Date()),
exp: ExpirationClaim(value: Date().addingTimeInterval(30 * 60))
)

return try await keys.sign(payload, header: header)
}
}
38 changes: 38 additions & 0 deletions Sources/AppleMapsKit/Authorization/TokenResponse.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import Foundation

/// An object that contains an access token and an expiration time in seconds.
internal struct TokenResponse: Codable {
/// A string that represents the access token.
let accessToken: String

/// An integer that indicates the time, in seconds from now until the token expires.
let expiresInSeconds: Int

/// A date that indicates when then token will expire.
let expirationDate: Date

internal init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.accessToken = try container.decode(String.self, forKey: .accessToken)
self.expiresInSeconds = try container.decode(Int.self, forKey: .expiresInSeconds)
self.expirationDate = Date.now.addingTimeInterval(TimeInterval(expiresInSeconds))
}

internal init(accessToken: String, expiresInSeconds: Int) {
self.accessToken = accessToken
self.expiresInSeconds = expiresInSeconds
self.expirationDate = Date.now.addingTimeInterval(TimeInterval(expiresInSeconds))
}

}

extension TokenResponse {

/// A boolean indicates whether to token is valid 10 seconds before it's actual expiry time.
var isValid: Bool {
Copy link
Owner

@fpseverino fpseverino Oct 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't ever become false, as the Date.now is new every time and expiresInSeconds doesn't get updated since the response is received.

Maybe we can add a private expirationDate property to TokenResponse to which we assign Date.now.addingTimeInterval(TimeInterval(expiresInSeconds)) when the TokenResponse gets initialized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's a very good point, sorry for that.
I will push a fix and write a test for it.

let currentDate = Date.now
// we consider a token invalid 10 seconds before it actual expiry time, so we have some time to refresh it.
let expirationBuffer: TimeInterval = 10
return currentDate < (expirationDate - expirationBuffer)
}
}
8 changes: 0 additions & 8 deletions Sources/AppleMapsKit/DTOs/TokenResponse.swift

This file was deleted.

2 changes: 1 addition & 1 deletion Tests/AppleMapsKitTests/AppleMapsKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct AppleMapsKitTests {

init() async throws {
// TODO: Replace the following values with valid ones.
client = try await AppleMapsClient(
client = AppleMapsClient(
httpClient: HTTPClient.shared,
teamID: "DEF123GHIJ",
keyID: "ABC123DEFG",
Expand Down
28 changes: 28 additions & 0 deletions Tests/AppleMapsKitTests/AuthorizationProviderTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//
// AuthorizationProviderTests.swift
// apple-maps-kit
//
// Created by FarouK on 12/10/2024.
//

import Testing

@testable import AppleMapsKit

struct AuthorizationProviderTests {

struct TokenValidityTests {
// It's 1 second actually due to the expiration buffer on the token.
let token = TokenResponse(accessToken: "some token", expiresInSeconds: 11)

@Test func tokenInvalidCheck() async {
try? await Task.sleep(for: .seconds(2))
#expect(token.isValid == false)
}

@Test func tokenValidCheck() async {
#expect(token.isValid)
}
}

}