From 5161a54ec5b19be29c5b248e9d32429d28c90c13 Mon Sep 17 00:00:00 2001 From: Daan De Deckere Date: Wed, 5 Jul 2023 13:53:52 +0200 Subject: [PATCH] Use tokio streams instead of crossbeam, crossbeam causes deadlocks --- Cargo.toml | 4 +--- src/client.rs | 51 +++++++++++++++++++++++++++------------------------ src/worker.rs | 15 +++++++++------ 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 57210c7..503ee34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,15 +17,13 @@ include = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -async-trait = "0.1.61" base64 = "0.21.0" bytes = { version = "1.3.0", features = ["serde"] } chrono = "0.4.23" -crossbeam = "0.8.2" log = "0.4.17" rmp-serde = "1.1.1" serde = { version = "1.0.152", features = ["derive"] } -tokio = { version = "1.24.2", features = ["net", "time", "io-util", "rt"] } +tokio = { version = "1.24.2", features = ["net", "time", "io-util", "rt", "sync"] } uuid = { version = "1.2.2", features = ["v4"] } [dev-dependencies] diff --git a/src/client.rs b/src/client.rs index 39e72dd..ac6d279 100644 --- a/src/client.rs +++ b/src/client.rs @@ -24,10 +24,12 @@ use std::net::SocketAddr; use std::time::Duration; -use async_trait::async_trait; use base64::{engine::general_purpose, Engine}; -use crossbeam::channel::{self, Sender}; -use tokio::{net::TcpStream, time::timeout}; +use tokio::{ + net::TcpStream, + sync::broadcast::{channel, Sender}, + time::timeout, +}; use uuid::Uuid; use crate::record::Map; @@ -79,10 +81,9 @@ impl Default for Config { } } -#[async_trait] pub trait FluentClient: Send + Sync { fn send(&self, tag: &'static str, record: Map) -> Result<(), SendError>; - async fn stop(self) -> Result<(), SendError>; + fn stop(self) -> Result<(), SendError>; } #[derive(Debug, Clone)] @@ -95,7 +96,7 @@ impl Client { /// Connect to the fluentd server and create a worker with tokio::spawn. pub async fn new(config: &Config) -> tokio::io::Result { let stream = timeout(config.timeout, TcpStream::connect(config.addr)).await??; - let (sender, receiver) = channel::unbounded(); + let (sender, receiver) = channel(1024); let config = config.clone(); let _ = tokio::spawn(async move { @@ -132,11 +133,11 @@ impl Client { .send(Message::Record(record)) .map_err(|e| SendError { source: e.to_string(), - }) + })?; + Ok(()) } } -#[async_trait] impl FluentClient for Client { /// Send a fluent record to the fluentd server. /// @@ -149,10 +150,13 @@ impl FluentClient for Client { } /// Stop the worker. - async fn stop(self) -> Result<(), SendError> { - self.sender.send(Message::Terminate).map_err(|e| SendError { - source: e.to_string(), - }) + fn stop(self) -> Result<(), SendError> { + self.sender + .send(Message::Terminate) + .map_err(|e| SendError { + source: e.to_string(), + })?; + Ok(()) } } @@ -167,13 +171,12 @@ impl Drop for Client { /// NopClient does nothing. pub struct NopClient; -#[async_trait] impl FluentClient for NopClient { - fn send(&self, _tag: &'static str, _record: Map) -> Result<(), SendError> { + fn send(&self, _tag: &str, _record: Map) -> Result<(), SendError> { Ok(()) } - async fn stop(self) -> Result<(), SendError> { + fn stop(self) -> Result<(), SendError> { Ok(()) } } @@ -191,7 +194,7 @@ mod tests { use crate::record::Value; use crate::record_map; - let (sender, receiver) = channel::unbounded(); + let (sender, mut receiver) = channel(1024); let client = Client { sender }; let timestamp = chrono::Utc.timestamp_opt(1234567, 0).unwrap().timestamp(); @@ -201,7 +204,7 @@ mod tests { "failed to send with time" ); - let got = receiver.recv().expect("failed to receive"); + let got = receiver.try_recv().expect("failed to receive"); match got { Message::Record(r) => { assert_eq!(r.tag, "test"); @@ -212,13 +215,13 @@ mod tests { } } - #[tokio::test] - async fn test_stop() { - let (sender, receiver) = channel::unbounded(); + #[test] + fn test_stop() { + let (sender, mut receiver) = channel(1024); let client = Client { sender }; - assert!(client.stop().await.is_ok(), "faled to stop"); + assert!(client.stop().is_ok(), "faled to stop"); - let got = receiver.recv().expect("failed to receive"); + let got = receiver.try_recv().expect("failed to receive"); match got { Message::Record(_) => unreachable!("got record message"), Message::Terminate => {} @@ -227,11 +230,11 @@ mod tests { #[test] fn test_client_drop_sends_terminate() { - let (sender, receiver) = channel::unbounded(); + let (sender, mut receiver) = channel(1024); { Client { sender }; } - let got = receiver.recv().expect("failed to receive"); + let got = receiver.try_recv().expect("failed to receive"); match got { Message::Record(_) => unreachable!("got record message"), Message::Terminate => {} diff --git a/src/worker.rs b/src/worker.rs index 1ad2193..4d563f5 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,11 +1,11 @@ use bytes::{Buf, BufMut}; -use crossbeam::channel::{self, Receiver}; use log::warn; use rmp_serde::Serializer; use serde::{ser::SerializeMap, Deserialize, Serialize}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, + sync::broadcast::{error::RecvError, Receiver}, time::Duration, }; @@ -37,7 +37,7 @@ impl std::fmt::Display for Error { } } -#[derive(Debug, Serialize)] +#[derive(Clone, Debug, Serialize)] pub struct Record { pub tag: &'static str, pub timestamp: i64, @@ -45,7 +45,7 @@ pub struct Record { pub options: Options, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Options { pub chunk: String, } @@ -61,6 +61,7 @@ impl Serialize for Options { } } +#[derive(Clone)] pub enum Message { Record(Record), Terminate, @@ -100,7 +101,7 @@ impl Worker { pub async fn run(&mut self) { loop { - match self.receiver.try_recv() { + match self.receiver.recv().await { Ok(Message::Record(record)) => { let record = match self.encode(record) { Ok(record) => record, @@ -115,8 +116,10 @@ impl Worker { Err(_) => continue, }; } - Err(channel::TryRecvError::Empty) => continue, - Ok(Message::Terminate) | Err(channel::TryRecvError::Disconnected) => break, + Err(RecvError::Closed) | Ok(Message::Terminate) => { + break; + } + Err(RecvError::Lagged(_)) => continue, } } }