diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index bdd51dda4f9f..fe1292fcff6e 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::task::Poll; + use crate::{ decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, @@ -24,8 +26,9 @@ use arrow_schema::Schema; use bytes::Bytes; use futures::{ future::ready, + ready, stream::{self, BoxStream}, - Stream, StreamExt, TryStreamExt, + FutureExt, Stream, StreamExt, TryStreamExt, }; use tonic::{metadata::MetadataMap, transport::Channel}; @@ -262,6 +265,15 @@ impl FlightClient { /// [`Stream`](futures::Stream) of [`FlightData`] and returning a /// stream of [`PutResult`]. /// + /// # Note + /// + /// The input stream is [`Result`] so that this can be connected + /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder), + /// without having to buffer. If the input stream returns an error + /// that error will not be sent to the server, instead it will be + /// placed into the result stream and the server connection + /// terminated. + /// /// # Example: /// ```no_run /// # async fn run() { @@ -279,9 +291,7 @@ impl FlightClient { /// /// // encode the batch as a stream of `FlightData` /// let flight_data_stream = FlightDataEncoderBuilder::new() - /// .build(futures::stream::iter(vec![Ok(batch)])) - /// // data encoder return Results, but do_put requires FlightData - /// .map(|batch|batch.unwrap()); + /// .build(futures::stream::iter(vec![Ok(batch)])); /// /// // send the stream and get the results as `PutResult` /// let response: Vec= client @@ -293,20 +303,40 @@ impl FlightClient { /// .expect("error calling do_put"); /// # } /// ``` - pub async fn do_put + Send + 'static>( + pub async fn do_put> + Send + 'static>( &mut self, request: S, ) -> Result>> { - let request = self.make_request(request); - - let response = self - .inner - .do_put(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + let (sender, mut receiver) = futures::channel::oneshot::channel(); + + // Intercepts client errors and sends them to the oneshot channel above + let mut request = Box::pin(request); // Pin to heap + let mut sender = Some(sender); // Wrap into Option so can be taken + let request_stream = futures::stream::poll_fn(move |cx| { + Poll::Ready(match ready!(request.poll_next_unpin(cx)) { + Some(Ok(data)) => Some(data), + Some(Err(e)) => { + let _ = sender.take().unwrap().send(e); + None + } + None => None, + }) + }); + + let request = self.make_request(request_stream); + let mut response_stream = self.inner.do_put(request).await?.into_inner(); + + // Forwards errors from the error oneshot with priority over responses from server + let error_stream = futures::stream::poll_fn(move |cx| { + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + } + let next = ready!(response_stream.poll_next_unpin(cx)); + Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + }); - Ok(response.boxed()) + // combine the response from the server and any error from the client + Ok(error_stream.boxed()) } /// Make a `DoExchange` call to the server with the provided diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index ab1cfa1fb053..ed928a52c99a 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -248,8 +248,10 @@ async fn test_do_put() { test_server .set_do_put_response(expected_response.clone().into_iter().map(Ok).collect()); + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + let response_stream = client - .do_put(futures::stream::iter(input_flight_data.clone())) + .do_put(input_stream) .await .expect("error making request"); @@ -266,15 +268,15 @@ async fn test_do_put() { } #[tokio::test] -async fn test_do_put_error() { +async fn test_do_put_error_server() { do_test(|test_server, mut client| async move { client.add_header("foo-header", "bar-header-value").unwrap(); let input_flight_data = test_flight_data().await; - let response = client - .do_put(futures::stream::iter(input_flight_data.clone())) - .await; + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response = client.do_put(input_stream).await; let response = match response { Ok(_) => panic!("unexpected success"), Err(e) => e, @@ -290,7 +292,7 @@ async fn test_do_put_error() { } #[tokio::test] -async fn test_do_put_error_stream() { +async fn test_do_put_error_stream_server() { do_test(|test_server, mut client| async move { client.add_header("foo-header", "bar-header-value").unwrap(); @@ -307,8 +309,10 @@ async fn test_do_put_error_stream() { test_server.set_do_put_response(response); + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + let response_stream = client - .do_put(futures::stream::iter(input_flight_data.clone())) + .do_put(input_stream) .await .expect("error making request"); @@ -326,6 +330,87 @@ async fn test_do_put_error_stream() { .await; } +#[tokio::test] +async fn test_do_put_error_client() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::invalid_argument("bad arg: client"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e.clone(), + ))])); + + // server responds with one good message + let response = vec![Ok(PutResult { + app_metadata: Bytes::from("foo-metadata"), + })]; + test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client + expect_status(response, e); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_client_and_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e_client = Status::invalid_argument("bad arg: client"); + let e_server = Status::invalid_argument("bad arg: server"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e_client.clone(), + ))])); + + // server responds with an error (e.g. because it got truncated data) + let response = vec![Err(e_server)]; + test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client (not the server) + expect_status(response, e_client); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + #[tokio::test] async fn test_do_exchange() { do_test(|test_server, mut client| async move {