From 005698c261d0b6e6fe78a9480514b6b7338fd333 Mon Sep 17 00:00:00 2001 From: Tlaster Date: Mon, 27 Jan 2025 17:14:27 +0900 Subject: [PATCH] fix bluesky token not being refreshed automatically --- .../datasource/bluesky/BlueskyDataSource.kt | 16 +++++- .../data/network/bluesky/BlueskyService.kt | 49 ++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/shared/src/commonMain/kotlin/dev/dimension/flare/data/datasource/bluesky/BlueskyDataSource.kt b/shared/src/commonMain/kotlin/dev/dimension/flare/data/datasource/bluesky/BlueskyDataSource.kt index 42557cce4..a7cb73bda 100644 --- a/shared/src/commonMain/kotlin/dev/dimension/flare/data/datasource/bluesky/BlueskyDataSource.kt +++ b/shared/src/commonMain/kotlin/dev/dimension/flare/data/datasource/bluesky/BlueskyDataSource.kt @@ -64,6 +64,7 @@ import dev.dimension.flare.common.Cacheable import dev.dimension.flare.common.FileItem import dev.dimension.flare.common.InAppNotification import dev.dimension.flare.common.MemCacheable +import dev.dimension.flare.common.encodeJson import dev.dimension.flare.data.database.app.AppDatabase import dev.dimension.flare.data.database.cache.CacheDatabase import dev.dimension.flare.data.database.cache.mapper.Bluesky @@ -151,7 +152,20 @@ internal class BlueskyDataSource( BlueskyService( baseUrl = credential.baseUrl, accountKey = accountKey, - bearerToken = credential.accessToken, + accessToken = credential.accessToken, + refreshToken = credential.refreshToken, + onTokenRefreshed = { accessToken, refreshToken -> + coroutineScope.launch { + appDatabase.accountDao().setCredential( + accountKey, + credential + .copy( + accessToken = accessToken, + refreshToken = refreshToken, + ).encodeJson(), + ) + } + }, ) } diff --git a/shared/src/commonMain/kotlin/dev/dimension/flare/data/network/bluesky/BlueskyService.kt b/shared/src/commonMain/kotlin/dev/dimension/flare/data/network/bluesky/BlueskyService.kt index 0d8ecc273..3eb125df5 100644 --- a/shared/src/commonMain/kotlin/dev/dimension/flare/data/network/bluesky/BlueskyService.kt +++ b/shared/src/commonMain/kotlin/dev/dimension/flare/data/network/bluesky/BlueskyService.kt @@ -1,19 +1,22 @@ package dev.dimension.flare.data.network.bluesky +import com.atproto.server.RefreshSessionResponse import dev.dimension.flare.common.JSON -import dev.dimension.flare.data.network.authorization.BearerAuthorization import dev.dimension.flare.data.network.ktorClient -import dev.dimension.flare.data.repository.LoginExpiredException import dev.dimension.flare.model.MicroBlogKey import io.ktor.client.HttpClient import io.ktor.client.call.HttpClientCall +import io.ktor.client.call.body import io.ktor.client.call.save import io.ktor.client.plugins.DefaultRequest import io.ktor.client.plugins.HttpClientPlugin import io.ktor.client.plugins.HttpSend import io.ktor.client.plugins.plugin import io.ktor.client.request.HttpRequestPipeline +import io.ktor.client.request.bearerAuth +import io.ktor.client.request.post import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpHeaders.Authorization import io.ktor.http.HttpStatusCode.Companion.BadRequest import io.ktor.http.Url import io.ktor.util.AttributeKey @@ -21,16 +24,19 @@ import kotlinx.serialization.json.Json import sh.christian.ozone.BlueskyApi import sh.christian.ozone.XrpcBlueskyApi import sh.christian.ozone.api.response.AtpErrorDescription +import sh.christian.ozone.api.response.StatusCode import sh.christian.ozone.unspecced.UnspeccedBlueskyApi import sh.christian.ozone.unspecced.XrpcUnspeccedBlueskyApi internal data class BlueskyService( private val baseUrl: String, - private val bearerToken: String? = null, private val accountKey: MicroBlogKey? = null, + private val accessToken: String? = null, + private val refreshToken: String? = null, + private val onTokenRefreshed: ((accessToken: String, refreshToken: String) -> Unit)? = null, ) : BlueskyApi by XrpcBlueskyApi( ktorClient( - authorization = bearerToken?.let { BearerAuthorization(it) }, +// authorization = bearerToken?.let { BearerAuthorization(it) }, ) { install(DefaultRequest) { val hostUrl = Url(baseUrl) @@ -40,6 +46,9 @@ internal data class BlueskyService( } install(XrpcAuthPlugin) { json = JSON + access = accessToken + refresh = refreshToken + tokenRefreshed = onTokenRefreshed } install(AtprotoProxyPlugin) @@ -77,9 +86,15 @@ private class AtprotoProxyPlugin { */ internal class XrpcAuthPlugin( private val json: Json, + private val accessToken: String? = null, + private val refreshToken: String? = null, + private val onTokenRefreshed: ((accessToken: String, refreshToken: String) -> Unit)? = null, ) { class Config( var json: Json = Json { ignoreUnknownKeys = true }, + var access: String? = null, + var refresh: String? = null, + var tokenRefreshed: ((accessToken: String, refreshToken: String) -> Unit)? = null, ) companion object : HttpClientPlugin { @@ -87,7 +102,12 @@ internal class XrpcAuthPlugin( override fun prepare(block: Config.() -> Unit): XrpcAuthPlugin { val config = Config().apply(block) - return XrpcAuthPlugin(config.json) + return XrpcAuthPlugin( + config.json, + config.access, + config.refresh, + config.tokenRefreshed, + ) } override fun install( @@ -95,6 +115,9 @@ internal class XrpcAuthPlugin( scope: HttpClient, ) { scope.plugin(HttpSend).intercept { context -> + if (!context.headers.contains(Authorization)) { + plugin.accessToken?.let { context.bearerAuth(it) } + } var result: HttpClientCall = execute(context) if (result.response.status != BadRequest) { return@intercept result @@ -109,7 +132,21 @@ internal class XrpcAuthPlugin( } if (response.getOrNull()?.error == "ExpiredToken") { - throw LoginExpiredException + val refreshResponse = + scope.post("/xrpc/com.atproto.server.refreshSession") { + plugin.refreshToken?.let { bearerAuth(it) } + } + if (StatusCode.fromCode(refreshResponse.status.value) == StatusCode.Okay) { + val refreshed = refreshResponse.body() + val newAccessToken = refreshed.accessJwt + val newRefreshToken = refreshed.refreshJwt + + plugin.onTokenRefreshed?.invoke(newAccessToken, newRefreshToken) + + context.headers.remove(Authorization) + context.bearerAuth(newAccessToken) + result = execute(context) + } } result