Skip to content

Commit

Permalink
Add RateLimiter for WebSocket (#1994)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pushkarm029 authored Oct 11, 2024
1 parent 43d2135 commit 2de463e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
34 changes: 29 additions & 5 deletions nautilus_core/network/src/python/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ impl WebSocketClient {
post_connection: Option<PyObject>,
post_reconnection: Option<PyObject>,
post_disconnection: Option<PyObject>,
keyed_quotas: Option<Vec<(String, Quota)>>,
default_quota: Option<Quota>,
py: Python<'_>,
) -> PyResult<Bound<PyAny>> {
pyo3_asyncio_0_21::tokio::future_into_py(py, async move {
Expand All @@ -82,6 +84,8 @@ impl WebSocketClient {
post_connection,
post_reconnection,
post_disconnection,
keyed_quotas,
default_quota,
)
.await
.map_err(to_websocket_pyerr)
Expand Down Expand Up @@ -143,21 +147,41 @@ impl WebSocketClient {
})
}

/// Send UTF-8 encoded bytes as text data to the server.
/// Send UTF-8 encoded bytes as text data to the server, respecting rate limits.
///
/// `data`: The byte data to be sent, which will be converted to a UTF-8 string.
/// `keys`: Optional list of rate limit keys. If provided, the function will wait for rate limits to be met for each key before sending the data.
///
/// # Errors
/// - Raises `PyRuntimeError` if unable to send the data.
///
/// - Raises PyRuntimeError if not able to send data.
/// # Example
///
/// When a request is made the URL should be split into all relevant keys within it.
///
/// For request /foo/bar, should pass keys ["foo/bar", "foo"] for rate limiting.
#[pyo3(name = "send_text")]
fn py_send_text<'py>(
slf: PyRef<'_, Self>,
data: Vec<u8>,
py: Python<'py>,
keys: Option<Vec<String>>,
) -> PyResult<Bound<'py, PyAny>> {
let data = String::from_utf8(data).map_err(to_pyvalue_err)?;
tracing::trace!("Sending text: {data}");
let keys = keys.unwrap_or_default();
let writer = slf.writer.clone();
let rate_limiter = slf.rate_limiter.clone();
pyo3_asyncio_0_21::tokio::future_into_py(py, async move {
let tasks = keys.iter().map(|key| rate_limiter.until_key_ready(key));
stream::iter(tasks)
.for_each(|key| async move {
key.await;
})
.await;

// Log after passing rate limit checks
tracing::trace!("Sending text: {data}");

let mut guard = writer.lock().await;
guard
.send(Message::Text(data))
Expand Down Expand Up @@ -337,7 +361,7 @@ counter = Counter()",
None,
None,
);
let client = WebSocketClient::connect(config, None, None, None)
let client = WebSocketClient::connect(config, None, None, None, None, None)
.await
.unwrap();

Expand Down Expand Up @@ -441,7 +465,7 @@ checker = Checker()",
Some("heartbeat message".to_string()),
None,
);
let client = WebSocketClient::connect(config, None, None, None)
let client = WebSocketClient::connect(config, None, None, None, None, None)
.await
.unwrap();

Expand Down
9 changes: 9 additions & 0 deletions nautilus_core/network/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream,
};

use crate::ratelimiter::{clock::MonotonicClock, quota::Quota, RateLimiter};
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
type SharedMessageWriter =
Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
Expand Down Expand Up @@ -320,6 +321,7 @@ impl Drop for WebSocketClientInner {
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
pub(crate) rate_limiter: Arc<RateLimiter<String, MonotonicClock>>,
pub(crate) writer: SharedMessageWriter,
pub(crate) controller_task: task::JoinHandle<()>,
pub(crate) disconnect_mode: Arc<AtomicBool>,
Expand All @@ -335,6 +337,8 @@ impl WebSocketClient {
post_connection: Option<PyObject>,
post_reconnection: Option<PyObject>,
post_disconnection: Option<PyObject>,
keyed_quotas: Option<Vec<(String, Quota)>>,
default_quota: Option<Quota>,
) -> Result<Self, Error> {
tracing::debug!("Connecting");
let inner = WebSocketClientInner::connect_url(config).await?;
Expand All @@ -347,6 +351,10 @@ impl WebSocketClient {
post_reconnection,
post_disconnection,
);
let rate_limiter = Arc::new(RateLimiter::new_with_quota(
default_quota,
keyed_quotas.unwrap_or_default(),
));

if let Some(handler) = post_connection {
Python::with_gil(|py| match handler.call0(py) {
Expand All @@ -356,6 +364,7 @@ impl WebSocketClient {
};

Ok(Self {
rate_limiter,
writer,
controller_task,
disconnect_mode,
Expand Down
4 changes: 3 additions & 1 deletion nautilus_trader/core/nautilus_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2646,11 +2646,13 @@ class WebSocketClient:
post_connection: Callable[..., None] | None = None,
post_reconnection: Callable[..., None] | None = None,
post_disconnection: Callable[..., None] | None = None,
keyed_quotas: list[tuple[str, Quota]] = [],
default_quota: Quota | None = None,
) -> Awaitable[WebSocketClient]: ...
def disconnect(self) -> Awaitable[None]: ...
def is_alive(self) -> bool: ...
def send(self, data: bytes) -> Awaitable[None]: ...
def send_text(self, data: bytes) -> Awaitable[None]: ...
def send_text(self, data: bytes, keys: list[str] | None = None,) -> Awaitable[None]: ...
def send_pong(self, data: bytes) -> Awaitable[None]: ...

class SocketClient:
Expand Down

0 comments on commit 2de463e

Please sign in to comment.