diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 8cb398ec0..40d4d2ff5 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -530,6 +530,7 @@ jobs: env: LLVM_VERSION: "16" SCCACHE_GHA_ENABLED: "on" + SCCACHE_SERVER_UDS: "\\x00sccache.socket" steps: - uses: actions/checkout@v4 @@ -651,6 +652,7 @@ jobs: env: SCCACHE_GHA_ENABLED: "on" + SCCACHE_SERVER_UDS: "/home/runner/sccache.socket" steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 8fa5a6e6c..107e6d25f 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,13 @@ If you don't [specify otherwise](#storage-options), sccache will use a local dis sccache works using a client-server model, where the server runs locally on the same machine as the client. The client-server model allows the server to be more efficient by keeping some state in memory. The sccache command will spawn a server process if one is not already running, or you can run `sccache --start-server` to start the background server process without performing any compilation. +By default sccache server will listen on `127.0.0.1:4226`, you can specify environment variable `SCCACHE_SERVER_PORT` to use a different port or `SCCACHE_SERVER_UDS` to listen on unix domain socket. Abstract unix socket is also supported as long as the path is escaped following the [format](https://doc.rust-lang.org/std/ascii/fn.escape_default.html). For example: + +``` +% env SCCACHE_SERVER_UDS=$HOME/sccache.sock sccache --start-server # unix socket +% env SCCACHE_SERVER_UDS=\\x00sccache.sock sccache --start-server # abstract unix socket +``` + You can run `sccache --stop-server` to terminate the server. It will also terminate after (by default) 10 minutes of inactivity. Running `sccache --show-stats` will print a summary of cache statistics. diff --git a/src/client.rs b/src/client.rs index 51a6a1fe3..f13ec2be7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,28 +13,28 @@ // limitations under the License. use crate::errors::*; +use crate::net::Connection; use crate::protocol::{Request, Response}; use crate::util; use byteorder::{BigEndian, ByteOrder}; use retry::{delay::Fixed, retry}; use std::io::{self, BufReader, BufWriter, Read}; -use std::net::TcpStream; /// A connection to an sccache server. pub struct ServerConnection { /// A reader for the socket connected to the server. - reader: BufReader, + reader: BufReader>, /// A writer for the socket connected to the server. - writer: BufWriter, + writer: BufWriter>, } impl ServerConnection { /// Create a new connection using `stream`. - pub fn new(stream: TcpStream) -> io::Result { - let writer = stream.try_clone()?; + pub fn new(conn: Box) -> io::Result { + let write_conn = conn.try_clone()?; Ok(ServerConnection { - reader: BufReader::new(stream), - writer: BufWriter::new(writer), + reader: BufReader::new(conn), + writer: BufWriter::new(write_conn), }) } @@ -62,24 +62,24 @@ impl ServerConnection { } } -/// Establish a TCP connection to an sccache server listening on `port`. -pub fn connect_to_server(port: u16) -> io::Result { - trace!("connect_to_server({})", port); - let stream = TcpStream::connect(("127.0.0.1", port))?; - ServerConnection::new(stream) +/// Establish a TCP connection to an sccache server listening on `addr`. +pub fn connect_to_server(addr: &crate::net::SocketAddr) -> io::Result { + trace!("connect_to_server({addr})"); + let conn = crate::net::connect(addr)?; + ServerConnection::new(conn) } -/// Attempt to establish a TCP connection to an sccache server listening on `port`. +/// Attempt to establish a TCP connection to an sccache server listening on `addr`. /// /// If the connection fails, retry a few times. -pub fn connect_with_retry(port: u16) -> io::Result { - trace!("connect_with_retry({})", port); +pub fn connect_with_retry(addr: &crate::net::SocketAddr) -> io::Result { + trace!("connect_with_retry({addr})"); // TODOs: // * Pass the server Child in here, so we can stop retrying // if the process exited. // * Send a pipe handle to the server process so it can notify // us once it starts the server instead of us polling. - match retry(Fixed::from_millis(500).take(10), || connect_to_server(port)) { + match retry(Fixed::from_millis(500).take(10), || connect_to_server(addr)) { Ok(conn) => Ok(conn), Err(e) => Err(io::Error::new( io::ErrorKind::TimedOut, diff --git a/src/commands.rs b/src/commands.rs index 2b68df1e0..3f6857f62 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -50,11 +50,18 @@ pub const DEFAULT_PORT: u16 = 4226; const SERVER_STARTUP_TIMEOUT: Duration = Duration::from_millis(10000); /// Get the port on which the server should listen. -fn get_port() -> u16 { - env::var("SCCACHE_SERVER_PORT") +fn get_addr() -> crate::net::SocketAddr { + #[cfg(unix)] + if let Ok(addr) = env::var("SCCACHE_SERVER_UDS") { + if let Ok(uds) = crate::net::SocketAddr::parse_uds(&addr) { + return uds; + } + } + let port = env::var("SCCACHE_SERVER_PORT") .ok() .and_then(|s| s.parse().ok()) - .unwrap_or(DEFAULT_PORT) + .unwrap_or(DEFAULT_PORT); + crate::net::SocketAddr::with_port(port) } /// Check if ignoring all response errors @@ -293,28 +300,27 @@ fn run_server_process(startup_timeout: Option) -> Result, ) -> Result { - trace!("connect_or_start_server({})", port); - match connect_to_server(port) { + trace!("connect_or_start_server({addr})"); + match connect_to_server(addr) { Ok(server) => Ok(server), Err(ref e) - if e.kind() == io::ErrorKind::ConnectionRefused - || e.kind() == io::ErrorKind::TimedOut => + if (e.kind() == io::ErrorKind::ConnectionRefused + || e.kind() == io::ErrorKind::TimedOut) + || (e.kind() == io::ErrorKind::NotFound && addr.is_unix_path()) => { // If the connection was refused we probably need to start // the server. match run_server_process(startup_timeout)? { - ServerStartup::Ok { port: actualport } => { - if port != actualport { + ServerStartup::Ok { addr: actual_addr } => { + if addr.to_string() != actual_addr { // bail as the next connect_with_retry will fail bail!( - "sccache: Listening on port {} instead of {}", - actualport, - port + "sccache: Listening on address {actual_addr} instead of {addr}" ); } } @@ -324,7 +330,7 @@ fn connect_or_start_server( ServerStartup::TimedOut => bail!("Timed out waiting for server startup. Maybe the remote service is unreachable?\nRun with SCCACHE_LOG=debug SCCACHE_NO_DAEMON=1 to get more information"), ServerStartup::Err { reason } => bail!("Server startup failed: {}\nRun with SCCACHE_LOG=debug SCCACHE_NO_DAEMON=1 to get more information", reason), } - let server = connect_with_retry(port)?; + let server = connect_with_retry(addr)?; Ok(server) } Err(e) => Err(e.into()), @@ -614,7 +620,7 @@ pub fn run_command(cmd: Command) -> Result { match cmd { Command::ShowStats(fmt, advanced) => { trace!("Command::ShowStats({:?})", fmt); - let stats = match connect_to_server(get_port()) { + let stats = match connect_to_server(&get_addr()) { Ok(srv) => request_stats(srv).context("failed to get stats from server")?, // If there is no server, spawning a new server would start with zero stats // anyways, so we can just return (mostly) empty stats directly. @@ -658,7 +664,7 @@ pub fn run_command(cmd: Command) -> Result { // We aren't asking for a log file daemonize()?; } - server::start_server(config, get_port())?; + server::start_server(config, &get_addr())?; } Command::StartServer => { trace!("Command::StartServer"); @@ -666,10 +672,8 @@ pub fn run_command(cmd: Command) -> Result { let startup = run_server_process(startup_timeout).context("failed to start server process")?; match startup { - ServerStartup::Ok { port } => { - if port != DEFAULT_PORT { - println!("sccache: Listening on port {}", port); - } + ServerStartup::Ok { addr } => { + println!("sccache: Listening on address {addr}"); } ServerStartup::TimedOut => bail!("Timed out waiting for server startup"), ServerStartup::AddrInUse => bail!("Server startup failed: Address in use"), @@ -679,13 +683,13 @@ pub fn run_command(cmd: Command) -> Result { Command::StopServer => { trace!("Command::StopServer"); println!("Stopping sccache server..."); - let server = connect_to_server(get_port()).context("couldn't connect to server")?; + let server = connect_to_server(&get_addr()).context("couldn't connect to server")?; let stats = request_shutdown(server)?; stats.print(false); } Command::ZeroStats => { trace!("Command::ZeroStats"); - let conn = connect_or_start_server(get_port(), startup_timeout)?; + let conn = connect_or_start_server(&get_addr(), startup_timeout)?; request_zero_stats(conn).context("couldn't zero stats on server")?; eprintln!("Statistics zeroed."); } @@ -747,7 +751,7 @@ pub fn run_command(cmd: Command) -> Result { ), Command::DistStatus => { trace!("Command::DistStatus"); - let srv = connect_or_start_server(get_port(), startup_timeout)?; + let srv = connect_or_start_server(&get_addr(), startup_timeout)?; let status = request_dist_status(srv).context("failed to get dist-status from server")?; serde_json::to_writer(&mut io::stdout(), &status)?; @@ -785,7 +789,7 @@ pub fn run_command(cmd: Command) -> Result { } => { trace!("Command::Compile {{ {:?}, {:?}, {:?} }}", exe, cmdline, cwd); let jobserver = unsafe { Client::new() }; - let conn = connect_or_start_server(get_port(), startup_timeout)?; + let conn = connect_or_start_server(&get_addr(), startup_timeout)?; let mut runtime = Runtime::new()?; let res = do_compile( ProcessCommandCreator::new(&jobserver), diff --git a/src/lib.rs b/src/lib.rs index a893a155f..8d3001c1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ pub mod dist; mod jobserver; pub mod lru_disk_cache; mod mock_command; +mod net; mod protocol; pub mod server; #[doc(hidden)] diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 000000000..79d350f68 --- /dev/null +++ b/src/net.rs @@ -0,0 +1,182 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! The module is used to provide abstraction over TCP socket and UDS. + +use std::fmt; +#[cfg(any(target_os = "linux", target_os = "android"))] +use std::os::linux::net::SocketAddrExt; + +use futures::{Future, TryFutureExt}; +use tokio::io::{AsyncRead, AsyncWrite}; + +// A unify version of `std::net::SocketAddr` and Unix domain socket. +#[derive(Debug)] +pub enum SocketAddr { + Net(std::net::SocketAddr), + // This could work on Windows in the future. See also rust-lang/rust#56533. + #[cfg(unix)] + Unix(std::path::PathBuf), + #[cfg(any(target_os = "linux", target_os = "android"))] + UnixAbstract(Vec), +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SocketAddr::Net(addr) => write!(f, "{}", addr), + #[cfg(unix)] + SocketAddr::Unix(p) => write!(f, "{}", p.display()), + #[cfg(any(target_os = "linux", target_os = "android"))] + SocketAddr::UnixAbstract(p) => write!(f, "\\x00{}", p.escape_ascii()), + } + } +} + +impl SocketAddr { + /// Get a Net address that with IP part set to "127.0.0.1". + #[inline] + pub fn with_port(port: u16) -> Self { + SocketAddr::Net(std::net::SocketAddr::from(([127, 0, 0, 1], port))) + } + + #[inline] + pub fn as_net(&self) -> Option<&std::net::SocketAddr> { + match self { + SocketAddr::Net(addr) => Some(addr), + #[cfg(unix)] + _ => None, + } + } + + /// Parse a string as a unix domain socket. + /// + /// The string should follow the format of `self.to_string()`. + #[cfg(unix)] + pub fn parse_uds(s: &str) -> std::io::Result { + // Parse abstract socket address first as it can contain any chars. + #[cfg(any(target_os = "linux", target_os = "android"))] + { + if s.starts_with("\\x00") { + // Rust abstract path expects no prepand '\x00'. + let data = crate::util::ascii_unescape_default(&s.as_bytes()[4..])?; + return Ok(SocketAddr::UnixAbstract(data)); + } + } + let path = std::path::PathBuf::from(s); + Ok(SocketAddr::Unix(path)) + } + + #[cfg(unix)] + pub fn is_unix_path(&self) -> bool { + matches!(self, SocketAddr::Unix(_)) + } + + #[cfg(not(unix))] + pub fn is_unix_path(&self) -> bool { + false + } +} + +// A helper trait to unify the behavior of TCP and UDS listener. +pub trait Acceptor { + type Socket: AsyncRead + AsyncWrite + Unpin + Send; + + fn accept(&self) -> impl Future> + Send; + fn local_addr(&self) -> tokio::io::Result>; +} + +impl Acceptor for tokio::net::TcpListener { + type Socket = tokio::net::TcpStream; + + #[inline] + fn accept(&self) -> impl Future> + Send { + tokio::net::TcpListener::accept(self).and_then(|(s, _)| futures::future::ok(s)) + } + + #[inline] + fn local_addr(&self) -> tokio::io::Result> { + tokio::net::TcpListener::local_addr(self).map(|a| Some(SocketAddr::Net(a))) + } +} + +// A helper trait to unify the behavior of TCP and UDS stream. +pub trait Connection: std::io::Read + std::io::Write { + fn try_clone(&self) -> std::io::Result>; +} + +impl Connection for std::net::TcpStream { + #[inline] + fn try_clone(&self) -> std::io::Result> { + let stream = std::net::TcpStream::try_clone(self)?; + Ok(Box::new(stream)) + } +} + +// Helper function to create a stream. Uses dynamic dispatch to make code more +// readable. +pub fn connect(addr: &SocketAddr) -> std::io::Result> { + match addr { + SocketAddr::Net(addr) => { + std::net::TcpStream::connect(addr).map(|s| Box::new(s) as Box) + } + #[cfg(unix)] + SocketAddr::Unix(p) => { + std::os::unix::net::UnixStream::connect(p).map(|s| Box::new(s) as Box) + } + #[cfg(any(target_os = "linux", target_os = "android"))] + SocketAddr::UnixAbstract(p) => { + let sock = std::os::unix::net::SocketAddr::from_abstract_name(p)?; + std::os::unix::net::UnixStream::connect_addr(&sock) + .map(|s| Box::new(s) as Box) + } + } +} + +#[cfg(unix)] +mod unix_imp { + use futures::TryFutureExt; + + use super::*; + + impl Acceptor for tokio::net::UnixListener { + type Socket = tokio::net::UnixStream; + + #[inline] + fn accept(&self) -> impl Future> + Send { + tokio::net::UnixListener::accept(self).and_then(|(s, _)| futures::future::ok(s)) + } + + #[inline] + fn local_addr(&self) -> tokio::io::Result> { + let addr = tokio::net::UnixListener::local_addr(self)?; + if let Some(p) = addr.as_pathname() { + return Ok(Some(SocketAddr::Unix(p.to_path_buf()))); + } + // TODO: support get addr from abstract socket. + // tokio::net::SocketAddr needs to support `as_abstract_name`. + // #[cfg(any(target_os = "linux", target_os = "android"))] + // if let Some(p) = addr.0.as_abstract_name() { + // return Ok(SocketAddr::UnixAbstract(p.to_vec())); + // } + Ok(None) + } + } + + impl Connection for std::os::unix::net::UnixStream { + #[inline] + fn try_clone(&self) -> std::io::Result> { + let stream = std::os::unix::net::UnixStream::try_clone(self)?; + Ok(Box::new(stream)) + } + } +} diff --git a/src/server.rs b/src/server.rs index 0128a4adf..03c2b6a63 100644 --- a/src/server.rs +++ b/src/server.rs @@ -46,7 +46,8 @@ use std::io::{self, Write}; use std::marker::Unpin; #[cfg(feature = "dist-client")] use std::mem; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +#[cfg(any(target_os = "linux", target_os = "android"))] +use std::os::linux::net::SocketAddrExt; use std::path::PathBuf; use std::pin::Pin; use std::process::{ExitStatus, Output}; @@ -60,7 +61,6 @@ use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpListener, runtime::Runtime, time::{self, sleep, Sleep}, }; @@ -81,8 +81,8 @@ const DIST_CLIENT_RECREATE_TIMEOUT: Duration = Duration::from_secs(30); /// Result of background server startup. #[derive(Debug, Serialize, Deserialize)] pub enum ServerStartup { - /// Server started successfully on `port`. - Ok { port: u16 }, + /// Server started successfully on `addr`. + Ok { addr: String }, /// Server Addr already in suse AddrInUse, /// Timed out waiting for server startup. @@ -401,12 +401,12 @@ thread_local! { static PANIC_LOCATION: Cell> = const { Cell::new(None) }; } -/// Start an sccache server, listening on `port`. +/// Start an sccache server, listening on `addr`. /// /// Spins an event loop handling client connections until a client /// requests a shutdown. -pub fn start_server(config: &Config, port: u16) -> Result<()> { - info!("start_server: port: {}", port); +pub fn start_server(config: &Config, addr: &crate::net::SocketAddr) -> Result<()> { + info!("start_server: {addr}"); let panic_hook = std::panic::take_hook(); std::panic::set_hook(Box::new(move |info| { PANIC_LOCATION.with(|l| { @@ -467,59 +467,105 @@ pub fn start_server(config: &Config, port: u16) -> Result<()> { _ => raw_storage, }; - let res = - SccacheServer::::new(port, runtime, client, dist_client, storage); + let res = (|| -> io::Result<_> { + match addr { + crate::net::SocketAddr::Net(addr) => { + trace!("binding TCP {addr}"); + let l = runtime.block_on(tokio::net::TcpListener::bind(addr))?; + let srv = + SccacheServer::<_>::with_listener(l, runtime, client, dist_client, storage); + Ok(( + srv.local_addr().unwrap(), + Box::new(move |f| srv.run(f)) as Box _>, + )) + } + #[cfg(unix)] + crate::net::SocketAddr::Unix(path) => { + trace!("binding unix socket {}", path.display()); + // Unix socket will report addr in use on any unlink file. + let _ = std::fs::remove_file(path); + let l = { + let _guard = runtime.enter(); + tokio::net::UnixListener::bind(path)? + }; + let srv = + SccacheServer::<_>::with_listener(l, runtime, client, dist_client, storage); + Ok(( + srv.local_addr().unwrap(), + Box::new(move |f| srv.run(f)) as Box _>, + )) + } + #[cfg(any(target_os = "linux", target_os = "android"))] + crate::net::SocketAddr::UnixAbstract(p) => { + trace!("binding abstract unix socket {}", p.escape_ascii()); + let abstract_addr = std::os::unix::net::SocketAddr::from_abstract_name(p)?; + let l = std::os::unix::net::UnixListener::bind_addr(&abstract_addr)?; + l.set_nonblocking(true)?; + let l = { + let _guard = runtime.enter(); + tokio::net::UnixListener::from_std(l)? + }; + let srv = + SccacheServer::<_>::with_listener(l, runtime, client, dist_client, storage); + Ok(( + srv.local_addr() + .unwrap_or_else(|| crate::net::SocketAddr::UnixAbstract(p.to_vec())), + Box::new(move |f| srv.run(f)) as Box _>, + )) + } + } + })(); match res { - Ok(srv) => { - let port = srv.port(); - info!("server started, listening on port {}", port); - notify_server_startup(¬ify, ServerStartup::Ok { port })?; - srv.run(future::pending::<()>())?; + Ok((addr, run)) => { + info!("server started, listening on {addr}"); + notify_server_startup( + ¬ify, + ServerStartup::Ok { + addr: addr.to_string(), + }, + )?; + run(future::pending::<()>())?; Ok(()) } Err(e) => { error!("failed to start server: {}", e); - match e.downcast_ref::() { - Some(io_err) if io::ErrorKind::AddrInUse == io_err.kind() => { - notify_server_startup(¬ify, ServerStartup::AddrInUse)?; - } - Some(io_err) if cfg!(windows) && Some(10013) == io_err.raw_os_error() => { - // 10013 is the "WSAEACCES" error, which can occur if the requested port - // has been allocated for other purposes, such as winNAT or Hyper-V. - let windows_help_message = - "A Windows port exclusion is blocking use of the configured port.\nTry setting SCCACHE_SERVER_PORT to a new value."; - let reason: String = format!("{windows_help_message}\n{e}"); - notify_server_startup(¬ify, ServerStartup::Err { reason })?; - } - _ => { - let reason = e.to_string(); - notify_server_startup(¬ify, ServerStartup::Err { reason })?; - } - }; - Err(e) + if io::ErrorKind::AddrInUse == e.kind() { + notify_server_startup(¬ify, ServerStartup::AddrInUse)?; + } else if cfg!(windows) && Some(10013) == e.raw_os_error() { + // 10013 is the "WSAEACCES" error, which can occur if the requested port + // has been allocated for other purposes, such as winNAT or Hyper-V. + let windows_help_message = + "A Windows port exclusion is blocking use of the configured port.\nTry setting SCCACHE_SERVER_PORT to a new value."; + let reason: String = format!("{windows_help_message}\n{e}"); + notify_server_startup(¬ify, ServerStartup::Err { reason })?; + } else { + let reason = e.to_string(); + notify_server_startup(¬ify, ServerStartup::Err { reason })?; + } + Err(e.into()) } } } -pub struct SccacheServer { +pub struct SccacheServer { runtime: Runtime, - listener: TcpListener, + listener: A, rx: mpsc::Receiver, timeout: Duration, service: SccacheService, wait: WaitUntilZero, } -impl SccacheServer { +impl SccacheServer { pub fn new( port: u16, runtime: Runtime, client: Client, dist_client: DistClientContainer, storage: Arc, - ) -> Result> { - let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port); - let listener = runtime.block_on(TcpListener::bind(&SocketAddr::V4(addr)))?; + ) -> Result { + let addr = crate::net::SocketAddr::with_port(port); + let listener = runtime.block_on(tokio::net::TcpListener::bind(addr.as_net().unwrap()))?; Ok(Self::with_listener( listener, @@ -529,14 +575,16 @@ impl SccacheServer { storage, )) } +} +impl SccacheServer { pub fn with_listener( - listener: TcpListener, + listener: A, runtime: Runtime, client: Client, dist_client: DistClientContainer, storage: Arc, - ) -> SccacheServer { + ) -> Self { // Prepare the service which we'll use to service all incoming TCP // connections. let (tx, rx) = mpsc::channel(1); @@ -580,8 +628,8 @@ impl SccacheServer { /// Returns the port that this server is bound to #[allow(dead_code)] - pub fn port(&self) -> u16 { - self.listener.local_addr().unwrap().port() + pub fn local_addr(&self) -> Option { + self.listener.local_addr().unwrap() } /// Runs this server to completion. @@ -593,6 +641,7 @@ impl SccacheServer { where F: Future, C: Send, + A::Socket: 'static, { let SccacheServer { runtime, @@ -607,7 +656,7 @@ impl SccacheServer { // connections in separate tasks. let server = async move { loop { - let (socket, _) = listener.accept().await?; + let socket = listener.accept().await?; trace!("incoming connection"); let conn = service.clone().bind(socket).map_err(|res| { error!("Failed to bind socket: {}", res); diff --git a/src/test/tests.rs b/src/test/tests.rs index d5c3ee0e9..3aa98a80f 100644 --- a/src/test/tests.rs +++ b/src/test/tests.rs @@ -58,7 +58,7 @@ fn run_server_thread( cache_dir: &Path, options: T, ) -> ( - u16, + crate::net::SocketAddr, Sender, Arc>, thread::JoinHandle<()>, @@ -90,28 +90,28 @@ where let client = unsafe { Client::new() }; let srv = SccacheServer::new(0, runtime, client, dist_client, storage).unwrap(); - let mut srv: SccacheServer>> = srv; - assert!(srv.port() > 0); + let mut srv: SccacheServer<_, Arc>> = srv; + let addr = srv.local_addr().unwrap(); + assert!(matches!(addr, crate::net::SocketAddr::Net(a) if a.port() > 0)); if let Some(options) = options { if let Some(timeout) = options.idle_timeout { srv.set_idle_timeout(Duration::from_millis(timeout)); } } - let port = srv.port(); let creator = srv.command_creator().clone(); - tx.send((port, creator)).unwrap(); + tx.send((addr, creator)).unwrap(); srv.run(shutdown_rx).unwrap(); }); - let (port, creator) = rx.recv().unwrap(); - (port, shutdown_tx, creator, handle) + let (addr, creator) = rx.recv().unwrap(); + (addr, shutdown_tx, creator, handle) } #[test] fn test_server_shutdown() { let f = TestFixture::new(); - let (port, _sender, _storage, child) = run_server_thread(f.tempdir.path(), None); + let (addr, _sender, _storage, child) = run_server_thread(f.tempdir.path(), None); // Connect to the server. - let conn = connect_to_server(port).unwrap(); + let conn = connect_to_server(&addr).unwrap(); // Ask it to shut down request_shutdown(conn).unwrap(); // Ensure that it shuts down. @@ -123,7 +123,7 @@ fn test_server_shutdown() { fn test_server_shutdown_no_idle() { let f = TestFixture::new(); // Set a ridiculously low idle timeout. - let (port, _sender, _storage, child) = run_server_thread( + let (addr, _sender, _storage, child) = run_server_thread( f.tempdir.path(), ServerOptions { idle_timeout: Some(0), @@ -131,7 +131,7 @@ fn test_server_shutdown_no_idle() { }, ); - let conn = connect_to_server(port).unwrap(); + let conn = connect_to_server(&addr).unwrap(); request_shutdown(conn).unwrap(); child.join().unwrap(); } @@ -157,9 +157,9 @@ fn test_server_idle_timeout() { #[test] fn test_server_stats() { let f = TestFixture::new(); - let (port, sender, _storage, child) = run_server_thread(f.tempdir.path(), None); + let (addr, sender, _storage, child) = run_server_thread(f.tempdir.path(), None); // Connect to the server. - let conn = connect_to_server(port).unwrap(); + let conn = connect_to_server(&addr).unwrap(); // Ask it for stats. let info = request_stats(conn).unwrap(); assert_eq!(0, info.stats.compile_requests); @@ -174,9 +174,9 @@ fn test_server_stats() { #[test] fn test_server_unsupported_compiler() { let f = TestFixture::new(); - let (port, sender, server_creator, child) = run_server_thread(f.tempdir.path(), None); + let (addr, sender, server_creator, child) = run_server_thread(f.tempdir.path(), None); // Connect to the server. - let conn = connect_to_server(port).unwrap(); + let conn = connect_to_server(&addr).unwrap(); { let mut c = server_creator.lock().unwrap(); // fail rust driver check @@ -226,13 +226,13 @@ fn test_server_compile() { let _ = env_logger::try_init(); let f = TestFixture::new(); let gcc = f.mk_bin("gcc").unwrap(); - let (port, sender, server_creator, child) = run_server_thread(f.tempdir.path(), None); + let (addr, sender, server_creator, child) = run_server_thread(f.tempdir.path(), None); // Connect to the server. const PREPROCESSOR_STDOUT: &[u8] = b"preprocessor stdout"; const PREPROCESSOR_STDERR: &[u8] = b"preprocessor stderr"; const STDOUT: &[u8] = b"some stdout"; const STDERR: &[u8] = b"some stderr"; - let conn = connect_to_server(port).unwrap(); + let conn = connect_to_server(&addr).unwrap(); // Write a dummy input file so the preprocessor cache mode can work std::fs::write(f.tempdir.path().join("file.c"), "whatever").unwrap(); { @@ -308,6 +308,7 @@ fn test_server_port_in_use() { "SCCACHE_SERVER_PORT", listener.local_addr().unwrap().port().to_string(), ) + .env_remove("SCCACHE_SERVER_UDS") .output() .unwrap(); assert!(!output.status.success()); diff --git a/src/util.rs b/src/util.rs index 2000ae832..001aedd2f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -945,6 +945,67 @@ pub fn new_reqwest_blocking_client() -> reqwest::blocking::Client { .expect("http client must build with success") } +fn unhex(b: u8) -> std::io::Result { + match b { + b'0'..=b'9' => Ok(b - b'0'), + b'a'..=b'f' => Ok(b - b'a' + 10), + b'A'..=b'F' => Ok(b - b'A' + 10), + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid hex digit", + )), + } +} + +/// A reverse version of std::ascii::escape_default +pub fn ascii_unescape_default(s: &[u8]) -> std::io::Result> { + let mut out = Vec::with_capacity(s.len() + 4); + let mut offset = 0; + while offset < s.len() { + let c = s[offset]; + if c == b'\\' { + offset += 1; + if offset >= s.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "incomplete escape", + )); + } + let c = s[offset]; + match c { + b'n' => out.push(b'\n'), + b'r' => out.push(b'\r'), + b't' => out.push(b'\t'), + b'\'' => out.push(b'\''), + b'"' => out.push(b'"'), + b'\\' => out.push(b'\\'), + b'x' => { + offset += 1; + if offset + 1 >= s.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "incomplete hex escape", + )); + } + let v = unhex(s[offset])? << 4 | unhex(s[offset + 1])?; + out.push(v); + offset += 1; + } + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid escape", + )); + } + } + } else { + out.push(c); + } + offset += 1; + } + Ok(out) +} + #[cfg(test)] mod tests { use super::{OsStrExt, TimeMacroFinder}; @@ -1055,4 +1116,46 @@ mod tests { finder.find_time_macros(b"TIMESTAMP__ This is larger than the haystack"); assert!(finder.found_timestamp()); } + + #[test] + fn test_ascii_unescape_default() { + let mut alphabet = r#"\\'"\t\n\r"#.as_bytes().to_vec(); + alphabet.push(b'a'); + alphabet.push(b'1'); + alphabet.push(0); + alphabet.push(0xff); + let mut input = vec![]; + let mut output = vec![]; + let mut alphabet_indexes = [0; 3]; + let mut tested_cases = 0; + // Following loop may test duplicated inputs, but it's not a problem + loop { + input.clear(); + output.clear(); + for idx in alphabet_indexes { + if idx < alphabet.len() { + input.push(alphabet[idx]); + } + } + if input.is_empty() { + break; + } + output.extend(input.as_slice().escape_ascii()); + let result = super::ascii_unescape_default(&output).unwrap(); + assert_eq!(input, result, "{:?}", output); + tested_cases += 1; + for idx in &mut alphabet_indexes { + *idx += 1; + if *idx > alphabet.len() { + // Use `>` so we can test various input length. + *idx = 0; + } else { + break; + } + } + } + assert_eq!(tested_cases, (alphabet.len() + 1).pow(3) - 1); + let empty_result = super::ascii_unescape_default(&[]).unwrap(); + assert!(empty_result.is_empty(), "{:?}", empty_result); + } }