diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3312a83..17d98ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,3 +47,7 @@ jobs: run: | cargo test timeout-minutes: 4 + - name: example + run: | + cargo run --example basic_op + timeout-minutes: 4 diff --git a/fbthrift-transport/tests/transport_tokio_io.rs b/fbthrift-transport/tests/transport_tokio_io.rs index 06121c7..ab4ae9e 100644 --- a/fbthrift-transport/tests/transport_tokio_io.rs +++ b/fbthrift-transport/tests/transport_tokio_io.rs @@ -23,8 +23,8 @@ mod transport_tokio_io_tests { task::JoinHandle, }; - use nebula_fbthrift_transport::AsyncTransport; use fbthrift_transport_response_handler::ResponseHandler; + use nebula_fbthrift_transport::AsyncTransport; #[derive(Clone)] pub struct FooResponseHandler; diff --git a/nebula_rust/Cargo.toml b/nebula_rust/Cargo.toml index 06604f6..6fab764 100644 --- a/nebula_rust/Cargo.toml +++ b/nebula_rust/Cargo.toml @@ -22,7 +22,11 @@ tokio = { version = "1.8.2", features = ["full"] } fbthrift = { version = "0.0.2" } fbthrift-transport = { path = "../fbthrift-transport", package = "nebula-fbthrift-transport" , features = ["tokio_io"], version = "0.0.2" } bytes = { version = "0.5" } +futures = { version = "0.3.16" } [build-dependencies] [dev-dependencies] + +[[example]] +name = "basic_op" diff --git a/nebula_rust/examples/basic_op.rs b/nebula_rust/examples/basic_op.rs new file mode 100644 index 0000000..171b1f3 --- /dev/null +++ b/nebula_rust/examples/basic_op.rs @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +use nebula_rust::graph_client; + +#[tokio::main] +async fn main() { + let mut conf = graph_client::pool_config::PoolConfig::new(); + conf.min_connection_pool_size(2) + .max_connection_pool_size(10) + .address("localhost:9669".to_string()); + + let pool = graph_client::connection_pool::ConnectionPool::new(&conf).await; + let session = pool.get_session("root", "nebula", true).await.unwrap(); + + let resp = session.execute("YIELD 1").await.unwrap(); + assert!(resp.error_code == common::types::ErrorCode::SUCCEEDED); + + println!("{:?}", resp.data.as_ref().unwrap()); + println!( + "The result of query `YIELD 1' is {}.", + if let common::types::Value::iVal(v) = resp.data.unwrap().rows[0].values[0] { + v + } else { + panic!() + } + ); +} diff --git a/nebula_rust/src/graph_client/connection.rs b/nebula_rust/src/graph_client/connection.rs index 0999519..10e19e0 100644 --- a/nebula_rust/src/graph_client/connection.rs +++ b/nebula_rust/src/graph_client/connection.rs @@ -13,18 +13,24 @@ use tokio::net::TcpStream; use crate::graph_client::transport_response_handler; +/// The simple abstraction of a connection to nebula graph server +#[derive(Default)] pub struct Connection { - client: client::GraphServiceImpl< - BinaryProtocol, - AsyncTransport, + // The option is used to construct a null connection + // which is used to give back the connection to pool from session + // So we could assume it's alway not null + client: Option< + client::GraphServiceImpl< + BinaryProtocol, + AsyncTransport, + >, >, } impl Connection { - /// Create connection with the specified [host:port] - pub async fn new(host: &str, port: i32) -> Result { - let addr = format!("{}:{}", host, port); - let stream = TcpStream::connect(addr).await?; + /// Create connection with the specified [host:port] address + pub async fn new_from_address(address: &str) -> Result { + let stream = TcpStream::connect(address).await?; let transport = AsyncTransport::new( stream, AsyncTransportConfiguration::new( @@ -32,11 +38,20 @@ impl Connection { ), ); Ok(Connection { - client: client::GraphServiceImpl::new(transport), + client: Some(client::GraphServiceImpl::new(transport)), }) } + /// Create connection with the specified [host:port] + pub async fn new(host: &str, port: i32) -> Result { + let address = format!("{}:{}", host, port); + Connection::new_from_address(&address).await + } + /// Authenticate by username and password + /// The returned error of `Result` only means the request/response status + /// The error from Nebula Graph is still in `error_code` field in response, so you need check it + /// to known wether authenticate succeeded pub async fn authenticate( &self, username: &str, @@ -44,6 +59,8 @@ impl Connection { ) -> std::result::Result { let result = self .client + .as_ref() + .unwrap() .authenticate( &username.to_string().into_bytes(), &password.to_string().into_bytes(), @@ -56,11 +73,12 @@ impl Connection { } /// Sign out the authentication by session id which got by authenticating previous + /// The returned error of `Result` only means the request/response status pub async fn signout( &self, session_id: i64, ) -> std::result::Result<(), common::types::ErrorCode> { - let result = self.client.signout(session_id).await; + let result = self.client.as_ref().unwrap().signout(session_id).await; if let Err(_) = result { return Err(common::types::ErrorCode::E_RPC_FAILURE); } @@ -68,6 +86,9 @@ impl Connection { } /// Execute the query with current session id which got by authenticating previous + /// The returned error of `Result` only means the request/response status + /// The error from Nebula Graph is still in `error_code` field in response, so you need check it + /// to known wether the query execute succeeded pub async fn execute( &self, session_id: i64, @@ -75,6 +96,8 @@ impl Connection { ) -> std::result::Result { let result = self .client + .as_ref() + .unwrap() .execute(session_id, &query.to_string().into_bytes()) .await; if let Err(_) = result { diff --git a/nebula_rust/src/graph_client/connection_pool.rs b/nebula_rust/src/graph_client/connection_pool.rs index c4e7abb..1846657 100644 --- a/nebula_rust/src/graph_client/connection_pool.rs +++ b/nebula_rust/src/graph_client/connection_pool.rs @@ -4,4 +4,143 @@ * attached with Common Clause Condition 1.0, found in the LICENSES directory. */ -mod graph_client {}; +use crate::graph_client::connection::Connection; +use crate::graph_client::pool_config::PoolConfig; +use crate::graph_client::session::Session; + +/// The pool of connection to server, it's MT-safe to access. +pub struct ConnectionPool { + /// The connections + /// The interior mutable to enable could get multiple sessions in one scope + conns: std::sync::Mutex>>, + /// It should be immutable + config: PoolConfig, + /// Address cursor + cursor: std::cell::RefCell, + /// The total count of connections, contains which hold by session + conns_count: std::cell::RefCell, +} + +impl ConnectionPool { + /// Construct pool by the configuration + pub async fn new(conf: &PoolConfig) -> Self { + let conns = std::collections::LinkedList::::new(); + let pool = ConnectionPool { + conns: std::sync::Mutex::new(std::cell::RefCell::new(conns)), + config: conf.clone(), + cursor: std::cell::RefCell::new(std::sync::atomic::AtomicUsize::new(0)), + conns_count: std::cell::RefCell::new(std::sync::atomic::AtomicUsize::new(0)), + }; + assert!(pool.config.min_connection_pool_size <= pool.config.max_connection_pool_size); + pool.new_connection(pool.config.min_connection_pool_size) + .await; + pool + } + + /// Get a session authenticated by username and password + /// retry_connect means keep the connection available if true + pub async fn get_session( + &self, + username: &str, + password: &str, + retry_connect: bool, + ) -> std::result::Result, common::types::ErrorCode> { + if self.conns.lock().unwrap().borrow_mut().is_empty() { + self.new_connection(1).await; + } + let conn = self.conns.lock().unwrap().borrow_mut().pop_back(); + if let Some(conn) = conn { + let resp = conn.authenticate(username, password).await?; + if resp.error_code != common::types::ErrorCode::SUCCEEDED { + return Err(resp.error_code); + } + Ok(Session::new( + resp.session_id.unwrap(), + conn, + self, + username.to_string(), + password.to_string(), + if let Some(time_zone_name) = resp.time_zone_name { + std::str::from_utf8(&time_zone_name).unwrap().to_string() + } else { + String::new() + }, + resp.time_zone_offset_seconds.unwrap(), + retry_connect, + )) + } else { + Err(common::types::ErrorCode::E_UNKNOWN) + } + } + + /// Give back the connection to pool + #[inline] + pub fn give_back(&self, conn: Connection) { + self.conns.lock().unwrap().borrow_mut().push_back(conn); + } + + /// Get the count of connections + #[inline] + pub fn len(&self) -> usize { + self.conns.lock().unwrap().borrow().len() + } + + // Add new connection to pool + // inc is the count of new connection created, which shouldn't be zero + // the incremental count maybe can't fit when occurs error in connection creating + async fn new_connection(&self, inc: u32) { + assert!(inc != 0); + // TODO concurrent these + let mut count = 0; + let mut loop_count = 0; + let loop_limit = inc as usize * self.config.addresses.len(); + while count < inc { + if count as usize + + self + .conns_count + .borrow() + .load(std::sync::atomic::Ordering::Acquire) + >= self.config.max_connection_pool_size as usize + { + // Reach the pool size limit + break; + } + let cursor = { self.cursor() }; + match Connection::new_from_address(&self.config.addresses[cursor]).await { + Ok(conn) => { + self.conns.lock().unwrap().borrow_mut().push_back(conn); + count += 1; + } + Err(_) => (), + }; + loop_count += 1; + if loop_count > loop_limit { + // Can't get so many connections, avoid dead loop + break; + } + } + // Release ordering make sure inc happened after creating new connections + self.conns_count + .borrow_mut() + .fetch_add(count as usize, std::sync::atomic::Ordering::Release); + } + + // cursor on the server addresses + fn cursor(&self) -> usize { + if self + .cursor + .borrow() + .load(std::sync::atomic::Ordering::Relaxed) + >= self.config.addresses.len() + { + self.cursor + .borrow_mut() + .store(0, std::sync::atomic::Ordering::Relaxed); + 0 + } else { + self.cursor + .borrow_mut() + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + } + } +} diff --git a/nebula_rust/src/graph_client/mod.rs b/nebula_rust/src/graph_client/mod.rs index 99cf0db..9aeaa93 100644 --- a/nebula_rust/src/graph_client/mod.rs +++ b/nebula_rust/src/graph_client/mod.rs @@ -5,4 +5,7 @@ */ pub mod connection; +pub mod connection_pool; +pub mod pool_config; +pub mod session; mod transport_response_handler; diff --git a/nebula_rust/src/graph_client/pool_config.rs b/nebula_rust/src/graph_client/pool_config.rs index 33c14ca..3e3de23 100644 --- a/nebula_rust/src/graph_client/pool_config.rs +++ b/nebula_rust/src/graph_client/pool_config.rs @@ -4,4 +4,58 @@ * attached with Common Clause Condition 1.0, found in the LICENSES directory. */ -mod graph_client {}; +#[derive(Debug, Default, Clone)] +pub struct PoolConfig { + /// connection timeout in ms + pub timeout: u32, + pub idle_time: u32, + /// max limit count of connections in pool + pub max_connection_pool_size: u32, + /// min limit count of connections in pool, also the initial count if works well + pub min_connection_pool_size: u32, + /// address of graph server + pub addresses: std::vec::Vec, +} + +impl PoolConfig { + #[inline] + pub fn new() -> Self { + Self::default() + } + + #[inline] + pub fn timeout(&mut self, timeout: u32) -> &mut Self { + self.timeout = timeout; + self + } + + #[inline] + pub fn idle_time(&mut self, idle_time: u32) -> &mut Self { + self.idle_time = idle_time; + self + } + + #[inline] + pub fn max_connection_pool_size(&mut self, size: u32) -> &mut Self { + self.max_connection_pool_size = size; + self + } + + #[inline] + pub fn min_connection_pool_size(&mut self, size: u32) -> &mut Self { + self.min_connection_pool_size = size; + self + } + + #[inline] + pub fn addresses(&mut self, addresses: std::vec::Vec) -> &mut Self { + self.addresses = addresses; + self + } + + #[inline] + pub fn address(&mut self, address: String) -> &mut Self { + self.addresses.push(address); + self + } +} diff --git a/nebula_rust/src/graph_client/session.rs b/nebula_rust/src/graph_client/session.rs index c4e7abb..b215abc 100644 --- a/nebula_rust/src/graph_client/session.rs +++ b/nebula_rust/src/graph_client/session.rs @@ -4,4 +4,82 @@ * attached with Common Clause Condition 1.0, found in the LICENSES directory. */ -mod graph_client {}; +use crate::graph_client::connection::Connection; +use crate::graph_client::connection_pool::ConnectionPool; + +pub struct Session<'a> { + session_id: i64, + conn: Connection, + pool: &'a ConnectionPool, + username: String, + password: String, + // empty means not a named timezone + time_zone_name: String, + // Offset to utc in seconds + offset_secs: i32, + // Keep connection if true + retry_connect: bool, +} + +impl<'a> Session<'a> { + pub fn new( + session_id: i64, + conn: Connection, + pool: &'a ConnectionPool, + username: String, + password: String, + time_zone_name: String, + offset_secs: i32, + retry_connect: bool, + ) -> Self { + Session { + session_id: session_id, + conn: conn, + pool: pool, + username: username, + password: password, + time_zone_name: time_zone_name, + offset_secs: offset_secs, + retry_connect: retry_connect, + } + } + + /// sign out the session + #[inline] + pub async fn signout(&self) -> std::result::Result<(), common::types::ErrorCode> { + self.conn.signout(self.session_id).await + } + + /// Execute the query in current session + /// The returned error of `Result` only means the request/response status + /// The error from Nebula Graph is still in `error_code` field in response, so you need check it + /// to known wether the query execute succeeded + #[inline] + pub async fn execute( + &self, + query: &str, + ) -> std::result::Result { + self.conn.execute(self.session_id, query).await + } + + /// Get the time zone name + #[inline] + pub fn time_zone_name(&self) -> &str { + &self.time_zone_name + } + + /// Get the time zone offset to UTC in seconds + #[inline] + pub fn offset_secs(&self) -> i32 { + self.offset_secs + } +} + +impl<'a> Drop for Session<'a> { + /// Drop session will sign out the session in server + /// and give back connection to pool + fn drop(&mut self) { + futures::executor::block_on(self.signout()); + self.pool.give_back(std::mem::take(&mut self.conn)); + } +} diff --git a/nebula_rust/src/lib.rs b/nebula_rust/src/lib.rs index e87dc26..5ef11a8 100644 --- a/nebula_rust/src/lib.rs +++ b/nebula_rust/src/lib.rs @@ -5,3 +5,4 @@ */ pub mod graph_client; +pub mod value; diff --git a/nebula_rust/src/value/data_set.rs b/nebula_rust/src/value/data_set.rs new file mode 100644 index 0000000..11eb837 --- /dev/null +++ b/nebula_rust/src/value/data_set.rs @@ -0,0 +1,58 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +pub trait DataSet { + /// Construct data set with name of columns + fn new(col_names: &[String]) -> Self; + + /// Construct data set from vec of columns name + fn from_columns_name(col_names: std::vec::Vec) -> Self; + + /// push one row into back of data set + fn push(&mut self, row: common::types::Row); + + /// Get rows size + fn len(&self) -> usize; + + /// Get count of columns + fn cols_len(&self) -> usize; +} + +impl DataSet for common::types::DataSet { + fn new(col_names: &[String]) -> Self { + let cols_bytes = col_names.into_iter().map(|s| s.as_bytes().to_vec()).collect(); + common::types::DataSet { + column_names: cols_bytes, + rows: vec![], + } + } + + fn from_columns_name(col_names: std::vec::Vec) -> Self { + let cols_bytes = col_names + .into_iter() + .map(|s| s.as_bytes().to_vec()) + .collect(); + common::types::DataSet { + column_names: cols_bytes, + rows: vec![], + } + } + + #[inline] + fn push(&mut self, row: common::types::Row) { + self.rows.push(row); + } + + #[inline] + fn len(&self) -> usize { + self.rows.len() + } + + #[inline] + fn cols_len(&self) -> usize { + self.column_names.len() + } +} diff --git a/nebula_rust/src/value/mod.rs b/nebula_rust/src/value/mod.rs new file mode 100644 index 0000000..513df73 --- /dev/null +++ b/nebula_rust/src/value/mod.rs @@ -0,0 +1,9 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +/// Some extension of the thrift value +pub mod data_set; +pub mod row; diff --git a/nebula_rust/src/value/row.rs b/nebula_rust/src/value/row.rs new file mode 100644 index 0000000..ea2e028 --- /dev/null +++ b/nebula_rust/src/value/row.rs @@ -0,0 +1,35 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +pub trait Row { + /// Construct row by columns name + fn new(row: &[common::types::Value]) -> Self; + + /// Construct row by vec of column name + fn from_vec(row: std::vec::Vec) -> Self; + + /// Get row length + fn len(&self) -> usize; +} + +impl Row for common::types::Row { + #[inline] + fn new(row: &[common::types::Value]) -> Self { + common::types::Row { + values: row.to_vec(), + } + } + + #[inline] + fn from_vec(row: std::vec::Vec) -> Self { + common::types::Row { values: row } + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } +} diff --git a/nebula_rust/tests/test_connection.rs b/nebula_rust/tests/test_connection.rs index dd21ec0..d6bec99 100644 --- a/nebula_rust/tests/test_connection.rs +++ b/nebula_rust/tests/test_connection.rs @@ -9,6 +9,8 @@ extern crate nebula_rust; #[cfg(test)] mod test_connection { use nebula_rust::graph_client; + use nebula_rust::value::data_set::DataSet; + use nebula_rust::value::row::Row; #[tokio::test] async fn basic_op() { @@ -26,7 +28,9 @@ mod test_connection { assert!(result.is_ok()); let response = result.unwrap(); assert!(response.error_code == common::types::ErrorCode::SUCCEEDED); - println!("{:?}", response.data.unwrap()); + let mut dt = common::types::DataSet::new(&["1".to_string()]); + dt.push(common::types::Row::new(&[common::types::Value::iVal(1)])); + assert!(dt == response.data.unwrap()); let result = conn.signout(session_id).await; assert!(result.is_ok()); diff --git a/nebula_rust/tests/test_connection_pool.rs b/nebula_rust/tests/test_connection_pool.rs new file mode 100644 index 0000000..a13037c --- /dev/null +++ b/nebula_rust/tests/test_connection_pool.rs @@ -0,0 +1,47 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +extern crate nebula_rust; + +#[cfg(test)] +mod test_connection { + use nebula_rust::graph_client; + use nebula_rust::value::data_set::DataSet; + use nebula_rust::value::row::Row; + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn mt_safe() { + let mut conf = graph_client::pool_config::PoolConfig::new(); + conf.min_connection_pool_size(5) + .max_connection_pool_size(10) + .address("localhost:9669".to_string()); + let pool = graph_client::connection_pool::ConnectionPool::new(&conf).await; + + { + // Consume all connections + let futs = (0..conf.max_connection_pool_size) + .into_iter() + .map(|_| pool.get_session("root", "nebula", true)) + .collect::>(); + let sessions = futures::future::join_all(futs).await; + for session in &sessions { + let resp = session.as_ref().unwrap().execute("YIELD 1").await.unwrap(); + assert!(resp.error_code == common::types::ErrorCode::SUCCEEDED); + + let mut dt = common::types::DataSet::new(&["1".to_string()]); + dt.push(common::types::Row::new(&[common::types::Value::iVal(1)])); + assert!(dt == resp.data.unwrap()); + } + + assert!(pool.len() == 0); + + // out of pool size limit + let result = pool.get_session("root", "nebula", true).await; + assert!(!result.is_ok()); + } + assert!(pool.len() == 10); + } +} diff --git a/nebula_rust/tests/test_session.rs b/nebula_rust/tests/test_session.rs new file mode 100644 index 0000000..b81e383 --- /dev/null +++ b/nebula_rust/tests/test_session.rs @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +extern crate nebula_rust; + +#[cfg(test)] +mod test_connection { + use nebula_rust::graph_client; + use nebula_rust::value::data_set::DataSet; + use nebula_rust::value::row::Row; + + #[tokio::test] + async fn basic_op() { + let mut conf = graph_client::pool_config::PoolConfig::new(); + conf.min_connection_pool_size(2) + .max_connection_pool_size(10) + .address("localhost:9669".to_string()); + + let pool = graph_client::connection_pool::ConnectionPool::new(&conf).await; + let session = pool.get_session("root", "nebula", true).await.unwrap(); + + let resp = session.execute("YIELD 1").await.unwrap(); + assert!(resp.error_code == common::types::ErrorCode::SUCCEEDED); + + let mut dt = common::types::DataSet::new(&["1".to_string()]); + dt.push(common::types::Row::new(&[common::types::Value::iVal(1)])); + assert!(dt == resp.data.unwrap()); + } +}