Skip to content

Commit

Permalink
Expose Inner FlightServiceClient on FlightSqlServiceClient (#3551) (#…
Browse files Browse the repository at this point in the history
…3556)

* Remove unnecessary Mutex from FlightSqlServiceClient (#3551)

* Add inner and inner_mut

* Add into_inner
  • Loading branch information
tustvold authored Jan 18, 2023
1 parent 40837a8 commit 3ae1c72
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use crate::flight_service_client::FlightServiceClient;
Expand All @@ -45,7 +44,6 @@ use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use prost::Message;
use tokio::sync::{Mutex, MutexGuard};
#[cfg(feature = "tls")]
use tonic::transport::{Certificate, ClientTlsConfig, Identity};
use tonic::transport::{Channel, Endpoint};
Expand All @@ -56,7 +54,7 @@ use tonic::Streaming;
#[derive(Debug, Clone)]
pub struct FlightSqlServiceClient {
token: Option<String>,
flight_client: Arc<Mutex<FlightServiceClient<Channel>>>,
flight_client: FlightServiceClient<Channel>,
}

/// A FlightSql protocol client that can run queries against FlightSql servers
Expand Down Expand Up @@ -124,16 +122,23 @@ impl FlightSqlServiceClient {
let flight_client = FlightServiceClient::new(channel);
FlightSqlServiceClient {
token: None,
flight_client: Arc::new(Mutex::new(flight_client)),
flight_client,
}
}

fn mut_client(
&mut self,
) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
/// Return a reference to the underlying [`FlightServiceClient`]
pub fn inner(&self) -> &FlightServiceClient<Channel> {
&self.flight_client
}

/// Return a mutable reference to the underlying [`FlightServiceClient`]
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
&mut self.flight_client
}

/// Consume this client and return the underlying [`FlightServiceClient`]
pub fn into_inner(self) -> FlightServiceClient<Channel> {
self.flight_client
.try_lock()
.map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
}

async fn get_flight_info_for_command<M: ProstMessageExt>(
Expand All @@ -142,7 +147,7 @@ impl FlightSqlServiceClient {
) -> Result<FlightInfo, ArrowError> {
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let fi = self
.mut_client()?
.flight_client
.get_flight_info(descriptor)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -174,7 +179,7 @@ impl FlightSqlServiceClient {
.map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
req.metadata_mut().insert("authorization", val);
let resp = self
.mut_client()?
.flight_client
.handshake(req)
.await
.map_err(|e| ArrowError::IoError(format!("Can't handshake {}", e)))?;
Expand Down Expand Up @@ -208,7 +213,7 @@ impl FlightSqlServiceClient {
let cmd = CommandStatementUpdate { query };
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.mut_client()?
.flight_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
Expand Down Expand Up @@ -247,7 +252,7 @@ impl FlightSqlServiceClient {
ticket: Ticket,
) -> Result<Streaming<FlightData>, ArrowError> {
Ok(self
.mut_client()?
.flight_client
.do_get(ticket)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -332,7 +337,7 @@ impl FlightSqlServiceClient {
req.metadata_mut().insert("authorization", val);
}
let mut result = self
.mut_client()?
.flight_client
.do_action(req)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -369,7 +374,7 @@ impl FlightSqlServiceClient {
/// A PreparedStatement
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
flight_client: Arc<Mutex<FlightServiceClient<T>>>,
flight_client: FlightServiceClient<T>,
parameter_binding: Option<RecordBatch>,
handle: Bytes,
dataset_schema: Schema,
Expand All @@ -378,13 +383,13 @@ pub struct PreparedStatement<T> {

impl PreparedStatement<Channel> {
pub(crate) fn new(
client: Arc<Mutex<FlightServiceClient<Channel>>>,
flight_client: FlightServiceClient<Channel>,
handle: impl Into<Bytes>,
dataset_schema: Schema,
parameter_schema: Schema,
) -> Self {
PreparedStatement {
flight_client: client,
flight_client,
parameter_binding: None,
handle: handle.into(),
dataset_schema,
Expand All @@ -399,7 +404,7 @@ impl PreparedStatement<Channel> {
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let result = self
.mut_client()?
.flight_client
.get_flight_info(descriptor)
.await
.map_err(status_to_arrow_error)?
Expand All @@ -414,7 +419,7 @@ impl PreparedStatement<Channel> {
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.mut_client()?
.flight_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
Expand Down Expand Up @@ -463,20 +468,12 @@ impl PreparedStatement<Channel> {
body: cmd.as_any().encode_to_vec().into(),
};
let _ = self
.mut_client()?
.flight_client
.do_action(action)
.await
.map_err(status_to_arrow_error)?;
Ok(())
}

fn mut_client(
&mut self,
) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
self.flight_client
.try_lock()
.map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
}
}

fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
Expand Down

0 comments on commit 3ae1c72

Please sign in to comment.