Skip to content

Commit

Permalink
Refactor and support Android's flow statistic report
Browse files Browse the repository at this point in the history
  • Loading branch information
zonyitoo committed Mar 1, 2020
1 parent 3700338 commit a67154f
Show file tree
Hide file tree
Showing 12 changed files with 809 additions and 990 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ strum = "0.17"
strum_macros = "0.17"
iprange = "0.6"
ipnet = "2.2"
async-trait = "0.1"

[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["mswsock", "winsock2"] }
Expand Down
19 changes: 18 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use trust_dns_resolver::TokioAsyncResolver;
use crate::relay::dns_resolver::create_resolver;
use crate::{
config::{Config, ConfigType, ServerConfig},
relay::{dns_resolver::resolve, socks5::Address},
relay::{dns_resolver::resolve, flow::ServerFlowStatistic, socks5::Address},
};

// Entries for server's bloom filter
Expand Down Expand Up @@ -144,9 +144,20 @@ pub type SharedServerState = Arc<ServerState>;
/// Shared basic configuration for the whole server
pub struct Context {
config: Config,

// Shared variables for all servers
server_state: SharedServerState,

// Server's running indicator
// For killing all background jobs
server_running: AtomicBool,

// Check for duplicated IV/Nonce, for prevent replay attack
// https://github.com/shadowsocks/shadowsocks-org/issues/44
nonce_ppbloom: Mutex<PingPongBloom>,

// For Android's flow stat report
local_flow_statistic: ServerFlowStatistic,
}

/// Unique context thw whole server
Expand All @@ -162,6 +173,7 @@ impl Context {
server_state,
server_running: AtomicBool::new(true),
nonce_ppbloom,
local_flow_statistic: ServerFlowStatistic::new(),
}
}

Expand Down Expand Up @@ -264,4 +276,9 @@ impl Context {
Some(ref a) => a.check_target_bypassed(self, target).await,
}
}

/// Get client flow statistics
pub fn local_flow_statistic(&self) -> &ServerFlowStatistic {
&self.local_flow_statistic
}
}
65 changes: 64 additions & 1 deletion src/relay/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,73 @@
use std::io::{self, ErrorKind};

use cfg_if::cfg_if;
use futures::{future::select_all, FutureExt};
use log::{debug, error, trace, warn};
use tokio::runtime::Handle;

use crate::{
config::{Config, ConfigType},
context::{Context, ServerState},
context::{Context, ServerState, SharedContext},
plugin::{PluginMode, Plugins},
relay::{tcprelay::local::run as run_tcp, udprelay::local::run as run_udp, utils::set_nofile},
};

cfg_if! {
if #[cfg(target_os = "android")] {
async fn flow_report_task(context: SharedContext) -> io::Result<()> {
use std::{ptr, time::Duration};

use tokio::{io::AsyncWriteExt, net::UnixStream, time};

// Android's flow statistic report RPC

let path = context.config().stat_path.as_ref().expect("stat_path must be provided");
let timeout = Duration::from_secs(1);

while context.server_running() {
// keep it as libev's default, 0.5 seconds
time::delay_for(Duration::from_millis(500)).await;

let stream = match time::timeout(timeout, UnixStream::connect(path)).await {
Ok(Ok(s)) => s,
Ok(Err(err)) => {
error!("send client flow statistic error: {}", err);
continue;
}
Err(..) => {
error!("send client flow statistic error: timeout");
continue;
}
};

let flow_stat = context.local_flow_statistic();
let tx = flow_stat.tcp().tx() + flow_stat.udp().tx();
let rx = flow_stat.tcp().rx() + flow_stat.udp().rx();

let buf: [u64; 2] = [tx, rx];

let buf = unsafe { &*(ptr::slice_from_raw_parts(buf.as_ptr() as *const _, 16)) };
match time::timeout(timeout, stream.write_all(buf)).await {
Ok(Ok(..)) => {}
Ok(Err(err)) => {
error!("send client flow statistic error: {}", err);
}
Err(..) => {
error!("send client flow statistic error: timeout");
}
}
}

Ok(())
}
} else {
async fn flow_report_task(_context: SharedContext) -> io::Result<()> {
unimplemented!("only for android")
}
}
}

/// Relay server running under local environment.
pub async fn run(mut config: Config, rt: Handle) -> io::Result<()> {
trace!("initializing local server with {:?}", config);
Expand Down Expand Up @@ -96,6 +152,13 @@ pub async fn run(mut config: Config, rt: Handle) -> io::Result<()> {
vf.push(udp_fut.boxed());
}

if cfg!(target_os = "android") && context.config().stat_path.is_some() {
// For Android's flow statistic

let report_fut = flow_report_task(context.clone());
vf.push(report_fut.boxed());
}

let (res, ..) = select_all(vf.into_iter()).await;
error!("one of servers exited unexpectly, result: {:?}", res);

Expand Down
2 changes: 1 addition & 1 deletion src/relay/tcprelay/http_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl tower::Service<Uri> for DirectConnector {
let err = Error::new(ErrorKind::Other, "URI must be a valid Address");
Err(err)
}
Some(addr) => ProxyStream::connect_direct(&*context, &addr).await,
Some(addr) => ProxyStream::connect_direct(context, &addr).await,
}
}
.boxed(),
Expand Down
84 changes: 61 additions & 23 deletions src/relay/tcprelay/proxy_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,17 @@ use crate::{

use super::{connection::Connection, CryptoStream, STcpStream};

macro_rules! forward_call {
($self:expr, $method:ident $(, $param:expr)*) => {
match *$self {
ProxyStream::Direct(ref mut s) => Pin::new(s).$method($($param),*),
ProxyStream::Proxied(ref mut s) => Pin::new(s).$method($($param),*),
}
};
}

/// Stream wrapper for both direct connections and proxied connections
#[allow(clippy::large_enum_variant)]
pub enum ProxyStream {
Direct(STcpStream),
Proxied(CryptoStream<STcpStream>),
Direct {
stream: STcpStream,
context: SharedContext,
},
Proxied {
stream: CryptoStream<STcpStream>,
context: SharedContext,
},
}

#[derive(Debug)]
Expand Down Expand Up @@ -80,7 +77,7 @@ impl ProxyStream {
addr: &Address,
) -> Result<ProxyStream, ProxyStreamError> {
if context.check_target_bypassed(addr).await {
ProxyStream::connect_direct_wrapped(&*context, addr).await
ProxyStream::connect_direct_wrapped(context, addr).await
} else {
ProxyStream::connect_proxied_wrapped(context, svr_cfg, addr).await
}
Expand All @@ -89,7 +86,7 @@ impl ProxyStream {
/// Connect to remote directly (without proxy)
///
/// This is used for hosts that matches ACL bypassed rules
pub async fn connect_direct(context: &Context, addr: &Address) -> io::Result<ProxyStream> {
pub async fn connect_direct(context: SharedContext, addr: &Address) -> io::Result<ProxyStream> {
debug!("connect to {} directly (bypassed)", addr);

// NOTE: Direct connection's timeout is controlled by the global key
Expand All @@ -105,10 +102,13 @@ impl ProxyStream {
}
};

Ok(ProxyStream::Direct(Connection::new(stream, timeout)))
Ok(ProxyStream::Direct {
stream: Connection::new(stream, timeout),
context,
})
}

async fn connect_direct_wrapped(context: &Context, addr: &Address) -> Result<ProxyStream, ProxyStreamError> {
async fn connect_direct_wrapped(context: SharedContext, addr: &Address) -> Result<ProxyStream, ProxyStreamError> {
match ProxyStream::connect_direct(context, addr).await {
Ok(s) => Ok(s),
Err(err) => Err(ProxyStreamError::new(err, true)),
Expand All @@ -130,10 +130,13 @@ impl ProxyStream {
svr_cfg.external_addr()
);

let server_stream = connect_proxy_server(&*context, svr_cfg).await?;
let proxy_stream = proxy_server_handshake(context, server_stream, svr_cfg, addr).await?;
let server_stream = connect_proxy_server(&context, svr_cfg).await?;
let proxy_stream = proxy_server_handshake(context.clone(), server_stream, svr_cfg, addr).await?;

Ok(ProxyStream::Proxied(proxy_stream))
Ok(ProxyStream::Proxied {
stream: proxy_stream,
context,
})
}

async fn connect_proxied_wrapped(
Expand All @@ -156,31 +159,66 @@ impl ProxyStream {
/// Returns the local socket address of this stream socket
pub fn local_addr(&self) -> io::Result<SocketAddr> {
match *self {
ProxyStream::Direct(ref s) => s.get_ref().local_addr(),
ProxyStream::Proxied(ref s) => s.get_ref().get_ref().local_addr(),
ProxyStream::Direct { ref stream, .. } => stream.get_ref().local_addr(),
ProxyStream::Proxied { ref stream, .. } => stream.get_ref().get_ref().local_addr(),
}
}

/// Check if the underlying connection is proxied
pub fn is_proxied(&self) -> bool {
match *self {
ProxyStream::Proxied(..) => true,
ProxyStream::Proxied { .. } => true,
_ => false,
}
}

/// Get reference to context
pub fn context(&self) -> &Context {
match *self {
ProxyStream::Direct { ref context, .. } => &context,
ProxyStream::Proxied { ref context, .. } => &context,
}
}
}

impl Unpin for ProxyStream {}

macro_rules! forward_call {
($self:expr, $method:ident $(, $param:expr)*) => {
match *$self {
ProxyStream::Direct { ref mut stream, .. } => Pin::new(stream).$method($($param),*),
ProxyStream::Proxied { ref mut stream, .. } => Pin::new(stream).$method($($param),*),
}
};
}

impl AsyncRead for ProxyStream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
forward_call!(self, poll_read, cx, buf)
let p = forward_call!(self, poll_read, cx, buf);

// Flow statistic for Android client
if cfg!(target_os = "android") && self.is_proxied() {
if let Poll::Ready(Ok(n)) = p {
self.context().local_flow_statistic().tcp().incr_tx(n as u64);
}
}

p
}
}

impl AsyncWrite for ProxyStream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
forward_call!(self, poll_write, cx, buf)
let p = forward_call!(self, poll_write, cx, buf);

// Flow statistic for Android client
if cfg!(target_os = "android") && self.is_proxied() {
if let Poll::Ready(Ok(n)) = p {
self.context().local_flow_statistic().tcp().incr_rx(n as u64);
}
}

p
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> {
Expand Down
Loading

0 comments on commit a67154f

Please sign in to comment.