diff --git a/Cargo.toml b/Cargo.toml index fe304a1d9f03..f4271bd50a13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -97,6 +97,7 @@ strum = "0.17" strum_macros = "0.17" iprange = "0.6" ipnet = "2.2" +async-trait = "0.1" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["mswsock", "winsock2"] } diff --git a/src/context.rs b/src/context.rs index 291c193247ed..895dcb0d3c56 100644 --- a/src/context.rs +++ b/src/context.rs @@ -19,7 +19,7 @@ use trust_dns_resolver::TokioAsyncResolver; use crate::relay::dns_resolver::create_resolver; use crate::{ config::{Config, ConfigType, ServerConfig}, - relay::{dns_resolver::resolve, socks5::Address}, + relay::{dns_resolver::resolve, flow::ServerFlowStatistic, socks5::Address}, }; // Entries for server's bloom filter @@ -144,9 +144,20 @@ pub type SharedServerState = Arc; /// Shared basic configuration for the whole server pub struct Context { config: Config, + + // Shared variables for all servers server_state: SharedServerState, + + // Server's running indicator + // For killing all background jobs server_running: AtomicBool, + + // Check for duplicated IV/Nonce, for prevent replay attack + // https://github.com/shadowsocks/shadowsocks-org/issues/44 nonce_ppbloom: Mutex, + + // For Android's flow stat report + local_flow_statistic: ServerFlowStatistic, } /// Unique context thw whole server @@ -162,6 +173,7 @@ impl Context { server_state, server_running: AtomicBool::new(true), nonce_ppbloom, + local_flow_statistic: ServerFlowStatistic::new(), } } @@ -264,4 +276,9 @@ impl Context { Some(ref a) => a.check_target_bypassed(self, target).await, } } + + /// Get client flow statistics + pub fn local_flow_statistic(&self) -> &ServerFlowStatistic { + &self.local_flow_statistic + } } diff --git a/src/relay/local.rs b/src/relay/local.rs index 12eaab797ea9..d8e6d11902d1 100644 --- a/src/relay/local.rs +++ b/src/relay/local.rs @@ -2,17 +2,70 @@ use std::io::{self, ErrorKind}; +use cfg_if::cfg_if; use futures::{future::select_all, FutureExt}; use log::{debug, error, trace, warn}; use tokio::runtime::Handle; use crate::{ config::{Config, ConfigType}, - context::{Context, ServerState}, + context::{Context, ServerState, SharedContext}, plugin::{PluginMode, Plugins}, relay::{tcprelay::local::run as run_tcp, udprelay::local::run as run_udp, utils::set_nofile}, }; +cfg_if! { + if #[cfg(target_os = "android")] { + async fn flow_report_task(context: SharedContext) -> io::Result<()> { + use std::{slice, time::Duration}; + + use tokio::{io::AsyncWriteExt, net::UnixStream, time}; + + // Android's flow statistic report RPC + let path = context.config().stat_path.as_ref().expect("stat_path must be provided"); + let timeout = Duration::from_secs(1); + + while context.server_running() { + // keep it as libev's default, 0.5 seconds + time::delay_for(Duration::from_millis(500)).await; + let mut stream = match time::timeout(timeout, UnixStream::connect(path)).await { + Ok(Ok(s)) => s, + Ok(Err(err)) => { + error!("send client flow statistic error: {}", err); + continue; + } + Err(..) => { + error!("send client flow statistic error: timeout"); + continue; + } + }; + + let flow_stat = context.local_flow_statistic(); + let tx = flow_stat.tcp().tx() + flow_stat.udp().tx(); + let rx = flow_stat.tcp().rx() + flow_stat.udp().rx(); + + let buf: [u64; 2] = [tx, rx]; + let buf = unsafe { slice::from_raw_parts(buf.as_ptr() as *const _, 16) }; + + match time::timeout(timeout, stream.write_all(buf)).await { + Ok(Ok(..)) => {} + Ok(Err(err)) => { + error!("send client flow statistic error: {}", err); + } + Err(..) => { + error!("send client flow statistic error: timeout"); + } + } + } + Ok(()) + } + } else { + async fn flow_report_task(_context: SharedContext) -> io::Result<()> { + unimplemented!("only for android") + } + } +} + /// Relay server running under local environment. pub async fn run(mut config: Config, rt: Handle) -> io::Result<()> { trace!("initializing local server with {:?}", config); @@ -96,6 +149,13 @@ pub async fn run(mut config: Config, rt: Handle) -> io::Result<()> { vf.push(udp_fut.boxed()); } + if cfg!(target_os = "android") && context.config().stat_path.is_some() { + // For Android's flow statistic + + let report_fut = flow_report_task(context.clone()); + vf.push(report_fut.boxed()); + } + let (res, ..) = select_all(vf.into_iter()).await; error!("one of servers exited unexpectly, result: {:?}", res); diff --git a/src/relay/tcprelay/http_local.rs b/src/relay/tcprelay/http_local.rs index 69911b860529..091620cc1ef0 100644 --- a/src/relay/tcprelay/http_local.rs +++ b/src/relay/tcprelay/http_local.rs @@ -158,7 +158,7 @@ impl tower::Service for DirectConnector { let err = Error::new(ErrorKind::Other, "URI must be a valid Address"); Err(err) } - Some(addr) => ProxyStream::connect_direct(&*context, &addr).await, + Some(addr) => ProxyStream::connect_direct(context, &addr).await, } } .boxed(), diff --git a/src/relay/tcprelay/proxy_stream.rs b/src/relay/tcprelay/proxy_stream.rs index c055bd2eeb2e..ee8d1ba52cc4 100644 --- a/src/relay/tcprelay/proxy_stream.rs +++ b/src/relay/tcprelay/proxy_stream.rs @@ -22,20 +22,17 @@ use crate::{ use super::{connection::Connection, CryptoStream, STcpStream}; -macro_rules! forward_call { - ($self:expr, $method:ident $(, $param:expr)*) => { - match *$self { - ProxyStream::Direct(ref mut s) => Pin::new(s).$method($($param),*), - ProxyStream::Proxied(ref mut s) => Pin::new(s).$method($($param),*), - } - }; -} - /// Stream wrapper for both direct connections and proxied connections #[allow(clippy::large_enum_variant)] pub enum ProxyStream { - Direct(STcpStream), - Proxied(CryptoStream), + Direct { + stream: STcpStream, + context: SharedContext, + }, + Proxied { + stream: CryptoStream, + context: SharedContext, + }, } #[derive(Debug)] @@ -80,7 +77,7 @@ impl ProxyStream { addr: &Address, ) -> Result { if context.check_target_bypassed(addr).await { - ProxyStream::connect_direct_wrapped(&*context, addr).await + ProxyStream::connect_direct_wrapped(context, addr).await } else { ProxyStream::connect_proxied_wrapped(context, svr_cfg, addr).await } @@ -89,7 +86,7 @@ impl ProxyStream { /// Connect to remote directly (without proxy) /// /// This is used for hosts that matches ACL bypassed rules - pub async fn connect_direct(context: &Context, addr: &Address) -> io::Result { + pub async fn connect_direct(context: SharedContext, addr: &Address) -> io::Result { debug!("connect to {} directly (bypassed)", addr); // NOTE: Direct connection's timeout is controlled by the global key @@ -105,10 +102,13 @@ impl ProxyStream { } }; - Ok(ProxyStream::Direct(Connection::new(stream, timeout))) + Ok(ProxyStream::Direct { + stream: Connection::new(stream, timeout), + context, + }) } - async fn connect_direct_wrapped(context: &Context, addr: &Address) -> Result { + async fn connect_direct_wrapped(context: SharedContext, addr: &Address) -> Result { match ProxyStream::connect_direct(context, addr).await { Ok(s) => Ok(s), Err(err) => Err(ProxyStreamError::new(err, true)), @@ -130,10 +130,13 @@ impl ProxyStream { svr_cfg.external_addr() ); - let server_stream = connect_proxy_server(&*context, svr_cfg).await?; - let proxy_stream = proxy_server_handshake(context, server_stream, svr_cfg, addr).await?; + let server_stream = connect_proxy_server(&context, svr_cfg).await?; + let proxy_stream = proxy_server_handshake(context.clone(), server_stream, svr_cfg, addr).await?; - Ok(ProxyStream::Proxied(proxy_stream)) + Ok(ProxyStream::Proxied { + stream: proxy_stream, + context, + }) } async fn connect_proxied_wrapped( @@ -156,31 +159,66 @@ impl ProxyStream { /// Returns the local socket address of this stream socket pub fn local_addr(&self) -> io::Result { match *self { - ProxyStream::Direct(ref s) => s.get_ref().local_addr(), - ProxyStream::Proxied(ref s) => s.get_ref().get_ref().local_addr(), + ProxyStream::Direct { ref stream, .. } => stream.get_ref().local_addr(), + ProxyStream::Proxied { ref stream, .. } => stream.get_ref().get_ref().local_addr(), } } /// Check if the underlying connection is proxied pub fn is_proxied(&self) -> bool { match *self { - ProxyStream::Proxied(..) => true, + ProxyStream::Proxied { .. } => true, _ => false, } } + + /// Get reference to context + pub fn context(&self) -> &Context { + match *self { + ProxyStream::Direct { ref context, .. } => &context, + ProxyStream::Proxied { ref context, .. } => &context, + } + } } impl Unpin for ProxyStream {} +macro_rules! forward_call { + ($self:expr, $method:ident $(, $param:expr)*) => { + match *$self { + ProxyStream::Direct { ref mut stream, .. } => Pin::new(stream).$method($($param),*), + ProxyStream::Proxied { ref mut stream, .. } => Pin::new(stream).$method($($param),*), + } + }; +} + impl AsyncRead for ProxyStream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut [u8]) -> Poll> { - forward_call!(self, poll_read, cx, buf) + let p = forward_call!(self, poll_read, cx, buf); + + // Flow statistic for Android client + if cfg!(target_os = "android") && self.is_proxied() { + if let Poll::Ready(Ok(n)) = p { + self.context().local_flow_statistic().tcp().incr_tx(n as u64); + } + } + + p } } impl AsyncWrite for ProxyStream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { - forward_call!(self, poll_write, cx, buf) + let p = forward_call!(self, poll_write, cx, buf); + + // Flow statistic for Android client + if cfg!(target_os = "android") && self.is_proxied() { + if let Poll::Ready(Ok(n)) = p { + self.context().local_flow_statistic().tcp().incr_rx(n as u64); + } + } + + p } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { diff --git a/src/relay/udprelay/association.rs b/src/relay/udprelay/association.rs new file mode 100644 index 000000000000..fafeadab8f25 --- /dev/null +++ b/src/relay/udprelay/association.rs @@ -0,0 +1,429 @@ +//! UDP Association +//! +//! Working like a NAT proxy + +#![allow(dead_code)] + +use std::{ + io::{self, Cursor, Read}, + net::{IpAddr, Ipv4Addr, SocketAddr}, +}; + +use async_trait::async_trait; +use bytes::BytesMut; +use futures::future; +use log::{debug, error, warn}; +use tokio::{ + self, + net::udp::{RecvHalf, SendHalf}, + sync::{mpsc, oneshot}, +}; + +use crate::{ + config::{ServerAddr, ServerConfig}, + context::Context, + relay::{ + loadbalancing::server::{ServerData, SharedServerStatistic}, + socks5::Address, + sys::create_udp_socket_with_context, + }, +}; + +use super::{ + crypto_io::{decrypt_payload, encrypt_payload}, + MAXIMUM_UDP_PAYLOAD_SIZE, +}; + +#[async_trait] +pub trait ProxySend { + async fn send_packet(&mut self, data: Vec) -> io::Result<()>; +} + +pub struct ProxyAssociation { + tx: mpsc::Sender<(Address, Vec)>, + watchers: Vec>, +} + +impl ProxyAssociation { + pub async fn associate_proxied( + src_addr: SocketAddr, + server: SharedServerStatistic, + sender: H, + ) -> io::Result + where + S: ServerData + Send + 'static, + H: ProxySend + Send + 'static, + { + // Create a socket for receiving packets + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); + + let remote_udp = create_udp_socket_with_context(&local_addr, server.context()).await?; + let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); + + debug!("created UDP association {} <-> {}", src_addr, remote_bind_addr); + + // Create a channel for sending packets to remote + // FIXME: Channel size 1024? + let (tx, rx) = mpsc::channel::<(Address, Vec)>(1024); + + // Splits socket into sender and receiver + let (remote_receiver, remote_sender) = remote_udp.split(); + + // LOCAL -> REMOTE task + // All packets will be sent directly to proxy + tokio::spawn(Self::l2r_packet_proxied(src_addr, server.clone(), rx, remote_sender)); + + // REMOTE <- LOCAL task + let (remote_watcher_tx, remote_watcher_rx) = oneshot::channel::<()>(); + tokio::spawn(Self::r2l_packet( + src_addr, + server, + sender, + remote_receiver, + remote_watcher_rx, + )); + + let watchers = vec![remote_watcher_tx]; + + Ok(ProxyAssociation { tx, watchers }) + } + + pub async fn associate_bypassed( + src_addr: SocketAddr, + server: SharedServerStatistic, + sender: H, + ) -> io::Result + where + S: ServerData + Send + 'static, + H: ProxySend + Send + 'static, + { + // Create a socket for receiving packets + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); + + let remote_udp = create_udp_socket_with_context(&local_addr, server.context()).await?; + let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); + + debug!("created UDP association {} <-> {}", src_addr, remote_bind_addr); + + // Create a channel for sending packets to remote + // FIXME: Channel size 1024? + let (tx, rx) = mpsc::channel::<(Address, Vec)>(1024); + + // Splits socket into sender and receiver + let (remote_receiver, remote_sender) = remote_udp.split(); + + // LOCAL -> REMOTE task + // All packets will be sent directly to proxy + tokio::spawn(Self::l2r_packet_bypassed(src_addr, server.clone(), rx, remote_sender)); + + // REMOTE <- LOCAL task + let (remote_watcher_tx, remote_watcher_rx) = oneshot::channel::<()>(); + tokio::spawn(Self::r2l_packet( + src_addr, + server, + sender, + remote_receiver, + remote_watcher_rx, + )); + + let watchers = vec![remote_watcher_tx]; + + Ok(ProxyAssociation { tx, watchers }) + } + + pub async fn associate_with_acl( + src_addr: SocketAddr, + server: SharedServerStatistic, + sender: H, + ) -> io::Result + where + S: ServerData + Send + 'static, + H: ProxySend + Clone + Send + 'static, + { + // Create a socket for receiving packets + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); + + let remote_udp = create_udp_socket_with_context(&local_addr, server.context()).await?; + let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); + + // A socket for bypassed + let bypass_udp = create_udp_socket_with_context(&local_addr, server.context()).await?; + let bypass_bind_addr = bypass_udp.local_addr().expect("determine port bound to"); + + debug!( + "created UDP association {} <-> {}, {}", + src_addr, remote_bind_addr, bypass_bind_addr + ); + + // Create a channel for sending packets to remote + // FIXME: Channel size 1024? + let (tx, rx) = mpsc::channel::<(Address, Vec)>(1024); + + // Splits socket into sender and receiver + let (remote_receiver, remote_sender) = remote_udp.split(); + let (bypass_receiver, bypass_sender) = bypass_udp.split(); + + // LOCAL -> REMOTE task + // Packets may be sent via proxy decided by acl rules + + tokio::spawn(Self::l2r_packet_acl( + src_addr, + server.clone(), + rx, + bypass_sender, + remote_sender, + )); + + let (bypass_watcher_tx, bypass_watcher_rx) = oneshot::channel::<()>(); + tokio::spawn(Self::r2l_packet( + src_addr, + server.clone(), + sender.clone(), + bypass_receiver, + bypass_watcher_rx, + )); + + // REMOTE <- LOCAL task + let (remote_watcher_tx, remote_watcher_rx) = oneshot::channel::<()>(); + tokio::spawn(Self::r2l_packet( + src_addr, + server, + sender, + remote_receiver, + remote_watcher_rx, + )); + + let watchers = vec![bypass_watcher_tx, remote_watcher_tx]; + + Ok(ProxyAssociation { tx, watchers }) + } + + pub async fn send(&mut self, target: Address, payload: Vec) { + if let Err(..) = self.tx.send((target, payload)).await { + // SHOULDn't HAPPEN + unreachable!("UDP association local -> remote queue closed unexpectly"); + } + } + + async fn l2r_packet_acl( + src_addr: SocketAddr, + server: SharedServerStatistic, + mut rx: mpsc::Receiver<(Address, Vec)>, + mut bypass_sender: SendHalf, + mut remote_sender: SendHalf, + ) where + S: ServerData + Send + 'static, + { + let context = server.context(); + let svr_cfg = server.server_config(); + + while let Some((addr, payload)) = rx.recv().await { + // Check if addr should be bypassed + let is_bypassed = context.check_target_bypassed(&addr).await; + + let res = if is_bypassed { + Self::send_packet_bypassed(src_addr, context, &addr, &payload, &mut bypass_sender).await + } else { + Self::send_packet_proxied(src_addr, context, svr_cfg, &addr, &payload, &mut remote_sender).await + }; + + if let Err(err) = res { + error!( + "failed to send packet {} -> {}, bypassed? {}, error: {}", + src_addr, addr, is_bypassed, err + ); + } + } + + debug!("UDP association {} -> .. task is closing", src_addr); + } + + async fn l2r_packet_proxied( + src_addr: SocketAddr, + server: SharedServerStatistic, + mut rx: mpsc::Receiver<(Address, Vec)>, + mut remote_sender: SendHalf, + ) where + S: ServerData + Send + 'static, + { + let context = server.context(); + let svr_cfg = server.server_config(); + + while let Some((addr, payload)) = rx.recv().await { + let res = Self::send_packet_proxied(src_addr, context, svr_cfg, &addr, &payload, &mut remote_sender).await; + + if let Err(err) = res { + error!("UDP association send packet {} -> {}, error: {}", src_addr, addr, err); + } + } + + debug!("UDP association {} -> .. task is closing", src_addr); + } + + async fn l2r_packet_bypassed( + src_addr: SocketAddr, + server: SharedServerStatistic, + mut rx: mpsc::Receiver<(Address, Vec)>, + mut remote_sender: SendHalf, + ) where + S: ServerData + Send + 'static, + { + let context = server.context(); + + while let Some((addr, payload)) = rx.recv().await { + let res = Self::send_packet_bypassed(src_addr, context, &addr, &payload, &mut remote_sender).await; + + if let Err(err) = res { + error!("UDP association send packet {} -> {}, error: {}", src_addr, addr, err); + } + } + + debug!("UDP association {} -> .. task is closing", src_addr); + } + + async fn send_packet_proxied( + src_addr: SocketAddr, + context: &Context, + svr_cfg: &ServerConfig, + target: &Address, + payload: &[u8], + socket: &mut SendHalf, + ) -> io::Result<()> { + // CLIENT -> SERVER protocol: ADDRESS + PAYLOAD + let mut send_buf = Vec::new(); + target.write_to_buf(&mut send_buf); + send_buf.extend_from_slice(payload); + + let mut encrypt_buf = BytesMut::new(); + encrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?; + + let send_len = match svr_cfg.addr() { + ServerAddr::SocketAddr(ref remote_addr) => socket.send_to(&encrypt_buf[..], remote_addr).await?, + ServerAddr::DomainName(ref dname, port) => { + lookup_then!(context, dname, *port, |addr| { + socket.send_to(&encrypt_buf[..], &addr).await + })? + .1 + } + }; + + if encrypt_buf.len() != send_len { + warn!( + "UDP association {} -> {} via proxy {} payload truncated, expected {} bytes, but sent {} bytes", + src_addr, + target, + svr_cfg.addr(), + encrypt_buf.len(), + send_len + ); + } + + if cfg!(target_os = "android") { + context.local_flow_statistic().udp().incr_tx(send_len as u64); + } + + Ok(()) + } + + async fn send_packet_bypassed( + src_addr: SocketAddr, + context: &Context, + target: &Address, + payload: &[u8], + socket: &mut SendHalf, + ) -> io::Result<()> { + // BYPASSED, so just send it directly without any modifications + + let send_len = match *target { + Address::SocketAddress(ref saddr) => socket.send_to(payload, saddr).await?, + Address::DomainNameAddress(ref host, port) => { + lookup_then!(context, host, port, |saddr| { socket.send_to(payload, &saddr).await })?.1 + } + }; + + if payload.len() != send_len { + warn!( + "UDP association {} -> {} payload truncated, expected {} bytes, but sent {} bytes", + src_addr, + target, + payload.len(), + send_len + ); + } + + Ok(()) + } + + async fn r2l_packet( + src_addr: SocketAddr, + server: SharedServerStatistic, + mut sender: H, + mut socket: RecvHalf, + watcher_rx: oneshot::Receiver<()>, + ) -> io::Result<()> + where + S: ServerData + Send + 'static, + H: ProxySend + Send + 'static, + { + let recv_fut = async move { + let context = server.context(); + let svr_cfg = server.server_config(); + + loop { + match Self::recv_packet_proxied(context, svr_cfg, &mut socket).await { + Ok(data) => { + if let Err(err) = sender.send_packet(data).await { + error!("UDP association send {} <- .., error: {}", src_addr, err); + } + } + Err(err) => { + error!("UDP association recv {} <- .., error: {}", src_addr, err); + } + } + } + }; + + tokio::pin!(recv_fut); + + // Resolve if watcher_rx resolves + let _ = future::select(recv_fut, watcher_rx).await; + + debug!("UDP association {} <- .. task is closing", src_addr); + + Ok(()) + } + + async fn recv_packet_proxied( + context: &Context, + svr_cfg: &ServerConfig, + socket: &mut RecvHalf, + ) -> io::Result> { + // Waiting for response from server SERVER -> CLIENT + // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. + let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; + + let (recv_n, _) = socket.recv_from(&mut recv_buf).await?; + + let decrypt_buf = match decrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? { + None => { + error!("UDP packet too short, received length {}", recv_n); + let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short"); + return Err(err); + } + Some(b) => b, + }; + // SERVER -> CLIENT protocol: ADDRESS + PAYLOAD + let mut cur = Cursor::new(decrypt_buf); + // FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server + let _ = Address::read_from(&mut cur).await?; + + let mut payload = Vec::new(); + cur.read_to_end(&mut payload)?; + + if cfg!(target_os = "android") { + context.local_flow_statistic().udp().incr_rx(recv_n as u64); + } + + Ok(payload) + } +} diff --git a/src/relay/udprelay/mod.rs b/src/relay/udprelay/mod.rs index 06c8f73efa40..b4fa91c2edb9 100644 --- a/src/relay/udprelay/mod.rs +++ b/src/relay/udprelay/mod.rs @@ -49,7 +49,9 @@ use std::time::Duration; +mod association; pub mod client; +mod crypto_io; pub mod local; mod redir_local; pub mod server; @@ -59,8 +61,6 @@ mod tproxy_socket; mod tunnel_local; mod utils; -mod crypto_io; - /// The maximum UDP payload size (defined in the original shadowsocks Python) /// /// *I cannot find any references about why clowwindy used this value as the maximum diff --git a/src/relay/udprelay/redir_local.rs b/src/relay/udprelay/redir_local.rs index 9206fabbbfab..caa95eb9e28c 100644 --- a/src/relay/udprelay/redir_local.rs +++ b/src/relay/udprelay/redir_local.rs @@ -1,334 +1,74 @@ //! UDP relay local server -use std::{ - io::{self, Cursor, Read}, - net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{io, net::SocketAddr, sync::Arc}; -use bytes::BytesMut; -use futures::{future, FutureExt}; -use log::{debug, error, info, trace}; +use async_trait::async_trait; +use log::{error, info, trace}; use lru_time_cache::{Entry, LruCache}; -use tokio::{ - self, - net::udp::{RecvHalf, SendHalf}, - sync::{mpsc, oneshot, Mutex}, - time, -}; +use tokio::{self, sync::Mutex, time}; use crate::{ - config::{ServerAddr, ServerConfig}, - context::{Context, SharedContext}, + context::SharedContext, relay::{ - loadbalancing::server::{PlainPingBalancer, ServerType, SharedPlainServerStatistic}, + loadbalancing::server::{PlainPingBalancer, ServerType}, socks5::Address, - sys::create_udp_socket, - utils::try_timeout, }, }; use super::{ - crypto_io::{decrypt_payload, encrypt_payload}, + association::{ProxyAssociation, ProxySend}, tproxy_socket::TProxyUdpSocket, DEFAULT_TIMEOUT, MAXIMUM_UDP_PAYLOAD_SIZE, }; -fn cache_key(src: &SocketAddr, dst: &SocketAddr) -> String { - format!("{}-{}", src, dst) -} - -// Drop the oneshot::Sender<()> will trigger local <- remote task to finish -struct UdpAssociationWatcher(oneshot::Sender<()>); - -// Represent a UDP association -#[derive(Clone)] -struct UdpAssociation { - // local -> remote Queue - // Drops tx, will close local -> remote task - tx: mpsc::Sender>, +type AssocMap = LruCache; +type SharedAssocMap = Arc>; - // local <- remote task life watcher - watcher: Arc, +struct ProxyHandler { + src_addr: SocketAddr, + local_udp: TProxyUdpSocket, + cache_key: String, + assoc_map: SharedAssocMap, } -impl UdpAssociation { - /// Create an association with addr - async fn associate( - server: SharedPlainServerStatistic, +impl ProxyHandler { + pub fn new( src_addr: SocketAddr, dst_addr: SocketAddr, + cache_key: String, assoc_map: SharedAssocMap, - ) -> io::Result { - // Create a socket for receiving packets - let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); - - let remote_udp = create_udp_socket(&local_addr).await?; - let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); - - debug!("created UDP association for {} from {}", src_addr, remote_bind_addr); - - // Create a channel for sending packets to remote - // FIXME: Channel size 1024? - let (tx, mut rx) = mpsc::channel::>(1024); - - // Create a watcher for local <- remote task - let (watcher_tx, watcher_rx) = oneshot::channel::<()>(); - - let close_flag = Arc::new(UdpAssociationWatcher(watcher_tx)); - - // Splits socket into sender and receiver - let (mut remote_receiver, mut remote_sender) = remote_udp.split(); - - // Create a socket for sending packets back - let mut local_udp = TProxyUdpSocket::bind(&dst_addr)?; - - let timeout = server.config().udp_timeout.unwrap_or(DEFAULT_TIMEOUT); - - let dst_saddr = Address::from(dst_addr); - let is_bypassed = server.context().check_target_bypassed(&dst_saddr).await; - - { - // local -> remote - - let server = server.clone(); - tokio::spawn(async move { - let svr_cfg = server.server_config(); - let context = server.context(); - - while let Some(pkt) = rx.recv().await { - // pkt is already a raw packet, so just send it - debug!( - "UDP REDIR {} -> {}, payload length {} bytes", - src_addr, - dst_addr, - pkt.len() - ); - - let res = if is_bypassed { - Self::relay_l2r_bypassed(context, &mut remote_sender, &dst_saddr, &pkt, timeout).await - } else { - UdpAssociation::relay_l2r_proxied( - context, - &mut remote_sender, - &dst_saddr, - &pkt, - timeout, - svr_cfg, - ) - .await - }; - - if let Err(err) = res { - error!("failed to send packet {} -> {}, error: {}", src_addr, dst_addr, err); - - // FIXME: Ignore? Or how to deal with it? - } - } - - debug!("UDP REDIR {} -> {} finished", src_addr, dst_addr); - }); - } - - // local <- remote - tokio::spawn(async move { - let transfer_fut = async move { - let svr_cfg = server.server_config(); - let context = server.context(); - - loop { - // Read and send back to source - let res = if is_bypassed { - UdpAssociation::relay_r2l_bypassed(&src_addr, &mut remote_receiver, &mut local_udp).await - } else { - UdpAssociation::relay_r2l_proxied( - context, - &src_addr, - &mut remote_receiver, - &mut local_udp, - svr_cfg, - ) - .await - }; - - if let Err(err) = res { - error!("failed to receive packet, {} <- {}, error: {}", src_addr, dst_addr, err); - - // FIXME: Don't break, or if you can find a way to drop the UdpAssociation - // break; - } - - let cache_key = cache_key(&src_addr, &dst_addr); - { - let mut amap = assoc_map.lock().await; - - // Check or update expire time - let _ = amap.get(&cache_key); - } - } - }; - - // Resolved only if watcher_rx resolved - let _ = future::select(transfer_fut.boxed(), watcher_rx.boxed()).await; - debug!("UDP REDIR {} <- {} finished", src_addr, dst_addr); - }); - - Ok(UdpAssociation { - tx, - watcher: close_flag, - }) - } - - /// Relay packets from local to remote - async fn relay_l2r_proxied( - context: &Context, - remote_udp: &mut SendHalf, - addr: &Address, - payload: &[u8], - timeout: Duration, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // CLIENT -> SERVER protocol: ADDRESS + PAYLOAD - let mut send_buf = Vec::new(); - addr.write_to_buf(&mut send_buf); - send_buf.extend_from_slice(&payload); - - let mut encrypt_buf = BytesMut::new(); - encrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?; - - let send_len = match svr_cfg.addr() { - ServerAddr::SocketAddr(ref remote_addr) => { - try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await? - } - ServerAddr::DomainName(ref dname, port) => { - lookup_then!(context, dname, *port, |addr| { - try_timeout(remote_udp.send_to(&encrypt_buf[..], &addr), Some(timeout)).await - })? - .1 - } - }; + ) -> io::Result { + // Create a socket binds to destination addr + // This only works for systems that supports binding to non-local addresses + let local_udp = TProxyUdpSocket::bind(&dst_addr)?; - assert_eq!(encrypt_buf.len(), send_len); - - Ok(()) - } - - async fn relay_l2r_bypassed( - context: &Context, - bypass_udp: &mut SendHalf, - addr: &Address, - payload: &[u8], - timeout: Duration, - ) -> io::Result<()> { - let send_len = match *addr { - Address::SocketAddress(ref saddr) => try_timeout(bypass_udp.send_to(payload, saddr), Some(timeout)).await?, - Address::DomainNameAddress(ref host, port) => { - lookup_then!(context, host, port, |saddr| { - try_timeout(bypass_udp.send_to(payload, &saddr), Some(timeout)).await - })? - .1 - } - }; - - assert_eq!(payload.len(), send_len); - - Ok(()) - } - - /// Relay packets from remote to local (proxied) - async fn relay_r2l_proxied( - context: &Context, - src_addr: &SocketAddr, - remote_udp: &mut RecvHalf, - local_udp: &mut TProxyUdpSocket, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // Waiting for response from server SERVER -> CLIENT - // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. - let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; - - let (recv_n, remote_addr) = remote_udp.recv_from(&mut recv_buf).await?; - - let decrypt_buf = match decrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? { - None => { - error!("UDP packet too short, received length {}", recv_n); - let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short"); - return Err(err); - } - Some(b) => b, - }; - // SERVER -> CLIENT protocol: ADDRESS + PAYLOAD - let mut cur = Cursor::new(decrypt_buf); - // FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server - let _ = Address::read_from(&mut cur).await?; - - let mut payload = Vec::new(); - cur.read_to_end(&mut payload)?; - - debug!( - "UDP REDIR {} <- {}, payload length {} bytes", + Ok(ProxyHandler { src_addr, - remote_addr, - payload.len() - ); - - // Send back to src_addr - if let Err(err) = local_udp.send_to(&payload, src_addr).await { - error!("failed to send packet back to {}, error: {}", src_addr, err); - - // FIXME: What to do? Ignore? - } - - Ok(()) + local_udp, + cache_key, + assoc_map, + }) } +} - /// Relay packets from remote to local (bypassed) - async fn relay_r2l_bypassed( - src_addr: &SocketAddr, - bypass_udp: &mut RecvHalf, - local_udp: &mut TProxyUdpSocket, - ) -> io::Result<()> { - // Waiting for response from server SERVER -> CLIENT - // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. - let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; - - let (recv_n, remote_addr) = bypass_udp.recv_from(&mut recv_buf).await?; - - let payload = recv_buf[..recv_n].to_vec(); - - debug!( - "UDP REDIR {} <- {}, payload length {} bytes", - src_addr, - remote_addr, - payload.len() - ); +#[async_trait] +impl ProxySend for ProxyHandler { + async fn send_packet(&mut self, data: Vec) -> io::Result<()> { + self.local_udp.send_to(&data, &self.src_addr).await?; - // Send back to src_addr - if let Err(err) = local_udp.send_to(&payload, src_addr).await { - error!("failed to send packet back to {}, error: {}", src_addr, err); + // Update LRU + { + let mut amap = self.assoc_map.lock().await; - // FIXME: What to do? Ignore? + // Check or update expire time + let _ = amap.get(&self.cache_key); } Ok(()) } - - // Send packet to remote - // - // Return `Err` if receiver have been closed - async fn send(&mut self, pkt: Vec) { - if let Err(..) = self.tx.send(pkt).await { - // SHOULDn't HAPPEN - unreachable!("UDP association local -> remote Queue closed unexpectly"); - } - } } -type AssocMap = LruCache; -type SharedAssocMap = Arc>; - /// Starts a UDP local server pub async fn run(context: SharedContext) -> io::Result<()> { if let Err(err) = super::sys::check_support_tproxy() { @@ -387,32 +127,46 @@ pub async fn run(context: SharedContext) -> io::Result<()> { continue; } + // Check destination should be proxied or not + let target = Address::SocketAddress(dst); + let is_bypassed = context.check_target_bypassed(&target).await; + // Check or (re)create an association - let mut assoc = { + { // Locks the whole association map let mut ref_assoc_map = assoc_map.lock().await; + let cache_key = format!("{}-{}", src, dst); + // Get or create an association - let assoc = match ref_assoc_map.entry(cache_key(&src, &dst)) { + let assoc = match ref_assoc_map.entry(cache_key.clone()) { Entry::Occupied(oc) => oc.into_mut(), Entry::Vacant(vc) => { // Pick a server let server = balancer.pick_server(); - vc.insert( - UdpAssociation::associate(server, src, dst, assoc_map.clone()) - .await - .expect("create udp association"), - ) + let sender = match ProxyHandler::new(src, dst, cache_key, assoc_map.clone()) { + Ok(s) => s, + Err(err) => { + error!("create UDP association for {} <-> {}, error: {}", src, dst, err); + continue; + } + }; + + let assoc = if is_bypassed { + ProxyAssociation::associate_bypassed(src, server, sender).await + } else { + ProxyAssociation::associate_proxied(src, server, sender).await + } + .expect("create UDP association"); + + vc.insert(assoc) } }; - // Clone the handle and release the lock. - // Make sure we keep the critical section small - assoc.clone() - }; - - // Send to local -> remote task - assoc.send(pkt.to_vec()).await; + // FIXME: Lock is still kept for a mutable reference + // Send to local -> remote task + assoc.send(target, pkt.to_vec()).await; + } } } diff --git a/src/relay/udprelay/socks5_local.rs b/src/relay/udprelay/socks5_local.rs index 026a856b8754..ff1416ecf2b1 100644 --- a/src/relay/udprelay/socks5_local.rs +++ b/src/relay/udprelay/socks5_local.rs @@ -2,39 +2,50 @@ use std::{ io::{self, Cursor, ErrorKind, Read}, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::SocketAddr, sync::Arc, - time::Duration, }; -use bytes::BytesMut; -use futures::{future, FutureExt}; +use async_trait::async_trait; use log::{debug, error, info, trace}; use lru_time_cache::{Entry, LruCache}; use tokio::{ self, - net::udp::{RecvHalf, SendHalf}, - sync::{mpsc, oneshot, Mutex}, + sync::{mpsc, Mutex}, time, }; use crate::{ - config::{ServerAddr, ServerConfig}, - context::{Context, SharedContext}, + context::SharedContext, relay::{ - loadbalancing::server::{PlainPingBalancer, ServerType, SharedPlainServerStatistic}, + loadbalancing::server::{PlainPingBalancer, ServerType}, socks5::{Address, UdpAssociateHeader}, - sys::{create_udp_socket, create_udp_socket_with_context}, - utils::try_timeout, + sys::create_udp_socket, }, }; use super::{ - crypto_io::{decrypt_payload, encrypt_payload}, + association::{ProxyAssociation, ProxySend}, DEFAULT_TIMEOUT, MAXIMUM_UDP_PAYLOAD_SIZE, }; +#[derive(Clone)] +struct ProxyHandler { + src_addr: SocketAddr, + response_tx: mpsc::Sender<(SocketAddr, Vec)>, +} + +#[async_trait] +impl ProxySend for ProxyHandler { + async fn send_packet(&mut self, data: Vec) -> io::Result<()> { + if let Err(err) = self.response_tx.send((self.src_addr, data)).await { + error!("UDP associate response channel error: {}", err); + } + Ok(()) + } +} + async fn parse_packet(pkt: &[u8]) -> io::Result<(Address, Vec)> { // PKT = UdpAssociateHeader + PAYLOAD let mut cur = Cursor::new(pkt); @@ -56,323 +67,6 @@ async fn parse_packet(pkt: &[u8]) -> io::Result<(Address, Vec)> { Ok((addr, payload)) } -struct UdpAssociationWatcher(oneshot::Sender<()>, oneshot::Sender<()>); - -// Represent a UDP association -#[derive(Clone)] -struct UdpAssociation { - // local -> remote Queue - // Drops tx, will close local -> remote task - tx: mpsc::Sender>, - - // local <- remote task life watcher - watcher: Arc, -} - -impl UdpAssociation { - /// Create an association with addr - async fn associate( - server: SharedPlainServerStatistic, - src_addr: SocketAddr, - mut response_tx: mpsc::Sender<(SocketAddr, Vec)>, - ) -> io::Result { - // Create a socket for receiving packets - let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); - - let remote_udp = create_udp_socket_with_context(&local_addr, &server.context()).await?; - let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); - - let bypass_udp = create_udp_socket_with_context(&local_addr, &server.context()).await?; - let bypass_bind_addr = bypass_udp.local_addr().expect("determine port bound to"); - - debug!( - "created UDP association for {} from {} and {}", - src_addr, remote_bind_addr, bypass_bind_addr - ); - - // Create a channel for sending packets to remote - // FIXME: Channel size 1024? - let (tx, mut rx) = mpsc::channel::>(1024); - - // Create a watcher for local <- remote task - let (remote_watcher_tx, remote_watcher_rx) = oneshot::channel::<()>(); - let (bypass_watcher_tx, bypass_watcher_rx) = oneshot::channel::<()>(); - - let close_flag = Arc::new(UdpAssociationWatcher(remote_watcher_tx, bypass_watcher_tx)); - - // Splits socket into sender and receiver - let (mut remote_receiver, mut remote_sender) = remote_udp.split(); - let (mut bypass_receiver, mut bypass_sender) = bypass_udp.split(); - - let timeout = server.config().udp_timeout.unwrap_or(DEFAULT_TIMEOUT); - - { - // local -> remote - - let server = server.clone(); - tokio::spawn(async move { - let svr_cfg = server.server_config(); - let context = server.context(); - - while let Some(pkt) = rx.recv().await { - // pkt is already a raw packet, so just send it - let res = Self::relay_l2r( - context, - &src_addr, - &mut remote_sender, - &mut bypass_sender, - &pkt[..], - timeout, - svr_cfg, - ) - .await; - - if let Err(err) = res { - error!("failed to send packet {} -> ..., error: {}", src_addr, err); - - // FIXME: Ignore? Or how to deal with it? - } - } - - debug!("UDP ASSOCIATE {} -> .. finished", src_addr); - }); - } - - { - // local <- remote (proxied) - - let mut response_tx = response_tx.clone(); - tokio::spawn(async move { - let transfer_fut = async move { - let svr_cfg = server.server_config(); - let context = server.context(); - - loop { - // Read and send back to source - let res = - Self::relay_r2l_proxied(context, src_addr, &mut remote_receiver, &mut response_tx, svr_cfg) - .await; - - match res { - Ok(..) => {} - Err(err) => { - error!("failed to receive packet, {} <- .., error: {}", src_addr, err); - - // FIXME: Don't break, or if you can find a way to drop the UdpAssociation - // break; - } - } - } - }; - - // Resolved only if watcher_rx resolved - let _ = future::select(transfer_fut.boxed(), remote_watcher_rx.boxed()).await; - - debug!("UDP ASSOCIATE {} <- .. finished", src_addr); - }); - } - - // local <- remote (bypassed) - tokio::spawn(async move { - let transfer_fut = async move { - loop { - // Read and send back to source - match Self::relay_r2l_bypassed(src_addr, &mut bypass_receiver, &mut response_tx).await { - Ok(..) => {} - Err(err) => { - error!("failed to receive packet, {} <- .., error: {}", src_addr, err); - - // FIXME: Don't break, or if you can find a way to drop the UdpAssociation - // break; - } - } - } - }; - - // Resolved only if watcher_rx resolved - let _ = future::select(transfer_fut.boxed(), bypass_watcher_rx.boxed()).await; - - debug!("UDP ASSOCIATE {} <- .. finished", src_addr); - }); - - Ok(UdpAssociation { - tx, - watcher: close_flag, - }) - } - - /// Relay packets from local to remote - async fn relay_l2r( - context: &Context, - src: &SocketAddr, - remote_udp: &mut SendHalf, - bypass_udp: &mut SendHalf, - pkt: &[u8], - timeout: Duration, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - let (addr, payload) = parse_packet(&pkt).await?; - - debug!( - "UDP ASSOCIATE {} -> {}, payload length {} bytes", - src, - addr, - payload.len() - ); - - if context.check_target_bypassed(&addr).await { - Self::relay_l2r_bypassed(context, bypass_udp, &addr, &payload, timeout).await - } else { - Self::relay_l2r_proxied(context, remote_udp, &addr, &payload, timeout, svr_cfg).await - } - } - - async fn relay_l2r_proxied( - context: &Context, - remote_udp: &mut SendHalf, - addr: &Address, - payload: &[u8], - timeout: Duration, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // CLIENT -> SERVER protocol: ADDRESS + PAYLOAD - let mut send_buf = Vec::new(); - addr.write_to_buf(&mut send_buf); - send_buf.extend_from_slice(&payload); - - let mut encrypt_buf = BytesMut::new(); - encrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?; - - let send_len = match svr_cfg.addr() { - ServerAddr::SocketAddr(ref remote_addr) => { - try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await? - } - ServerAddr::DomainName(ref dname, port) => { - lookup_then!(context, dname, *port, |addr| { - try_timeout(remote_udp.send_to(&encrypt_buf[..], &addr), Some(timeout)).await - })? - .1 - } - }; - - assert_eq!(encrypt_buf.len(), send_len); - - Ok(()) - } - - async fn relay_l2r_bypassed( - context: &Context, - bypass_udp: &mut SendHalf, - addr: &Address, - payload: &[u8], - timeout: Duration, - ) -> io::Result<()> { - let send_len = match *addr { - Address::SocketAddress(ref saddr) => try_timeout(bypass_udp.send_to(payload, saddr), Some(timeout)).await?, - Address::DomainNameAddress(ref host, port) => { - lookup_then!(context, host, port, |saddr| { - try_timeout(bypass_udp.send_to(payload, &saddr), Some(timeout)).await - })? - .1 - } - }; - - assert_eq!(payload.len(), send_len); - - Ok(()) - } - - /// Relay packets from remote to local - async fn relay_r2l_proxied( - context: &Context, - src_addr: SocketAddr, - remote_udp: &mut RecvHalf, - response_tx: &mut mpsc::Sender<(SocketAddr, Vec)>, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // Waiting for response from server SERVER -> CLIENT - // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. - let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; - let (recv_n, remote_addr) = remote_udp.recv_from(&mut recv_buf).await?; - - let decrypt_buf = match decrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? { - None => { - error!("UDP packet too short, received length {}", recv_n); - let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short"); - return Err(err); - } - Some(b) => b, - }; - // SERVER -> CLIENT protocol: ADDRESS + PAYLOAD - let mut cur = Cursor::new(decrypt_buf); - // FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server - let _ = Address::read_from(&mut cur).await?; - - let mut payload = Vec::new(); - - let header = UdpAssociateHeader::new(0, Address::SocketAddress(src_addr)); - - header.write_to_buf(&mut payload); - cur.read_to_end(&mut payload)?; - - debug!( - "UDP ASSOCIATE {} <- {}, payload length {} bytes", - src_addr, - remote_addr, - payload.len() - ); - - // Send back to src_addr - if let Err(err) = response_tx.send((src_addr, payload)).await { - error!("failed to send packet into response channel, error: {}", err); - - // FIXME: What to do? Ignore? - } - - Ok(()) - } - - /// Relay packets from remote to local - async fn relay_r2l_bypassed( - src_addr: SocketAddr, - remote_udp: &mut RecvHalf, - response_tx: &mut mpsc::Sender<(SocketAddr, Vec)>, - ) -> io::Result<()> { - // Waiting for response from server SERVER -> CLIENT - // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. - let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; - let (recv_n, remote_addr) = remote_udp.recv_from(&mut recv_buf).await?; - - let payload = recv_buf[..recv_n].to_vec(); - - debug!( - "UDP ASSOCIATE {} <- {}, payload length {} bytes", - src_addr, - remote_addr, - payload.len() - ); - - // Send back to src_addr - if let Err(err) = response_tx.send((src_addr, payload)).await { - error!("failed to send packet into response channel, error: {}", err); - - // FIXME: What to do? Ignore? - } - - Ok(()) - } - - // Send packet to remote - // - // Return `Err` if receiver have been closed - async fn send(&mut self, pkt: Vec) { - if let Err(..) = self.tx.send(pkt).await { - // SHOULDn't HAPPEN - unreachable!("UDP Association local -> remote Queue closed unexpectly"); - } - } -} - /// Starts a UDP local server pub async fn run(context: SharedContext) -> io::Result<()> { let local_addr = context.config().local.as_ref().expect("local config"); @@ -454,8 +148,20 @@ pub async fn run(context: SharedContext) -> io::Result<()> { continue; } + // Parse it for validating + let (target, payload) = match parse_packet(pkt).await { + Ok(t) => t, + Err(err) => { + error!( + "received unrecognized UDP packet from {}, length {} bytes, error: {}", + src, recv_len, err + ); + continue; + } + }; + // Check or (re)create an association - let mut assoc = { + { // Locks the whole association map let mut assoc_map = assoc_map.lock().await; @@ -466,20 +172,22 @@ pub async fn run(context: SharedContext) -> io::Result<()> { // Pick a server let server = balancer.pick_server(); + let sender = ProxyHandler { + src_addr: src, + response_tx: tx.clone(), + }; + vc.insert( - UdpAssociation::associate(server, src, tx.clone()) + ProxyAssociation::associate_with_acl(src, server, sender) .await - .expect("create udp association"), + .expect("create UDP association"), ) } }; - // Clone the handle and release the lock. - // Make sure we keep the critical section small - assoc.clone() - }; - - // Send to local -> remote task - assoc.send(pkt.to_vec()).await; + // FIXME: Lock is still kept for a mutable reference + // Send to local -> remote task + assoc.send(target, payload).await; + } } } diff --git a/src/relay/udprelay/sys/mod.rs b/src/relay/udprelay/sys/mod.rs index 3d4584fde97b..eeb4d597b247 100644 --- a/src/relay/udprelay/sys/mod.rs +++ b/src/relay/udprelay/sys/mod.rs @@ -1,13 +1,4 @@ -use std::{ - io::{self, ErrorKind}, - net::SocketAddr, - task::{Context, Poll}, -}; - use cfg_if::cfg_if; -use futures::{future::poll_fn, ready}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; -use tokio::io::PollEvented; cfg_if! { if #[cfg(unix)] { @@ -20,85 +11,3 @@ cfg_if! { compile_error!("UDP Relay is not supported in current platform"); } } - -/// A socket interface for transparent proxy -/// -/// It has basically the same APIs like `tokio::net::UdpSocket`, -/// but `recv_from` will return destination address of UDP packet -pub struct TProxyUdpSocket { - io: PollEvented, -} - -impl TProxyUdpSocket { - /// Create a new UDP socket binded to `addr` - /// - /// This will allow binding to `addr` that is not in local host - pub fn bind(addr: &SocketAddr) -> io::Result { - // Check if current plaform supports TPROXY (UDP) - // This is a runtime error. - check_support_tproxy()?; - - let domain = match *addr { - SocketAddr::V4(..) => Domain::ipv4(), - SocketAddr::V6(..) => Domain::ipv6(), - }; - let socket = Socket::new(domain, Type::dgram(), Some(Protocol::udp()))?; - set_socket_before_bind(addr, &socket)?; - - socket.set_nonblocking(true)?; - socket.set_reuse_address(true)?; - - socket.bind(&SockAddr::from(*addr))?; - - let msock = mio::net::UdpSocket::from_socket(socket.into_udp_socket())?; - let io = PollEvented::new(msock)?; - Ok(TProxyUdpSocket { io }) - } - - /// Send data to the socket to the given target address - pub async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result { - poll_fn(|cx| self.poll_send_to(cx, buf, target)).await - } - - fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: &SocketAddr) -> Poll> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().send_to(buf, target) { - Err(ref e) if e.kind() == ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - x => Poll::Ready(x), - } - } - - /// Returns the local address that this socket is bound to. - pub fn local_addr(&self) -> io::Result { - self.io.get_ref().local_addr() - } - - /// Receive a single datagram from the socket. - /// - /// On success, the future resolves to the number of bytes read and the origin, target address - /// - /// `(bytes read, origin address, target address)` - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from(cx, buf)).await - } - - fn poll_recv_from( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match recv_from_with_destination(self.io.get_ref(), buf) { - Err(ref e) if e.kind() == ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending - } - x => Poll::Ready(x), - } - } -} diff --git a/src/relay/udprelay/tproxy_socket.rs b/src/relay/udprelay/tproxy_socket.rs index 8fe13bf16598..71269201d8a4 100644 --- a/src/relay/udprelay/tproxy_socket.rs +++ b/src/relay/udprelay/tproxy_socket.rs @@ -1,3 +1,95 @@ //! Socket for supporting TPROXY -pub use super::sys::TProxyUdpSocket; +use std::{ + io::{self, ErrorKind}, + net::SocketAddr, + task::{Context, Poll}, +}; + +use futures::{future::poll_fn, ready}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use tokio::io::PollEvented; + +use super::sys::{check_support_tproxy, recv_from_with_destination, set_socket_before_bind}; + +/// A socket interface for transparent proxy +/// +/// It has basically the same APIs like `tokio::net::UdpSocket`, +/// but `recv_from` will return destination address of UDP packet +pub struct TProxyUdpSocket { + io: PollEvented, +} + +impl TProxyUdpSocket { + /// Create a new UDP socket binded to `addr` + /// + /// This will allow binding to `addr` that is not in local host + pub fn bind(addr: &SocketAddr) -> io::Result { + // Check if current plaform supports TPROXY (UDP) + // This is a runtime error. + check_support_tproxy()?; + + let domain = match *addr { + SocketAddr::V4(..) => Domain::ipv4(), + SocketAddr::V6(..) => Domain::ipv6(), + }; + let socket = Socket::new(domain, Type::dgram(), Some(Protocol::udp()))?; + set_socket_before_bind(addr, &socket)?; + + socket.set_nonblocking(true)?; + socket.set_reuse_address(true)?; + + socket.bind(&SockAddr::from(*addr))?; + + let msock = mio::net::UdpSocket::from_socket(socket.into_udp_socket())?; + let io = PollEvented::new(msock)?; + Ok(TProxyUdpSocket { io }) + } + + /// Send data to the socket to the given target address + pub async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result { + poll_fn(|cx| self.poll_send_to(cx, buf, target)).await + } + + fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: &SocketAddr) -> Poll> { + ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send_to(buf, target) { + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx)?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + /// Returns the local address that this socket is bound to. + pub fn local_addr(&self) -> io::Result { + self.io.get_ref().local_addr() + } + + /// Receive a single datagram from the socket. + /// + /// On success, the future resolves to the number of bytes read and the origin, target address + /// + /// `(bytes read, origin address, target address)` + pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, SocketAddr)> { + poll_fn(|cx| self.poll_recv_from(cx, buf)).await + } + + fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; + + match recv_from_with_destination(self.io.get_ref(), buf) { + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx, mio::Ready::readable())?; + Poll::Pending + } + x => Poll::Ready(x), + } + } +} diff --git a/src/relay/udprelay/tunnel_local.rs b/src/relay/udprelay/tunnel_local.rs index 7dd96fd21d0f..7af6f8967197 100644 --- a/src/relay/udprelay/tunnel_local.rs +++ b/src/relay/udprelay/tunnel_local.rs @@ -1,239 +1,44 @@ //! UDP relay local server -use std::{ - io::{self, Cursor, Read}, - net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{io, net::SocketAddr, sync::Arc}; -use bytes::BytesMut; -use futures::{future, FutureExt}; +use async_trait::async_trait; use log::{debug, error, info, trace}; use lru_time_cache::{Entry, LruCache}; use tokio::{ self, - net::udp::{RecvHalf, SendHalf}, - sync::{mpsc, oneshot, Mutex}, + sync::{mpsc, Mutex}, time, }; use crate::{ - config::{ServerAddr, ServerConfig}, - context::{Context, SharedContext}, + context::SharedContext, relay::{ - loadbalancing::server::{PlainPingBalancer, ServerType, SharedPlainServerStatistic}, - socks5::Address, + loadbalancing::server::{PlainPingBalancer, ServerType}, sys::create_udp_socket, - utils::try_timeout, }, }; use super::{ - crypto_io::{decrypt_payload, encrypt_payload}, + association::{ProxyAssociation, ProxySend}, DEFAULT_TIMEOUT, MAXIMUM_UDP_PAYLOAD_SIZE, }; -// Drop the oneshot::Sender<()> will trigger local <- remote task to finish -struct UdpAssociationWatcher(oneshot::Sender<()>); - -// Represent a UDP association #[derive(Clone)] -struct UdpAssociation { - // local -> remote Queue - // Drops tx, will close local -> remote task - tx: mpsc::Sender>, - - // local <- remote task life watcher - watcher: Arc, +struct ProxyHandler { + src_addr: SocketAddr, + response_tx: mpsc::Sender<(SocketAddr, Vec)>, } -impl UdpAssociation { - /// Create an association with addr - async fn associate( - server: SharedPlainServerStatistic, - src_addr: SocketAddr, - mut response_tx: mpsc::Sender<(SocketAddr, Vec)>, - ) -> io::Result { - // Create a socket for receiving packets - let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); - - let remote_udp = create_udp_socket(&local_addr).await?; - let remote_bind_addr = remote_udp.local_addr().expect("determine port bound to"); - - debug!("created UDP association for {} from {}", src_addr, remote_bind_addr); - - // Create a channel for sending packets to remote - // FIXME: Channel size 1024? - let (tx, mut rx) = mpsc::channel::>(1024); - - // Create a watcher for local <- remote task - let (watcher_tx, watcher_rx) = oneshot::channel::<()>(); - - let close_flag = Arc::new(UdpAssociationWatcher(watcher_tx)); - - // Splits socket into sender and receiver - let (mut remote_receiver, mut remote_sender) = remote_udp.split(); - - let timeout = server.config().udp_timeout.unwrap_or(DEFAULT_TIMEOUT); - - { - // local -> remote - - let server = server.clone(); - tokio::spawn(async move { - let svr_cfg = server.server_config(); - let context = server.context(); - let dst_addr = server.config().forward.as_ref().expect("forward address"); - - while let Some(pkt) = rx.recv().await { - // pkt is already a raw packet, so just send it - debug!( - "UDP TUNNEL {} -> {}, payload length {} bytes", - src_addr, - dst_addr, - pkt.len() - ); - - let res = - Self::relay_l2r_proxied(context, &mut remote_sender, dst_addr, &pkt, timeout, svr_cfg).await; - - if let Err(err) = res { - error!("failed to send packet {} -> {}, error: {}", src_addr, dst_addr, err); - - // FIXME: Ignore? Or how to deal with it? - } - } - - debug!("UDP TUNNEL {} -> {} finished", src_addr, dst_addr); - }); +#[async_trait] +impl ProxySend for ProxyHandler { + async fn send_packet(&mut self, data: Vec) -> io::Result<()> { + if let Err(err) = self.response_tx.send((self.src_addr, data)).await { + error!("UDP associate response channel error: {}", err); } - - // local <- remote - tokio::spawn(async move { - let dst_addr = server.config().forward.as_ref().expect("forward address"); - - let transfer_fut = async { - let svr_cfg = server.server_config(); - let context = server.context(); - - loop { - // Read and send back to source - let res = - Self::relay_r2l_proxied(context, &src_addr, &mut remote_receiver, &mut response_tx, svr_cfg) - .await; - - if let Err(err) = res { - error!("failed to receive packet, {} <- {}, error: {}", src_addr, dst_addr, err); - - // FIXME: Don't break, or if you can find a way to drop the UdpAssociation - // break; - } - } - }; - - // Resolved only if watcher_rx resolved - let _ = future::select(transfer_fut.boxed(), watcher_rx.boxed()).await; - debug!("UDP TUNNEL {} <- {} finished", src_addr, dst_addr); - }); - - Ok(UdpAssociation { - tx, - watcher: close_flag, - }) - } - - /// Relay packets from local to remote - async fn relay_l2r_proxied( - context: &Context, - remote_udp: &mut SendHalf, - addr: &Address, - payload: &[u8], - timeout: Duration, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // CLIENT -> SERVER protocol: ADDRESS + PAYLOAD - let mut send_buf = Vec::new(); - addr.write_to_buf(&mut send_buf); - send_buf.extend_from_slice(&payload); - - let mut encrypt_buf = BytesMut::new(); - encrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?; - - let send_len = match svr_cfg.addr() { - ServerAddr::SocketAddr(ref remote_addr) => { - try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await? - } - ServerAddr::DomainName(ref dname, port) => { - lookup_then!(context, dname, *port, |addr| { - try_timeout(remote_udp.send_to(&encrypt_buf[..], &addr), Some(timeout)).await - })? - .1 - } - }; - - assert_eq!(encrypt_buf.len(), send_len); - Ok(()) } - - /// Relay packets from remote to local (proxied) - async fn relay_r2l_proxied( - context: &Context, - src_addr: &SocketAddr, - remote_udp: &mut RecvHalf, - response_tx: &mut mpsc::Sender<(SocketAddr, Vec)>, - svr_cfg: &ServerConfig, - ) -> io::Result<()> { - // Waiting for response from server SERVER -> CLIENT - // Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded. - let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE]; - - let (recv_n, remote_addr) = remote_udp.recv_from(&mut recv_buf).await?; - - let decrypt_buf = match decrypt_payload(context, svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? { - None => { - error!("UDP packet too short, received length {}", recv_n); - let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short"); - return Err(err); - } - Some(b) => b, - }; - // SERVER -> CLIENT protocol: ADDRESS + PAYLOAD - let mut cur = Cursor::new(decrypt_buf); - // FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server - let _ = Address::read_from(&mut cur).await?; - - let mut payload = Vec::new(); - cur.read_to_end(&mut payload)?; - - debug!( - "UDP TUNNEL {} <- {}, payload length {} bytes", - src_addr, - remote_addr, - payload.len() - ); - - // Send back to src_addr - if let Err(err) = response_tx.send((*src_addr, payload)).await { - error!("failed to send packet into response channel, error: {}", err); - - // FIXME: What to do? Ignore? - } - - Ok(()) - } - - // Send packet to remote - // - // Return `Err` if receiver have been closed - async fn send(&mut self, pkt: Vec) { - if let Err(..) = self.tx.send(pkt).await { - // SHOULDn't HAPPEN - unreachable!("UDP association local -> remote Queue closed unexpectly"); - } - } } /// Starts a UDP local server @@ -248,10 +53,11 @@ pub async fn run(context: SharedContext) -> io::Result<()> { let (mut r, mut w) = l.split(); + let forward_target = context.config().forward.clone().expect("`forward` address in config"); + info!( "shadowsocks UDP tunnel listening on {}, forward to {}", - local_addr, - context.config().forward.as_ref().expect("`forward` address in config") + local_addr, forward_target ); // NOTE: Associations are only eliminated by expire time @@ -322,7 +128,7 @@ pub async fn run(context: SharedContext) -> io::Result<()> { } // Check or (re)create an association - let mut assoc = { + { // Locks the whole association map let mut assoc_map = assoc_map.lock().await; @@ -333,20 +139,22 @@ pub async fn run(context: SharedContext) -> io::Result<()> { // Pick a server let server = balancer.pick_server(); + let sender = ProxyHandler { + src_addr: src, + response_tx: tx.clone(), + }; + vc.insert( - UdpAssociation::associate(server, src, tx.clone()) + ProxyAssociation::associate_with_acl(src, server, sender) .await - .expect("create udp association"), + .expect("create UDP association"), ) } }; - // Clone the handle and release the lock. - // Make sure we keep the critical section small - assoc.clone() - }; - - // Send to local -> remote task - assoc.send(pkt.to_vec()).await; + // FIXME: Lock is still kept for a mutable reference + // Send to local -> remote task + assoc.send(forward_target.clone(), pkt.to_vec()).await; + } } }