Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(daemon): add support for daemon on windows #2014

Merged
merged 14 commits into from
May 13, 2024
Merged
4 changes: 4 additions & 0 deletions crates/atuin-client/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ pub struct Daemon {

/// The path to the unix socket used by the daemon
pub socket_path: String,

// The port that should be used if using tcp (mainly on windows)
pub tcp_port: u64,
}

impl Default for Preview {
Expand All @@ -369,6 +372,7 @@ impl Default for Daemon {
enabled: false,
sync_frequency: 300,
socket_path: "".to_string(),
tcp_port: 2468,
YummyOreo marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down
16 changes: 14 additions & 2 deletions crates/atuin-daemon/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use eyre::{eyre, Result};
use tokio::net::UnixStream;
#[cfg(windows)]
use tokio::net::TcpStream;
use tonic::transport::{Channel, Endpoint, Uri};
use tower::service_fn;

#[cfg(unix)]
use tokio::net::UnixStream;

use atuin_client::history::History;

use crate::history::{
Expand All @@ -20,7 +24,15 @@ impl HistoryClient {
.connect_with_connector(service_fn(move |_: Uri| {
let path = path.to_string();

UnixStream::connect(path)
#[cfg(unix)]
{
UnixStream::connect(path)
}
#[cfg(windows)]
{
let url = format!("127.0.0.1:{}", settings.daemon.tcp_port);
TcpStream::connect(url)
}
}))
.await
.map_err(|_| eyre!("failed to connect to local atuin daemon. Is it running?"))?;
Expand Down
53 changes: 41 additions & 12 deletions crates/atuin-daemon/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use atuin_client::encryption;
use atuin_client::history::store::HistoryStore;
use atuin_client::record::sqlite_store::SqliteStore;
use atuin_client::settings::Settings;
use std::fmt::format;
use std::path::PathBuf;
use std::sync::Arc;
use time::OffsetDateTime;
Expand All @@ -13,8 +14,6 @@ use atuin_client::database::{Database, Sqlite as HistoryDatabase};
use atuin_client::history::{History, HistoryId};
use dashmap::DashMap;
use eyre::Result;
use tokio::net::UnixListener;
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{transport::Server, Request, Response, Status};

use crate::history::history_server::{History as HistorySvc, HistoryServer};
Expand Down Expand Up @@ -134,6 +133,7 @@ impl HistorySvc for HistoryService {
}
}

#[cfg(unix)]
async fn shutdown_signal(socket: PathBuf) {
let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to register sigterm handler");
Expand All @@ -150,6 +150,14 @@ async fn shutdown_signal(socket: PathBuf) {
eprintln!("Shutting down...");
}

#[cfg(windows)]
async fn shutdown_signal() {
tokio::signal::windows::ctrl_c()
.expect("failed to register signal handler")
.recv()
.await;
}

// break the above down when we end up with multiple services

/// Listen on a unix socket
Expand All @@ -168,12 +176,6 @@ pub async fn listen(

let history = HistoryService::new(history_store.clone(), history_db.clone());

let socket = settings.daemon.socket_path.clone();
let uds = UnixListener::bind(socket.clone())?;
let uds_stream = UnixListenerStream::new(uds);

tracing::info!("listening on unix socket {:?}", socket);

// start services
tokio::spawn(sync::worker(
settings.clone(),
Expand All @@ -182,10 +184,37 @@ pub async fn listen(
history_db,
));

Server::builder()
.add_service(HistoryServer::new(history))
.serve_with_incoming_shutdown(uds_stream, shutdown_signal(socket.into()))
.await?;
#[cfg(unix)]
YummyOreo marked this conversation as resolved.
Show resolved Hide resolved
{
use tokio::net::UnixListener;
use tokio_stream::wrappers::UnixListenerStream;

let socket = settings.daemon.socket_path.clone();

let uds = UnixListener::bind(socket.clone())?;
let uds_stream = UnixListenerStream::new(uds);

tracing::info!("listening on unix socket {:?}", socket);
Server::builder()
.add_service(HistoryServer::new(history))
.serve_with_incoming_shutdown(uds_stream, shutdown_signal(socket.into()))
.await?;
}

#[cfg(not(unix))]
{
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;

let url = format!("127.0.0.1:{}", settings.daemon.tcp_port);
let tcp = TcpListener::bind(url).await?;
let tcp_stream = TcpListenerStream::new(tcp);

Server::builder()
.add_service(HistoryServer::new(history))
.serve_with_incoming_shutdown(tcp_stream, shutdown_signal())
.await?;
}

Ok(())
}
Loading