Skip to content

Commit

Permalink
bug(reattach): connection is lost and session id is empty (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrusso8 authored Dec 17, 2024
1 parent 7e55ffb commit 57dcd2d
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions crates/connect/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use arrow_ipc::reader::StreamReader;

use uuid::Uuid;

use crate::errors::SparkError;

mod builder;
Expand All @@ -30,9 +32,6 @@ pub type SparkClient = SparkConnectClient<HeadersMiddleware<Channel>>;
#[allow(dead_code)]
#[derive(Default, Debug, Clone)]
pub(crate) struct ResponseHandler {
session_id: Option<String>,
operation_id: Option<String>,
response_id: Option<String>,
metrics: Option<spark::execute_plan_response::Metrics>,
observed_metrics: Option<spark::execute_plan_response::ObservedMetrics>,
pub(crate) schema: Option<spark::DataType>,
Expand Down Expand Up @@ -67,6 +66,9 @@ pub(crate) struct AnalyzeHandler {
pub struct SparkConnectClient<T> {
stub: Arc<RwLock<SparkConnectServiceClient<T>>>,
builder: ChannelBuilder,
session_id: String,
operation_id: Option<String>,
response_id: Option<String>,
pub(crate) handler: ResponseHandler,
pub(crate) analyzer: AnalyzeHandler,
pub(crate) user_context: Option<spark::UserContext>,
Expand All @@ -83,10 +85,14 @@ where
{
pub fn new(stub: Arc<RwLock<SparkConnectServiceClient<T>>>, builder: ChannelBuilder) -> Self {
let user_ref = builder.user_id.clone().unwrap_or("".to_string());
let session_id = builder.session_id.to_string();

SparkConnectClient {
stub,
builder,
session_id,
operation_id: None,
response_id: None,
handler: ResponseHandler::default(),
analyzer: AnalyzeHandler::default(),
user_context: Some(spark::UserContext {
Expand All @@ -101,7 +107,7 @@ where

/// Session ID
pub fn session_id(&self) -> String {
self.builder.session_id.to_string()
self.session_id.clone()
}

/// Change the reattachable execute value
Expand All @@ -127,11 +133,15 @@ where
vec![]
}

pub fn execute_plan_request_with_metadata(&self) -> spark::ExecutePlanRequest {
pub fn execute_plan_request_with_metadata(&mut self) -> spark::ExecutePlanRequest {
let operation_id = Uuid::new_v4().to_string();

self.operation_id = Some(operation_id.clone());

spark::ExecutePlanRequest {
session_id: self.session_id(),
user_context: self.user_context.clone(),
operation_id: None,
operation_id: Some(operation_id),
plan: None,
client_type: self.builder.user_agent.clone(),
request_options: self.request_options(),
Expand Down Expand Up @@ -173,11 +183,11 @@ where
let mut client = self.stub.write().await;

let req = spark::ReattachExecuteRequest {
session_id: self.handler.session_id.clone().unwrap(),
session_id: self.session_id(),
user_context: self.user_context.clone(),
operation_id: self.handler.operation_id.clone().unwrap(),
operation_id: self.operation_id.clone().unwrap(),
client_type: self.builder.user_agent.clone(),
last_response_id: self.handler.response_id.clone(),
last_response_id: self.response_id.clone(),
};

let mut stream = client.reattach_execute(req).await?.into_inner();
Expand Down Expand Up @@ -208,7 +218,7 @@ where
None
}
Err(err) => {
if self.use_reattachable_execute && self.handler.response_id.is_some() {
if self.use_reattachable_execute && self.response_id.is_some() {
self.release_until().await?;
}
return Err(err.into());
Expand All @@ -220,7 +230,7 @@ where

async fn release_until(&mut self) -> Result<(), SparkError> {
let release_until = spark::release_execute_request::ReleaseUntil {
response_id: self.handler.response_id.clone().unwrap(),
response_id: self.response_id.clone().unwrap(),
};

self.release_execute(Some(spark::release_execute_request::Release::ReleaseUntil(
Expand All @@ -245,9 +255,9 @@ where
let mut client = self.stub.write().await;

let req = spark::ReleaseExecuteRequest {
session_id: self.handler.session_id.clone().unwrap(),
session_id: self.session_id(),
user_context: self.user_context.clone(),
operation_id: self.handler.operation_id.clone().unwrap(),
operation_id: self.operation_id.clone().unwrap(),
client_type: self.builder.user_agent.clone(),
release,
};
Expand Down Expand Up @@ -375,9 +385,8 @@ where
fn handle_response(&mut self, resp: spark::ExecutePlanResponse) -> Result<(), SparkError> {
self.validate_session(&resp.session_id)?;

self.handler.session_id = Some(resp.session_id);
self.handler.operation_id = Some(resp.operation_id);
self.handler.response_id = Some(resp.response_id);
self.operation_id = Some(resp.operation_id);
self.response_id = Some(resp.response_id);

if let Some(schema) = &resp.schema {
self.handler.schema = Some(schema.clone());
Expand Down

0 comments on commit 57dcd2d

Please sign in to comment.