From 027196ff7b2770e8187c3db4c721fe6a54e3d53d Mon Sep 17 00:00:00 2001
From: Peter Zhu <peter.zhu@amplitude.com>
Date: Mon, 26 Aug 2024 15:01:07 -0700
Subject: [PATCH] added config, minor renames

---
 buildSrc/src/main/kotlin/Versions.kt          |  2 +-
 src/main/kotlin/LocalEvaluationClient.kt      | 16 ++++++-
 src/main/kotlin/LocalEvaluationConfig.kt      | 30 +++++++++++++
 src/main/kotlin/ServerZone.kt                 |  2 +
 src/main/kotlin/flag/FlagConfigStreamApi.kt   | 14 ++++---
 src/main/kotlin/flag/FlagConfigUpdater.kt     | 42 +++++++++----------
 src/main/kotlin/util/Request.kt               |  2 +-
 .../util/{SdkStream.kt => SseStream.kt}       | 20 ++++-----
 8 files changed, 86 insertions(+), 42 deletions(-)
 rename src/main/kotlin/util/{SdkStream.kt => SseStream.kt} (88%)

diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt
index b1ed2c0..098ce65 100644
--- a/buildSrc/src/main/kotlin/Versions.kt
+++ b/buildSrc/src/main/kotlin/Versions.kt
@@ -6,7 +6,7 @@ object Versions {
     const val serializationRuntime = "1.4.1"
     const val json = "20231013"
     const val okhttp = "4.12.0"
-    const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes.
+    const val okhttpSse = "4.12.0" // Update this alongside okhttp. Note this library isn't stable and may contain breaking changes. Search uses of okhttp3.internal classes before updating.
     const val evaluationCore = "2.0.0-beta.2"
     const val amplitudeAnalytics = "1.12.0"
     const val mockk = "1.13.9"
diff --git a/src/main/kotlin/LocalEvaluationClient.kt b/src/main/kotlin/LocalEvaluationClient.kt
index a25e143..a9b9e5d 100644
--- a/src/main/kotlin/LocalEvaluationClient.kt
+++ b/src/main/kotlin/LocalEvaluationClient.kt
@@ -43,12 +43,13 @@ class LocalEvaluationClient internal constructor(
 ) {
     private val assignmentService: AssignmentService? = createAssignmentService(apiKey)
     private val serverUrl: HttpUrl = getServerUrl(config)
+    private val streamServerUrl: HttpUrl = getStreamServerUrl(config)
     private val evaluation: EvaluationEngine = EvaluationEngineImpl()
     private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(config.metrics)
     private val flagConfigApi = DynamicFlagConfigApi(apiKey, serverUrl, null, httpClient)
     private val proxyUrl: HttpUrl? = getProxyUrl(config)
     private val flagConfigProxyApi = if (proxyUrl == null) null else DynamicFlagConfigApi(apiKey, proxyUrl, null, httpClient)
-    private val flagConfigStreamApi = FlagConfigStreamApi(apiKey, "https://stream.lab.amplitude.com", httpClient)
+    private val flagConfigStreamApi = if (config.streamUpdates) FlagConfigStreamApi(apiKey, streamServerUrl, httpClient, config.streamFlagConnTimeoutMillis) else null
     private val flagConfigStorage = InMemoryFlagConfigStorage()
     private val cohortStorage = if (config.cohortSyncConfig == null) {
         null
@@ -192,6 +193,17 @@ private fun getServerUrl(config: LocalEvaluationConfig): HttpUrl {
     }
 }
 
+private fun getStreamServerUrl(config: LocalEvaluationConfig): HttpUrl {
+    return if (config.streamServerUrl == LocalEvaluationConfig.Defaults.STREAM_SERVER_URL) {
+        when (config.serverZone) {
+            ServerZone.US -> US_STREAM_SERVER_URL.toHttpUrl()
+            ServerZone.EU -> EU_STREAM_SERVER_URL.toHttpUrl()
+        }
+    } else {
+        config.streamServerUrl.toHttpUrl()
+    }
+}
+
 private fun getProxyUrl(config: LocalEvaluationConfig): HttpUrl? {
     return config.evaluationProxyConfig?.proxyUrl?.toHttpUrl()
 }
@@ -223,7 +235,7 @@ private fun getEventServerUrl(
 }
 
 fun main() {
-    val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz")
+    val client = LocalEvaluationClient("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", LocalEvaluationConfig(streamUpdates = true))
     client.start()
     println(client.evaluateV2(ExperimentUser("1")))
 }
\ No newline at end of file
diff --git a/src/main/kotlin/LocalEvaluationConfig.kt b/src/main/kotlin/LocalEvaluationConfig.kt
index e271610..eeab7cf 100644
--- a/src/main/kotlin/LocalEvaluationConfig.kt
+++ b/src/main/kotlin/LocalEvaluationConfig.kt
@@ -22,6 +22,12 @@ class LocalEvaluationConfig internal constructor(
     @JvmField
     val flagConfigPollerRequestTimeoutMillis: Long = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS,
     @JvmField
+    val streamUpdates: Boolean = Defaults.STREAM_UPDATES,
+    @JvmField
+    val streamServerUrl: String = Defaults.STREAM_SERVER_URL,
+    @JvmField
+    val streamFlagConnTimeoutMillis: Long = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS,
+    @JvmField
     val assignmentConfiguration: AssignmentConfiguration? = Defaults.ASSIGNMENT_CONFIGURATION,
     @JvmField
     val cohortSyncConfig: CohortSyncConfig? = Defaults.COHORT_SYNC_CONFIGURATION,
@@ -76,6 +82,12 @@ class LocalEvaluationConfig internal constructor(
          */
         const val FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS = 10_000L
 
+        const val STREAM_UPDATES = false
+
+        const val STREAM_SERVER_URL = US_STREAM_SERVER_URL
+
+        const val STREAM_FLAG_CONN_TIMEOUT_MILLIS = 1_500L
+
         /**
          * null
          */
@@ -111,6 +123,9 @@ class LocalEvaluationConfig internal constructor(
         private var serverUrl = Defaults.SERVER_URL
         private var flagConfigPollerIntervalMillis = Defaults.FLAG_CONFIG_POLLER_INTERVAL_MILLIS
         private var flagConfigPollerRequestTimeoutMillis = Defaults.FLAG_CONFIG_POLLER_REQUEST_TIMEOUT_MILLIS
+        private var streamUpdates = Defaults.STREAM_UPDATES
+        private var streamServerUrl = Defaults.STREAM_SERVER_URL
+        private var streamFlagConnTimeoutMillis = Defaults.STREAM_FLAG_CONN_TIMEOUT_MILLIS
         private var assignmentConfiguration = Defaults.ASSIGNMENT_CONFIGURATION
         private var cohortSyncConfiguration = Defaults.COHORT_SYNC_CONFIGURATION
         private var evaluationProxyConfiguration = Defaults.EVALUATION_PROXY_CONFIGURATION
@@ -136,6 +151,18 @@ class LocalEvaluationConfig internal constructor(
             this.flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis
         }
 
+        fun streamUpdates(streamUpdates: Boolean) = apply {
+            this.streamUpdates = streamUpdates
+        }
+
+        fun streamServerUrl(streamServerUrl: String) = apply {
+            this.streamServerUrl = streamServerUrl
+        }
+
+        fun streamFlagConnTimeoutMillis(streamFlagConnTimeoutMillis: Long) = apply {
+            this.streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis
+        }
+
         fun enableAssignmentTracking(assignmentConfiguration: AssignmentConfiguration) = apply {
             this.assignmentConfiguration = assignmentConfiguration
         }
@@ -161,6 +188,9 @@ class LocalEvaluationConfig internal constructor(
                 serverZone = serverZone,
                 flagConfigPollerIntervalMillis = flagConfigPollerIntervalMillis,
                 flagConfigPollerRequestTimeoutMillis = flagConfigPollerRequestTimeoutMillis,
+                streamUpdates = streamUpdates,
+                streamServerUrl = streamServerUrl,
+                streamFlagConnTimeoutMillis = streamFlagConnTimeoutMillis,
                 assignmentConfiguration = assignmentConfiguration,
                 cohortSyncConfig = cohortSyncConfiguration,
                 evaluationProxyConfig = evaluationProxyConfiguration,
diff --git a/src/main/kotlin/ServerZone.kt b/src/main/kotlin/ServerZone.kt
index c5d5dc3..c658ded 100644
--- a/src/main/kotlin/ServerZone.kt
+++ b/src/main/kotlin/ServerZone.kt
@@ -2,6 +2,8 @@ package com.amplitude.experiment
 
 internal const val US_SERVER_URL = "https://api.lab.amplitude.com"
 internal const val EU_SERVER_URL = "https://api.lab.eu.amplitude.com"
+internal const val US_STREAM_SERVER_URL = "https://stream.lab.amplitude.com"
+internal const val EU_STREAM_SERVER_URL = "https://stream.lab.eu.amplitude.com"
 internal const val US_COHORT_SERVER_URL = "https://cohort-v2.lab.amplitude.com"
 internal const val EU_COHORT_SERVER_URL = "https://cohort-v2.lab.eu.amplitude.com"
 internal const val US_EVENT_SERVER_URL = "https://api2.amplitude.com/2/httpapi"
diff --git a/src/main/kotlin/flag/FlagConfigStreamApi.kt b/src/main/kotlin/flag/FlagConfigStreamApi.kt
index 3d61313..919e936 100644
--- a/src/main/kotlin/flag/FlagConfigStreamApi.kt
+++ b/src/main/kotlin/flag/FlagConfigStreamApi.kt
@@ -2,8 +2,9 @@ package com.amplitude.experiment.flag
 
 import com.amplitude.experiment.evaluation.EvaluationFlag
 import com.amplitude.experiment.util.*
-import com.amplitude.experiment.util.SdkStream
+import com.amplitude.experiment.util.SseStream
 import kotlinx.serialization.decodeFromString
+import okhttp3.HttpUrl
 import okhttp3.OkHttpClient
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.ExecutionException
@@ -19,12 +20,12 @@ internal class FlagConfigStreamApiConnTimeoutError: FlagConfigStreamApiError("In
 internal class FlagConfigStreamApiDataCorruptError: FlagConfigStreamApiError("Stream data corrupted")
 internal class FlagConfigStreamApiStreamError(cause: Throwable?): FlagConfigStreamApiError("Stream error", cause)
 
-private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 2000L
-private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L
+private const val CONNECTION_TIMEOUT_MILLIS_DEFAULT = 1500L
+private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 17000L // keep alive sends at 15s interval. 2s grace period
 private const val RECONN_INTERVAL_MILLIS_DEFAULT = 15 * 60 * 1000L
 internal class FlagConfigStreamApi (
     deploymentKey: String,
-    serverUrl: String,
+    serverUrl: HttpUrl,
     httpClient: OkHttpClient = OkHttpClient(),
     connectionTimeoutMillis: Long = CONNECTION_TIMEOUT_MILLIS_DEFAULT,
     keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT,
@@ -33,9 +34,10 @@ internal class FlagConfigStreamApi (
     var onInitUpdate: ((List<EvaluationFlag>) -> Unit)? = null
     var onUpdate: ((List<EvaluationFlag>) -> Unit)? = null
     var onError: ((Exception?) -> Unit)? = null
-    private val stream: SdkStream = SdkStream(
+    val url = serverUrl.newBuilder().addPathSegments("sdk/stream/v1/flags").build()
+    private val stream: SseStream = SseStream(
         "Api-Key $deploymentKey",
-        "$serverUrl/sdk/stream/v1/flags",
+        url,
         httpClient,
         connectionTimeoutMillis,
         keepaliveTimeoutMillis,
diff --git a/src/main/kotlin/flag/FlagConfigUpdater.kt b/src/main/kotlin/flag/FlagConfigUpdater.kt
index f882ab5..0cf3616 100644
--- a/src/main/kotlin/flag/FlagConfigUpdater.kt
+++ b/src/main/kotlin/flag/FlagConfigUpdater.kt
@@ -17,9 +17,9 @@ import java.util.concurrent.TimeUnit
 
 internal interface FlagConfigUpdater {
     // Start the updater. There can be multiple calls.
-    // If start fails, it should throw exception. The caller should handle fallback.
-    // If some other error happened while updating (already started successfully), it should call fallback.
-    fun start(fallback: (() -> Unit)? = null)
+    // If start fails, it should throw exception. The caller should handle error.
+    // If some other error happened while updating (already started successfully), it should call onError.
+    fun start(onError: (() -> Unit)? = null)
     // Stop should stop updater temporarily. There may be another start in the future.
     // To stop completely, with intention to never start again, use shutdown() instead.
     fun stop()
@@ -94,8 +94,7 @@ internal class FlagConfigPoller(
 ) {
     private val poller = Executors.newScheduledThreadPool(1, daemonFactory)
     private var scheduledFuture: ScheduledFuture<*>? = null
-    override fun start(fallback: (() -> Unit)?) {
-        // Perform updates
+    override fun start(onError: (() -> Unit)?) {
         refresh()
         scheduledFuture = poller.scheduleWithFixedDelay(
             {
@@ -104,7 +103,7 @@ internal class FlagConfigPoller(
                 } catch (t: Throwable) {
                     Logger.e("Refresh flag configs failed.", t)
                     stop()
-                    fallback?.invoke()
+                    onError?.invoke()
                 }
             },
             config.flagConfigPollerIntervalMillis,
@@ -124,7 +123,7 @@ internal class FlagConfigPoller(
         poller.shutdown()
     }
 
-    fun refresh() {
+    private fun refresh() {
         Logger.d("Refreshing flag configs.")
         println("flag poller refreshing")
         // Get updated flags from the network.
@@ -150,15 +149,14 @@ internal class FlagConfigStreamer(
 ): FlagConfigUpdaterBase(
     storage, cohortLoader, cohortStorage
 ) {
-    override fun start(fallback: (() -> Unit)?) {
+    override fun start(onError: (() -> Unit)?) {
         flagConfigStreamApi.onUpdate = {flags ->
-            println("flag streamer received")
             update(flags)
         }
         flagConfigStreamApi.onError = {e ->
             Logger.e("Stream flag configs streaming failed.", e)
             metrics.onFlagConfigStreamFailure(e)
-            fallback?.invoke()
+            onError?.invoke()
         }
         wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) {
             flagConfigStreamApi.connect()
@@ -184,20 +182,20 @@ internal class FlagConfigFallbackRetryWrapper(
     private val executor = Executors.newScheduledThreadPool(1, daemonFactory)
     private var retryTask: ScheduledFuture<*>? = null
 
-    override fun start(fallback: (() -> Unit)?) {
+    override fun start(onError: (() -> Unit)?) {
         try {
             mainUpdater.start {
-                startRetry(fallback) // Don't care if poller start error or not, always retry.
+                scheduleRetry(onError) // Don't care if poller start error or not, always retry.
                 try {
-                    fallbackUpdater.start(fallback)
+                    fallbackUpdater.start(onError)
                 } catch (_: Throwable) {
-                    fallback?.invoke()
+                    onError?.invoke()
                 }
             }
         } catch (t: Throwable) {
-            Logger.e("Update flag configs start failed.", t)
-            fallbackUpdater.start(fallback) // If fallback failed, don't retry.
-            startRetry(fallback)
+            Logger.e("Primary flag configs start failed, start fallback. Error: ", t)
+            fallbackUpdater.start(onError) // If fallback failed, don't retry.
+            scheduleRetry(onError)
         }
     }
 
@@ -213,20 +211,20 @@ internal class FlagConfigFallbackRetryWrapper(
         retryTask?.cancel(true)
     }
 
-    private fun startRetry(fallback: (() -> Unit)?) {
+    private fun scheduleRetry(onError: (() -> Unit)?) {
         retryTask = executor.schedule({
             try {
                 mainUpdater.start {
-                    startRetry(fallback) // Don't care if poller start error or not, always retry stream.
+                    scheduleRetry(onError) // Don't care if poller start error or not, always retry stream.
                     try {
-                        fallbackUpdater.start(fallback)
+                        fallbackUpdater.start(onError)
                     } catch (_: Throwable) {
-                        fallback?.invoke()
+                        onError?.invoke()
                     }
                 }
                 fallbackUpdater.stop()
             } catch (_: Throwable) {
-                startRetry(fallback)
+                scheduleRetry(onError)
             }
         }, reconnIntervalRange.random(), TimeUnit.MILLISECONDS)
     }
diff --git a/src/main/kotlin/util/Request.kt b/src/main/kotlin/util/Request.kt
index 2bd09fb..1a826df 100644
--- a/src/main/kotlin/util/Request.kt
+++ b/src/main/kotlin/util/Request.kt
@@ -60,7 +60,7 @@ private fun OkHttpClient.submit(
     return future
 }
 
-private fun newGet(
+internal fun newGet(
     serverUrl: HttpUrl,
     path: String? = null,
     headers: Map<String, String>? = null,
diff --git a/src/main/kotlin/util/SdkStream.kt b/src/main/kotlin/util/SseStream.kt
similarity index 88%
rename from src/main/kotlin/util/SdkStream.kt
rename to src/main/kotlin/util/SseStream.kt
index ec2292f..2776fd4 100644
--- a/src/main/kotlin/util/SdkStream.kt
+++ b/src/main/kotlin/util/SseStream.kt
@@ -1,5 +1,7 @@
 package com.amplitude.experiment.util
 
+import com.amplitude.experiment.LIBRARY_VERSION
+import okhttp3.HttpUrl
 import okhttp3.OkHttpClient
 import okhttp3.Request
 import okhttp3.Response
@@ -14,14 +16,15 @@ import kotlin.concurrent.schedule
 
 internal class StreamException(error: String): Throwable(error)
 
+private const val KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT = 0L // no timeout
 private const val RECONN_INTERVAL_MILLIS_DEFAULT = 30 * 60 * 1000L
 private const val MAX_JITTER_MILLIS_DEFAULT = 5000L
-internal class SdkStream (
+internal class SseStream (
     private val authToken: String,
-    private val serverUrl: String,
+    private val url: HttpUrl,
     private val httpClient: OkHttpClient = OkHttpClient(),
     private val connectionTimeoutMillis: Long,
-    private val keepaliveTimeoutMillis: Long,
+    private val keepaliveTimeoutMillis: Long = KEEP_ALIVE_TIMEOUT_MILLIS_DEFAULT,
     private val reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT,
     private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT
 ) {
@@ -67,7 +70,8 @@ internal class SdkStream (
                 return
             }
             if (t is StreamResetException && t.errorCode == ErrorCode.CANCEL) {
-                // TODO: relying on okhttp3.internal to differentiate cancel case.
+                // Relying on okhttp3.internal to differentiate cancel case.
+                // Can be a pitfall later on.
                 return
             }
             cancel()
@@ -83,11 +87,7 @@ internal class SdkStream (
         }
     }
 
-    private val request = Request.Builder()
-        .url(serverUrl)
-        .header("Authorization", authToken)
-        .addHeader("Accept", "text/event-stream")
-        .build()
+    private val request = newGet(url, null, mapOf("Authorization" to authToken, "Accept" to "text/event-stream"))
 
     private val client = httpClient.newBuilder() // client.newBuilder reuses the connection pool in the same client with new configs.
         .connectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS) // Connection timeout for establishing SSE.
@@ -107,7 +107,7 @@ internal class SdkStream (
         es = EventSources.createFactory(client).newEventSource(request = request, listener = eventSourceListener)
         reconnectTimerTask = Timer().schedule(reconnIntervalRange.random()) {// Timer for a new event source.
             // This forces client side reconnection after interval.
-            this@SdkStream.cancel()
+            this@SseStream.cancel()
             connect()
         }
     }