Skip to content

Commit

Permalink
enhancement(vector source): implement client connection limits for gr…
Browse files Browse the repository at this point in the history
…pc server

Related: #19457
Related: #10728
  • Loading branch information
fpytloun committed Aug 14, 2024
1 parent 7f206cd commit 14066df
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/components/validation/runner/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ pub fn spawn_grpc_server<S>(

let server = run_grpc_server(
listen_addr.as_socket_addr(),
None,
None,
tls_settings,
service,
shutdown_signal,
Expand Down
103 changes: 103 additions & 0 deletions src/sources/util/grpc/connectionlimit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use std::{

Check failure on line 1 in src/sources/util/grpc/connectionlimit.rs

View workflow job for this annotation

GitHub Actions / Check Spelling

`connectionlimit` is not a recognized word. (check-file-path)
sync::{Arc, Mutex},
time::{Duration, Instant},
task::{Context, Poll},
};

use http::{Request, Response};
use hyper::Body;
use std::future::Future;
use std::pin::Pin;
use tonic::{body::BoxBody, Status};
use tower::{Layer, Service};
use futures_util::FutureExt;

/// A service that tracks the number of requests and elapsed time,
/// shutting down the connection gracefully if the configured limits are reached.
#[derive(Clone)]
pub struct ConnectionLimit<S> {
inner: S,
request_count: Arc<Mutex<usize>>,
max_requests: Option<usize>,
max_duration: Duration,
start_time: Instant,
}

impl<S> ConnectionLimit<S> {
pub fn new(inner: S, max_requests: Option<usize>, max_duration: Duration) -> Self {
Self {
inner,
request_count: Arc::new(Mutex::new(0)),
max_requests: max_requests.unwrap_or(usize::MAX), // Default to no limit if not set,
max_duration: max_duration.unwrap_or(Duration::from_secs(u64::MAX)), // Default to very long duration if not set
start_time: Instant::now(),
}
}
}

impl<S> Service<Request<Body>> for ConnectionLimit<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: std::fmt::Display,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
let max_requests = self.max_requests;
let max_duration = self.max_duration;
let request_count = Arc::clone(&self.request_count);
let start_time = self.start_time;

let elapsed_time = start_time.elapsed();

let future = self.inner.call(req);

Box::pin(async move {
let response = future.await?;

// After processing the request, increment the request count and check the limits.
let mut count = request_count.lock().unwrap();
*count += 1;

if *count > max_requests || elapsed_time > max_duration {
// If the limit is reached, return a ResourceExhausted error to close the connection.
return Err(Status::resource_exhausted(
"Connection closed after reaching the limit.",
));
}

Ok(response)
})
}
}

/// A layer that adds the ConnectionLimit functionality to a service.
#[derive(Clone, Default)]
pub struct ConnectionLimitLayer {
max_requests: Option<usize>,
max_duration: Duration,
}

impl ConnectionLimitLayer {
pub fn new(max_requests: Option<usize>, max_duration: Duration) -> Self {
Self {
max_requests,
max_duration,
}
}
}

impl<S> Layer<S> for ConnectionLimitLayer {
type Service = ConnectionLimit<S>;

fn layer(&self, inner: S) -> Self::Service {
ConnectionLimit::new(inner, self.max_requests, self.max_duration)
}
}
12 changes: 12 additions & 0 deletions src/sources/util/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ use tracing::Span;
mod decompression;
pub use self::decompression::{DecompressionAndMetrics, DecompressionAndMetricsLayer};

mod connectionlimit;

Check failure on line 25 in src/sources/util/grpc/mod.rs

View workflow job for this annotation

GitHub Actions / Check Spelling

`connectionlimit` is not a recognized word. (unrecognized-spelling)
pub use self::connectionlimit::{ConnectionLimit, ConnectionLimitLayer};

Check failure on line 26 in src/sources/util/grpc/mod.rs

View workflow job for this annotation

GitHub Actions / Check Spelling

`connectionlimit` is not a recognized word. (unrecognized-spelling)

pub async fn run_grpc_server<S>(
address: SocketAddr,
max_requests: Option<usize>,
max_duration: Option<Duration>,
tls_settings: MaybeTlsSettings,
service: S,
shutdown: ShutdownSignal,
Expand All @@ -43,6 +48,13 @@ where

info!(%address, "Building gRPC server.");

// Conditionally apply the ConnectionLimitLayer if any limits are set
let service = if max_requests.is_some() || max_duration.is_some() {
ConnectionLimitLayer::new(max_requests, max_duration).layer(service)
} else {
service
};

Server::builder()
.layer(build_grpc_trace_layer(span.clone()))
// This layer explicitly decompresses payloads, if compressed, and reports the number of message bytes we've
Expand Down
22 changes: 20 additions & 2 deletions src/sources/vector/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! The `vector` source. See [VectorConfig].
use std::net::SocketAddr;
use std::{
net::SocketAddr,
time::Duration,
};

use chrono::Utc;
use futures::TryFutureExt;
Expand All @@ -26,6 +29,8 @@ use crate::{
SourceSender,
};

use serde_with::serde_as;

/// Marker type for version two of the configuration for the `vector` source.
#[configurable_component]
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -114,6 +119,7 @@ async fn handle_batch_status(receiver: Option<BatchStatusReceiver>) -> Result<()
}

/// Configuration for the `vector` source.
#[serde_as]
#[configurable_component(source("vector", "Collect observability data from a Vector instance."))]
#[derive(Clone, Debug)]
#[serde(deny_unknown_fields)]
Expand All @@ -134,6 +140,16 @@ pub struct VectorConfig {
#[serde(default, deserialize_with = "bool_or_struct")]
acknowledgements: SourceAcknowledgementsConfig,

/// Maximum duration of client connection before it is closed
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
#[configurable(metadata(docs::human_name = "Max client connection duration"))]
max_duration: Option<Duration>,

/// Maximum number of client requests before connection is closed
#[configurable(metadata(docs::type_unit = "requests"))]
#[configurable(metadata(docs::human_name = "Max client requests before connection is closed"))]
max_requests: Option<usize>,

/// The namespace to use for logs. This overrides the global setting.
#[serde(default)]
#[configurable(metadata(docs::hidden))]
Expand All @@ -157,6 +173,8 @@ impl Default for VectorConfig {
address: "0.0.0.0:6000".parse().unwrap(),
tls: None,
acknowledgements: Default::default(),
max_requests: None,
max_duration: None,
log_namespace: None,
}
}
Expand Down Expand Up @@ -186,7 +204,7 @@ impl SourceConfig for VectorConfig {
.max_decoding_message_size(usize::MAX);

let source =
run_grpc_server(self.address, tls_settings, service, cx.shutdown).map_err(|error| {
run_grpc_server(self.address, self.max_requests, self.max_duration, tls_settings, service, cx.shutdown).map_err(|error| {
error!(message = "Source future failed.", %error);
});

Expand Down

0 comments on commit 14066df

Please sign in to comment.