From 048c297fc70484db7ba6e424987a5fcc9ca4c824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Frosteg=C3=A5rd?= Date: Thu, 16 Jan 2025 20:13:28 +0100 Subject: [PATCH] udp: open one socket each for IPv4 and IPv6 (#220) * tmp work on udp double sockets * WIP: udp: open two sockets (one for ipv4, one for ipv6) io_uring not ported yet * udp: open one socket each for IPv4 and IPv6 Config file now has one setting for each * file transfer ci: fix udp network.address_ipv4 --- .../actions/test-file-transfers/entrypoint.sh | 2 +- CHANGELOG.md | 9 + crates/bencher/src/protocols/udp.rs | 3 +- crates/http/src/config.rs | 4 +- crates/toml_config/src/lib.rs | 4 +- crates/udp/README.md | 4 +- crates/udp/src/config.rs | 59 ++-- crates/udp/src/lib.rs | 23 +- crates/udp/src/workers/socket/mio.rs | 333 ------------------ crates/udp/src/workers/socket/mio/mod.rs | 194 ++++++++++ crates/udp/src/workers/socket/mio/socket.rs | 322 +++++++++++++++++ crates/udp/src/workers/socket/mod.rs | 57 +-- crates/udp/src/workers/socket/uring/mod.rs | 179 ++++++++-- .../src/workers/socket/uring/recv_helper.rs | 139 +++++--- .../src/workers/socket/uring/send_buffers.rs | 43 ++- crates/udp/tests/access_list.rs | 3 +- crates/udp/tests/invalid_connection_id.rs | 3 +- crates/udp/tests/requests_responses.rs | 3 +- crates/ws/src/config.rs | 4 +- 19 files changed, 863 insertions(+), 525 deletions(-) delete mode 100644 crates/udp/src/workers/socket/mio.rs create mode 100644 crates/udp/src/workers/socket/mio/mod.rs create mode 100644 crates/udp/src/workers/socket/mio/socket.rs diff --git a/.github/actions/test-file-transfers/entrypoint.sh b/.github/actions/test-file-transfers/entrypoint.sh index f585ddc3..868c88c8 100755 --- a/.github/actions/test-file-transfers/entrypoint.sh +++ b/.github/actions/test-file-transfers/entrypoint.sh @@ -59,7 +59,7 @@ echo " log_level = 'debug' [network] -address = '127.0.0.1:3000'" > udp.toml +address_ipv4 = '127.0.0.1:3000'" > udp.toml ./target/debug/aquatic udp -c udp.toml > "$HOME/udp.log" 2>&1 & # HTTP diff --git a/CHANGELOG.md b/CHANGELOG.md index d1529bb6..0ff855a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## Unreleased + +### aquatic_udp + +#### Changed + +* (Breaking) Open one socket each for IPv4 and IPv6. The config file now has + one setting for each. + ## 0.9.0 - 2024-04-03 ### General diff --git a/crates/bencher/src/protocols/udp.rs b/crates/bencher/src/protocols/udp.rs index a514d044..62058ef3 100644 --- a/crates/bencher/src/protocols/udp.rs +++ b/crates/bencher/src/protocols/udp.rs @@ -300,7 +300,8 @@ impl ProcessRunner for AquaticUdpRunner { let mut c = aquatic_udp::config::Config::default(); c.socket_workers = self.socket_workers; - c.network.address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3000)); + c.network.address_ipv4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3000); + c.network.use_ipv6 = false; c.network.use_io_uring = self.use_io_uring; c.protocol.max_response_peers = 30; diff --git a/crates/http/src/config.rs b/crates/http/src/config.rs index 958985be..644844bb 100644 --- a/crates/http/src/config.rs +++ b/crates/http/src/config.rs @@ -71,13 +71,13 @@ impl aquatic_common::cli::Config for Config { #[serde(default, deny_unknown_fields)] pub struct NetworkConfig { /// Bind to this address - /// + /// /// When providing an IPv4 style address, only IPv4 traffic will be /// handled. Examples: /// - "0.0.0.0:3000" binds to port 3000 on all network interfaces /// - "127.0.0.1:3000" binds to port 3000 on the loopback interface /// (localhost) - /// + /// /// When it comes to IPv6-style addresses, behaviour is more complex and /// differs between operating systems. On Linux, to accept both IPv4 and /// IPv6 traffic on any interface, use "[::]:3000". Set the "only_ipv6" diff --git a/crates/toml_config/src/lib.rs b/crates/toml_config/src/lib.rs index 42745d36..ed778030 100644 --- a/crates/toml_config/src/lib.rs +++ b/crates/toml_config/src/lib.rs @@ -75,7 +75,7 @@ pub trait TomlConfig: Default { } pub mod __private { - use std::net::SocketAddr; + use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::path::PathBuf; pub trait Private { @@ -123,4 +123,6 @@ pub mod __private { impl_trait!(PathBuf); impl_trait!(SocketAddr); + impl_trait!(SocketAddrV4); + impl_trait!(SocketAddrV6); } diff --git a/crates/udp/README.md b/crates/udp/README.md index 39776fc4..8af68b95 100644 --- a/crates/udp/README.md +++ b/crates/udp/README.md @@ -53,8 +53,8 @@ Generate the configuration file: ./target/release/aquatic_udp -p > "aquatic-udp-config.toml" ``` -Make necessary adjustments to the file. You will likely want to adjust `address` -(listening address) under the `network` section. +Make necessary adjustments to the file. You will likely want to adjust +listening addresses under the `network` section. Once done, start the application: diff --git a/crates/udp/src/config.rs b/crates/udp/src/config.rs index 9fe1c71c..0d3e6bc7 100644 --- a/crates/udp/src/config.rs +++ b/crates/udp/src/config.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, path::PathBuf}; +use std::{ + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + path::PathBuf, +}; use aquatic_common::{access_list::AccessListConfig, privileges::PrivilegeConfig}; use cfg_if::cfg_if; @@ -54,25 +57,24 @@ impl aquatic_common::cli::Config for Config { #[derive(Clone, Debug, PartialEq, TomlConfig, Deserialize, Serialize)] #[serde(default, deny_unknown_fields)] pub struct NetworkConfig { - /// Bind to this address - /// - /// When providing an IPv4 style address, only IPv4 traffic will be - /// handled. Examples: - /// - "0.0.0.0:3000" binds to port 3000 on all network interfaces - /// - "127.0.0.1:3000" binds to port 3000 on the loopback interface - /// (localhost) - /// - /// When it comes to IPv6-style addresses, behaviour is more complex and - /// differs between operating systems. On Linux, to accept both IPv4 and - /// IPv6 traffic on any interface, use "[::]:3000". Set the "only_ipv6" - /// flag below to limit traffic to IPv6. To bind to the loopback interface - /// and only accept IPv6 packets, use "[::1]:3000" and set the only_ipv6 - /// flag. Receiving both IPv4 and IPv6 traffic on loopback is currently - /// not supported. For other operating systems, please refer to their - /// respective documentation. - pub address: SocketAddr, - /// Only allow access over IPv6 - pub only_ipv6: bool, + /// Use IPv4 + pub use_ipv4: bool, + /// Use IPv6 + pub use_ipv6: bool, + /// IPv4 address and port + /// + /// Examples: + /// - Use 0.0.0.0:3000 to bind to all interfaces on port 3000 + /// - Use 127.0.0.1:3000 to bind to the loopback interface (localhost) on + /// port 3000 + pub address_ipv4: SocketAddrV4, + /// IPv6 address and port + /// + /// Examples: + /// - Use [::]:3000 to bind to all interfaces on port 3000 + /// - Use [::1]:3000 to bind to the loopback interface (localhost) on + /// port 3000 + pub address_ipv6: SocketAddrV6, /// Size of socket recv buffer. Use 0 for OS default. /// /// This setting can have a big impact on dropped packages. It might @@ -95,6 +97,12 @@ pub struct NetworkConfig { /// such as FreeBSD. Setting the value to zero disables resending /// functionality. pub resend_buffer_max_len: usize, + /// Set flag on IPv6 socket to only accept IPv6 traffic. + /// + /// This should typically be set to true unless your OS does not support + /// double-stack sockets (that is, sockets that receive both IPv4 and IPv6 + /// packets). + pub set_only_ipv6: bool, #[cfg(feature = "io-uring")] pub use_io_uring: bool, /// Number of ring entries (io_uring backend only) @@ -106,21 +114,24 @@ pub struct NetworkConfig { impl NetworkConfig { pub fn ipv4_active(&self) -> bool { - self.address.is_ipv4() || !self.only_ipv6 + self.use_ipv4 } pub fn ipv6_active(&self) -> bool { - self.address.is_ipv6() + self.use_ipv6 } } impl Default for NetworkConfig { fn default() -> Self { Self { - address: SocketAddr::from(([0, 0, 0, 0], 3000)), - only_ipv6: false, + use_ipv4: true, + use_ipv6: true, + address_ipv4: SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 3000), + address_ipv6: SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 3000, 0, 0), socket_recv_buffer_size: 8_000_000, poll_timeout_ms: 50, resend_buffer_max_len: 0, + set_only_ipv6: true, #[cfg(feature = "io-uring")] use_io_uring: true, #[cfg(feature = "io-uring")] diff --git a/crates/udp/src/lib.rs b/crates/udp/src/lib.rs index b1215c1f..058e6bca 100644 --- a/crates/udp/src/lib.rs +++ b/crates/udp/src/lib.rs @@ -25,14 +25,26 @@ pub const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn run(mut config: Config) -> ::anyhow::Result<()> { let mut signals = Signals::new([SIGUSR1])?; + if !(config.network.use_ipv4 || config.network.use_ipv6) { + return Result::Err(anyhow::anyhow!( + "Both use_ipv4 and use_ipv6 can not be set to false" + )); + } + if config.socket_workers == 0 { config.socket_workers = available_parallelism().map(Into::into).unwrap_or(1); }; + let num_sockets_per_worker = + if config.network.use_ipv4 { 1 } else { 0 } + if config.network.use_ipv6 { 1 } else { 0 }; + let state = State::default(); let statistics = Statistics::new(&config); let connection_validator = ConnectionValidator::new(&config)?; - let priv_dropper = PrivilegeDropper::new(config.privileges.clone(), config.socket_workers); + let priv_dropper = PrivilegeDropper::new( + config.privileges.clone(), + config.socket_workers * num_sockets_per_worker, + ); let (statistics_sender, statistics_receiver) = unbounded(); update_access_list(&config.access_list, &state.access_list)?; @@ -44,10 +56,15 @@ pub fn run(mut config: Config) -> ::anyhow::Result<()> { let state = state.clone(); let config = config.clone(); let connection_validator = connection_validator.clone(); - let priv_dropper = priv_dropper.clone(); let statistics = statistics.socket[i].clone(); let statistics_sender = statistics_sender.clone(); + let mut priv_droppers = Vec::new(); + + for _ in 0..num_sockets_per_worker { + priv_droppers.push(priv_dropper.clone()); + } + let handle = Builder::new() .name(format!("socket-{:02}", i + 1)) .spawn(move || { @@ -57,7 +74,7 @@ pub fn run(mut config: Config) -> ::anyhow::Result<()> { statistics, statistics_sender, connection_validator, - priv_dropper, + priv_droppers, ) }) .with_context(|| "spawn socket worker")?; diff --git a/crates/udp/src/workers/socket/mio.rs b/crates/udp/src/workers/socket/mio.rs deleted file mode 100644 index a73a35eb..00000000 --- a/crates/udp/src/workers/socket/mio.rs +++ /dev/null @@ -1,333 +0,0 @@ -use std::io::{Cursor, ErrorKind}; -use std::sync::atomic::Ordering; -use std::time::Duration; - -use anyhow::Context; -use aquatic_common::access_list::AccessListCache; -use crossbeam_channel::Sender; -use mio::net::UdpSocket; -use mio::{Events, Interest, Poll, Token}; - -use aquatic_common::{ - access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, - ValidUntil, -}; -use aquatic_udp_protocol::*; -use rand::rngs::SmallRng; -use rand::SeedableRng; - -use crate::common::*; -use crate::config::Config; - -use super::validator::ConnectionValidator; -use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; - -pub struct SocketWorker { - config: Config, - shared_state: State, - statistics: CachePaddedArc>, - statistics_sender: Sender, - access_list_cache: AccessListCache, - validator: ConnectionValidator, - socket: UdpSocket, - buffer: [u8; BUFFER_SIZE], - rng: SmallRng, - peer_valid_until: ValidUntil, -} - -impl SocketWorker { - pub fn run( - config: Config, - shared_state: State, - statistics: CachePaddedArc>, - statistics_sender: Sender, - validator: ConnectionValidator, - priv_dropper: PrivilegeDropper, - ) -> anyhow::Result<()> { - let socket = UdpSocket::from_std(create_socket(&config, priv_dropper)?); - let access_list_cache = create_access_list_cache(&shared_state.access_list); - let peer_valid_until = ValidUntil::new( - shared_state.server_start_instant, - config.cleaning.max_peer_age, - ); - - let mut worker = Self { - config, - shared_state, - statistics, - statistics_sender, - validator, - access_list_cache, - socket, - buffer: [0; BUFFER_SIZE], - rng: SmallRng::from_entropy(), - peer_valid_until, - }; - - worker.run_inner() - } - - pub fn run_inner(&mut self) -> anyhow::Result<()> { - let mut opt_resend_buffer = - (self.config.network.resend_buffer_max_len > 0).then_some(Vec::new()); - let mut events = Events::with_capacity(1); - let mut poll = Poll::new().context("create poll")?; - - poll.registry() - .register(&mut self.socket, Token(0), Interest::READABLE) - .context("register poll")?; - - let poll_timeout = Duration::from_millis(self.config.network.poll_timeout_ms); - - let mut iter_counter = 0u64; - - loop { - poll.poll(&mut events, Some(poll_timeout)).context("poll")?; - - for event in events.iter() { - if event.is_readable() { - self.read_and_handle_requests(&mut opt_resend_buffer); - } - } - - // If resend buffer is enabled, send any responses in it - if let Some(resend_buffer) = opt_resend_buffer.as_mut() { - for (addr, response) in resend_buffer.drain(..) { - self.send_response(&mut None, addr, response); - } - } - - if iter_counter % 256 == 0 { - self.validator.update_elapsed(); - - self.peer_valid_until = ValidUntil::new( - self.shared_state.server_start_instant, - self.config.cleaning.max_peer_age, - ); - } - - iter_counter = iter_counter.wrapping_add(1); - } - } - - fn read_and_handle_requests( - &mut self, - opt_resend_buffer: &mut Option>, - ) { - let max_scrape_torrents = self.config.protocol.max_scrape_torrents; - - loop { - match self.socket.recv_from(&mut self.buffer[..]) { - Ok((bytes_read, src)) => { - let src_port = src.port(); - let src = CanonicalSocketAddr::new(src); - - // Use canonical address for statistics - let opt_statistics = if self.config.statistics.active() { - if src.is_ipv4() { - let statistics = &self.statistics.ipv4; - - statistics - .bytes_received - .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - - Some(statistics) - } else { - let statistics = &self.statistics.ipv6; - - statistics - .bytes_received - .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - - Some(statistics) - } - } else { - None - }; - - if src_port == 0 { - ::log::debug!("Ignored request because source port is zero"); - - continue; - } - - match Request::parse_bytes(&self.buffer[..bytes_read], max_scrape_torrents) { - Ok(request) => { - if let Some(statistics) = opt_statistics { - statistics.requests.fetch_add(1, Ordering::Relaxed); - } - - if let Some(response) = self.handle_request(request, src) { - self.send_response(opt_resend_buffer, src, response); - } - } - Err(RequestParseError::Sendable { - connection_id, - transaction_id, - err, - }) if self.validator.connection_id_valid(src, connection_id) => { - let response = ErrorResponse { - transaction_id, - message: err.into(), - }; - - self.send_response(opt_resend_buffer, src, Response::Error(response)); - - ::log::debug!("request parse error (sent error response): {:?}", err); - } - Err(err) => { - ::log::debug!( - "request parse error (didn't send error response): {:?}", - err - ); - } - }; - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - break; - } - Err(err) => { - ::log::warn!("recv_from error: {:#}", err); - } - } - } - } - - fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) -> Option { - let access_list_mode = self.config.access_list.mode; - - match request { - Request::Connect(request) => { - return Some(Response::Connect(ConnectResponse { - connection_id: self.validator.create_connection_id(src), - transaction_id: request.transaction_id, - })); - } - Request::Announce(request) => { - if self - .validator - .connection_id_valid(src, request.connection_id) - { - if self - .access_list_cache - .load() - .allows(access_list_mode, &request.info_hash.0) - { - let response = self.shared_state.torrent_maps.announce( - &self.config, - &self.statistics_sender, - &mut self.rng, - &request, - src, - self.peer_valid_until, - ); - - return Some(response); - } else { - return Some(Response::Error(ErrorResponse { - transaction_id: request.transaction_id, - message: "Info hash not allowed".into(), - })); - } - } - } - Request::Scrape(request) => { - if self - .validator - .connection_id_valid(src, request.connection_id) - { - return Some(Response::Scrape( - self.shared_state.torrent_maps.scrape(request, src), - )); - } - } - } - - None - } - - fn send_response( - &mut self, - opt_resend_buffer: &mut Option>, - canonical_addr: CanonicalSocketAddr, - response: Response, - ) { - let mut buffer = Cursor::new(&mut self.buffer[..]); - - if let Err(err) = response.write_bytes(&mut buffer) { - ::log::error!("failed writing response to buffer: {:#}", err); - - return; - } - - let bytes_written = buffer.position() as usize; - - let addr = if self.config.network.address.is_ipv4() { - canonical_addr - .get_ipv4() - .expect("found peer ipv6 address while running bound to ipv4 address") - } else { - canonical_addr.get_ipv6_mapped() - }; - - match self - .socket - .send_to(&buffer.into_inner()[..bytes_written], addr) - { - Ok(bytes_sent) if self.config.statistics.active() => { - let stats = if canonical_addr.is_ipv4() { - let stats = &self.statistics.ipv4; - - stats - .bytes_sent - .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); - - stats - } else { - let stats = &self.statistics.ipv6; - - stats - .bytes_sent - .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); - - stats - }; - - match response { - Response::Connect(_) => { - stats.responses_connect.fetch_add(1, Ordering::Relaxed); - } - Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { - stats.responses_announce.fetch_add(1, Ordering::Relaxed); - } - Response::Scrape(_) => { - stats.responses_scrape.fetch_add(1, Ordering::Relaxed); - } - Response::Error(_) => { - stats.responses_error.fetch_add(1, Ordering::Relaxed); - } - } - } - Ok(_) => (), - Err(err) => match opt_resend_buffer.as_mut() { - Some(resend_buffer) - if (err.raw_os_error() == Some(libc::ENOBUFS)) - || (err.kind() == ErrorKind::WouldBlock) => - { - if resend_buffer.len() < self.config.network.resend_buffer_max_len { - ::log::debug!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); - - resend_buffer.push((canonical_addr, response)); - } else { - ::log::warn!("Response resend buffer full, dropping response"); - } - } - _ => { - ::log::warn!("Sending response to {} failed: {:#}", addr, err); - } - }, - } - - ::log::debug!("send response fn finished"); - } -} diff --git a/crates/udp/src/workers/socket/mio/mod.rs b/crates/udp/src/workers/socket/mio/mod.rs new file mode 100644 index 00000000..3b0ebc43 --- /dev/null +++ b/crates/udp/src/workers/socket/mio/mod.rs @@ -0,0 +1,194 @@ +mod socket; + +use std::time::Duration; + +use anyhow::Context; +use aquatic_common::access_list::AccessListCache; +use crossbeam_channel::Sender; +use mio::{Events, Interest, Poll, Token}; + +use aquatic_common::{ + access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, + ValidUntil, +}; +use aquatic_udp_protocol::*; +use rand::rngs::SmallRng; +use rand::SeedableRng; + +use crate::common::*; +use crate::config::Config; + +use socket::Socket; + +use super::validator::ConnectionValidator; +use super::{EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; + +const TOKEN_V4: Token = Token(0); +const TOKEN_V6: Token = Token(1); + +pub fn run( + config: Config, + shared_state: State, + statistics: CachePaddedArc>, + statistics_sender: Sender, + validator: ConnectionValidator, + mut priv_droppers: Vec, +) -> anyhow::Result<()> { + let mut opt_socket_ipv4 = if config.network.use_ipv4 { + let priv_dropper = priv_droppers.pop().expect("not enough privilege droppers"); + + Some(Socket::::create(&config, priv_dropper)?) + } else { + None + }; + let mut opt_socket_ipv6 = if config.network.use_ipv6 { + let priv_dropper = priv_droppers.pop().expect("not enough privilege droppers"); + + Some(Socket::::create(&config, priv_dropper)?) + } else { + None + }; + + let access_list_cache = create_access_list_cache(&shared_state.access_list); + let peer_valid_until = ValidUntil::new( + shared_state.server_start_instant, + config.cleaning.max_peer_age, + ); + + let mut shared = WorkerSharedData { + config, + shared_state, + statistics, + statistics_sender, + validator, + access_list_cache, + buffer: [0; BUFFER_SIZE], + rng: SmallRng::from_entropy(), + peer_valid_until, + }; + + let mut events = Events::with_capacity(2); + let mut poll = Poll::new().context("create poll")?; + + if let Some(socket) = opt_socket_ipv4.as_mut() { + poll.registry() + .register(&mut socket.socket, TOKEN_V4, Interest::READABLE) + .context("register poll")?; + } + if let Some(socket) = opt_socket_ipv6.as_mut() { + poll.registry() + .register(&mut socket.socket, TOKEN_V6, Interest::READABLE) + .context("register poll")?; + } + + let poll_timeout = Duration::from_millis(shared.config.network.poll_timeout_ms); + + let mut iter_counter = 0u64; + + loop { + poll.poll(&mut events, Some(poll_timeout)).context("poll")?; + + for event in events.iter() { + if event.is_readable() { + match event.token() { + TOKEN_V4 => { + if let Some(socket) = opt_socket_ipv4.as_mut() { + socket.read_and_handle_requests(&mut shared); + } + } + TOKEN_V6 => { + if let Some(socket) = opt_socket_ipv6.as_mut() { + socket.read_and_handle_requests(&mut shared); + } + } + _ => (), + } + } + } + + if let Some(socket) = opt_socket_ipv4.as_mut() { + socket.resend_failed(&mut shared); + } + if let Some(socket) = opt_socket_ipv6.as_mut() { + socket.resend_failed(&mut shared); + } + + if iter_counter % 256 == 0 { + shared.validator.update_elapsed(); + + shared.peer_valid_until = ValidUntil::new( + shared.shared_state.server_start_instant, + shared.config.cleaning.max_peer_age, + ); + } + + iter_counter = iter_counter.wrapping_add(1); + } +} + +pub struct WorkerSharedData { + config: Config, + shared_state: State, + statistics: CachePaddedArc>, + statistics_sender: Sender, + access_list_cache: AccessListCache, + validator: ConnectionValidator, + buffer: [u8; BUFFER_SIZE], + rng: SmallRng, + peer_valid_until: ValidUntil, +} + +impl WorkerSharedData { + fn handle_request(&mut self, request: Request, src: CanonicalSocketAddr) -> Option { + let access_list_mode = self.config.access_list.mode; + + match request { + Request::Connect(request) => { + return Some(Response::Connect(ConnectResponse { + connection_id: self.validator.create_connection_id(src), + transaction_id: request.transaction_id, + })); + } + Request::Announce(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + if self + .access_list_cache + .load() + .allows(access_list_mode, &request.info_hash.0) + { + let response = self.shared_state.torrent_maps.announce( + &self.config, + &self.statistics_sender, + &mut self.rng, + &request, + src, + self.peer_valid_until, + ); + + return Some(response); + } else { + return Some(Response::Error(ErrorResponse { + transaction_id: request.transaction_id, + message: "Info hash not allowed".into(), + })); + } + } + } + Request::Scrape(request) => { + if self + .validator + .connection_id_valid(src, request.connection_id) + { + return Some(Response::Scrape( + self.shared_state.torrent_maps.scrape(request, src), + )); + } + } + } + + None + } +} diff --git a/crates/udp/src/workers/socket/mio/socket.rs b/crates/udp/src/workers/socket/mio/socket.rs new file mode 100644 index 00000000..03227ddc --- /dev/null +++ b/crates/udp/src/workers/socket/mio/socket.rs @@ -0,0 +1,322 @@ +use std::io::{Cursor, ErrorKind}; +use std::marker::PhantomData; +use std::sync::atomic::Ordering; + +use anyhow::Context; +use mio::net::UdpSocket; +use socket2::{Domain, Protocol, Type}; + +use aquatic_common::{privileges::PrivilegeDropper, CanonicalSocketAddr}; +use aquatic_udp_protocol::*; + +use crate::config::Config; + +use super::{WorkerSharedData, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; + +pub trait IpVersion { + fn is_v4() -> bool; +} + +#[derive(Clone, Copy, Debug)] +pub struct Ipv4; + +impl IpVersion for Ipv4 { + fn is_v4() -> bool { + true + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Ipv6; + +impl IpVersion for Ipv6 { + fn is_v4() -> bool { + false + } +} + +pub struct Socket { + pub socket: UdpSocket, + opt_resend_buffer: Option>, + phantom_data: PhantomData, +} + +impl Socket { + pub fn create(config: &Config, priv_dropper: PrivilegeDropper) -> anyhow::Result { + let socket = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + + socket + .set_reuse_port(true) + .with_context(|| "socket: set reuse port")?; + socket + .set_nonblocking(true) + .with_context(|| "socket: set nonblocking")?; + + let recv_buffer_size = config.network.socket_recv_buffer_size; + + if recv_buffer_size != 0 { + if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { + ::log::error!( + "socket: failed setting recv buffer to {}: {:?}", + recv_buffer_size, + err + ); + } + } + + socket + .bind(&config.network.address_ipv4.into()) + .with_context(|| format!("socket: bind to {}", config.network.address_ipv4))?; + + priv_dropper.after_socket_creation()?; + + let mut s = Self { + socket: UdpSocket::from_std(::std::net::UdpSocket::from(socket)), + opt_resend_buffer: None, + phantom_data: Default::default(), + }; + + if config.network.resend_buffer_max_len > 0 { + s.opt_resend_buffer = Some(Vec::new()); + } + + Ok(s) + } +} + +impl Socket { + pub fn create(config: &Config, priv_dropper: PrivilegeDropper) -> anyhow::Result { + let socket = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + + if config.network.set_only_ipv6 { + socket + .set_only_v6(true) + .with_context(|| "socket: set only ipv6")?; + } + socket + .set_reuse_port(true) + .with_context(|| "socket: set reuse port")?; + socket + .set_nonblocking(true) + .with_context(|| "socket: set nonblocking")?; + + let recv_buffer_size = config.network.socket_recv_buffer_size; + + if recv_buffer_size != 0 { + if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { + ::log::error!( + "socket: failed setting recv buffer to {}: {:?}", + recv_buffer_size, + err + ); + } + } + + socket + .bind(&config.network.address_ipv6.into()) + .with_context(|| format!("socket: bind to {}", config.network.address_ipv6))?; + + priv_dropper.after_socket_creation()?; + + let mut s = Self { + socket: UdpSocket::from_std(::std::net::UdpSocket::from(socket)), + opt_resend_buffer: None, + phantom_data: Default::default(), + }; + + if config.network.resend_buffer_max_len > 0 { + s.opt_resend_buffer = Some(Vec::new()); + } + + Ok(s) + } +} + +impl Socket { + pub fn read_and_handle_requests(&mut self, shared: &mut WorkerSharedData) { + let max_scrape_torrents = shared.config.protocol.max_scrape_torrents; + + loop { + match self.socket.recv_from(&mut shared.buffer[..]) { + Ok((bytes_read, src)) => { + let src_port = src.port(); + let src = CanonicalSocketAddr::new(src); + + // Use canonical address for statistics + let opt_statistics = if shared.config.statistics.active() { + if src.is_ipv4() { + let statistics = &shared.statistics.ipv4; + + statistics + .bytes_received + .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + + Some(statistics) + } else { + let statistics = &shared.statistics.ipv6; + + statistics + .bytes_received + .fetch_add(bytes_read + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + + Some(statistics) + } + } else { + None + }; + + if src_port == 0 { + ::log::debug!("Ignored request because source port is zero"); + + continue; + } + + match Request::parse_bytes(&shared.buffer[..bytes_read], max_scrape_torrents) { + Ok(request) => { + if let Some(statistics) = opt_statistics { + statistics.requests.fetch_add(1, Ordering::Relaxed); + } + + if let Some(response) = shared.handle_request(request, src) { + self.send_response(shared, src, response, false); + } + } + Err(RequestParseError::Sendable { + connection_id, + transaction_id, + err, + }) if shared.validator.connection_id_valid(src, connection_id) => { + let response = ErrorResponse { + transaction_id, + message: err.into(), + }; + + self.send_response(shared, src, Response::Error(response), false); + + ::log::debug!("request parse error (sent error response): {:?}", err); + } + Err(err) => { + ::log::debug!( + "request parse error (didn't send error response): {:?}", + err + ); + } + }; + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + break; + } + Err(err) => { + ::log::warn!("recv_from error: {:#}", err); + } + } + } + } + pub fn send_response( + &mut self, + shared: &mut WorkerSharedData, + canonical_addr: CanonicalSocketAddr, + response: Response, + disable_resend_buffer: bool, + ) { + let mut buffer = Cursor::new(&mut shared.buffer[..]); + + if let Err(err) = response.write_bytes(&mut buffer) { + ::log::error!("failed writing response to buffer: {:#}", err); + + return; + } + + let bytes_written = buffer.position() as usize; + + let addr = if V::is_v4() { + canonical_addr + .get_ipv4() + .expect("found peer ipv6 address while running bound to ipv4 address") + } else { + canonical_addr.get_ipv6_mapped() + }; + + match self + .socket + .send_to(&buffer.into_inner()[..bytes_written], addr) + { + Ok(bytes_sent) if shared.config.statistics.active() => { + let stats = if canonical_addr.is_ipv4() { + let stats = &shared.statistics.ipv4; + + stats + .bytes_sent + .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV4, Ordering::Relaxed); + + stats + } else { + let stats = &shared.statistics.ipv6; + + stats + .bytes_sent + .fetch_add(bytes_sent + EXTRA_PACKET_SIZE_IPV6, Ordering::Relaxed); + + stats + }; + + match response { + Response::Connect(_) => { + stats.responses_connect.fetch_add(1, Ordering::Relaxed); + } + Response::AnnounceIpv4(_) | Response::AnnounceIpv6(_) => { + stats.responses_announce.fetch_add(1, Ordering::Relaxed); + } + Response::Scrape(_) => { + stats.responses_scrape.fetch_add(1, Ordering::Relaxed); + } + Response::Error(_) => { + stats.responses_error.fetch_add(1, Ordering::Relaxed); + } + } + } + Ok(_) => (), + Err(err) => match self.opt_resend_buffer.as_mut() { + Some(resend_buffer) + if !disable_resend_buffer && (err.raw_os_error() == Some(libc::ENOBUFS)) + || (err.kind() == ErrorKind::WouldBlock) => + { + if resend_buffer.len() < shared.config.network.resend_buffer_max_len { + ::log::debug!("Adding response to resend queue, since sending it to {} failed with: {:#}", addr, err); + + resend_buffer.push((canonical_addr, response)); + } else { + ::log::warn!("Response resend buffer full, dropping response"); + } + } + _ => { + ::log::warn!("Sending response to {} failed: {:#}", addr, err); + } + }, + } + + ::log::debug!("send response fn finished"); + } + + /// If resend buffer is enabled, send any responses in it + pub fn resend_failed(&mut self, shared: &mut WorkerSharedData) { + if self.opt_resend_buffer.is_some() { + let mut tmp_resend_buffer = Vec::new(); + + // Do memory swap shenanigans to get around false positive in + // borrow checker regarding double mut borrowing of self + + if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() { + ::std::mem::swap(resend_buffer, &mut tmp_resend_buffer); + } + + for (addr, response) in tmp_resend_buffer.drain(..) { + self.send_response(shared, addr, response, true); + } + + if let Some(resend_buffer) = self.opt_resend_buffer.as_mut() { + ::std::mem::swap(resend_buffer, &mut tmp_resend_buffer); + } + } + } +} diff --git a/crates/udp/src/workers/socket/mod.rs b/crates/udp/src/workers/socket/mod.rs index ef1adeac..282530ac 100644 --- a/crates/udp/src/workers/socket/mod.rs +++ b/crates/udp/src/workers/socket/mod.rs @@ -3,10 +3,8 @@ mod mio; mod uring; mod validator; -use anyhow::Context; use aquatic_common::privileges::PrivilegeDropper; use crossbeam_channel::Sender; -use socket2::{Domain, Protocol, Socket, Type}; use crate::{ common::{ @@ -44,10 +42,12 @@ pub fn run_socket_worker( statistics: CachePaddedArc>, statistics_sender: Sender, validator: ConnectionValidator, - priv_dropper: PrivilegeDropper, + priv_droppers: Vec, ) -> anyhow::Result<()> { #[cfg(all(target_os = "linux", feature = "io-uring"))] if config.network.use_io_uring { + use anyhow::Context; + self::uring::supported_on_current_kernel().context("check for io_uring compatibility")?; return self::uring::SocketWorker::run( @@ -56,61 +56,16 @@ pub fn run_socket_worker( statistics, statistics_sender, validator, - priv_dropper, + priv_droppers, ); } - self::mio::SocketWorker::run( + self::mio::run( config, shared_state, statistics, statistics_sender, validator, - priv_dropper, + priv_droppers, ) } - -fn create_socket( - config: &Config, - priv_dropper: PrivilegeDropper, -) -> anyhow::Result<::std::net::UdpSocket> { - let socket = if config.network.address.is_ipv4() { - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))? - } else { - Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))? - }; - - if config.network.only_ipv6 { - socket - .set_only_v6(true) - .with_context(|| "socket: set only ipv6")?; - } - - socket - .set_reuse_port(true) - .with_context(|| "socket: set reuse port")?; - - socket - .set_nonblocking(true) - .with_context(|| "socket: set nonblocking")?; - - let recv_buffer_size = config.network.socket_recv_buffer_size; - - if recv_buffer_size != 0 { - if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { - ::log::error!( - "socket: failed setting recv buffer to {}: {:?}", - recv_buffer_size, - err - ); - } - } - - socket - .bind(&config.network.address.into()) - .with_context(|| format!("socket: bind to {}", config.network.address))?; - - priv_dropper.after_socket_creation()?; - - Ok(socket.into()) -} diff --git a/crates/udp/src/workers/socket/uring/mod.rs b/crates/udp/src/workers/socket/uring/mod.rs index ac572d4e..a115f47c 100644 --- a/crates/udp/src/workers/socket/uring/mod.rs +++ b/crates/udp/src/workers/socket/uring/mod.rs @@ -4,6 +4,7 @@ mod send_buffers; use std::cell::RefCell; use std::collections::VecDeque; +use std::net::SocketAddr; use std::net::UdpSocket; use std::ops::DerefMut; use std::os::fd::AsRawFd; @@ -15,6 +16,8 @@ use crossbeam_channel::Sender; use io_uring::opcode::Timeout; use io_uring::types::{Fixed, Timespec}; use io_uring::{IoUring, Probe}; +use recv_helper::RecvHelper; +use socket2::{Domain, Protocol, Socket, Type}; use aquatic_common::{ access_list::create_access_list_cache, privileges::PrivilegeDropper, CanonicalSocketAddr, @@ -28,11 +31,11 @@ use crate::common::*; use crate::config::Config; use self::buf_ring::BufRing; -use self::recv_helper::RecvHelper; +use self::recv_helper::{RecvHelperV4, RecvHelperV6}; use self::send_buffers::{ResponseType, SendBuffers}; use super::validator::ConnectionValidator; -use super::{create_socket, EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; +use super::{EXTRA_PACKET_SIZE_IPV4, EXTRA_PACKET_SIZE_IPV6}; /// Size of each request buffer /// @@ -48,10 +51,12 @@ const REQUEST_BUF_LEN: usize = 512; /// - scrape response for 170 info hashes const RESPONSE_BUF_LEN: usize = 2048; -const USER_DATA_RECV: u64 = u64::MAX; -const USER_DATA_PULSE_TIMEOUT: u64 = u64::MAX - 1; +const USER_DATA_RECV_V4: u64 = u64::MAX; +const USER_DATA_RECV_V6: u64 = u64::MAX - 1; +const USER_DATA_PULSE_TIMEOUT: u64 = u64::MAX - 2; -const SOCKET_IDENTIFIER: Fixed = Fixed(0); +const SOCKET_IDENTIFIER_V4: Fixed = Fixed(0); +const SOCKET_IDENTIFIER_V6: Fixed = Fixed(1); thread_local! { /// Store IoUring instance here so that it can be accessed in BufRing::drop @@ -81,13 +86,17 @@ pub struct SocketWorker { access_list_cache: AccessListCache, validator: ConnectionValidator, #[allow(dead_code)] - socket: UdpSocket, + opt_socket_ipv4: Option, + #[allow(dead_code)] + opt_socket_ipv6: Option, buf_ring: BufRing, send_buffers: SendBuffers, - recv_helper: RecvHelper, + recv_helper_v4: RecvHelperV4, + recv_helper_v6: RecvHelperV6, local_responses: VecDeque<(CanonicalSocketAddr, Response)>, resubmittable_sqe_buf: Vec, - recv_sqe: io_uring::squeue::Entry, + recv_sqe_ipv4: io_uring::squeue::Entry, + recv_sqe_ipv6: io_uring::squeue::Entry, pulse_timeout_sqe: io_uring::squeue::Entry, peer_valid_until: ValidUntil, rng: SmallRng, @@ -100,17 +109,38 @@ impl SocketWorker { statistics: CachePaddedArc>, statistics_sender: Sender, validator: ConnectionValidator, - priv_dropper: PrivilegeDropper, + mut priv_droppers: Vec, ) -> anyhow::Result<()> { let ring_entries = config.network.ring_size.next_power_of_two(); // Try to fill up the ring with send requests let send_buffer_entries = ring_entries; - let socket = create_socket(&config, priv_dropper).expect("create socket"); + let opt_socket_ipv4 = if config.network.use_ipv4 { + let priv_dropper = priv_droppers.pop().expect("not enough priv droppers"); + + Some( + create_socket(&config, priv_dropper, config.network.address_ipv4.into()) + .context("create ipv4 socket")?, + ) + } else { + None + }; + let opt_socket_ipv6 = if config.network.use_ipv6 { + let priv_dropper = priv_droppers.pop().expect("not enough priv droppers"); + + Some( + create_socket(&config, priv_dropper, config.network.address_ipv6.into()) + .context("create ipv6 socket")?, + ) + } else { + None + }; + let access_list_cache = create_access_list_cache(&shared_state.access_list); - let send_buffers = SendBuffers::new(&config, send_buffer_entries as usize); - let recv_helper = RecvHelper::new(&config); + let send_buffers = SendBuffers::new(send_buffer_entries as usize); + let recv_helper_v4 = RecvHelperV4::new(&config); + let recv_helper_v6 = RecvHelperV6::new(&config); let ring = IoUring::builder() .setup_coop_taskrun() @@ -120,7 +150,16 @@ impl SocketWorker { .unwrap(); ring.submitter() - .register_files(&[socket.as_raw_fd()]) + .register_files(&[ + opt_socket_ipv4 + .as_ref() + .map(|s| s.as_raw_fd()) + .unwrap_or(-1), + opt_socket_ipv6 + .as_ref() + .map(|s| s.as_raw_fd()) + .unwrap_or(-1), + ]) .unwrap(); // Store ring in thread local storage before creating BufRing @@ -132,8 +171,6 @@ impl SocketWorker { .build() .unwrap(); - let recv_sqe = recv_helper.create_entry(buf_ring.bgid()); - // This timeout enables regular updates of ConnectionValidator and // peer_valid_until let pulse_timeout_sqe = { @@ -144,7 +181,17 @@ impl SocketWorker { .user_data(USER_DATA_PULSE_TIMEOUT) }; - let resubmittable_sqe_buf = vec![recv_sqe.clone(), pulse_timeout_sqe.clone()]; + let mut resubmittable_sqe_buf = vec![pulse_timeout_sqe.clone()]; + + let recv_sqe_ipv4 = recv_helper_v4.create_entry(buf_ring.bgid()); + let recv_sqe_ipv6 = recv_helper_v6.create_entry(buf_ring.bgid()); + + if opt_socket_ipv4.is_some() { + resubmittable_sqe_buf.push(recv_sqe_ipv4.clone()); + } + if opt_socket_ipv6.is_some() { + resubmittable_sqe_buf.push(recv_sqe_ipv6.clone()); + } let peer_valid_until = ValidUntil::new( shared_state.server_start_instant, @@ -158,14 +205,17 @@ impl SocketWorker { statistics_sender, validator, access_list_cache, + opt_socket_ipv4, + opt_socket_ipv6, send_buffers, - recv_helper, + recv_helper_v4, + recv_helper_v6, local_responses: Default::default(), buf_ring, - recv_sqe, + recv_sqe_ipv4, + recv_sqe_ipv6, pulse_timeout_sqe, resubmittable_sqe_buf, - socket, peer_valid_until, rng: SmallRng::from_entropy(), }; @@ -192,7 +242,24 @@ impl SocketWorker { // Enqueue local responses for _ in 0..sq_space { if let Some((addr, response)) = self.local_responses.pop_front() { - match self.send_buffers.prepare_entry(response, addr) { + let send_to_ipv4_socket = if addr.is_ipv4() { + if self.opt_socket_ipv4.is_some() { + true + } else if self.opt_socket_ipv6.is_some() { + false + } else { + panic!("No socket open") + } + } else if self.opt_socket_ipv6.is_some() { + false + } else { + panic!("IPv6 response with no IPv6 socket") + }; + + match self + .send_buffers + .prepare_entry(send_to_ipv4_socket, response, addr) + { Ok(entry) => { unsafe { ring.submission().push(&entry).unwrap() }; @@ -229,13 +296,22 @@ impl SocketWorker { fn handle_cqe(&mut self, cqe: io_uring::cqueue::Entry) { match cqe.user_data() { - USER_DATA_RECV => { - if let Some((addr, response)) = self.handle_recv_cqe(&cqe) { + USER_DATA_RECV_V4 => { + if let Some((addr, response)) = self.handle_recv_cqe(&cqe, true) { + self.local_responses.push_back((addr, response)); + } + + if !io_uring::cqueue::more(cqe.flags()) { + self.resubmittable_sqe_buf.push(self.recv_sqe_ipv4.clone()); + } + } + USER_DATA_RECV_V6 => { + if let Some((addr, response)) = self.handle_recv_cqe(&cqe, false) { self.local_responses.push_back((addr, response)); } if !io_uring::cqueue::more(cqe.flags()) { - self.resubmittable_sqe_buf.push(self.recv_sqe.clone()); + self.resubmittable_sqe_buf.push(self.recv_sqe_ipv6.clone()); } } USER_DATA_PULSE_TIMEOUT => { @@ -296,6 +372,7 @@ impl SocketWorker { fn handle_recv_cqe( &mut self, cqe: &io_uring::cqueue::Entry, + received_on_ipv4_socket: bool, ) -> Option<(CanonicalSocketAddr, Response)> { let result = cqe.result(); @@ -328,7 +405,13 @@ impl SocketWorker { } }; - match self.recv_helper.parse(buffer.as_slice()) { + let recv_helper = if received_on_ipv4_socket { + &self.recv_helper_v4 as &dyn RecvHelper + } else { + &self.recv_helper_v6 as &dyn RecvHelper + }; + + match recv_helper.parse(buffer.as_slice()) { Ok((request, addr)) => { if self.config.statistics.active() { let (statistics, extra_bytes) = if addr.is_ipv4() { @@ -459,6 +542,54 @@ impl SocketWorker { } } +fn create_socket( + config: &Config, + priv_dropper: PrivilegeDropper, + address: SocketAddr, +) -> anyhow::Result<::std::net::UdpSocket> { + let socket = if address.is_ipv4() { + Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))? + } else { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + + if config.network.set_only_ipv6 { + socket + .set_only_v6(true) + .with_context(|| "socket: set only ipv6")?; + } + + socket + }; + + socket + .set_reuse_port(true) + .with_context(|| "socket: set reuse port")?; + + socket + .set_nonblocking(true) + .with_context(|| "socket: set nonblocking")?; + + let recv_buffer_size = config.network.socket_recv_buffer_size; + + if recv_buffer_size != 0 { + if let Err(err) = socket.set_recv_buffer_size(recv_buffer_size) { + ::log::error!( + "socket: failed setting recv buffer to {}: {:?}", + recv_buffer_size, + err + ); + } + } + + socket + .bind(&address.into()) + .with_context(|| format!("socket: bind to {}", address))?; + + priv_dropper.after_socket_creation()?; + + Ok(socket.into()) +} + pub fn supported_on_current_kernel() -> anyhow::Result<()> { let opcodes = [ // We can't probe for RecvMsgMulti, so we probe for SendZc, which was diff --git a/crates/udp/src/workers/socket/uring/recv_helper.rs b/crates/udp/src/workers/socket/uring/recv_helper.rs index 7e7ca64d..2248a5bc 100644 --- a/crates/udp/src/workers/socket/uring/recv_helper.rs +++ b/crates/udp/src/workers/socket/uring/recv_helper.rs @@ -9,7 +9,7 @@ use io_uring::{opcode::RecvMsgMulti, types::RecvMsgOut}; use crate::config::Config; -use super::{SOCKET_IDENTIFIER, USER_DATA_RECV}; +use super::{SOCKET_IDENTIFIER_V4, SOCKET_IDENTIFIER_V6, USER_DATA_RECV_V4, USER_DATA_RECV_V6}; #[allow(clippy::enum_variant_names)] pub enum Error { @@ -19,18 +19,19 @@ pub enum Error { InvalidSocketAddress, } -pub struct RecvHelper { - socket_is_ipv4: bool, +pub trait RecvHelper { + fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error>; +} + +// For IPv4 sockets +pub struct RecvHelperV4 { max_scrape_torrents: u8, #[allow(dead_code)] name_v4: *const libc::sockaddr_in, msghdr_v4: *const libc::msghdr, - #[allow(dead_code)] - name_v6: *const libc::sockaddr_in6, - msghdr_v6: *const libc::msghdr, } -impl RecvHelper { +impl RecvHelperV4 { pub fn new(config: &Config) -> Self { let name_v4 = Box::into_raw(Box::new(libc::sockaddr_in { sin_family: 0, @@ -47,6 +48,62 @@ impl RecvHelper { Box::into_raw(Box::new(hdr)) }; + Self { + max_scrape_torrents: config.protocol.max_scrape_torrents, + name_v4, + msghdr_v4, + } + } + + pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry { + RecvMsgMulti::new(SOCKET_IDENTIFIER_V4, self.msghdr_v4, buf_group) + .build() + .user_data(USER_DATA_RECV_V4) + } +} + +impl RecvHelper for RecvHelperV4 { + fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> { + // Safe as long as kernel only reads from the pointer and doesn't + // write to it. I think this is the case. + let msghdr = unsafe { self.msghdr_v4.read() }; + + let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?; + + if msg.is_name_data_truncated() | msg.is_payload_truncated() { + return Err(Error::RecvMsgTruncated); + } + + let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in) }; + + let addr = SocketAddr::V4(SocketAddrV4::new( + u32::from_be(name_data.sin_addr.s_addr).into(), + u16::from_be(name_data.sin_port), + )); + + if addr.port() == 0 { + return Err(Error::InvalidSocketAddress); + } + + let addr = CanonicalSocketAddr::new(addr); + + let request = Request::parse_bytes(msg.payload_data(), self.max_scrape_torrents) + .map_err(|err| Error::RequestParseError(err, addr))?; + + Ok((request, addr)) + } +} + +// For IPv6 sockets (can theoretically still receive IPv4 packets, though) +pub struct RecvHelperV6 { + max_scrape_torrents: u8, + #[allow(dead_code)] + name_v6: *const libc::sockaddr_in6, + msghdr_v6: *const libc::msghdr, +} + +impl RecvHelperV6 { + pub fn new(config: &Config) -> Self { let name_v6 = Box::into_raw(Box::new(libc::sockaddr_in6 { sin6_family: 0, sin6_port: 0, @@ -64,69 +121,39 @@ impl RecvHelper { }; Self { - socket_is_ipv4: config.network.address.is_ipv4(), max_scrape_torrents: config.protocol.max_scrape_torrents, - name_v4, - msghdr_v4, name_v6, msghdr_v6, } } pub fn create_entry(&self, buf_group: u16) -> io_uring::squeue::Entry { - let msghdr = if self.socket_is_ipv4 { - self.msghdr_v4 - } else { - self.msghdr_v6 - }; - - RecvMsgMulti::new(SOCKET_IDENTIFIER, msghdr, buf_group) + RecvMsgMulti::new(SOCKET_IDENTIFIER_V6, self.msghdr_v6, buf_group) .build() - .user_data(USER_DATA_RECV) + .user_data(USER_DATA_RECV_V6) } +} - pub fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> { - let (msg, addr) = if self.socket_is_ipv4 { - // Safe as long as kernel only reads from the pointer and doesn't - // write to it. I think this is the case. - let msghdr = unsafe { self.msghdr_v4.read() }; - - let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?; - - if msg.is_name_data_truncated() | msg.is_payload_truncated() { - return Err(Error::RecvMsgTruncated); - } - - let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in) }; - - let addr = SocketAddr::V4(SocketAddrV4::new( - u32::from_be(name_data.sin_addr.s_addr).into(), - u16::from_be(name_data.sin_port), - )); - - (msg, addr) - } else { - // Safe as long as kernel only reads from the pointer and doesn't - // write to it. I think this is the case. - let msghdr = unsafe { self.msghdr_v6.read() }; - - let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?; +impl RecvHelper for RecvHelperV6 { + fn parse(&self, buffer: &[u8]) -> Result<(Request, CanonicalSocketAddr), Error> { + // Safe as long as kernel only reads from the pointer and doesn't + // write to it. I think this is the case. + let msghdr = unsafe { self.msghdr_v6.read() }; - if msg.is_name_data_truncated() | msg.is_payload_truncated() { - return Err(Error::RecvMsgTruncated); - } + let msg = RecvMsgOut::parse(buffer, &msghdr).map_err(|_| Error::RecvMsgParseError)?; - let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in6) }; + if msg.is_name_data_truncated() | msg.is_payload_truncated() { + return Err(Error::RecvMsgTruncated); + } - let addr = SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from(name_data.sin6_addr.s6_addr), - u16::from_be(name_data.sin6_port), - u32::from_be(name_data.sin6_flowinfo), - u32::from_be(name_data.sin6_scope_id), - )); + let name_data = unsafe { *(msg.name_data().as_ptr() as *const libc::sockaddr_in6) }; - (msg, addr) - }; + let addr = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(name_data.sin6_addr.s6_addr), + u16::from_be(name_data.sin6_port), + u32::from_be(name_data.sin6_flowinfo), + u32::from_be(name_data.sin6_scope_id), + )); if addr.port() == 0 { return Err(Error::InvalidSocketAddress); diff --git a/crates/udp/src/workers/socket/uring/send_buffers.rs b/crates/udp/src/workers/socket/uring/send_buffers.rs index b2b17724..9016f46f 100644 --- a/crates/udp/src/workers/socket/uring/send_buffers.rs +++ b/crates/udp/src/workers/socket/uring/send_buffers.rs @@ -10,9 +10,7 @@ use aquatic_common::CanonicalSocketAddr; use aquatic_udp_protocol::Response; use io_uring::opcode::SendMsg; -use crate::config::Config; - -use super::{RESPONSE_BUF_LEN, SOCKET_IDENTIFIER}; +use super::{RESPONSE_BUF_LEN, SOCKET_IDENTIFIER_V4, SOCKET_IDENTIFIER_V6}; pub enum Error { NoBuffers(Response), @@ -21,21 +19,17 @@ pub enum Error { pub struct SendBuffers { likely_next_free_index: usize, - socket_is_ipv4: bool, buffers: Vec<(SendBufferMetadata, *mut SendBuffer)>, } impl SendBuffers { - pub fn new(config: &Config, capacity: usize) -> Self { - let socket_is_ipv4 = config.network.address.is_ipv4(); - - let buffers = repeat_with(|| (Default::default(), SendBuffer::new(socket_is_ipv4))) + pub fn new(capacity: usize) -> Self { + let buffers = repeat_with(|| (Default::default(), SendBuffer::new())) .take(capacity) .collect::>(); Self { likely_next_free_index: 0, - socket_is_ipv4, buffers, } } @@ -61,6 +55,7 @@ impl SendBuffers { pub fn prepare_entry( &mut self, + send_to_ipv4_socket: bool, response: Response, addr: CanonicalSocketAddr, ) -> Result { @@ -75,7 +70,7 @@ impl SendBuffers { // Safe as long as `mark_buffer_as_free` was used correctly let buffer = unsafe { &mut *(*buffer) }; - match buffer.prepare_entry(response, addr, self.socket_is_ipv4, buffer_metadata) { + match buffer.prepare_entry(response, addr, send_to_ipv4_socket, buffer_metadata) { Ok(entry) => { buffer_metadata.free = false; @@ -116,7 +111,7 @@ struct SendBuffer { } impl SendBuffer { - fn new(socket_is_ipv4: bool) -> *mut Self { + fn new() -> *mut Self { let mut instance = Box::new(Self { name_v4: libc::sockaddr_in { sin_family: libc::AF_INET as u16, @@ -145,13 +140,9 @@ impl SendBuffer { instance.msghdr.msg_iov = addr_of_mut!(instance.iovec); instance.msghdr.msg_iovlen = 1; - if socket_is_ipv4 { - instance.msghdr.msg_name = addr_of_mut!(instance.name_v4) as *mut libc::c_void; - instance.msghdr.msg_namelen = core::mem::size_of::() as u32; - } else { - instance.msghdr.msg_name = addr_of_mut!(instance.name_v6) as *mut libc::c_void; - instance.msghdr.msg_namelen = core::mem::size_of::() as u32; - } + // Set IPv4 initially. Will be overridden with each prepare_entry call + instance.msghdr.msg_name = addr_of_mut!(instance.name_v4) as *mut libc::c_void; + instance.msghdr.msg_namelen = core::mem::size_of::() as u32; Box::into_raw(instance) } @@ -160,10 +151,10 @@ impl SendBuffer { &mut self, response: Response, addr: CanonicalSocketAddr, - socket_is_ipv4: bool, + send_to_ipv4_socket: bool, metadata: &mut SendBufferMetadata, ) -> Result { - if socket_is_ipv4 { + let entry_fd = if send_to_ipv4_socket { metadata.receiver_is_ipv4 = true; let addr = if let Some(SocketAddr::V4(addr)) = addr.get_ipv4() { @@ -174,6 +165,10 @@ impl SendBuffer { self.name_v4.sin_port = addr.port().to_be(); self.name_v4.sin_addr.s_addr = u32::from(*addr.ip()).to_be(); + self.msghdr.msg_name = addr_of_mut!(self.name_v4) as *mut libc::c_void; + self.msghdr.msg_namelen = core::mem::size_of::() as u32; + + SOCKET_IDENTIFIER_V4 } else { // Set receiver protocol type before calling addr.get_ipv6_mapped() metadata.receiver_is_ipv4 = addr.is_ipv4(); @@ -186,7 +181,11 @@ impl SendBuffer { self.name_v6.sin6_port = addr.port().to_be(); self.name_v6.sin6_addr.s6_addr = addr.ip().octets(); - } + self.msghdr.msg_name = addr_of_mut!(self.name_v6) as *mut libc::c_void; + self.msghdr.msg_namelen = core::mem::size_of::() as u32; + + SOCKET_IDENTIFIER_V6 + }; let mut cursor = Cursor::new(&mut self.bytes[..]); @@ -196,7 +195,7 @@ impl SendBuffer { metadata.response_type = ResponseType::from_response(&response); - Ok(SendMsg::new(SOCKET_IDENTIFIER, addr_of_mut!(self.msghdr)).build()) + Ok(SendMsg::new(entry_fd, addr_of_mut!(self.msghdr)).build()) } Err(err) => Err(Error::SerializationFailed(err)), } diff --git a/crates/udp/tests/access_list.rs b/crates/udp/tests/access_list.rs index bc3dacf9..dab3e0d3 100644 --- a/crates/udp/tests/access_list.rs +++ b/crates/udp/tests/access_list.rs @@ -60,7 +60,8 @@ fn test_access_list( let mut config = Config::default(); - config.network.address.set_port(tracker_port); + config.network.address_ipv4.set_port(tracker_port); + config.network.use_ipv6 = false; config.access_list.mode = mode; config.access_list.path = access_list_path; diff --git a/crates/udp/tests/invalid_connection_id.rs b/crates/udp/tests/invalid_connection_id.rs index 96ecf67e..c3a6cd60 100644 --- a/crates/udp/tests/invalid_connection_id.rs +++ b/crates/udp/tests/invalid_connection_id.rs @@ -22,7 +22,8 @@ fn test_invalid_connection_id() -> anyhow::Result<()> { let mut config = Config::default(); - config.network.address.set_port(TRACKER_PORT); + config.network.address_ipv4.set_port(TRACKER_PORT); + config.network.use_ipv6 = false; run_tracker(config); diff --git a/crates/udp/tests/requests_responses.rs b/crates/udp/tests/requests_responses.rs index dc55aa02..d00163e1 100644 --- a/crates/udp/tests/requests_responses.rs +++ b/crates/udp/tests/requests_responses.rs @@ -21,7 +21,8 @@ fn test_multiple_connect_announce_scrape() -> anyhow::Result<()> { let mut config = Config::default(); - config.network.address.set_port(TRACKER_PORT); + config.network.address_ipv4.set_port(TRACKER_PORT); + config.network.use_ipv6 = false; run_tracker(config); diff --git a/crates/ws/src/config.rs b/crates/ws/src/config.rs index 8bb28294..45026867 100644 --- a/crates/ws/src/config.rs +++ b/crates/ws/src/config.rs @@ -75,13 +75,13 @@ impl aquatic_common::cli::Config for Config { #[serde(default, deny_unknown_fields)] pub struct NetworkConfig { /// Bind to this address - /// + /// /// When providing an IPv4 style address, only IPv4 traffic will be /// handled. Examples: /// - "0.0.0.0:3000" binds to port 3000 on all network interfaces /// - "127.0.0.1:3000" binds to port 3000 on the loopback interface /// (localhost) - /// + /// /// When it comes to IPv6-style addresses, behaviour is more complex and /// differs between operating systems. On Linux, to accept both IPv4 and /// IPv6 traffic on any interface, use "[::]:3000". Set the "only_ipv6"