From 2de463e572af46632fba7747881486ba6bfa771f Mon Sep 17 00:00:00 2001 From: Pushkar Mishra Date: Sat, 12 Oct 2024 00:48:59 +0530 Subject: [PATCH] Add RateLimiter for WebSocket (#1994) --- nautilus_core/network/src/python/websocket.rs | 34 ++++++++++++++++--- nautilus_core/network/src/websocket.rs | 9 +++++ nautilus_trader/core/nautilus_pyo3.pyi | 4 ++- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/nautilus_core/network/src/python/websocket.rs b/nautilus_core/network/src/python/websocket.rs index 598342025aff..4810d5de878f 100644 --- a/nautilus_core/network/src/python/websocket.rs +++ b/nautilus_core/network/src/python/websocket.rs @@ -74,6 +74,8 @@ impl WebSocketClient { post_connection: Option, post_reconnection: Option, post_disconnection: Option, + keyed_quotas: Option>, + default_quota: Option, py: Python<'_>, ) -> PyResult> { pyo3_asyncio_0_21::tokio::future_into_py(py, async move { @@ -82,6 +84,8 @@ impl WebSocketClient { post_connection, post_reconnection, post_disconnection, + keyed_quotas, + default_quota, ) .await .map_err(to_websocket_pyerr) @@ -143,21 +147,41 @@ impl WebSocketClient { }) } - /// Send UTF-8 encoded bytes as text data to the server. + /// Send UTF-8 encoded bytes as text data to the server, respecting rate limits. + /// + /// `data`: The byte data to be sent, which will be converted to a UTF-8 string. + /// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data. /// /// # Errors + /// - Raises `PyRuntimeError` if unable to send the data. /// - /// - Raises PyRuntimeError if not able to send data. + /// # Example + /// + /// When a request is made the URL should be split into all relevant keys within it. + /// + /// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting. #[pyo3(name = "send_text")] fn py_send_text<'py>( slf: PyRef<'_, Self>, data: Vec, py: Python<'py>, + keys: Option>, ) -> PyResult> { let data = String::from_utf8(data).map_err(to_pyvalue_err)?; - tracing::trace!("Sending text: {data}"); + let keys = keys.unwrap_or_default(); let writer = slf.writer.clone(); + let rate_limiter = slf.rate_limiter.clone(); pyo3_asyncio_0_21::tokio::future_into_py(py, async move { + let tasks = keys.iter().map(|key| rate_limiter.until_key_ready(key)); + stream::iter(tasks) + .for_each(|key| async move { + key.await; + }) + .await; + + // Log after passing rate limit checks + tracing::trace!("Sending text: {data}"); + let mut guard = writer.lock().await; guard .send(Message::Text(data)) @@ -337,7 +361,7 @@ counter = Counter()", None, None, ); - let client = WebSocketClient::connect(config, None, None, None) + let client = WebSocketClient::connect(config, None, None, None, None, None) .await .unwrap(); @@ -441,7 +465,7 @@ checker = Checker()", Some("heartbeat message".to_string()), None, ); - let client = WebSocketClient::connect(config, None, None, None) + let client = WebSocketClient::connect(config, None, None, None, None, None) .await .unwrap(); diff --git a/nautilus_core/network/src/websocket.rs b/nautilus_core/network/src/websocket.rs index 6f29dde7ff2b..be5ec3d84d92 100644 --- a/nautilus_core/network/src/websocket.rs +++ b/nautilus_core/network/src/websocket.rs @@ -37,6 +37,7 @@ use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, }; +use crate::ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter}; type MessageWriter = SplitSink>, Message>; type SharedMessageWriter = Arc>, Message>>>; @@ -320,6 +321,7 @@ impl Drop for WebSocketClientInner { pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network") )] pub struct WebSocketClient { + pub(crate) rate_limiter: Arc>, pub(crate) writer: SharedMessageWriter, pub(crate) controller_task: task::JoinHandle<()>, pub(crate) disconnect_mode: Arc, @@ -335,6 +337,8 @@ impl WebSocketClient { post_connection: Option, post_reconnection: Option, post_disconnection: Option, + keyed_quotas: Option>, + default_quota: Option, ) -> Result { tracing::debug!("Connecting"); let inner = WebSocketClientInner::connect_url(config).await?; @@ -347,6 +351,10 @@ impl WebSocketClient { post_reconnection, post_disconnection, ); + let rate_limiter = Arc::new(RateLimiter::new_with_quota( + default_quota, + keyed_quotas.unwrap_or_default(), + )); if let Some(handler) = post_connection { Python::with_gil(|py| match handler.call0(py) { @@ -356,6 +364,7 @@ impl WebSocketClient { }; Ok(Self { + rate_limiter, writer, controller_task, disconnect_mode, diff --git a/nautilus_trader/core/nautilus_pyo3.pyi b/nautilus_trader/core/nautilus_pyo3.pyi index 6ed435779bdd..fc3c0b274096 100644 --- a/nautilus_trader/core/nautilus_pyo3.pyi +++ b/nautilus_trader/core/nautilus_pyo3.pyi @@ -2646,11 +2646,13 @@ class WebSocketClient: post_connection: Callable[..., None] | None = None, post_reconnection: Callable[..., None] | None = None, post_disconnection: Callable[..., None] | None = None, + keyed_quotas: list[tuple[str, Quota]] = [], + default_quota: Quota | None = None, ) -> Awaitable[WebSocketClient]: ... def disconnect(self) -> Awaitable[None]: ... def is_alive(self) -> bool: ... def send(self, data: bytes) -> Awaitable[None]: ... - def send_text(self, data: bytes) -> Awaitable[None]: ... + def send_text(self, data: bytes, keys: list[str] | None = None,) -> Awaitable[None]: ... def send_pong(self, data: bytes) -> Awaitable[None]: ... class SocketClient: