From 49636d89642e9203a6355b162a1b3a7fc73e8339 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Fri, 23 Aug 2024 16:30:15 -0700 Subject: [PATCH] added flag push --- build.gradle.kts | 1 + buildSrc/src/main/kotlin/Versions.kt | 1 + src/main/kotlin/LocalEvaluationClient.kt | 15 +- src/main/kotlin/LocalEvaluationConfig.kt | 2 + .../kotlin/deployment/DeploymentRunner.kt | 100 ++------ src/main/kotlin/flag/FlagConfigStreamApi.kt | 134 ++++++++++ src/main/kotlin/flag/FlagConfigUpdater.kt | 233 ++++++++++++++++++ src/main/kotlin/util/Metrics.kt | 10 + src/main/kotlin/util/SdkStream.kt | 122 +++++++++ .../kotlin/deployment/DeploymentRunnerTest.kt | 19 +- 10 files changed, 551 insertions(+), 86 deletions(-) create mode 100644 src/main/kotlin/flag/FlagConfigStreamApi.kt create mode 100644 src/main/kotlin/flag/FlagConfigUpdater.kt create mode 100644 src/main/kotlin/util/SdkStream.kt diff --git a/build.gradle.kts b/build.gradle.kts index f8d832e..863afa6 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,6 +24,7 @@ dependencies { testImplementation("io.mockk:mockk:${Versions.mockk}") implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${Versions.serializationRuntime}") implementation("com.squareup.okhttp3:okhttp:${Versions.okhttp}") + implementation("com.squareup.okhttp3:okhttp-sse:${Versions.okhttpSse}") implementation("com.amplitude:evaluation-core:${Versions.evaluationCore}") implementation("com.amplitude:java-sdk:${Versions.amplitudeAnalytics}") implementation("org.json:json:${Versions.json}") diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index cfc233d..b1ed2c0 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -6,6 +6,7 @@ object Versions { const val serializationRuntime = "1.4.1" const val json = "20231013" const val okhttp = "4.12.0" + const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. const val evaluationCore = "2.0.0-beta.2" const val amplitudeAnalytics = "1.12.0" const val mockk = "1.13.9" diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt index 57dc4ae..a25e143 100644 --- a/src/main/kotlin/LocalEvaluationClient.kt +++ b/src/main/kotlin/LocalEvaluationClient.kt @@ -19,6 +19,8 @@ import com.amplitude.experiment.evaluation.EvaluationEngineImpl import com.amplitude.experiment.evaluation.EvaluationFlag import com.amplitude.experiment.evaluation.topologicalSort import com.amplitude.experiment.flag.DynamicFlagConfigApi +import com.amplitude.experiment.flag.FlagConfigPoller +import com.amplitude.experiment.flag.FlagConfigStreamApi import com.amplitude.experiment.flag.InMemoryFlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger @@ -43,7 +45,10 @@ class LocalEvaluationClient internal constructor( private val serverUrl: HttpUrl = getServerUrl(config) private val evaluation: EvaluationEngine = EvaluationEngineImpl() private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(config.metrics) - private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, getProxyUrl(config), httpClient) + private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, null, httpClient) + private val proxyUrl: HttpUrl? = getProxyUrl(config) + private val flagConfigProxyApi = if (proxyUrl == null) null else DynamicFlagConfigApi(apiKey, proxyUrl, null, httpClient) + private val flagConfigStreamApi = FlagConfigStreamApi(apiKey, "https://stream.lab.amplitude.com", httpClient) private val flagConfigStorage = InMemoryFlagConfigStorage() private val cohortStorage = if (config.cohortSyncConfig == null) { null @@ -60,6 +65,8 @@ class LocalEvaluationClient internal constructor( private val deploymentRunner = DeploymentRunner( config = config, flagConfigApi = flagConfigApi, + flagConfigProxyApi = flagConfigProxyApi, + flagConfigStreamApi = flagConfigStreamApi, flagConfigStorage = flagConfigStorage, cohortApi = cohortApi, cohortStorage = cohortStorage, @@ -214,3 +221,9 @@ private fun getEventServerUrl( assignmentConfiguration.serverUrl } } + +fun main() { + val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz") + client.start() + println(client.evaluateV2(ExperimentUser("1"))) +} \ No newline at end of file diff --git a/src/main/kotlin/LocalEvaluationConfig.kt b/src/main/kotlin/LocalEvaluationConfig.kt index 67a5a72..e271610 100644 --- a/src/main/kotlin/LocalEvaluationConfig.kt +++ b/src/main/kotlin/LocalEvaluationConfig.kt @@ -207,6 +207,8 @@ interface LocalEvaluationMetrics { fun onFlagConfigFetch() fun onFlagConfigFetchFailure(exception: Exception) fun onFlagConfigFetchOriginFallback(exception: Exception) + fun onFlagConfigStream() + fun onFlagConfigStreamFailure(exception: Exception?) fun onCohortDownload() fun onCohortDownloadFailure(exception: Exception) fun onCohortDownloadOriginFallback(exception: Exception) diff --git a/src/main/kotlin/deployment/DeploymentRunner.kt b/src/main/kotlin/deployment/DeploymentRunner.kt index 0001278..6c38411 100644 --- a/src/main/kotlin/deployment/DeploymentRunner.kt +++ b/src/main/kotlin/deployment/DeploymentRunner.kt @@ -2,22 +2,19 @@ package com.amplitude.experiment.deployment -import com.amplitude.experiment.ExperimentalApi -import com.amplitude.experiment.LocalEvaluationConfig -import com.amplitude.experiment.LocalEvaluationMetrics +import com.amplitude.experiment.* import com.amplitude.experiment.cohort.CohortApi import com.amplitude.experiment.cohort.CohortLoader import com.amplitude.experiment.cohort.CohortStorage +import com.amplitude.experiment.flag.* import com.amplitude.experiment.flag.FlagConfigApi +import com.amplitude.experiment.flag.FlagConfigPoller import com.amplitude.experiment.flag.FlagConfigStorage import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper import com.amplitude.experiment.util.Logger import com.amplitude.experiment.util.Once import com.amplitude.experiment.util.daemonFactory import com.amplitude.experiment.util.getAllCohortIds -import com.amplitude.experiment.util.wrapMetrics -import java.util.concurrent.CompletableFuture -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.Executors import java.util.concurrent.TimeUnit @@ -26,6 +23,8 @@ private const val MIN_COHORT_POLLING_INTERVAL = 60000L internal class DeploymentRunner( private val config: LocalEvaluationConfig, private val flagConfigApi: FlagConfigApi, + private val flagConfigProxyApi: FlagConfigApi? = null, + private val flagConfigStreamApi: FlagConfigStreamApi? = null, private val flagConfigStorage: FlagConfigStorage, cohortApi: CohortApi?, private val cohortStorage: CohortStorage?, @@ -39,21 +38,26 @@ internal class DeploymentRunner( null } private val cohortPollingInterval: Long = getCohortPollingInterval() + // Fallback in this order: proxy, stream, poll. + private val amplitudeFlagConfigPoller = FlagConfigPoller(flagConfigApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics) + private val amplitudeFlagConfigUpdater = + if (flagConfigStreamApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + amplitudeFlagConfigPoller, + ) + else amplitudeFlagConfigPoller + private val flagConfigUpdater = + if (flagConfigProxyApi != null) + FlagConfigFallbackRetryWrapper( + FlagConfigPoller(flagConfigProxyApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics), + amplitudeFlagConfigPoller + ) + else + amplitudeFlagConfigUpdater fun start() = lock.once { - refresh() - poller.scheduleWithFixedDelay( - { - try { - refresh() - } catch (t: Throwable) { - Logger.e("Refresh flag configs failed.", t) - } - }, - config.flagConfigPollerIntervalMillis, - config.flagConfigPollerIntervalMillis, - TimeUnit.MILLISECONDS - ) + flagConfigUpdater.start() if (cohortLoader != null) { poller.scheduleWithFixedDelay( { @@ -74,63 +78,7 @@ internal class DeploymentRunner( fun stop() { poller.shutdown() - } - - fun refresh() { - Logger.d("Refreshing flag configs.") - // Get updated flags from the network. - val flagConfigs = wrapMetrics( - metric = metrics::onFlagConfigFetch, - failure = metrics::onFlagConfigFetchFailure, - ) { - flagConfigApi.getFlagConfigs() - } - - // Remove flags that no longer exist. - val flagKeys = flagConfigs.map { it.key }.toSet() - flagConfigStorage.removeIf { !flagKeys.contains(it.key) } - - // Get all flags from storage - val storageFlags = flagConfigStorage.getFlagConfigs() - - // Load cohorts for each flag if applicable and put the flag in storage. - val futures = ConcurrentHashMap>() - for (flagConfig in flagConfigs) { - if (cohortLoader == null) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - val cohortIds = flagConfig.getAllCohortIds() - val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() - val cohortsToLoad = cohortIds - storageCohortIds - if (cohortsToLoad.isEmpty()) { - flagConfigStorage.putFlagConfig(flagConfig) - continue - } - for (cohortId in cohortsToLoad) { - futures.putIfAbsent( - cohortId, - cohortLoader.loadCohort(cohortId).handle { _, exception -> - if (exception != null) { - Logger.e("Failed to load cohort $cohortId", exception) - } - flagConfigStorage.putFlagConfig(flagConfig) - } - ) - } - } - futures.values.forEach { it.join() } - - // Delete unused cohorts - if (cohortStorage != null) { - val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() - val storageCohortIds = cohortStorage.getCohorts().keys - val deletedCohortIds = storageCohortIds - flagCohortIds - for (deletedCohortId in deletedCohortIds) { - cohortStorage.deleteCohort(deletedCohortId) - } - } - Logger.d("Refreshed ${flagConfigs.size} flag configs.") + flagConfigUpdater.shutdown() } private fun getCohortPollingInterval(): Long { diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt new file mode 100644 index 0000000..3d61313 --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt @@ -0,0 +1,134 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.* +import com.amplitude.experiment.util.SdkStream +import kotlinx.serialization.decodeFromString +import okhttp3.OkHttpClient +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicBoolean + +internal open class FlagConfigStreamApiError(message: String?, cause: Throwable?): Exception(message, cause) { + constructor(message: String?) : this(message, null) + constructor(cause: Throwable?) : this(cause?.toString(), cause) +} +internal class FlagConfigStreamApiConnTimeoutError: FlagConfigStreamApiError("Initial connection timed out") +internal class FlagConfigStreamApiDataCorruptError: FlagConfigStreamApiError("Stream data corrupted") +internal class FlagConfigStreamApiStreamError(cause: Throwable?): FlagConfigStreamApiError("Stream error", cause) + +private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 2000L +private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L +internal class FlagConfigStreamApi ( + deploymentKey: String, + serverUrl: String, + httpClient: OkHttpClient = OkHttpClient(), + connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT, + keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT, + reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT +) { + var onInitUpdate: ((List) -> Unit)? = null + var onUpdate: ((List) -> Unit)? = null + var onError: ((Exception?) -> Unit)? = null + private val stream: SdkStream = SdkStream( + "Api-Key $deploymentKey", + "$serverUrl/sdk/stream/v1/flags", + httpClient, + connectionTimeoutMillis, + keepaliveTimeoutMillis, + reconnIntervalMillis) + + fun connect() { + val isInit = AtomicBoolean(true) + val connectTimeoutFuture = CompletableFuture() + val updateTimeoutFuture = CompletableFuture() + stream.onUpdate = { data -> + if (isInit.getAndSet(false)) { + // Stream is establishing. First data received. + // Resolve timeout. + connectTimeoutFuture.complete(Unit) + + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + if (onInitUpdate != null) { + onInitUpdate?.let { it(flags) } + } else { + onUpdate?.let { it(flags) } + } + updateTimeoutFuture.complete(Unit) + } catch (e: Throwable) { + updateTimeoutFuture.completeExceptionally(e) + } + } catch (_: Throwable) { + updateTimeoutFuture.completeExceptionally(FlagConfigStreamApiDataCorruptError()) + } + + } else { + // Stream has already established. + // Make sure valid data. + try { + val flags = getFlagsFromData(data) + + try { + onUpdate?.let { it(flags) } + } catch (_: Throwable) { + // Don't care about application error. + } + } catch (_: Throwable) { + // Stream corrupted. Reconnect. + handleError(FlagConfigStreamApiDataCorruptError()) + } + + } + } + stream.onError = { t -> + if (isInit.getAndSet(false)) { + connectTimeoutFuture.completeExceptionally(t) + updateTimeoutFuture.completeExceptionally(t) + } else { + handleError(FlagConfigStreamApiStreamError(t)) + } + } + stream.connect() + + val t: Throwable + try { + connectTimeoutFuture.get(2000, TimeUnit.MILLISECONDS) + updateTimeoutFuture.get() + return + } catch (e: TimeoutException) { + // Timeouts should retry + t = FlagConfigStreamApiConnTimeoutError() + } catch (e: ExecutionException) { + val cause = e.cause + t = if (cause is StreamException) { + FlagConfigStreamApiStreamError(cause) + } else { + FlagConfigStreamApiError(e) + } + } catch (e: Throwable) { + t = FlagConfigStreamApiError(e) + } + close() + throw t + } + + fun close() { + stream.cancel() + } + + private fun getFlagsFromData(data: String): List { + return json.decodeFromString>(data) + } + + private fun handleError(e: Exception?) { + close() + onError?.let { it(e) } + } +} diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt new file mode 100644 index 0000000..f882ab5 --- /dev/null +++ b/src/main/kotlin/flag/FlagConfigUpdater.kt @@ -0,0 +1,233 @@ +package com.amplitude.experiment.flag + +import com.amplitude.experiment.LocalEvaluationConfig +import com.amplitude.experiment.LocalEvaluationMetrics +import com.amplitude.experiment.cohort.CohortLoader +import com.amplitude.experiment.cohort.CohortStorage +import com.amplitude.experiment.evaluation.EvaluationFlag +import com.amplitude.experiment.util.* +import com.amplitude.experiment.util.Logger +import com.amplitude.experiment.util.daemonFactory +import com.amplitude.experiment.util.wrapMetrics +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit + +internal interface FlagConfigUpdater { + // Start the updater. There can be multiple calls. + // If start fails, it should throw exception. The caller should handle fallback. + // If some other error happened while updating (already started successfully), it should call fallback. + fun start(fallback: (() -> Unit)? = null) + // Stop should stop updater temporarily. There may be another start in the future. + // To stop completely, with intention to never start again, use shutdown() instead. + fun stop() + // Destroy should stop the updater forever in preparation for server shutdown. + fun shutdown() +} + +internal abstract class FlagConfigUpdaterBase( + private val flagConfigStorage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, +): FlagConfigUpdater { + fun update(flagConfigs: List) { + println("update") + // Remove flags that no longer exist. + val flagKeys = flagConfigs.map { it.key }.toSet() + flagConfigStorage.removeIf { !flagKeys.contains(it.key) } + + // Get all flags from storage + val storageFlags = flagConfigStorage.getFlagConfigs() + + // Load cohorts for each flag if applicable and put the flag in storage. + val futures = ConcurrentHashMap>() + for (flagConfig in flagConfigs) { + if (cohortLoader == null) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + val cohortIds = flagConfig.getAllCohortIds() + val storageCohortIds = storageFlags[flagConfig.key]?.getAllCohortIds() ?: emptySet() + val cohortsToLoad = cohortIds - storageCohortIds + if (cohortsToLoad.isEmpty()) { + flagConfigStorage.putFlagConfig(flagConfig) + continue + } + for (cohortId in cohortsToLoad) { + futures.putIfAbsent( + cohortId, + cohortLoader.loadCohort(cohortId).handle { _, exception -> + if (exception != null) { + Logger.e("Failed to load cohort $cohortId", exception) + } + flagConfigStorage.putFlagConfig(flagConfig) + } + ) + } + } + futures.values.forEach { it.join() } + + // Delete unused cohorts + if (cohortStorage != null) { + val flagCohortIds = flagConfigStorage.getFlagConfigs().values.toList().getAllCohortIds() + val storageCohortIds = cohortStorage.getCohorts().keys + val deletedCohortIds = storageCohortIds - flagCohortIds + for (deletedCohortId in deletedCohortIds) { + cohortStorage.deleteCohort(deletedCohortId) + } + } + Logger.d("Refreshed ${flagConfigs.size} flag configs.") + } +} + +internal class FlagConfigPoller( + private val flagConfigApi: FlagConfigApi, + private val storage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, + private val config: LocalEvaluationConfig, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() +): FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + private val poller = Executors.newScheduledThreadPool(1, daemonFactory) + private var scheduledFuture: ScheduledFuture<*>? = null + override fun start(fallback: (() -> Unit)?) { + // Perform updates + refresh() + scheduledFuture = poller.scheduleWithFixedDelay( + { + try { + refresh() + } catch (t: Throwable) { + Logger.e("Refresh flag configs failed.", t) + stop() + fallback?.invoke() + } + }, + config.flagConfigPollerIntervalMillis, + config.flagConfigPollerIntervalMillis, + TimeUnit.MILLISECONDS + ) + } + + override fun stop() { + // Pause only stop the task scheduled. It doesn't stop the executor. + scheduledFuture?.cancel(true) + scheduledFuture = null + } + + override fun shutdown() { + // Stop the executor. + poller.shutdown() + } + + fun refresh() { + Logger.d("Refreshing flag configs.") + println("flag poller refreshing") + // Get updated flags from the network. + val flagConfigs = wrapMetrics( + metric = metrics::onFlagConfigFetch, + failure = metrics::onFlagConfigFetchFailure, + ) { + flagConfigApi.getFlagConfigs() + } + + update(flagConfigs) + println("flag poller refreshed") + } +} + +internal class FlagConfigStreamer( + private val flagConfigStreamApi: FlagConfigStreamApi, + private val storage: FlagConfigStorage, + private val cohortLoader: CohortLoader?, + private val cohortStorage: CohortStorage?, + private val config: LocalEvaluationConfig, + private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper() +): FlagConfigUpdaterBase( + storage, cohortLoader, cohortStorage +) { + override fun start(fallback: (() -> Unit)?) { + flagConfigStreamApi.onUpdate = {flags -> + println("flag streamer received") + update(flags) + } + flagConfigStreamApi.onError = {e -> + Logger.e("Stream flag configs streaming failed.", e) + metrics.onFlagConfigStreamFailure(e) + fallback?.invoke() + } + wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) { + flagConfigStreamApi.connect() + } + } + + override fun stop() { + flagConfigStreamApi.close() + } + + override fun shutdown() = stop() +} + +private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 2000L +internal class FlagConfigFallbackRetryWrapper( + private val mainUpdater: FlagConfigUpdater, + private val fallbackUpdater: FlagConfigUpdater, + private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT, + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT +): FlagConfigUpdater { + private val reconnIntervalRange = (retryDelayMillis - maxJitterMillis)..(retryDelayMillis + maxJitterMillis) + private val executor = Executors.newScheduledThreadPool(1, daemonFactory) + private var retryTask: ScheduledFuture<*>? = null + + override fun start(fallback: (() -> Unit)?) { + try { + mainUpdater.start { + startRetry(fallback) // Don't care if poller start error or not, always retry. + try { + fallbackUpdater.start(fallback) + } catch (_: Throwable) { + fallback?.invoke() + } + } + } catch (t: Throwable) { + Logger.e("Update flag configs start failed.", t) + fallbackUpdater.start(fallback) // If fallback failed, don't retry. + startRetry(fallback) + } + } + + override fun stop() { + mainUpdater.stop() + fallbackUpdater.stop() + retryTask?.cancel(true) + } + + override fun shutdown() { + mainUpdater.shutdown() + fallbackUpdater.shutdown() + retryTask?.cancel(true) + } + + private fun startRetry(fallback: (() -> Unit)?) { + retryTask = executor.schedule({ + try { + mainUpdater.start { + startRetry(fallback) // Don't care if poller start error or not, always retry stream. + try { + fallbackUpdater.start(fallback) + } catch (_: Throwable) { + fallback?.invoke() + } + } + fallbackUpdater.stop() + } catch (_: Throwable) { + startRetry(fallback) + } + }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS) + } +} \ No newline at end of file diff --git a/src/main/kotlin/util/Metrics.kt b/src/main/kotlin/util/Metrics.kt index 2748652..99345c5 100644 --- a/src/main/kotlin/util/Metrics.kt +++ b/src/main/kotlin/util/Metrics.kt @@ -66,6 +66,16 @@ internal class LocalEvaluationMetricsWrapper( executor?.execute { metrics.onFlagConfigFetchFailure(exception) } } + override fun onFlagConfigStream() { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStream() } + } + + override fun onFlagConfigStreamFailure(exception: Exception?) { + val metrics = metrics ?: return + executor?.execute { metrics.onFlagConfigStreamFailure(exception) } + } + override fun onFlagConfigFetchOriginFallback(exception: Exception) { val metrics = metrics ?: return executor?.execute { metrics.onFlagConfigFetchOriginFallback(exception) } diff --git a/src/main/kotlin/util/SdkStream.kt b/src/main/kotlin/util/SdkStream.kt new file mode 100644 index 0000000..ec2292f --- /dev/null +++ b/src/main/kotlin/util/SdkStream.kt @@ -0,0 +1,122 @@ +package com.amplitude.experiment.util + +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.internal.http2.ErrorCode +import okhttp3.internal.http2.StreamResetException +import okhttp3.sse.EventSource +import okhttp3.sse.EventSourceListener +import okhttp3.sse.EventSources +import java.util.* +import java.util.concurrent.TimeUnit +import kotlin.concurrent.schedule + +internal class StreamException(error: String): Throwable(error) + +private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L +private const val MAX_JITTER_MILLIS_DEFAULT = 5000L +internal class SdkStream ( + private val authToken: String, + private val serverUrl: String, + private val httpClient: OkHttpClient = OkHttpClient(), + private val connectionTimeoutMillis: Long, + private val keepaliveTimeoutMillis: Long, + private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT, + private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT +) { + private val reconnIntervalRange = (reconnIntervalMillis - maxJitterMillis)..(reconnIntervalMillis + maxJitterMillis) + private val eventSourceListener = object : EventSourceListener() { + override fun onOpen(eventSource: EventSource, response: Response) { + // No action needed. + } + + override fun onClosed(eventSource: EventSource) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + // Server closed the connection, just reconnect. + cancel() + connect() + } + + override fun onEvent( + eventSource: EventSource, + id: String?, + type: String?, + data: String + ) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + // Keep alive data + if (" " == data) { + return + } + onUpdate?.let { it(data) } + } + + override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { + if ((eventSource != es)) { + // Not the current event source using right now, should cancel. + eventSource.cancel() + return + } + if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) { + // TODO: relying on okhttp3.internal to differentiate cancel case. + return + } + cancel() + var err = t + if (t == null) { + err = if (response != null) { + StreamException(response.toString()) + } else { + StreamException("Unknown stream failure") + } + } + onError?.let { it(err) } + } + } + + private val request = Request.Builder() + .url(serverUrl) + .header("Authorization", authToken) + .addHeader("Accept", "text/event-stream") + .build() + + private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs. + .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE. + .callTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Call timeout for establishing SSE. + .readTimeout(keepaliveTimeoutMillis, TimeUnit.MILLISECONDS) // Timeout between messages, keepalive in this case. + .writeTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) + .retryOnConnectionFailure(false) + .build() + + private var es: EventSource? = null + private var reconnectTimerTask: TimerTask? = null + var onUpdate: ((String) -> Unit)? = null + var onError: ((Throwable?) -> Unit)? = null + + fun connect() { + cancel() // Clear any existing event sources. + es = EventSources.createFactory(client).newEventSource(request = request, listener = eventSourceListener) + reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source. + // This forces client side reconnection after interval. + this@SdkStream.cancel() + connect() + } + } + + fun cancel() { + reconnectTimerTask?.cancel() + + // There can be cases where an event source is being cancelled by these calls, but take a long time and made a callback to onFailure callback. + es?.cancel() + es = null + } +} \ No newline at end of file diff --git a/src/test/kotlin/deployment/DeploymentRunnerTest.kt b/src/test/kotlin/deployment/DeploymentRunnerTest.kt index 8e58a79..9b74742 100644 --- a/src/test/kotlin/deployment/DeploymentRunnerTest.kt +++ b/src/test/kotlin/deployment/DeploymentRunnerTest.kt @@ -42,11 +42,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, - flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenThrow(RuntimeException("test")) try { @@ -71,10 +71,11 @@ class DeploymentRunnerTest { val flagConfigStorage = Mockito.mock(FlagConfigStorage::class.java) val cohortStorage = Mockito.mock(CohortStorage::class.java) val runner = DeploymentRunner( - LocalEvaluationConfig(), - flagApi, flagConfigStorage, - cohortApi, - cohortStorage, + config = LocalEvaluationConfig(), + flagConfigApi = flagApi, + flagConfigStorage = flagConfigStorage, + cohortApi = cohortApi, + cohortStorage = cohortStorage, ) Mockito.`when`(flagApi.getFlagConfigs()).thenReturn(listOf(flag)) Mockito.`when`(cohortApi.getCohort(COHORT_ID, null)).thenThrow(RuntimeException("test"))