diff --git a/Cargo.toml b/Cargo.toml index 7e8a251..d2bd198 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Utility Library for Rust" repository = "https://github.com/tinrab/bomboni" @@ -30,9 +30,9 @@ proto = ["prost", "dep:bomboni_proto"] request = ["dep:bomboni_request"] [dependencies] -bomboni_common = { path = "bomboni_common", version = "0.1.4" } -bomboni_derive = { path = "bomboni_derive", version = "0.1.4" } +bomboni_common = { path = "bomboni_common", version = "0.1.5" } +bomboni_derive = { path = "bomboni_derive", version = "0.1.5" } -bomboni_prost = { path = "bomboni_prost", version = "0.1.4", optional = true } -bomboni_proto = { path = "bomboni_proto", version = "0.1.4", optional = true } -bomboni_request = { path = "bomboni_request", version = "0.1.4", optional = true } +bomboni_prost = { path = "bomboni_prost", version = "0.1.5", optional = true } +bomboni_proto = { path = "bomboni_proto", version = "0.1.5", optional = true } +bomboni_request = { path = "bomboni_request", version = "0.1.5", optional = true } diff --git a/bomboni_common/Cargo.toml b/bomboni_common/Cargo.toml index a746928..b1f3278 100644 --- a/bomboni_common/Cargo.toml +++ b/bomboni_common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni_common" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Common things for Bomboni library." repository = "https://github.com/tinrab/bomboni" diff --git a/bomboni_derive/Cargo.toml b/bomboni_derive/Cargo.toml index 66a901b..b09fd51 100644 --- a/bomboni_derive/Cargo.toml +++ b/bomboni_derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni_derive" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Provides derive implementations for Bomboni library." repository = "https://github.com/tinrab/bomboni" diff --git a/bomboni_prost/Cargo.toml b/bomboni_prost/Cargo.toml index 7ab09aa..581a1d5 100644 --- a/bomboni_prost/Cargo.toml +++ b/bomboni_prost/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni_prost" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Utilities for working with prost. Part of Bomboni library." repository = "https://github.com/tinrab/bomboni" @@ -14,7 +14,7 @@ name = "bomboni_prost" path = "src/lib.rs" [dependencies] -itertools = "0.11.0" +itertools = "0.12.0" convert_case = "0.6.0" prost = "0.12.1" prost-types = "0.12.1" diff --git a/bomboni_proto/Cargo.toml b/bomboni_proto/Cargo.toml index 686263d..d67ff4e 100644 --- a/bomboni_proto/Cargo.toml +++ b/bomboni_proto/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni_proto" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Utilities for working with Protobuf/gRPC. Part of Bomboni library." repository = "https://github.com/tinrab/bomboni" @@ -18,7 +18,7 @@ testing = [] [dependencies] thiserror = "1.0.50" -itertools = "0.11.0" +itertools = "0.12.0" chrono = { version = "0.4.31", features = ["serde"] } prost = "0.12.1" serde = { version = "1.0.192", features = ["derive"] } @@ -28,5 +28,5 @@ pot = "3.0.0" serde_json = "1.0.108" [build-dependencies] -bomboni_prost = { path = "../bomboni_prost", version = "0.1.4" } +bomboni_prost = { path = "../bomboni_prost", version = "0.1.5" } prost-build = "0.12.1" diff --git a/bomboni_proto/src/protobuf/duration.rs b/bomboni_proto/src/protobuf/duration.rs index 19284c1..acaf5b5 100644 --- a/bomboni_proto/src/protobuf/duration.rs +++ b/bomboni_proto/src/protobuf/duration.rs @@ -11,7 +11,7 @@ use thiserror::Error; use crate::google::protobuf::Duration; -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone, PartialEq)] pub enum DurationError { #[error("duration is out of range")] OutOfRange, diff --git a/bomboni_proto/src/protobuf/timestamp.rs b/bomboni_proto/src/protobuf/timestamp.rs index d990b9b..a953185 100644 --- a/bomboni_proto/src/protobuf/timestamp.rs +++ b/bomboni_proto/src/protobuf/timestamp.rs @@ -11,7 +11,7 @@ use thiserror::Error; use crate::google::protobuf::Timestamp; -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone, PartialEq)] pub enum TimestampError { #[error("invalid nanoseconds")] InvalidNanoseconds, diff --git a/bomboni_request/Cargo.toml b/bomboni_request/Cargo.toml index 29d3534..7875454 100644 --- a/bomboni_request/Cargo.toml +++ b/bomboni_request/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bomboni_request" -version = "0.1.4" +version = "0.1.5" authors = ["Tin Rabzelj "] description = "Utilities for working with API requests. Part of Bomboni library." repository = "https://github.com/tinrab/bomboni" @@ -17,10 +17,15 @@ path = "src/lib.rs" testing = [] [dependencies] -bomboni_common = { path = "../bomboni_common", version = "0.1.4" } -bomboni_derive = { path = "../bomboni_derive", version = "0.1.4" } +bomboni_common = { path = "../bomboni_common", version = "0.1.5" } +bomboni_derive = { path = "../bomboni_derive", version = "0.1.5" } thiserror = "1.0.50" -itertools = "0.11.0" +itertools = "0.12.0" chrono = "0.4.31" pest = "2.7.5" pest_derive = "2.7.5" +base64ct = { version = "1.6.0", features = ["alloc"] } +aes-gcm = { version = "0.10.3", features = ["alloc"] } +blake2 = "0.10.6" +rsa = "0.9.3" +rand = "0.8.5" diff --git a/bomboni_request/src/filter/error.rs b/bomboni_request/src/filter/error.rs index 3ab0cb6..511666b 100644 --- a/bomboni_request/src/filter/error.rs +++ b/bomboni_request/src/filter/error.rs @@ -5,7 +5,7 @@ use crate::schema::ValueType; use super::parser::Rule; -#[derive(Error, Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq)] pub enum FilterError { #[error("failed to parse filter from `{start}` to `{end}`")] Parse { start: usize, end: usize }, diff --git a/bomboni_request/src/filter/mod.rs b/bomboni_request/src/filter/mod.rs index 11c5595..a212697 100644 --- a/bomboni_request/src/filter/mod.rs +++ b/bomboni_request/src/filter/mod.rs @@ -1,3 +1,9 @@ +//! # Filter +//! +//! Utility for specifying filters on queries, as described in Google AIP standard [1]. +//! +//! [1]: https://google.aip.dev/160 + use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::ops::Deref; @@ -381,6 +387,17 @@ impl Filter { Filter::Value(value) => value.value_type(), } } + + pub fn is_valid(&self, schema: &Schema) -> bool { + // TODO: verify if this is fine + self.get_result_value_type(schema).is_some() + } +} + +impl Default for Filter { + fn default() -> Self { + Filter::Conjunction(Vec::new()) + } } impl Display for Filter { @@ -443,6 +460,29 @@ mod tests { use super::*; + #[test] + fn validate_schema() { + let schema = UserItem::get_schema(); + macro_rules! check { + (@valid $filter:expr) => { + assert!(check!($filter)); + }; + (@invalid $filter:expr) => { + assert!(!check!($filter)); + }; + ($filter:expr) => { + Filter::parse($filter).unwrap().is_valid(&schema) + }; + } + + check!(@valid "42"); + check!(@valid "false"); + + check!(@invalid "a"); + check!(@invalid "a.b"); + check!(@invalid "f()"); + } + #[test] fn it_works() { Filter::parse(r#" "#).unwrap(); diff --git a/bomboni_request/src/ordering/error.rs b/bomboni_request/src/ordering/error.rs index d0ff667..8a6905c 100644 --- a/bomboni_request/src/ordering/error.rs +++ b/bomboni_request/src/ordering/error.rs @@ -1,6 +1,6 @@ use thiserror::Error; -#[derive(Error, Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq)] pub enum OrderingError { #[error("duplicate ordering field `{0}`")] DuplicateField(String), diff --git a/bomboni_request/src/ordering/mod.rs b/bomboni_request/src/ordering/mod.rs index 69d0e61..12b860c 100644 --- a/bomboni_request/src/ordering/mod.rs +++ b/bomboni_request/src/ordering/mod.rs @@ -6,13 +6,15 @@ use std::{ use itertools::Itertools; +use crate::schema::Schema; + use self::error::{OrderingError, OrderingResult}; use super::schema::SchemaMapped; pub mod error; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Ordering { pub terms: Vec, } @@ -91,6 +93,19 @@ impl Ordering { } Some(cmp::Ordering::Equal) } + + pub fn is_valid(&self, schema: &Schema) -> bool { + for term in self.terms.iter() { + if let Some(field) = schema.get_field(&term.name) { + if !field.ordered { + return false; + } + } else { + return false; + } + } + true + } } impl Display for Ordering { diff --git a/bomboni_request/src/query.rs b/bomboni_request/src/query.rs deleted file mode 100644 index 8b13789..0000000 --- a/bomboni_request/src/query.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/bomboni_request/src/query/error.rs b/bomboni_request/src/query/error.rs new file mode 100644 index 0000000..d8a155b --- /dev/null +++ b/bomboni_request/src/query/error.rs @@ -0,0 +1,41 @@ +use thiserror::Error; + +use crate::{filter::error::FilterError, ordering::error::OrderingError}; + +#[derive(Error, Debug, Clone, PartialEq)] +pub enum QueryError { + #[error("filter error: {0}")] + FilterError(FilterError), + #[error("filter is too long")] + FilterTooLong, + #[error("filter schema mismatch")] + FilterSchemaMismatch, + #[error("ordering error: {0}")] + OrderingError(OrderingError), + #[error("ordering is too long")] + OrderingTooLong, + #[error("ordering schema mismatch")] + OrderingSchemaMismatch, + #[error("query is too long")] + QueryTooLong, + #[error("page token is invalid")] + InvalidPageToken, + #[error("page token could not be built")] + PageTokenFailure, + #[error("page size specified is invalid")] + InvalidPageSize, +} + +pub type QueryResult = Result; + +impl From for QueryError { + fn from(err: FilterError) -> Self { + QueryError::FilterError(err) + } +} + +impl From for QueryError { + fn from(err: OrderingError) -> Self { + QueryError::OrderingError(err) + } +} diff --git a/bomboni_request/src/query/list.rs b/bomboni_request/src/query/list.rs new file mode 100644 index 0000000..b12088e --- /dev/null +++ b/bomboni_request/src/query/list.rs @@ -0,0 +1,251 @@ +//! # List query. +//! +//! Utility for working with Google AIP standard List method [1]. +//! +//! [1]: https://google.aip.dev/132 + +use crate::{ + filter::Filter, + ordering::{Ordering, OrderingTerm}, + schema::Schema, +}; + +use super::{ + error::{QueryError, QueryResult}, + page_token::PageTokenBuilder, + utility::{parse_query_filter, parse_query_ordering}, +}; + +/// Represents a list query. +/// List queries list paged, filtered and ordered items. +#[derive(Debug, Clone)] +pub struct ListQuery { + pub filter: Filter, + pub ordering: Ordering, + pub page_size: i32, + pub page_token: Option, +} + +/// Config for list query builder. +/// +/// `primary_ordering_term` should probably never be `None`. +/// If the request does not contain an "order_by" field, usage of this function should pre-insert one. +/// The default ordering term can the primary key of the schema item. +/// If ordering is not specified, then behavior of query's page tokens [`PageToken`] is undefined. +#[derive(Debug, Clone)] +pub struct ListQueryConfig { + pub max_page_size: Option, + pub default_page_size: i32, + pub primary_ordering_term: Option, + pub max_filter_length: Option, + pub max_ordering_length: Option, +} + +pub struct ListQueryBuilder { + schema: Schema, + options: ListQueryConfig, + page_token_builder: P, +} + +impl Default for ListQueryConfig { + fn default() -> Self { + ListQueryConfig { + max_page_size: None, + default_page_size: 20, + primary_ordering_term: None, + max_filter_length: None, + max_ordering_length: None, + } + } +} + +impl ListQueryBuilder

