diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 6706928aa..9d569cc1b 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -8,15 +8,18 @@ use std::sync::{Arc, Mutex}; use std::{env, io, process}; use anyhow::Context; -use async_channel::{Receiver, Sender, TrySendError}; +use async_channel::{Sender, TrySendError}; use calloop::futures::Scheduler; use calloop::io::Async; use directories::BaseDirs; -use futures_util::io::{AsyncReadExt, BufReader}; -use futures_util::{select_biased, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, FutureExt as _}; +use futures_util::io::{AsyncBufReadExt, AsyncReadExt, BufReader, ReadHalf}; +use futures_util::io::{AsyncWriteExt, WriteHalf}; +use futures_util::stream::{AbortHandle, Abortable, StreamExt}; +use futures_util::{pin_mut, TryFutureExt}; use niri_config::OutputName; use niri_ipc::state::{EventStreamState, EventStreamStatePart as _}; use niri_ipc::{Event, KeyboardLayouts, OutputConfigChanged, Reply, Request, Response, Workspace}; +use serde::Serialize; use smithay::desktop::layer_map_for_output; use smithay::reexports::calloop::generic::Generic; use smithay::reexports::calloop::{Interest, LoopHandle, Mode, PostAction}; @@ -50,15 +53,9 @@ struct ClientCtx { event_stream_state: Rc>, } -struct EventStreamClient { - events: Receiver, - disconnect: Receiver<()>, - write: Box, -} - struct EventStreamSender { events: Sender, - disconnect: Sender<()>, + abort: AbortHandle, } impl IpcServer { @@ -123,7 +120,7 @@ impl IpcServer { for idx in to_remove.into_iter().rev() { let stream = streams.swap_remove(idx); - let _ = stream.disconnect.send_blocking(()); + stream.abort.abort(); } } } @@ -177,80 +174,87 @@ fn on_new_ipc_client(state: &mut State, stream: UnixStream) { } async fn handle_client(ctx: ClientCtx, stream: Async<'static, UnixStream>) -> anyhow::Result<()> { - let (read, mut write) = stream.split(); - let mut buf = String::new(); - - // Read a single line to allow extensibility in the future to keep reading. - BufReader::new(read) - .read_line(&mut buf) - .await - .context("error reading request")?; - - let request = serde_json::from_str(&buf) - .context("error parsing request") - .map_err(|err| err.to_string()); - let requested_error = matches!(request, Ok(Request::ReturnError)); - let requested_event_stream = matches!(request, Ok(Request::EventStream)); - - let reply = match request { - Ok(request) => process(&ctx, request).await, - Err(err) => Err(err), + let (read, write) = stream.split(); + + let mut input_buf = Vec::::new(); + let mut read: BufReader>> = BufReader::new(read); + let mut write = JsonWriter { + write, + buf: Vec::new(), }; - if let Err(err) = &reply { - if !requested_error { - warn!("error processing IPC request: {err:?}"); + loop { + input_buf.clear(); + let res = read.read_until(b'\n', &mut input_buf).await; + match res { + Ok(0) => return Ok(()), + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::BrokenPipe => return Ok(()), + Err(err) => { + return Err(err).context("error parsing request"); + } } - } - - let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; - buf.push(b'\n'); - write.write_all(&buf).await.context("error writing reply")?; - if requested_event_stream { - let (events_tx, events_rx) = async_channel::bounded(EVENT_STREAM_BUFFER_SIZE); - let (disconnect_tx, disconnect_rx) = async_channel::bounded(1); - - // Spawn a task for the client. - let client = EventStreamClient { - events: events_rx, - disconnect: disconnect_rx, - write: Box::new(write) as _, - }; - let future = async move { - if let Err(err) = handle_event_stream_client(client).await { - warn!("error handling IPC event stream client: {err:?}"); + let request = serde_json::from_slice(&input_buf) + .context("error parsing request") + .map_err(|err| err.to_string()); + let requested_error = matches!(request, Ok(Request::ReturnError)); + + let reply = match request { + Ok(request) => { + let reply; + write = { + let mut write = Some(write); + reply = process_request(&ctx, request, &mut write).await; + if let Some(write) = write { + write + } else { + // The writer was moved out to `process_request`, which means we are done + // reading user requests. + return Ok(()); + } + }; + reply } + Err(err) => Err(err), }; - if let Err(err) = ctx.scheduler.schedule(future) { - warn!("error scheduling IPC event stream future: {err:?}"); - } - // Send the initial state. - { - let state = ctx.event_stream_state.borrow(); - for event in state.replicate() { - events_tx - .try_send(event) - .expect("initial event burst had more events than buffer size"); + if let Err(err) = &reply { + if !requested_error { + warn!("error processing IPC request: {err:?}"); } } - - // Add it to the list. - { - let mut streams = ctx.event_streams.borrow_mut(); - let sender = EventStreamSender { - events: events_tx, - disconnect: disconnect_tx, - }; - streams.push(sender); - } + write.json_line(&reply).await?; } +} + +struct JsonWriter { + write: WriteHalf>, + buf: Vec, +} + +impl JsonWriter { + async fn json_line(&mut self, val: &T) -> anyhow::Result<()> { + let Self { write, buf } = self; - Ok(()) + buf.clear(); + serde_json::to_writer(&mut *buf, val).context("error formatting reply")?; + buf.push(b'\n'); + write.write_all(buf).await.context("error writing reply")?; + write.flush().await?; + Ok(()) + } } -async fn process(ctx: &ClientCtx, request: Request) -> Reply { +/// Process client request. +/// +/// +/// +async fn process_request( + ctx: &ClientCtx, + request: Request, + writer: &mut Option, +) -> Reply { let response = match request { Request::ReturnError => return Err(String::from("example compositor error")), Request::Version => Response::Version(version()), @@ -384,37 +388,58 @@ async fn process(ctx: &ClientCtx, request: Request) -> Reply { let output = result.map_err(|_| String::from("error getting active output info"))?; Response::FocusedOutput(output) } - Request::EventStream => Response::Handled, - }; + Request::EventStream => { + let mut writer = writer + .take() + .expect("this function is only allowed to be called with `Some(writer)`"); + let (events_tx, events_rx) = async_channel::bounded(EVENT_STREAM_BUFFER_SIZE); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); - Ok(response) -} + let stream = Abortable::new(events_rx, abort_registration); -async fn handle_event_stream_client(client: EventStreamClient) -> anyhow::Result<()> { - let EventStreamClient { - events, - disconnect, - mut write, - } = client; + // Send the initial state. + { + let state = ctx.event_stream_state.borrow(); + for event in state.replicate() { + events_tx + .try_send(event) + .expect("initial event burst had more events than buffer size"); + } + } - while let Ok(event) = events.recv().await { - let mut buf = serde_json::to_vec(&event).context("error formatting event")?; - buf.push(b'\n'); + // Add it to the list. + { + let mut streams = ctx.event_streams.borrow_mut(); + streams.push(EventStreamSender { + events: events_tx, + abort: abort_handle, + }); + } - let res = select_biased! { - _ = disconnect.recv().fuse() => return Ok(()), - res = write.write_all(&buf).fuse() => res, - }; + writer + .json_line(&Result::<_, ()>::Ok(Response::Handled)) + .await + .map_err(|_| String::from("failed to write `Response::Handled`"))?; + + ctx.scheduler + .schedule( + async move { + pin_mut!(stream); + while let Some(event) = stream.next().await { + writer.json_line(&event).await?; + } + anyhow::Ok(()) + } + .unwrap_or_else(|err| warn!("error handling IPC event stream client: {err:?}")), + ) + .unwrap_or_else(|err| warn!("error handling IPC event stream client: {err:?}")); - match res { - Ok(()) => (), - // Normal client disconnection. - Err(err) if err.kind() == io::ErrorKind::BrokenPipe => return Ok(()), - res @ Err(_) => res.context("error writing event")?, + // Note that this value will be ignored in the further execution + Response::Handled } - } + }; - Ok(()) + Ok(response) } fn make_ipc_window(mapped: &Mapped, workspace_id: Option) -> niri_ipc::Window {