diff --git a/cli/src/main.rs b/cli/src/main.rs index c063ba419..0e03512e8 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -367,6 +367,14 @@ pub async fn main() -> Result<()> { } settings.time = args.time; + let log_dir = format!( + "{}/.bendsql", + std::env::var("HOME").unwrap_or_else(|_| ".".to_string()) + ); + + let _guards = trace::init_logging(&log_dir, &args.log_level).await?; + info!("-> bendsql version: {}", VERSION.as_str()); + let mut session = match session::Session::try_new(dsn, settings, is_repl).await { Ok(session) => session, Err(err) => { @@ -390,14 +398,6 @@ pub async fn main() -> Result<()> { } }; - let log_dir = format!( - "{}/.bendsql", - std::env::var("HOME").unwrap_or_else(|_| ".".to_string()) - ); - - let _guards = trace::init_logging(&log_dir, &args.log_level).await?; - info!("-> bendsql version: {}", VERSION.as_str()); - if args.check { session.check().await?; return Ok(()); diff --git a/cli/src/session.rs b/cli/src/session.rs index ca4a52132..ed01886b0 100644 --- a/cli/src/session.rs +++ b/cli/src/session.rs @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; -use std::io::BufRead; -use std::path::Path; -use std::sync::Arc; - use anyhow::anyhow; use anyhow::Result; use async_recursion::async_recursion; @@ -31,6 +26,10 @@ use rustyline::config::Builder; use rustyline::error::ReadlineError; use rustyline::history::DefaultHistory; use rustyline::{CompletionType, Editor}; +use std::collections::BTreeMap; +use std::io::BufRead; +use std::path::Path; +use std::sync::Arc; use tokio::fs::{remove_file, File}; use tokio::io::AsyncWriteExt; use tokio::task::JoinHandle; @@ -352,6 +351,9 @@ impl Session { }, } } + if let Err(e) = self.conn.close().await { + println!("got error when closing session: {}", e); + } println!("Bye~"); let _ = rl.save_history(&get_history_path()); } @@ -394,6 +396,7 @@ impl Session { println!("{:.3}", server_time_ms / 1000.0); } } + self.conn.close().await.ok(); Ok(()) } diff --git a/core/src/client.rs b/core/src/client.rs index d21855c6e..4b4da3c05 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::collections::BTreeMap; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -32,7 +32,7 @@ use crate::{ response::QueryResponse, session::SessionState, }; -use log::{error, info, warn}; +use log::{debug, error, info, warn}; use once_cell::sync::Lazy; use percent_encoding::percent_decode_str; use reqwest::cookie::CookieStore; @@ -78,6 +78,8 @@ pub struct APIClient { disable_session_token: bool, session_token_info: Option>>, + closed: Arc, + server_version: Option, wait_time_secs: Option, @@ -354,6 +356,7 @@ impl APIClient { } pub async fn start_query(&self, sql: &str) -> Result { + info!("start query: {}", sql); self.start_query_inner(sql, None).await } @@ -483,7 +486,6 @@ impl APIClient { } pub async fn query(&self, sql: &str) -> Result { - info!("query: {}", sql); let resp = self.start_query(sql).await?; self.wait_for_query(resp).await } @@ -652,7 +654,7 @@ impl APIClient { Err(Error::Logic(status, ..)) | Err(Error::Response { status, .. }) if status == 404 => { - // old server + info!("login return 404, skip login on the old version server"); return Ok(()); } Err(e) => return Err(e), @@ -664,15 +666,17 @@ impl APIClient { LoginResponseResult::Ok(info) => { self.server_version = Some(info.version.clone()); if let Some(tokens) = info.tokens { + info!("login success with session token"); self.session_token_info = Some(Arc::new(parking_lot::Mutex::new((tokens, Instant::now())))) } + info!("login success without session token"); } } Ok(()) } - fn build_log_out_request(&mut self) -> Result { + fn build_log_out_request(&self) -> Result { let endpoint = self.endpoint.join("/v1/session/logout")?; let session_state = self.session_state(); @@ -691,8 +695,12 @@ impl APIClient { } fn need_logout(&self) -> bool { - self.session_token_info.is_some() - || self.session_state.lock().need_keep_alive.unwrap_or(false) + (self.session_token_info.is_some() + || self.session_state.lock().need_keep_alive.unwrap_or(false)) + && !self + .closed + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .unwrap() } async fn refresh_session_token( @@ -882,20 +890,26 @@ impl APIClient { sleep(jitter(Duration::from_secs(10))).await; } } -} -impl Drop for APIClient { - fn drop(&mut self) { + pub async fn close(&self) { if self.need_logout() { let cli = self.cli.clone(); let req = self .build_log_out_request() .expect("failed to build logout request"); - tokio::spawn(async move { - if let Err(err) = cli.execute(req).await { - error!("logout request failed: {}", err); - }; - }); + if let Err(err) = cli.execute(req).await { + error!("logout request failed: {}", err); + } else { + debug!("logout success"); + }; + } + } +} + +impl Drop for APIClient { + fn drop(&mut self) { + if self.need_logout() { + warn!("APIClient::close() was not called"); } } } @@ -937,6 +951,7 @@ impl Default for APIClient { disable_session_token: true, disable_login: false, session_token_info: None, + closed: Arc::new(Default::default()), server_version: None, } } diff --git a/driver/src/conn.rs b/driver/src/conn.rs index 757078ea2..5acf77a32 100644 --- a/driver/src/conn.rs +++ b/driver/src/conn.rs @@ -90,6 +90,9 @@ pub type Reader = Box; #[async_trait] pub trait Connection: Send + Sync { async fn info(&self) -> ConnectionInfo; + async fn close(&self) -> Result<()> { + Ok(()) + } async fn version(&self) -> Result { let row = self.query_row("SELECT version()").await?; diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index 471c0f280..b8ef787fb 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -54,6 +54,11 @@ impl Connection for RestAPIConnection { } } + async fn close(&self) -> Result<()> { + self.client.close().await; + Ok(()) + } + async fn exec(&self, sql: &str) -> Result { info!("exec: {}", sql); let mut resp = self.client.start_query(sql).await?;