Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Inner FlightServiceClient on FlightSqlServiceClient (#3551) #3556

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use crate::flight_service_client::FlightServiceClient;
Expand All @@ -45,7 +44,6 @@ use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use prost::Message;
use tokio::sync::{Mutex, MutexGuard};
#[cfg(feature = "tls")]
use tonic::transport::{Certificate, ClientTlsConfig, Identity};
use tonic::transport::{Channel, Endpoint};
Expand All @@ -56,7 +54,7 @@ use tonic::Streaming;
#[derive(Debug, Clone)]
pub struct FlightSqlServiceClient {
token: Option<String>,
flight_client: Arc<Mutex<FlightServiceClient<Channel>>>,
flight_client: FlightServiceClient<Channel>,
}

/// A FlightSql protocol client that can run queries against FlightSql servers
Expand Down Expand Up @@ -124,16 +122,23 @@ impl FlightSqlServiceClient {
let flight_client = FlightServiceClient::new(channel);
FlightSqlServiceClient {
token: None,
flight_client: Arc::new(Mutex::new(flight_client)),
flight_client,
}
}

fn mut_client(
&mut self,
) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
/// Return a reference to the underlying [`FlightServiceClient`]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also recommend an into_inner() function if possible

pub fn inner(&self) -> &FlightServiceClient<Channel> {
&self.flight_client
}

/// Return a mutable reference to the underlying [`FlightServiceClient`]
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
&mut self.flight_client
}

/// Consume this client and return the underlying [`FlightServiceClient`]
pub fn into_inner(self) -> FlightServiceClient<Channel> {
self.flight_client
.try_lock()
.map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
}

async fn get_flight_info_for_command<M: ProstMessageExt>(
Expand All @@ -142,7 +147,7 @@ impl FlightSqlServiceClient {
) -> Result<FlightInfo, ArrowError> {
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let fi = self
.mut_client()?
.flight_client
.get_flight_info(descriptor)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -174,7 +179,7 @@ impl FlightSqlServiceClient {
.map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
req.metadata_mut().insert("authorization", val);
let resp = self
.mut_client()?
.flight_client
.handshake(req)
.await
.map_err(|e| ArrowError::IoError(format!("Can't handshake {}", e)))?;
Expand Down Expand Up @@ -208,7 +213,7 @@ impl FlightSqlServiceClient {
let cmd = CommandStatementUpdate { query };
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.mut_client()?
.flight_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
Expand Down Expand Up @@ -247,7 +252,7 @@ impl FlightSqlServiceClient {
ticket: Ticket,
) -> Result<Streaming<FlightData>, ArrowError> {
Ok(self
.mut_client()?
.flight_client
.do_get(ticket)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -332,7 +337,7 @@ impl FlightSqlServiceClient {
req.metadata_mut().insert("authorization", val);
}
let mut result = self
.mut_client()?
.flight_client
.do_action(req)
.await
.map_err(status_to_arrow_error)?
Expand Down Expand Up @@ -369,7 +374,7 @@ impl FlightSqlServiceClient {
/// A PreparedStatement
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
flight_client: Arc<Mutex<FlightServiceClient<T>>>,
flight_client: FlightServiceClient<T>,
parameter_binding: Option<RecordBatch>,
handle: Bytes,
dataset_schema: Schema,
Expand All @@ -378,13 +383,13 @@ pub struct PreparedStatement<T> {

impl PreparedStatement<Channel> {
pub(crate) fn new(
client: Arc<Mutex<FlightServiceClient<Channel>>>,
flight_client: FlightServiceClient<Channel>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Channel is cheaply cloneable, so there is no reason for the additional Arc<Mutex>

See https://docs.rs/tonic/latest/tonic/transport/struct.Channel.html#multiplexing-requests

handle: impl Into<Bytes>,
dataset_schema: Schema,
parameter_schema: Schema,
) -> Self {
PreparedStatement {
flight_client: client,
flight_client,
parameter_binding: None,
handle: handle.into(),
dataset_schema,
Expand All @@ -399,7 +404,7 @@ impl PreparedStatement<Channel> {
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let result = self
.mut_client()?
.flight_client
.get_flight_info(descriptor)
.await
.map_err(status_to_arrow_error)?
Expand All @@ -414,7 +419,7 @@ impl PreparedStatement<Channel> {
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.mut_client()?
.flight_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
Expand Down Expand Up @@ -463,20 +468,12 @@ impl PreparedStatement<Channel> {
body: cmd.as_any().encode_to_vec().into(),
};
let _ = self
.mut_client()?
.flight_client
.do_action(action)
.await
.map_err(status_to_arrow_error)?;
Ok(())
}

fn mut_client(
&mut self,
) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
self.flight_client
.try_lock()
.map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
}
}

fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
Expand Down