{ + pub fn new(schema: Schema, options: ListQueryConfig, page_token_builder: P) -> Self { + ListQueryBuilder { + schema, + options, + page_token_builder, + } + } + + pub fn build( + &self, + page_size: Option, + page_token: Option<&str>, + filter: Option<&str>, + ordering: Option<&str>, + ) -> QueryResult> { + let filter = parse_query_filter(filter, &self.schema, self.options.max_filter_length)?; + let mut ordering = + parse_query_ordering(ordering, &self.schema, self.options.max_ordering_length)?; + + // Pre-insert primary ordering term. + // This is needed for page tokens to work. + if let Some(primary_ordering_term) = self.options.primary_ordering_term.as_ref() { + if ordering + .terms + .iter() + .all(|term| term.name != primary_ordering_term.name) + { + ordering.terms.insert(0, primary_ordering_term.clone()); + } + } + + // Handle paging. + let mut page_size = page_size.unwrap_or(self.options.default_page_size); + if page_size < 0 { + return Err(QueryError::InvalidPageSize); + } + if let Some(max_page_size) = self.options.max_page_size { + // Intentionally clamp page size to max page size. + if page_size > max_page_size { + page_size = max_page_size; + } + } + + let page_token = + if let Some(page_token) = page_token.filter(|page_token| !page_token.is_empty()) { + Some( + self.page_token_builder + .parse(&filter, &ordering, page_token)?, + ) + } else { + None + }; + + Ok(ListQuery { + filter, + ordering, + page_size, + page_token, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + filter::error::FilterError, ordering::OrderingDirection, + query::page_token::plain::PlainPageTokenBuilder, testing::schema::UserItem, + }; + + use super::*; + + #[test] + fn it_works() { + let qb = get_query_builder(); + let query = qb + .build( + Some(10_000), + None, + Some("displayName = \"John\""), + Some("age desc"), + ) + .unwrap(); + assert_eq!(query.page_size, 20); + assert_eq!(query.filter.to_string(), "displayName = \"John\""); + assert_eq!(query.ordering.to_string(), "id desc, age desc"); + } + + #[test] + fn errors() { + let q = get_query_builder(); + assert!(matches!( + q.build(Some(-1), None, None, None), + Err(QueryError::InvalidPageSize) + )); + assert!(matches!( + q.build(Some(-1), None, None, None), + Err(QueryError::InvalidPageSize) + )); + assert!(matches!( + q.build(None, None, Some("f!"), None).unwrap_err(), + QueryError::FilterError(FilterError::Parse { start, end }) + if start == 1 && end == 1 + )); + assert_eq!( + q.build(None, None, Some(&("a".repeat(100))), None) + .unwrap_err(), + QueryError::FilterTooLong + ); + assert_eq!( + q.build(None, None, Some("lol"), None).unwrap_err(), + QueryError::FilterSchemaMismatch + ); + assert_eq!( + q.build(None, None, None, Some(&("a".repeat(100)))) + .unwrap_err(), + QueryError::OrderingTooLong + ); + assert_eq!( + q.build(None, None, None, Some("lol")).unwrap_err(), + QueryError::OrderingSchemaMismatch + ); + } + + #[test] + fn page_tokens() { + let qb = get_query_builder(); + let last_item: UserItem = UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }; + + macro_rules! assert_page_token { + ($filter1:expr, $ordering1:expr, $filter2:expr, $ordering2:expr, $expected_token:expr $(,)?) => {{ + let first_page = qb.build(Some(3), None, $filter1, $ordering1).unwrap(); + let next_page_token = qb + .page_token_builder + .build_next(&first_page.filter, &first_page.ordering, &last_item) + .unwrap(); + let next_page: ListQuery = qb + .build(Some(3), Some(&next_page_token), $filter2, $ordering2) + .unwrap(); + assert_eq!( + next_page.page_token.unwrap().filter.to_string(), + $expected_token + ); + }}; + } + + assert_page_token!( + Some(r#"displayName = "John""#), + None, + Some(r#"displayName = "John""#), + None, + r#"id <= "1337""#, + ); + assert_page_token!( + None, + Some(r#"id desc, age desc"#), + None, + Some(r#"id desc, age desc"#), + r#"id <= "1337" AND age <= 14000"#, + ); + assert_page_token!( + None, + Some(r#"id desc, age asc"#), + None, + Some(r#"id desc, age desc"#), + r#"id <= "1337" AND age >= 14000"#, + ); + } + + fn get_query_builder() -> ListQueryBuilder { + ListQueryBuilder::::new( + UserItem::get_schema(), + ListQueryConfig { + max_page_size: Some(20), + default_page_size: 10, + primary_ordering_term: Some(OrderingTerm { + name: "id".into(), + direction: OrderingDirection::Descending, + }), + max_filter_length: Some(50), + max_ordering_length: Some(50), + }, + PlainPageTokenBuilder {}, + ) + } +} diff --git a/bomboni_request/src/query/mod.rs b/bomboni_request/src/query/mod.rs new file mode 100644 index 0000000..46e834c --- /dev/null +++ b/bomboni_request/src/query/mod.rs @@ -0,0 +1,5 @@ +pub mod error; +pub mod list; +pub mod page_token; +pub mod search; +pub mod utility; diff --git a/bomboni_request/src/query/page_token/aes256.rs b/bomboni_request/src/query/page_token/aes256.rs new file mode 100644 index 0000000..bcb0ad3 --- /dev/null +++ b/bomboni_request/src/query/page_token/aes256.rs @@ -0,0 +1,174 @@ +use aes_gcm::{ + aead::{Aead, OsRng}, + AeadCore, Aes256Gcm, Key, KeyInit, +}; +use base64ct::{Base64, Base64Url, Encoding}; + +use crate::{ + filter::Filter, + ordering::Ordering, + query::{ + error::{QueryError, QueryResult}, + page_token::utility::get_page_filter, + }, + schema::SchemaMapped, +}; + +use super::{utility::make_page_key, FilterPageToken, PageTokenBuilder}; + +const NONCE_LENGTH: usize = 12; + +/// AES-256-GCM page token builder. +/// The page token is encrypted using the query parameters as the key. +/// This is useful for ensuring that the page token was generated for the same paging rules. +pub struct Aes256PageTokenBuilder { + url_safe: bool, +} + +impl Aes256PageTokenBuilder { + pub fn new(url_safe: bool) -> Self { + Aes256PageTokenBuilder { url_safe } + } +} + +impl PageTokenBuilder for Aes256PageTokenBuilder { + type PageToken = FilterPageToken; + + fn parse( + &self, + filter: &Filter, + ordering: &Ordering, + page_token: &str, + ) -> QueryResult { + let decoded = if self.url_safe { + Base64Url::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + } else { + Base64::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + }; + + let key = make_page_key::<32>(filter, ordering).ok_or(QueryError::InvalidPageToken)?; + let key: &Key = (&key).into(); + + let cipher = Aes256Gcm::new(key); + if decoded.len() <= NONCE_LENGTH { + return Err(QueryError::InvalidPageToken); + } + let (nonce_buf, encrypted) = decoded.split_at(NONCE_LENGTH); + let nonce = nonce_buf + .try_into() + .map_err(|_| QueryError::InvalidPageToken)?; + + let plaintext = cipher + .decrypt(nonce, encrypted) + .map_err(|_| QueryError::InvalidPageToken)?; + + let page_filter = + Filter::parse(&String::from_utf8(plaintext).map_err(|_| QueryError::InvalidPageToken)?) + .map_err(|_| QueryError::InvalidPageToken)?; + + Ok(Self::PageToken { + filter: page_filter, + }) + } + + fn build_next( + &self, + filter: &Filter, + ordering: &Ordering, + next_item: &T, + ) -> QueryResult { + let page_filter = get_page_filter(ordering, next_item); + if page_filter.is_empty() { + return Err(QueryError::PageTokenFailure); + } + let plaintext = page_filter.to_string(); + + let key = make_page_key::<32>(filter, ordering).ok_or(QueryError::PageTokenFailure)?; + let key: &Key = (&key).into(); + + let cipher = Aes256Gcm::new(key); + // 96-bits; unique per message + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); + let mut encrypted = cipher.encrypt(&nonce, plaintext.as_bytes()).unwrap(); + // Prepend nonce to encrypted buffer + encrypted.splice(0..0, nonce); + + if self.url_safe { + Ok(Base64Url::encode_string(&encrypted)) + } else { + Ok(Base64::encode_string(&encrypted)) + } + } +} + +#[cfg(test)] +mod tests { + use crate::testing::schema::UserItem; + + use super::*; + + #[test] + fn it_works() { + let b = Aes256PageTokenBuilder::new(true); + let filter = Filter::parse(r#"displayName = "John""#).unwrap(); + let ordering = Ordering::parse("id desc, age desc").unwrap(); + let page_token = b + .build_next( + &filter, + &ordering, + &UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }, + ) + .unwrap(); + assert!(page_token.trim().len() > NONCE_LENGTH); + assert!(!page_token.contains(r#"id <= "1337""#)); + let parsed = b.parse(&filter, &ordering, &page_token).unwrap(); + assert_eq!( + parsed.filter.to_string(), + r#"id <= "1337" AND age <= 14000"# + ); + } + + #[test] + fn errors() { + let b = Aes256PageTokenBuilder::new(true); + + // Generate key for different parameters + let filter = Filter::parse("id=1").unwrap(); + let ordering = Ordering::parse("age desc").unwrap(); + let page_token = b + .build_next( + &filter, + &ordering, + &UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }, + ) + .unwrap(); + let parsed = b.parse(&filter, &ordering, &page_token).unwrap(); + assert_eq!(parsed.filter.to_string(), r#"age <= 14000"#); + assert_eq!( + b.parse( + &Filter::parse("id=2").unwrap(), + &Ordering::parse("age desc").unwrap(), + &page_token + ) + .unwrap_err(), + QueryError::InvalidPageToken + ); + assert_eq!( + b.parse( + &Filter::parse("id=1").unwrap(), + &Ordering::parse("age asc").unwrap(), + &page_token + ) + .unwrap_err(), + QueryError::InvalidPageToken + ); + } +} diff --git a/bomboni_request/src/query/page_token/base64.rs b/bomboni_request/src/query/page_token/base64.rs new file mode 100644 index 0000000..a08e2e7 --- /dev/null +++ b/bomboni_request/src/query/page_token/base64.rs @@ -0,0 +1,93 @@ +use base64ct::{Base64, Base64Url, Encoding}; + +use crate::{ + filter::Filter, + ordering::Ordering, + query::{ + error::{QueryError, QueryResult}, + page_token::utility::get_page_filter, + }, +}; + +use super::{FilterPageToken, PageTokenBuilder}; + +/// Page token builder for Base64-encoded tokens. +pub struct Base64PageTokenBuilder { + url_safe: bool, +} + +impl Base64PageTokenBuilder { + pub fn new(url_safe: bool) -> Self { + Base64PageTokenBuilder { url_safe } + } +} + +impl PageTokenBuilder for Base64PageTokenBuilder { + type PageToken = FilterPageToken; + + fn parse( + &self, + _filter: &Filter, + _ordering: &Ordering, + page_token: &str, + ) -> QueryResult { + let decoded = if self.url_safe { + Base64Url::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + } else { + Base64::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + }; + let page_filter = + Filter::parse(&String::from_utf8(decoded).map_err(|_| QueryError::InvalidPageToken)?)?; + Ok(Self::PageToken { + filter: page_filter, + }) + } + + fn build_next( + &self, + _filter: &Filter, + ordering: &Ordering, + next_item: &T, + ) -> QueryResult { + let page_filter = get_page_filter(ordering, next_item); + if page_filter.is_empty() { + return Err(QueryError::PageTokenFailure); + } + if self.url_safe { + Ok(Base64Url::encode_string(page_filter.to_string().as_bytes())) + } else { + Ok(Base64::encode_string(page_filter.to_string().as_bytes())) + } + } +} + +#[cfg(test)] +mod tests { + use crate::testing::schema::UserItem; + + use super::*; + + #[test] + fn it_works() { + let b = Base64PageTokenBuilder::new(true); + let filter = Filter::parse(r#"displayName = "John""#).unwrap(); + let ordering = Ordering::parse("id desc, age desc").unwrap(); + let page_token = b + .build_next( + &filter, + &ordering, + &UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }, + ) + .unwrap(); + assert_eq!(page_token, "aWQgPD0gIjEzMzciIEFORCBhZ2UgPD0gMTQwMDA="); + let parsed = b.parse(&filter, &ordering, &page_token).unwrap(); + assert_eq!( + parsed.filter.to_string(), + r#"id <= "1337" AND age <= 14000"# + ); + } +} diff --git a/bomboni_request/src/query/page_token/mod.rs b/bomboni_request/src/query/page_token/mod.rs new file mode 100644 index 0000000..856965f --- /dev/null +++ b/bomboni_request/src/query/page_token/mod.rs @@ -0,0 +1,46 @@ +//! The page token is used to determine the next page of results. +//! How it is used to query the database is implementation-specific. +//! One way is to filter IDs greater than the last item's ID of the previous page. +//! If the query parameters change, then the page token is invalid. +//! To ensure that a valid token is used, we can encrypt it along with the query parameters and decrypt it before use. +//! Encryption is also desirable to prevent users from guessing the next page of results, or to hide sensitive information. + +use crate::{filter::Filter, ordering::Ordering, schema::SchemaMapped}; + +use super::error::QueryResult; +pub mod aes256; +pub mod base64; +pub mod plain; +pub mod rsa; +mod utility; + +/// A page token containing a filter. +#[derive(Debug, Clone)] +pub struct FilterPageToken { + pub filter: Filter, +} + +pub trait PageTokenBuilder { + type PageToken: Clone; + + /// Parse a page token. + /// [`QueryError::InvalidPageToken`] is returned if the page token is invalid for any reason. + fn parse( + &self, + filter: &Filter, + ordering: &Ordering, + page_token: &str, + ) -> QueryResult; + + /// Build a page token for the next page of results. + /// + /// Note that "last item" is not necessarily the last item of the page, but N+1th one. + /// We can fetch page_size+1 items from the database to determine if there are more results. + /// [`QueryError::PageTokenFailure`] is returned if the page token could not be built. + fn build_next( + &self, + filter: &Filter, + ordering: &Ordering, + next_item: &T, + ) -> QueryResult; +} diff --git a/bomboni_request/src/query/page_token/plain.rs b/bomboni_request/src/query/page_token/plain.rs new file mode 100644 index 0000000..cb6aa78 --- /dev/null +++ b/bomboni_request/src/query/page_token/plain.rs @@ -0,0 +1,41 @@ +use crate::{ + filter::Filter, + ordering::Ordering, + query::error::{QueryError, QueryResult}, + schema::SchemaMapped, +}; + +use super::{utility::get_page_filter, FilterPageToken, PageTokenBuilder}; + +/// Plain text page token builder. +/// Used only for testing. +pub struct PlainPageTokenBuilder {} + +impl PageTokenBuilder for PlainPageTokenBuilder { + type PageToken = FilterPageToken; + + fn parse( + &self, + _filter: &Filter, + _ordering: &Ordering, + page_token: &str, + ) -> QueryResult { + let page_filter = Filter::parse(page_token)?; + Ok(Self::PageToken { + filter: page_filter, + }) + } + + fn build_next( + &self, + _filter: &Filter, + ordering: &Ordering, + next_item: &T, + ) -> QueryResult { + let page_filter = get_page_filter(ordering, next_item); + if page_filter.is_empty() { + return Err(QueryError::PageTokenFailure); + } + Ok(format!("{}", page_filter)) + } +} diff --git a/bomboni_request/src/query/page_token/rsa.rs b/bomboni_request/src/query/page_token/rsa.rs new file mode 100644 index 0000000..be5a5b2 --- /dev/null +++ b/bomboni_request/src/query/page_token/rsa.rs @@ -0,0 +1,184 @@ +use crate::{ + filter::Filter, + ordering::Ordering, + query::{ + error::{QueryError, QueryResult}, + page_token::utility::get_page_filter, + }, +}; + +use super::{utility::make_page_key, FilterPageToken, PageTokenBuilder}; +use base64ct::{Base64, Base64Url, Encoding}; +use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; + +const PARAMS_KEY_LENGTH: usize = 32; + +/// Page token builder for RSA-encrypted tokens. +pub struct RsaPageTokenBuilder { + private_key: RsaPrivateKey, + public_key: RsaPublicKey, + url_safe: bool, +} + +impl RsaPageTokenBuilder { + pub fn new(private_key: RsaPrivateKey, public_key: RsaPublicKey, url_safe: bool) -> Self { + RsaPageTokenBuilder { + private_key, + public_key, + url_safe, + } + } +} + +impl PageTokenBuilder for RsaPageTokenBuilder { + type PageToken = FilterPageToken; + + fn parse( + &self, + filter: &Filter, + ordering: &Ordering, + page_token: &str, + ) -> QueryResult { + let decoded = if self.url_safe { + Base64Url::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + } else { + Base64::decode_vec(page_token).map_err(|_| QueryError::InvalidPageToken)? + }; + + let mut plaintext = self + .private_key + .decrypt(Pkcs1v15Encrypt, &decoded) + .map_err(|_| QueryError::InvalidPageToken)?; + + // Verify key + let page_filter_text = plaintext.split_off(PARAMS_KEY_LENGTH); + let params_key = make_page_key::(filter, ordering) + .ok_or(QueryError::InvalidPageToken)?; + if params_key != plaintext.as_slice() { + return Err(QueryError::InvalidPageToken); + } + + let page_filter = Filter::parse( + &String::from_utf8(page_filter_text).map_err(|_| QueryError::InvalidPageToken)?, + ) + .map_err(|_| QueryError::InvalidPageToken)?; + + Ok(Self::PageToken { + filter: page_filter, + }) + } + + fn build_next( + &self, + filter: &Filter, + ordering: &Ordering, + next_item: &T, + ) -> QueryResult { + let page_filter = get_page_filter(ordering, next_item); + if page_filter.is_empty() { + return Err(QueryError::PageTokenFailure); + } + + // Include both filter and ordering into encryption. + let mut plaintext = make_page_key::(filter, ordering) + .ok_or(QueryError::PageTokenFailure)? + .to_vec(); + plaintext.extend(page_filter.to_string().as_bytes()); + + let mut rng = rand::thread_rng(); + let encrypted = self + .public_key + .encrypt(&mut rng, Pkcs1v15Encrypt, &plaintext) + .unwrap(); + + if self.url_safe { + Ok(Base64Url::encode_string(&encrypted)) + } else { + Ok(Base64::encode_string(&encrypted)) + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::OnceLock; + + use crate::testing::schema::UserItem; + + use super::*; + + #[test] + fn it_works() { + let b = get_builder(); + let filter = Filter::parse(r#"displayName = "John""#).unwrap(); + let ordering = Ordering::parse("id desc, age desc").unwrap(); + let page_token = b + .build_next( + &filter, + &ordering, + &UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }, + ) + .unwrap(); + assert!(page_token.trim().len() > 16); + assert!(!page_token.contains(r#"id <= "1337""#)); + let parsed = b.parse(&filter, &ordering, &page_token).unwrap(); + assert_eq!( + parsed.filter.to_string(), + r#"id <= "1337" AND age <= 14000"# + ); + } + + #[test] + fn errors() { + let b = get_builder(); + + // Generate key for different parameters + let filter = Filter::parse("id=1").unwrap(); + let ordering = Ordering::parse("age desc").unwrap(); + let page_token = b + .build_next( + &filter, + &ordering, + &UserItem { + id: "1337".into(), + display_name: "John".into(), + age: 14000, + }, + ) + .unwrap(); + let parsed = b.parse(&filter, &ordering, &page_token).unwrap(); + assert_eq!(parsed.filter.to_string(), r#"age <= 14000"#); + assert_eq!( + b.parse( + &Filter::parse("id=2").unwrap(), + &Ordering::parse("age desc").unwrap(), + &page_token + ) + .unwrap_err(), + QueryError::InvalidPageToken + ); + assert_eq!( + b.parse( + &Filter::parse("id=1").unwrap(), + &Ordering::parse("age asc").unwrap(), + &page_token + ) + .unwrap_err(), + QueryError::InvalidPageToken + ); + } + + fn get_builder() -> &'static RsaPageTokenBuilder { + static SINGLETON: OnceLock = OnceLock::new(); + SINGLETON.get_or_init(|| { + let mut rng = rand::thread_rng(); + let private_key = RsaPrivateKey::new(&mut rng, 720).unwrap(); + let public_key = RsaPublicKey::from(&private_key); + RsaPageTokenBuilder::new(private_key, public_key, true) + }) + } +} diff --git a/bomboni_request/src/query/page_token/utility.rs b/bomboni_request/src/query/page_token/utility.rs new file mode 100644 index 0000000..2af988c --- /dev/null +++ b/bomboni_request/src/query/page_token/utility.rs @@ -0,0 +1,56 @@ +use blake2::Blake2s256; + +use crate::filter::FilterComparator; +use crate::ordering::OrderingDirection; +use crate::value::Value; +use crate::{filter::Filter, schema::SchemaMapped}; +use blake2::Digest; + +use crate::ordering::Ordering; + +/// Constructs a filter that selects items greater than `next_item` based on ordering. +/// For example, if the ordering is "age desc", then the filter will be "age <= {next_item.age}". +/// "Equals" (>=, <=) is used to ensure that the next item is included in the results. +pub fn get_page_filter(ordering: &Ordering, next_item: &T) -> Filter { + let mut filters = Vec::new(); + + for term in ordering.terms.iter() { + let term_argument = match next_item.get_field(&term.name) { + Value::Integer(value) => Filter::Value(value.into()), + Value::Float(value) => Filter::Value(value.into()), + Value::Boolean(value) => Filter::Value(value.into()), + Value::String(value) => Filter::Value(value.into()), + Value::Timestamp(value) => Filter::Value(value.into()), + Value::Repeated(value) => Filter::Value(value.into()), + Value::Any => Filter::Value(Value::Any), + }; + + filters.push(Filter::Restriction( + Box::new(Filter::Name(term.name.clone())), + match term.direction { + OrderingDirection::Ascending => FilterComparator::GreaterOrEqual, + OrderingDirection::Descending => FilterComparator::LessOrEqual, + }, + Box::new(term_argument), + )); + } + + // Disjunction? + Filter::Conjunction(filters) +} + +/// Constructs a page key from a filter and ordering. +/// The key should be completely different for different filters and orderings. +pub fn make_page_key(filter: &Filter, ordering: &Ordering) -> Option<[u8; N]> { + let mut hasher = Blake2s256::new(); + + hasher.update(filter.to_string().as_bytes()); + hasher.update(ordering.to_string().as_bytes()); + + let res = hasher.finalize(); + // TODO: other than 32 bytes? + debug_assert_eq!(res.len(), N); + let key: [u8; N] = res.as_slice().try_into().unwrap(); + + Some(key) +} diff --git a/bomboni_request/src/query/search.rs b/bomboni_request/src/query/search.rs new file mode 100644 index 0000000..67c3099 --- /dev/null +++ b/bomboni_request/src/query/search.rs @@ -0,0 +1,217 @@ +//! # Search query. +//! +//! Utility for working with fuzzy search queries. + +use crate::{ + filter::Filter, + ordering::{Ordering, OrderingTerm}, + schema::Schema, +}; + +use super::{ + error::{QueryError, QueryResult}, + page_token::PageTokenBuilder, + utility::{parse_query_filter, parse_query_ordering}, +}; + +#[derive(Debug, Clone)] +pub struct SearchQuery { + pub query: String, + pub filter: Filter, + pub ordering: Ordering, + pub page_size: i32, + pub page_token: Option, +} + +/// Config for search query builder. +/// +/// `primary_ordering_term` should probably never be `None`. +#[derive(Debug, Clone)] +pub struct SearchQueryConfig { + pub max_page_size: Option, + pub default_page_size: i32, + pub primary_ordering_term: Option, + pub max_query_length: Option, + pub max_filter_length: Option, + pub max_ordering_length: Option, +} + +pub struct SearchQueryBuilder { + schema: Schema, + options: SearchQueryConfig, + page_token_builder: P, +} + +impl Default for SearchQueryConfig { + fn default() -> Self { + SearchQueryConfig { + max_page_size: None, + default_page_size: 20, + primary_ordering_term: None, + max_query_length: None, + max_filter_length: None, + max_ordering_length: None, + } + } +} + +impl SearchQueryBuilder

{ + pub fn new(schema: Schema, options: SearchQueryConfig, page_token_builder: P) -> Self { + SearchQueryBuilder { + schema, + options, + page_token_builder, + } + } + + pub fn build( + &self, + query: &str, + page_size: Option, + page_token: Option<&str>, + filter: Option<&str>, + ordering: Option<&str>, + ) -> QueryResult> { + if matches!(self.options.max_query_length, Some(max) if query.len() > max) { + return Err(QueryError::QueryTooLong); + } + + let filter = parse_query_filter(filter, &self.schema, self.options.max_filter_length)?; + let mut ordering = + parse_query_ordering(ordering, &self.schema, self.options.max_ordering_length)?; + + // Pre-insert primary ordering term. + // This is needed for page tokens to work. + if let Some(primary_ordering_term) = self.options.primary_ordering_term.as_ref() { + if ordering + .terms + .iter() + .all(|term| term.name != primary_ordering_term.name) + { + ordering.terms.insert(0, primary_ordering_term.clone()); + } + } + + // Handle paging. + let mut page_size = page_size.unwrap_or(self.options.default_page_size); + if page_size < 0 { + return Err(QueryError::InvalidPageSize); + } + if let Some(max_page_size) = self.options.max_page_size { + // Intentionally clamp page size to max page size. + if page_size > max_page_size { + page_size = max_page_size; + } + } + + let page_token = + if let Some(page_token) = page_token.filter(|page_token| !page_token.is_empty()) { + Some( + self.page_token_builder + .parse(&filter, &ordering, page_token)?, + ) + } else { + None + }; + + Ok(SearchQuery { + query: query.into(), + filter, + ordering, + page_size, + page_token, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + filter::error::FilterError, ordering::OrderingDirection, + query::page_token::plain::PlainPageTokenBuilder, testing::schema::UserItem, + }; + + use super::*; + + #[test] + fn it_works() { + let qb = get_query_builder(); + let query = qb + .build( + "abc", + Some(10_000), + None, + Some("displayName = \"John\""), + Some("age desc"), + ) + .unwrap(); + assert_eq!(query.page_size, 20); + assert_eq!(query.filter.to_string(), "displayName = \"John\""); + assert_eq!(query.ordering.to_string(), "id desc, age desc"); + } + + #[test] + fn errors() { + let q = get_query_builder(); + assert_eq!( + q.build( + &("a".repeat(100)), + None, + None, + Some(&("a".repeat(100))), + None + ) + .unwrap_err(), + QueryError::QueryTooLong + ); + assert!(matches!( + q.build("abc", Some(-1), None, None, None), + Err(QueryError::InvalidPageSize) + )); + assert!(matches!( + q.build("abc", Some(-1), None, None, None), + Err(QueryError::InvalidPageSize) + )); + assert!(matches!( + q.build("abc", None, None, Some("f!"), None).unwrap_err(), + QueryError::FilterError(FilterError::Parse { start, end }) + if start == 1 && end == 1 + )); + assert_eq!( + q.build("abc", None, None, Some(&("a".repeat(100))), None) + .unwrap_err(), + QueryError::FilterTooLong + ); + assert_eq!( + q.build("abc", None, None, Some("lol"), None).unwrap_err(), + QueryError::FilterSchemaMismatch + ); + assert_eq!( + q.build("abc", None, None, None, Some(&("a".repeat(100)))) + .unwrap_err(), + QueryError::OrderingTooLong + ); + assert_eq!( + q.build("abc", None, None, None, Some("lol")).unwrap_err(), + QueryError::OrderingSchemaMismatch + ); + } + + fn get_query_builder() -> SearchQueryBuilder { + SearchQueryBuilder::::new( + UserItem::get_schema(), + SearchQueryConfig { + max_page_size: Some(20), + default_page_size: 10, + primary_ordering_term: Some(OrderingTerm { + name: "id".into(), + direction: OrderingDirection::Descending, + }), + max_query_length: Some(50), + max_filter_length: Some(50), + max_ordering_length: Some(50), + }, + PlainPageTokenBuilder {}, + ) + } +} diff --git a/bomboni_request/src/query/utility.rs b/bomboni_request/src/query/utility.rs new file mode 100644 index 0000000..dd4b9a4 --- /dev/null +++ b/bomboni_request/src/query/utility.rs @@ -0,0 +1,42 @@ +use crate::{filter::Filter, ordering::Ordering, schema::Schema}; + +use super::error::{QueryError, QueryResult}; + +pub fn parse_query_filter( + filter: Option<&str>, + schema: &Schema, + max_filter_length: Option, +) -> QueryResult { + // Empty string is considered as None, because an optional string can be "", from protobuf's side. + if let Some(filter) = filter.filter(|filter| !filter.is_empty()) { + if matches!(max_filter_length, Some(max) if filter.len() > max) { + return Err(QueryError::FilterTooLong); + } + let filter = Filter::parse(filter)?; + if !filter.is_valid(schema) { + return Err(QueryError::FilterSchemaMismatch); + } + Ok(filter) + } else { + Ok(Filter::default()) + } +} + +pub fn parse_query_ordering( + ordering: Option<&str>, + schema: &Schema, + max_ordering_length: Option, +) -> QueryResult { + if let Some(ordering) = ordering.filter(|ordering| !ordering.is_empty()) { + if matches!(max_ordering_length, Some(max) if ordering.len() > max) { + return Err(QueryError::OrderingTooLong); + } + let ordering = Ordering::parse(ordering)?; + if !ordering.is_valid(schema) { + return Err(QueryError::OrderingSchemaMismatch); + } + Ok(ordering) + } else { + Ok(Ordering::default()) + } +} diff --git a/bomboni_request/src/schema.rs b/bomboni_request/src/schema.rs index 1bd5c89..111cbec 100644 --- a/bomboni_request/src/schema.rs +++ b/bomboni_request/src/schema.rs @@ -50,28 +50,6 @@ pub trait SchemaMapped { } impl Schema { - pub fn is_ordered(&self, field: &str) -> bool { - if let Some(member) = self.get_member(field) { - match member { - MemberSchema::Resource(_) => false, - MemberSchema::Field(field) => field.ordered, - } - } else { - false - } - } - - pub fn is_repeated(&self, field: &str) -> bool { - if let Some(member) = self.get_member(field) { - match member { - MemberSchema::Resource(_) => false, - MemberSchema::Field(field) => field.repeated, - } - } else { - false - } - } - pub fn get_member(&self, name: &str) -> Option<&MemberSchema> { let mut member: Option<&MemberSchema> = None; for step in name.split('.') { @@ -93,6 +71,14 @@ impl Schema { } member } + + pub fn get_field(&self, name: &str) -> Option<&FieldMemberSchema> { + if let Some(MemberSchema::Field(field)) = self.get_member(name) { + Some(field) + } else { + None + } + } } impl FieldMemberSchema { @@ -156,7 +142,7 @@ mod tests { schema.get_member("task.deleted"), Some(MemberSchema::Field(field)) if field.value_type == ValueType::Boolean )); - assert!(schema.is_ordered("user.id")); - assert!(schema.is_repeated("task.tags")); + assert!(schema.get_field("user.id").unwrap().ordered); + assert!(schema.get_field("task.tags").unwrap().repeated); } } diff --git a/develop.sh b/develop.sh index 5f1e644..7b6982c 100755 --- a/develop.sh +++ b/develop.sh @@ -23,8 +23,8 @@ function lint() { } function test() { - cargo test --all-targets --all-features -- --nocapture - cargo test --doc --all-features -- --nocapture + cargo test --workspace --all-targets --all-features -- --nocapture + cargo test --workspace --doc --all-features -- --nocapture } function publish() {