Skip to content

Commit

Permalink
fix: Properly exit SimpleExecuteQueryStream on stream end (#2243)
Browse files Browse the repository at this point in the history
Closes #2242

Also moves the simple stuff to its own module to split up the code a
little bit.
  • Loading branch information
scsmithr authored Dec 11, 2023
1 parent 6b846d1 commit 3382a22
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 110 deletions.
2 changes: 1 addition & 1 deletion crates/glaredb/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pgsrv::auth::LocalAuthenticator;
use pgsrv::handler::{ProtocolHandler, ProtocolHandlerConfig};
use protogen::gen::rpcsrv::service::execution_service_server::ExecutionServiceServer;
use protogen::gen::rpcsrv::simple::simple_service_server::SimpleServiceServer;
use rpcsrv::handler::{RpcHandler, SimpleHandler};
use rpcsrv::{handler::RpcHandler, simple::SimpleHandler};
use sqlexec::engine::{Engine, EngineStorageConfig};
use std::collections::HashMap;
use std::net::SocketAddr;
Expand Down
1 change: 0 additions & 1 deletion crates/glaredb/tests/setup.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use assert_cmd::cmd::Command;
pub const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);

pub fn make_cli() -> Command {
Command::cargo_bin(env!("CARGO_PKG_NAME")).expect("Failed to find binary")
Expand Down
115 changes: 7 additions & 108 deletions crates/rpcsrv/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@ use crate::{
};
use async_trait::async_trait;
use dashmap::DashMap;
use datafusion::{arrow::ipc::writer::FileWriter as IpcFileWriter, variable::VarType};
use datafusion::{arrow::record_batch::RecordBatch, physical_plan::SendableRecordBatchStream};
use datafusion_ext::{
session_metrics::{BatchStreamWithMetricSender, QueryMetrics, SessionMetricsHandler},
vars::SessionVars,
use datafusion::arrow::ipc::writer::FileWriter as IpcFileWriter;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion_ext::session_metrics::{
BatchStreamWithMetricSender, QueryMetrics, SessionMetricsHandler,
};
use futures::{Stream, StreamExt};
use protogen::{
gen::rpcsrv::common,
gen::rpcsrv::service,
gen::rpcsrv::simple,
rpcsrv::types::{
service::{
DispatchAccessRequest, FetchCatalogRequest, FetchCatalogResponse,
InitializeSessionRequest, InitializeSessionResponse, PhysicalPlanExecuteRequest,
TableProviderResponse,
},
simple::{ExecuteQueryRequest, ExecuteQueryResponse, QueryResultError, QueryResultSuccess},
rpcsrv::types::service::{
DispatchAccessRequest, FetchCatalogRequest, FetchCatalogResponse, InitializeSessionRequest,
InitializeSessionResponse, PhysicalPlanExecuteRequest, TableProviderResponse,
},
};
use sqlexec::{
engine::{Engine, SessionStorageConfig},
remote::batch_stream::ExecutionBatchStream,
OperationInfo,
};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -320,97 +313,3 @@ impl Stream for ExecutionResponseBatchStream {
}
}
}

/// The "simple query" rpc handler.
///
/// Note that this doesn't keep state about sessions, and session only last the
/// lifetime of a query.
pub struct SimpleHandler {
/// Core db engine for creating sessions.
engine: Arc<Engine>,
}

impl SimpleHandler {
pub fn new(engine: Arc<Engine>) -> SimpleHandler {
SimpleHandler { engine }
}
}

#[async_trait]
impl simple::simple_service_server::SimpleService for SimpleHandler {
type ExecuteQueryStream = SimpleExecuteQueryStream;

async fn execute_query(
&self,
request: Request<simple::ExecuteQueryRequest>,
) -> Result<Response<Self::ExecuteQueryStream>, Status> {
// Note that this creates a local session independent of any "remote"
// sessions. This provides full session capabilities (e.g. parsing sql,
// use the dist exec scheduler).
//
// This may be something we change (into what?)
let request = ExecuteQueryRequest::try_from(request.into_inner())?;
let vars = SessionVars::default().with_database_id(request.database_id, VarType::System);
let mut session = self
.engine
.new_local_session_context(vars, request.config.into())
.await
.map_err(RpcsrvError::from)?;

let plan = session
.sql_to_lp(&request.query_text)
.await
.map_err(RpcsrvError::from)?;
let plan = plan.try_into_datafusion_plan().map_err(RpcsrvError::from)?;
let physical = session
.create_physical_plan(plan, &OperationInfo::default())
.await
.map_err(RpcsrvError::from)?;
let stream = session
.execute_physical(physical)
.await
.map_err(RpcsrvError::from)?;

Ok(Response::new(SimpleExecuteQueryStream { inner: stream }))
}
}

/// Stream implementation for sending the results of a simple query request.
// TODO: Only supports a single response stream (we can do many).
// TODO: Only provides "success" or "error" info to the client. Doesn't return
// the actual results.
pub struct SimpleExecuteQueryStream {
inner: SendableRecordBatchStream,
}

impl Stream for SimpleExecuteQueryStream {
type Item = Result<simple::ExecuteQueryResponse, Status>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.inner.poll_next_unpin(cx) {
// Drop the result, we're not sending it back to the client.
// And continue to the next loop iteration.
Poll::Ready(Some(Ok(_))) => (),

// Stream completed without error, return success to the client.
Poll::Ready(None) => {
return Poll::Ready(Some(Ok(ExecuteQueryResponse::SuccessResult(
QueryResultSuccess {},
)
.into())))
}

// We got an error, send it back to the client.
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Ok(ExecuteQueryResponse::ErrorResult(
QueryResultError { msg: e.to_string() },
)
.into())))
}

