diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index ecc121d985a0..5c5f84b3d15a 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -19,7 +19,6 @@ use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::Bytes; use std::collections::HashMap; -use std::sync::Arc; use std::time::Duration; use crate::flight_service_client::FlightServiceClient; @@ -45,7 +44,6 @@ use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; use futures::{stream, TryStreamExt}; use prost::Message; -use tokio::sync::{Mutex, MutexGuard}; #[cfg(feature = "tls")] use tonic::transport::{Certificate, ClientTlsConfig, Identity}; use tonic::transport::{Channel, Endpoint}; @@ -56,7 +54,7 @@ use tonic::Streaming; #[derive(Debug, Clone)] pub struct FlightSqlServiceClient { token: Option, - flight_client: Arc>>, + flight_client: FlightServiceClient, } /// A FlightSql protocol client that can run queries against FlightSql servers @@ -124,16 +122,23 @@ impl FlightSqlServiceClient { let flight_client = FlightServiceClient::new(channel); FlightSqlServiceClient { token: None, - flight_client: Arc::new(Mutex::new(flight_client)), + flight_client, } } - fn mut_client( - &mut self, - ) -> Result>, ArrowError> { + /// Return a reference to the underlying [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.flight_client + } + + /// Return a mutable reference to the underlying [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.flight_client + } + + /// Consume this client and return the underlying [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { self.flight_client - .try_lock() - .map_err(|_| ArrowError::IoError("Unable to lock client".to_string())) } async fn get_flight_info_for_command( @@ -142,7 +147,7 @@ impl FlightSqlServiceClient { ) -> Result { let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let fi = self - .mut_client()? + .flight_client .get_flight_info(descriptor) .await .map_err(status_to_arrow_error)? @@ -174,7 +179,7 @@ impl FlightSqlServiceClient { .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; req.metadata_mut().insert("authorization", val); let resp = self - .mut_client()? + .flight_client .handshake(req) .await .map_err(|e| ArrowError::IoError(format!("Can't handshake {}", e)))?; @@ -208,7 +213,7 @@ impl FlightSqlServiceClient { let cmd = CommandStatementUpdate { query }; let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let mut result = self - .mut_client()? + .flight_client .do_put(stream::iter(vec![FlightData { flight_descriptor: Some(descriptor), ..Default::default() @@ -247,7 +252,7 @@ impl FlightSqlServiceClient { ticket: Ticket, ) -> Result, ArrowError> { Ok(self - .mut_client()? + .flight_client .do_get(ticket) .await .map_err(status_to_arrow_error)? @@ -332,7 +337,7 @@ impl FlightSqlServiceClient { req.metadata_mut().insert("authorization", val); } let mut result = self - .mut_client()? + .flight_client .do_action(req) .await .map_err(status_to_arrow_error)? @@ -369,7 +374,7 @@ impl FlightSqlServiceClient { /// A PreparedStatement #[derive(Debug, Clone)] pub struct PreparedStatement { - flight_client: Arc>>, + flight_client: FlightServiceClient, parameter_binding: Option, handle: Bytes, dataset_schema: Schema, @@ -378,13 +383,13 @@ pub struct PreparedStatement { impl PreparedStatement { pub(crate) fn new( - client: Arc>>, + flight_client: FlightServiceClient, handle: impl Into, dataset_schema: Schema, parameter_schema: Schema, ) -> Self { PreparedStatement { - flight_client: client, + flight_client, parameter_binding: None, handle: handle.into(), dataset_schema, @@ -399,7 +404,7 @@ impl PreparedStatement { }; let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let result = self - .mut_client()? + .flight_client .get_flight_info(descriptor) .await .map_err(status_to_arrow_error)? @@ -414,7 +419,7 @@ impl PreparedStatement { }; let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let mut result = self - .mut_client()? + .flight_client .do_put(stream::iter(vec![FlightData { flight_descriptor: Some(descriptor), ..Default::default() @@ -463,20 +468,12 @@ impl PreparedStatement { body: cmd.as_any().encode_to_vec().into(), }; let _ = self - .mut_client()? + .flight_client .do_action(action) .await .map_err(status_to_arrow_error)?; Ok(()) } - - fn mut_client( - &mut self, - ) -> Result>, ArrowError> { - self.flight_client - .try_lock() - .map_err(|_| ArrowError::IoError("Unable to lock client".to_string())) - } } fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {