Skip to content

Commit

Permalink
feat: ctrl c to kill current query
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Jan 6, 2025
1 parent 7d49de3 commit 560fddf
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 7 deletions.
1 change: 1 addition & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow = { workspace = true }
async-recursion = "1.1.0"
async-trait = "0.1"
clap = { version = "4.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.3", features = ["termination"] }
comfy-table = "7.1"
csv = "1.3"
databend-common-ast = "0.1.3"
Expand Down
23 changes: 23 additions & 0 deletions cli/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

use std::fmt::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::{collections::HashSet, env};

use anyhow::{anyhow, Result};
Expand All @@ -31,6 +33,8 @@ use crate::{
web::set_data,
};

pub(crate) const INTERRUPTED_MESSAGE: &str = "Interrupted by Ctrl+C";

#[async_trait::async_trait]
pub trait ChunkDisplay {
async fn display(&mut self) -> Result<ServerStats>;
Expand All @@ -49,6 +53,7 @@ pub struct FormatDisplay<'a> {
progress: Option<ProgressBar>,
start: Instant,
stats: Option<ServerStats>,
interrupted: Arc<AtomicBool>,
}

impl<'a> FormatDisplay<'a> {
Expand All @@ -58,6 +63,7 @@ impl<'a> FormatDisplay<'a> {
replace_newline: bool,
start: Instant,
data: RowStatsIterator,
interrupted: Arc<AtomicBool>,
) -> Self {
Self {
settings,
Expand All @@ -69,6 +75,7 @@ impl<'a> FormatDisplay<'a> {
progress: None,
start,
stats: None,
interrupted,
}
}
}
Expand Down Expand Up @@ -133,6 +140,9 @@ impl<'a> FormatDisplay<'a> {
let mut rows = Vec::new();
let mut error = None;
while let Some(line) = self.data.next().await {
if self.interrupted.load(Ordering::SeqCst) {
return Err(anyhow!(INTERRUPTED_MESSAGE));
}
match line {
Ok(RowWithStats::Row(row)) => {
self.rows += 1;
Expand Down Expand Up @@ -224,6 +234,9 @@ impl<'a> FormatDisplay<'a> {
.quote_style(quote_style)
.from_writer(std::io::stdout());
while let Some(line) = self.data.next().await {
if self.interrupted.load(Ordering::SeqCst) {
return Err(anyhow!(INTERRUPTED_MESSAGE));
}
match line {
Ok(RowWithStats::Row(row)) => {
self.rows += 1;
Expand Down Expand Up @@ -254,6 +267,9 @@ impl<'a> FormatDisplay<'a> {
.quote_style(quote_style)
.from_writer(std::io::stdout());
while let Some(line) = self.data.next().await {
if self.interrupted.load(Ordering::SeqCst) {
return Err(anyhow!(INTERRUPTED_MESSAGE));
}
match line {
Ok(RowWithStats::Row(row)) => {
self.rows += 1;
Expand All @@ -274,6 +290,9 @@ impl<'a> FormatDisplay<'a> {
async fn display_null(&mut self) -> Result<()> {
let mut error = None;
while let Some(line) = self.data.next().await {
if self.interrupted.load(Ordering::SeqCst) {
return Err(anyhow!(INTERRUPTED_MESSAGE));
}
match line {
Ok(RowWithStats::Row(_)) => {
self.rows += 1;
Expand Down Expand Up @@ -365,6 +384,10 @@ impl<'a> FormatDisplay<'a> {
#[async_trait::async_trait]
impl ChunkDisplay for FormatDisplay<'_> {
async fn display(&mut self) -> Result<ServerStats> {
if self.interrupted.load(Ordering::SeqCst) {
return Err(anyhow!(INTERRUPTED_MESSAGE));
}

match self.settings.output_format {
OutputFormat::Table => {
self.display_table().await?;
Expand Down
32 changes: 30 additions & 2 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use rustyline::{CompletionType, Editor};
use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::fs::{remove_file, File};
use tokio::io::AsyncWriteExt;
Expand All @@ -38,6 +40,7 @@ use tokio_stream::StreamExt;

use crate::config::Settings;
use crate::config::TimeOption;
use crate::display::INTERRUPTED_MESSAGE;
use crate::display::{format_write_progress, ChunkDisplay, FormatDisplay};
use crate::helper::CliHelper;
use crate::web::find_available_port;
Expand Down Expand Up @@ -75,6 +78,7 @@ pub struct Session {

server_handle: Option<JoinHandle<std::io::Result<()>>>,
keywords: Option<Arc<sled::Db>>,
interrupted: Arc<AtomicBool>,
}

impl Session {
Expand Down Expand Up @@ -180,9 +184,19 @@ impl Session {
None
};

let interrupted = Arc::new(AtomicBool::new(false));
let interrupted_clone = interrupted.clone();

if is_repl {
println!();

// Register the Ctrl+C handler
ctrlc::set_handler(move || {
interrupted_clone.store(true, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
}

Ok(Self {
client,
conn,
Expand All @@ -191,6 +205,7 @@ impl Session {
query: String::new(),
keywords,
server_handle,
interrupted,
})
}

Expand Down Expand Up @@ -329,6 +344,12 @@ impl Session {
}
} else {
eprintln!("error: {}", e);
if e.to_string().contains(INTERRUPTED_MESSAGE) {
if let Some(query_id) = self.conn.last_query_id() {
println!("killing query: {}", query_id);
let _ = self.conn.kill_query(&query_id).await;
}
}
self.query.clear();
break;
}
Expand Down Expand Up @@ -458,6 +479,7 @@ impl Session {
) -> Result<Option<ServerStats>> {
let query = query.trim_end_matches(';').trim();

self.interrupted.store(false, Ordering::SeqCst);
if is_repl {
if query.starts_with('!') {
return self.handle_commands(query).await;
Expand Down Expand Up @@ -503,8 +525,14 @@ impl Session {
_ => self.conn.query_iter_ext(query).await?,
};

let mut displayer =
FormatDisplay::new(&self.settings, query, replace_newline, start, data);
let mut displayer = FormatDisplay::new(
&self.settings,
query,
replace_newline,
start,
data,
self.interrupted.clone(),
);
let stats = displayer.display().await?;
Ok(Some(stats))
}
Expand Down
25 changes: 20 additions & 5 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub struct APIClient {

presign: PresignMode,
last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
last_query_id: Arc<parking_lot::Mutex<Option<String>>>,
}

impl APIClient {
Expand Down Expand Up @@ -347,6 +348,15 @@ impl APIClient {
pub fn set_last_node_id(&self, node_id: String) {
*self.last_node_id.lock() = Some(node_id)
}

pub fn set_last_query_id(&self, query_id: Option<String>) {
*self.last_query_id.lock() = query_id
}

pub fn last_query_id(&self) -> Option<String> {
self.last_query_id.lock().clone()
}

fn last_node_id(&self) -> Option<String> {
self.last_node_id.lock().clone()
}
Expand Down Expand Up @@ -412,6 +422,8 @@ impl APIClient {
if let Some(err) = result.error {
return Err(Error::QueryFailed(err));
}

self.set_last_query_id(Some(query_id));
self.handle_warnings(&result);
Ok(result)
}
Expand Down Expand Up @@ -452,12 +464,13 @@ impl APIClient {
}
}

#[allow(dead_code)]
async fn kill_query(&self, query_id: &str, kill_uri: &str) -> Result<()> {
info!("kill query: {}", kill_uri);
let endpoint = self.endpoint.join(kill_uri)?;
pub async fn kill_query(&self, query_id: &str) -> Result<()> {
let kill_uri = format!("/v1/query/{}/kill", query_id);
let endpoint = self.endpoint.join(&kill_uri)?;
let headers = self.make_headers(Some(query_id))?;
let mut builder = self.cli.post(endpoint.clone());
info!("kill query: {}", kill_uri);

let mut builder = self.cli.post(endpoint);
builder = self.wrap_auth_or_session_token(builder)?;
let resp = builder.headers(headers.clone()).send().await?;
if resp.status() != 200 {
Expand All @@ -473,6 +486,7 @@ impl APIClient {
if let Some(node_id) = self.last_node_id() {
self.set_last_node_id(node_id.clone());
}

if let Some(next_uri) = &resp.next_uri {
let schema = resp.schema;
let mut data = resp.data;
Expand Down Expand Up @@ -954,6 +968,7 @@ impl Default for APIClient {
disable_login: false,
session_token_info: None,
closed: Arc::new(Default::default()),
last_query_id: Arc::new(Default::default()),
server_version: None,
}
}
Expand Down
3 changes: 3 additions & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub trait Connection: Send + Sync {
Ok(())
}

fn last_query_id(&self) -> Option<String>;

async fn version(&self) -> Result<String> {
let row = self.query_row("SELECT version()").await?;
let version = match row {
Expand All @@ -108,6 +110,7 @@ pub trait Connection: Send + Sync {
}

async fn exec(&self, sql: &str) -> Result<i64>;
async fn kill_query(&self, query_id: &str) -> Result<()>;
async fn query_iter(&self, sql: &str) -> Result<RowIterator>;
async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator>;

Expand Down
9 changes: 9 additions & 0 deletions driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,22 @@ impl Connection for FlightSQLConnection {
}
}

fn last_query_id(&self) -> Option<String> {
None
}

async fn exec(&self, sql: &str) -> Result<i64> {
self.handshake().await?;
let mut client = self.client.lock().await;
let affected_rows = client.execute_update(sql.to_string(), None).await?;
Ok(affected_rows)
}

async fn kill_query(&self, _query_id: &str) -> Result<()> {
// todo: implement kill query
Ok(())
}

async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
let rows_with_progress = self.query_iter_ext(sql).await?;
let rows = rows_with_progress.filter_rows().await;
Expand Down
8 changes: 8 additions & 0 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ impl Connection for RestAPIConnection {
}
}

fn last_query_id(&self) -> Option<String> {
self.client.last_query_id()
}

async fn close(&self) -> Result<()> {
self.client.close().await;
Ok(())
Expand All @@ -74,6 +78,10 @@ impl Connection for RestAPIConnection {
Ok(resp.stats.progresses.write_progress.rows as i64)
}

async fn kill_query(&self, query_id: &str) -> Result<()> {
Ok(self.client.kill_query(query_id).await?)
}

async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
info!("query iter: {}", sql);
let rows_with_progress = self.query_iter_ext(sql).await?;
Expand Down

0 comments on commit 560fddf

Please sign in to comment.