Poll::Pending => return Poll::Pending,
}
}
}
}
1 change: 1 addition & 0 deletions crates/rpcsrv/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod errors;
pub mod handler;
pub mod proxy;
pub mod simple;

mod session;
162 changes: 162 additions & 0 deletions crates/rpcsrv/src/simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::errors::{Result, RpcsrvError};
use async_trait::async_trait;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::variable::VarType;
use datafusion_ext::vars::SessionVars;
use futures::{Stream, StreamExt};
use protogen::{
gen::rpcsrv::simple,
rpcsrv::types::simple::{
ExecuteQueryRequest, ExecuteQueryResponse, QueryResultError, QueryResultSuccess,
},
};
use sqlexec::{engine::Engine, OperationInfo};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tonic::{Request, Response, Status};

/// The "simple query" rpc handler.
///
/// Note that this doesn't keep state about sessions, and sessions only last the
/// lifetime of a query.
pub struct SimpleHandler {
/// Core db engine for creating sessions.
engine: Arc<Engine>,
}

impl SimpleHandler {
pub fn new(engine: Arc<Engine>) -> SimpleHandler {
SimpleHandler { engine }
}
}

#[async_trait]
impl simple::simple_service_server::SimpleService for SimpleHandler {
type ExecuteQueryStream = SimpleExecuteQueryStream;

async fn execute_query(
&self,
request: Request<simple::ExecuteQueryRequest>,
) -> Result<Response<Self::ExecuteQueryStream>, Status> {
// Note that this creates a local session independent of any "remote"
// sessions. This provides full session capabilities (e.g. parsing sql,
// use the dist exec scheduler).
//
// This may be something we change (into what?)
let request = ExecuteQueryRequest::try_from(request.into_inner())?;
let vars = SessionVars::default().with_database_id(request.database_id, VarType::System);
let mut session = self
.engine
.new_local_session_context(vars, request.config.into())
.await
.map_err(RpcsrvError::from)?;

let plan = session
.sql_to_lp(&request.query_text)
.await
.map_err(RpcsrvError::from)?;
let plan = plan.try_into_datafusion_plan().map_err(RpcsrvError::from)?;
let physical = session
.create_physical_plan(plan, &OperationInfo::default())
.await
.map_err(RpcsrvError::from)?;
let stream = session
.execute_physical(physical)
.await
.map_err(RpcsrvError::from)?;

Ok(Response::new(SimpleExecuteQueryStream::new(stream)))
}
}

/// Stream implementation for sending the results of a simple query request.
// TODO: Only supports a single response stream (we can do many).
// TODO: Only provides "success" or "error" info to the client. Doesn't return
// the actual results.
pub struct SimpleExecuteQueryStream {
inner: SendableRecordBatchStream,
done: bool,
}

impl SimpleExecuteQueryStream {
pub fn new(stream: SendableRecordBatchStream) -> SimpleExecuteQueryStream {
SimpleExecuteQueryStream {
inner: stream,
done: false,
}
}
}

impl Stream for SimpleExecuteQueryStream {
type Item = Result<simple::ExecuteQueryResponse, Status>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}

loop {
match self.inner.poll_next_unpin(cx) {
// Drop the result, we're not sending it back to the client.
// And continue to the next loop iteration.
Poll::Ready(Some(Ok(_))) => (),

// Stream completed without error, return success to the client.
Poll::Ready(None) => {
self.done = true; // Make sure we properly signal stream end on next poll.
return Poll::Ready(Some(Ok(ExecuteQueryResponse::SuccessResult(
QueryResultSuccess {},
)
.into())));
}

// We got an error, send it back to the client.
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Ok(ExecuteQueryResponse::ErrorResult(
QueryResultError { msg: e.to_string() },
)
.into())))
}

Poll::Pending => return Poll::Pending,
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::{
arrow::{datatypes::Schema, record_batch::RecordBatch},
physical_plan::stream::RecordBatchStreamAdapter,
};
use futures::stream::{self, StreamExt};

#[tokio::test]
async fn simple_stream_exits() {
// https://github.com/GlareDB/glaredb/issues/2242

let inner = Box::pin(RecordBatchStreamAdapter::new(
Arc::new(Schema::empty()),
stream::iter([
Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))),
Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))),
]),
));

let stream = SimpleExecuteQueryStream::new(inner);
let output = stream
.map(|result| result.unwrap())
.collect::<Vec<_>>()
.await;

let expected: &[simple::ExecuteQueryResponse] =
&[ExecuteQueryResponse::SuccessResult(QueryResultSuccess {}).into()];

assert_eq!(output, expected)
}
}

0 comments on commit 3382a22

Please sign in to comment.