diff --git a/Cargo.lock b/Cargo.lock index d2d776b..a9c6c87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.2" @@ -95,13 +107,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.74" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -438,6 +450,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "enum-as-inner" version = "0.6.0" @@ -447,7 +465,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -523,6 +541,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -588,7 +621,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -690,6 +723,9 @@ name = "hashbrown" version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", +] [[package]] name = "heck" @@ -1098,6 +1134,35 @@ dependencies = [ "autocfg", ] +[[package]] +name = "metrics" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" +dependencies = [ + "ahash", + "portable-atomic", +] + +[[package]] +name = "metrics-util" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" +dependencies = [ + "aho-corasick", + "crossbeam-epoch", + "crossbeam-utils", + "hashbrown 0.14.3", + "indexmap 2.1.0", + "metrics", + "num_cpus", + "ordered-float", + "quanta", + "radix_trie", + "sketches-ddsketch", +] + [[package]] name = "mime" version = "0.3.17" @@ -1157,6 +1222,15 @@ dependencies = [ "getrandom", ] 
+[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1198,7 +1272,7 @@ dependencies = [ [[package]] name = "object_store" version = "0.10.2" -source = "git+https://github.com/andrebsguedes/arrow-rs.git?branch=unsigned-payload-and-azure-list-offset#5f3bf4e235194ab818e67d2a69e6695970af5b63" +source = "git+https://github.com/andrebsguedes/arrow-rs.git?tag=v0.10.2-beta1#3e15b9a308e29479bc33e4f06855227d93bf88a6" dependencies = [ "async-trait", "base64", @@ -1226,12 +1300,14 @@ dependencies = [ [[package]] name = "object_store_ffi" -version = "0.8.2" +version = "0.9.0" dependencies = [ "anyhow", "async-channel", "async-compression", + "async-trait", "backoff", + "base64", "bytes", "criterion", "crossbeam-queue", @@ -1239,20 +1315,28 @@ dependencies = [ "flume", "futures-util", "hyper", + "metrics", + "metrics-util", "moka", "object_store", "once_cell", + "openssl", "pin-project", + "quanta", + "rand", "regex", "reqwest", "serde", "serde_json", + "serde_path_to_error", "thiserror", "tokio", "tokio-util", "tracing", "tracing-subscriber", "url", + "uuid", + "zeroize", ] [[package]] @@ -1267,12 +1351,69 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "openssl" +version = "0.10.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +dependencies = [ + "bitflags 2.4.1", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-src" +version = "300.3.1+3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7259953d42a81bf137fbbd73bd30a8e1914d6dce43c2b90ed575783a22608b91" +dependencies = [ + "cc", +] + +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "openssl-src", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "ordered-float" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" +dependencies = [ + "num-traits", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -1337,7 +1478,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -1358,6 +1499,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +[[package]] +name = "portable-atomic" +version = 
"1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1366,9 +1513,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -1453,13 +1600,23 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.33" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "rand" version = "0.8.5" @@ -1759,22 +1916,22 @@ checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -1788,6 +1945,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1818,6 +1985,12 @@ dependencies = [ "libc", ] +[[package]] +name = "sketches-ddsketch" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" + [[package]] name = "slab" version = "0.4.9" @@ -1893,9 +2066,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -1937,7 +2110,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -2001,7 +2174,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -2075,7 +2248,7 @@ checksum = 
"34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", ] [[package]] @@ -2175,9 +2348,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.6.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ "getrandom", ] @@ -2188,6 +2361,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" @@ -2240,7 +2419,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", "wasm-bindgen-shared", ] @@ -2274,7 +2453,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.79", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2515,6 +2694,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index f839a1c..db497ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "object_store_ffi" -version = "0.8.2" +version = "0.9.0" edition = "2021" [[bench]] @@ -18,7 +18,7 @@ bench = false # https://doc.rust-lang.org/cargo/reference/profiles.html [profile.release] -debug = 1 +debug = true [features] default = ["julia"] @@ -34,7 +34,7 @@ futures-util = "0.3" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "hickory-dns"] } # object_store = { version = "0.10.1", features = ["azure", "aws"] } # Pinned to a specific commit while waiting for upstream -object_store = { git = "https://github.com/andrebsguedes/arrow-rs.git", branch = "unsigned-payload-and-azure-list-offset", features = ["azure", "aws", "experimental-azure-list-offset", "experimental-arbitrary-list-prefix"] } +object_store = { git = "https://github.com/andrebsguedes/arrow-rs.git", tag = "v0.10.2-beta1", features = ["azure", "aws", "experimental-azure-list-offset", "experimental-arbitrary-list-prefix"] } thiserror = "1" anyhow = { version = "1", features = ["backtrace"] } once_cell = "1.18" @@ -51,6 +51,21 @@ flate2 = { version = "1.0.28", features=["zlib-ng"], default-features = false} async-compression = { version = "0.4.6", default-features = false, features = ["tokio", "gzip", "zlib", "deflate", "zstd"] } flume = "0.11.0" pin-project = "1.1.5" +uuid = { version = "1.10.0", features = ["v4"] } +base64 = "0.22.1" +rand = "0.8.5" +zeroize = "1.8.1" +async-trait = "0.1.81" +serde_path_to_error = "0.1.16" +metrics = "0.23.0" +metrics-util = "0.17.0" 
+quanta = "0.12.3" + +[target.'cfg(not(target_os = "macos"))'.dependencies] +openssl = { version = "0.10.66" } + +[target.'cfg(target_os = "macos")'.dependencies] +openssl = { version = "0.10.66", features = ["vendored"] } [dev-dependencies] criterion = { version = "0.4", default-features = false, features = ["cargo_bench_support", "html_reports"] } diff --git a/Cross.toml b/Cross.toml new file mode 100644 index 0000000..ed66fa6 --- /dev/null +++ b/Cross.toml @@ -0,0 +1,11 @@ +[target.aarch64-unknown-linux-gnu] +pre-build = [ + "dpkg --add-architecture $CROSS_DEB_ARCH", + "apt-get update && apt-get --assume-yes install libssl-dev:$CROSS_DEB_ARCH" +] + +[target.x86_64-unknown-linux-gnu] +pre-build = [ + "dpkg --add-architecture $CROSS_DEB_ARCH", + "apt-get update && apt-get --assume-yes install libssl-dev:$CROSS_DEB_ARCH" +] diff --git a/src/crud_ops.rs b/src/crud_ops.rs index 6d2ea41..d3d2b51 100644 --- a/src/crud_ops.rs +++ b/src/crud_ops.rs @@ -1,11 +1,13 @@ -use crate::{CResult, Client, RawConfig, NotifyGuard, SQ, static_config, Request, util::cstr_to_path, Context, RawResponse, ResponseGuard}; +use crate::{duration_on_drop, encryption::{encrypt, CrypterReader, CrypterWriter, Mode}, error::Kind as ErrorKind, export_queued_op, metrics, util::cstr_to_path, with_retries, BoxedReader, BoxedUpload, CResult, Client, Context, NotifyGuard, RawConfig, RawResponse, Request, ResponseGuard, SQ}; +use bytes::Bytes; +use ::metrics::counter; use object_store::{path::Path, ObjectStore}; -use anyhow::anyhow; -use std::ffi::{c_char, c_void}; -use futures_util::StreamExt; -use tokio::io::AsyncWriteExt; +use tokio_util::io::StreamReader; +use std::{ffi::{c_char, c_void}, sync::Arc}; +use futures_util::{stream, StreamExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; // The type used to give Julia the result of an async request. It will be allocated // by Julia as part of the request and filled in by Rust. 
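// Submission pattern shared by every exported FFI entry point in this file (the
// hand-written `get`/`put`/`delete` wrappers removed below are replaced by
// `export_queued_op!` invocations that expand to the same shape): wrap the
// Julia-allocated response in a `ResponseGuard` so completion is always
// signalled, build the matching `Request` variant from the raw pointers, and
// `try_send` it to the global queue `SQ`, mapping a full channel to
// `CResult::Backoff` and a closed one to `CResult::Error`.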
@@ -42,206 +44,217 @@ impl RawResponse for Response { } } -async fn multipart_get(slice: &mut [u8], path: &Path, client: &Client) -> anyhow::Result { - let result = client.store.head(&path).await?; - if result.size > slice.len() { - return Err(anyhow!("Supplied buffer was too small")); - } - - let part_ranges = crate::util::size_to_ranges(result.size); - - let result_vec = client.store.get_ranges(&path, &part_ranges).await?; - let mut accum: usize = 0; - for i in 0..result_vec.len() { - slice[accum..accum + result_vec[i].len()].copy_from_slice(&result_vec[i]); - accum += result_vec[i].len(); - } - - return Ok(accum); -} - -async fn multipart_put(slice: &[u8], path: &Path, client: Client) -> anyhow::Result<()> { - let mut writer = object_store::buffered::BufWriter::with_capacity( - client.store, - path.clone(), - 10 * 1024 * 1024 - ) - .with_max_concurrency(64); - match writer.write_all(slice).await { - Ok(_) => { - match writer.flush().await { - Ok(_) => { - writer.shutdown().await?; - return Ok(()); - } - Err(e) => { - writer.abort().await?; - return Err(e.into()); - } - } - } - Err(e) => { - writer.abort().await?; - return Err(e.into()); - } - }; -} - -pub(crate) async fn handle_get(client: Client, slice: &mut [u8], path: &Path) -> anyhow::Result { - // Multipart Get - if slice.len() > static_config().multipart_get_threshold as usize { - let accum = multipart_get(slice, path, &client).await?; - return Ok(accum); - } - - // Single part Get - let body = client.store.get(path).await?; - let mut batch_stream = body.into_stream().chunks(8); - +async fn read_to_slice(reader: &mut BoxedReader, mut slice: &mut [u8]) -> crate::Result { let mut received_bytes = 0; - while let Some(batch) = batch_stream.next().await { - for result in batch { - let chunk = match result { - Ok(c) => c, - Err(e) => { - let err = anyhow::Error::new(e); - tracing::warn!("Error while fetching a chunk: {:#}", err); - return Err(err); + loop { + match reader.read_buf(&mut slice).await { + Ok(0) => { + if slice.len() == 0 { + // TODO is there a better way to check for this? 
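+                    // Probe with a one-byte read to distinguish "buffer exactly
+                    // full" from "buffer too small": reading zero extra bytes here
+                    // means the reader hit EOF right at the end of the slice.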
+ let mut scratch = [0u8; 1]; + if let Ok(0) = reader.read_buf(&mut scratch.as_mut_slice()).await { + // slice was the exact size, done + break; + } else { + return Err(ErrorKind::BufferTooSmall.into()); + } + } else { + // done + break; } - }; - - let len = chunk.len(); - - if received_bytes + len > slice.len() { - return Err(anyhow!("Supplied buffer was too small")); } - - slice[received_bytes..(received_bytes + len)].copy_from_slice(&chunk); - received_bytes += len; + Ok(n) => received_bytes += n, + Err(e) => { + let err = ErrorKind::BodyIo(e); + tracing::warn!("Error while reading body: {}", err); + return Err(err.into()); + } } } Ok(received_bytes) } -pub(crate) async fn handle_put(client: Client, slice: &'static [u8], path: &Path) -> anyhow::Result { - let len = slice.len(); - if len < static_config().multipart_put_threshold as usize { - let _ = client.store.put(path, slice.into()).await?; - } else { - let _ = multipart_put(slice, path, client).await?; - } +impl Client { + async fn get_impl(&self, path: &Path, slice: &mut [u8]) -> crate::Result { + let guard = duration_on_drop!(metrics::get_attempt_duration); + let path = &self.full_path(path); - Ok(len) -} + // Multipart Get + if slice.len() > self.config.multipart_get_threshold { + guard.discard(); + return self.multipart_get_impl(path, slice).await + } -pub(crate) async fn handle_delete(client: Client, path: &Path) -> anyhow::Result { - client.store.delete(path).await?; + // Single part Get + let result = self.store.get(path).await?; + let attributes = result.attributes.clone(); - Ok(0) -} + let mut reader: Box = Box::new(StreamReader::new(result.into_stream())); -#[no_mangle] -pub extern "C" fn get( - path: *const c_char, - buffer: *mut u8, - size: usize, - config: *const RawConfig, - response: *mut Response, - handle: *const c_void -) -> CResult { - let response = unsafe { ResponseGuard::new(response, handle) }; - let path = unsafe { std::ffi::CStr::from_ptr(path) }; - let path = unsafe{ cstr_to_path(path) }; - let slice = unsafe { std::slice::from_raw_parts_mut(buffer, size) }; - let config = unsafe { & (*config) }; - match SQ.get() { - Some(sq) => { - match sq.try_send(Request::Get(path, slice, config, response)) { - Ok(_) => CResult::Ok, - Err(flume::TrySendError::Full(Request::Get(_, _, _, response))) => { - response.into_error("object_store_ffi internal channel full, backoff"); - CResult::Backoff - } - Err(flume::TrySendError::Disconnected(Request::Get(_, _, _, response))) => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - CResult::Error - } - _ => unreachable!("the response type must match") - } - } - None => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - return CResult::Error; + if let Some(cryptmp) = self.crypto_material_provider.as_ref() { + let material = cryptmp.material_from_metadata(path.as_ref(), &attributes).await?; + let decrypter_reader = CrypterReader::new(reader, Mode::Decrypt, &material) + .map_err(ErrorKind::ContentDecrypt)?; + reader = Box::new(decrypter_reader); } - } -} -#[no_mangle] -pub extern "C" fn put( - path: *const c_char, - buffer: *const u8, - size: usize, - config: *const RawConfig, - response: *mut Response, - handle: *const c_void -) -> CResult { - let response = unsafe { ResponseGuard::new(response, handle) }; - let path = unsafe { std::ffi::CStr::from_ptr(path) }; - let path = unsafe{ cstr_to_path(path) }; - let slice = unsafe { std::slice::from_raw_parts(buffer, size) }; - let 
config = unsafe { & (*config) }; - match SQ.get() { - Some(sq) => { - match sq.try_send(Request::Put(path, slice, config, response)) { - Ok(_) => CResult::Ok, - Err(flume::TrySendError::Full(Request::Put(_, _, _, response))) => { - response.into_error("object_store_ffi internal channel full, backoff"); - CResult::Backoff - } - Err(flume::TrySendError::Disconnected(Request::Put(_, _, _, response))) => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - CResult::Error - } - _ => unreachable!("the response type must match") + Ok(read_to_slice(&mut reader, slice).await?) + } + pub async fn get(&self, path: &Path, slice: &mut [u8]) -> crate::Result { + counter!(metrics::total_get_ops).increment(1); + with_retries!(self, self.get_impl(path, slice).await) + } + async fn put_impl(&self, path: &Path, slice: Bytes) -> crate::Result { + let guard = duration_on_drop!(metrics::put_attempt_duration); + let path = &self.full_path(path); + let len = slice.len(); + if len < self.config.multipart_put_threshold { + if let Some(cryptmp) = self.crypto_material_provider.as_ref() { + let (material, attrs) = cryptmp.material_for_write(path.as_ref(), Some(slice.len())).await?; + let ciphertext = if slice.is_empty() { + // Do not write any padding if there was no data + vec![] + } else { + encrypt(&slice, &material) + .map_err(ErrorKind::ContentEncrypt)? + }; + let _ = self.store.put_opts( + path, + ciphertext.into(), + attrs.into() + ).await?; + } else { + let _ = self.store.put(path, slice.into()).await?; } + } else { + guard.discard(); + return self.multipart_put_impl(path, &slice).await; } - None => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - return CResult::Error; + + Ok(len) + } + pub async fn put(&self, path: &Path, slice: Bytes) -> crate::Result { + counter!(metrics::total_put_ops).increment(1); + with_retries!(self, self.put_impl(path, slice.clone()).await) + } + async fn delete_impl(&self, path: &Path) -> crate::Result { + let _guard = duration_on_drop!(metrics::delete_attempt_duration); + let path = &self.full_path(path); + self.store.delete(path).await?; + Ok(0) + } + pub async fn delete(&self, path: &Path) -> crate::Result { + counter!(metrics::total_delete_ops).increment(1); + with_retries!(self, self.delete_impl(path).await) + } + async fn multipart_get_impl(&self, path: &Path, slice: &mut [u8]) -> crate::Result { + let _guard = duration_on_drop!(metrics::multipart_get_attempt_duration); + let result = self.store.get_opts( + &path, + object_store::GetOptions { + head: true, + ..Default::default() + } + ).await?; + + let part_ranges = crate::util::size_to_ranges(result.meta.size, self.config.multipart_get_part_size); + let result_vec = self.store.get_ranges(&path, &part_ranges).await?; + let mut reader: BoxedReader = Box::new(StreamReader::new(stream::iter(result_vec).map(|b| Ok::<_, std::io::Error>(b)))); + + if let Some(cryptmp) = self.crypto_material_provider.as_ref() { + let material = cryptmp.material_from_metadata(path.as_ref(), &result.attributes).await?; + let decrypter_reader = CrypterReader::new(reader, Mode::Decrypt, &material) + .map_err(ErrorKind::ContentDecrypt)?; + reader = Box::new(decrypter_reader); } + + Ok(read_to_slice(&mut reader, slice).await?) 
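+        // For example (assuming `size_to_ranges` splits `[0, size)` into
+        // part-sized chunks): a 20 MiB object with an 8 MiB part size is fetched
+        // as the ranges [0, 8), [8, 16) and [16, 20) MiB, and the parts are
+        // reassembled in order by the StreamReader above before decryption.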
} -} + pub async fn multipart_get(&self, path: &Path, slice: &mut [u8]) -> crate::Result { + with_retries!(self, self.multipart_get_impl(path, slice).await) + } + async fn multipart_put_impl(&self, path: &Path, slice: &[u8]) -> crate::Result { + let _guard = duration_on_drop!(metrics::multipart_put_attempt_duration); + let mut writer: BoxedUpload = if let Some(cryptmp) = self.crypto_material_provider.as_ref() { + let (material, attrs) = cryptmp.material_for_write(path.as_ref(), Some(slice.len())).await?; + let writer = object_store::buffered::BufWriter::with_capacity( + Arc::clone(&self.store), + path.clone(), + self.config.multipart_put_part_size + ) + .with_attributes(attrs) + .with_max_concurrency(self.config.multipart_put_concurrency); + let encrypter_writer = CrypterWriter::new(writer, Mode::Encrypt, &material) + .map_err(ErrorKind::ContentEncrypt)?; + Box::new(encrypter_writer) + } else { + Box::new( + object_store::buffered::BufWriter::with_capacity( + Arc::clone(&self.store), + path.clone(), + self.config.multipart_put_part_size + ) + .with_max_concurrency(self.config.multipart_put_concurrency) + ) + }; -#[no_mangle] -pub extern "C" fn delete( - path: *const c_char, - config: *const RawConfig, - response: *mut Response, - handle: *const c_void -) -> CResult { - let response = unsafe { ResponseGuard::new(response, handle) }; - let path = unsafe { std::ffi::CStr::from_ptr(path) }; - let path = unsafe{ cstr_to_path(path) }; - let config = unsafe { & (*config) }; - match SQ.get() { - Some(sq) => { - match sq.try_send(Request::Delete(path, config, response)) { - Ok(_) => CResult::Ok, - Err(flume::TrySendError::Full(Request::Delete(_, _, response))) => { - response.into_error("object_store_ffi internal channel full, backoff"); - CResult::Backoff - } - Err(flume::TrySendError::Disconnected(Request::Delete(_, _, response))) => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - CResult::Error + match writer.write_all(slice).await { + Ok(_) => { + match writer.flush().await { + Ok(_) => { + writer.shutdown().await + .map_err(ErrorKind::BodyIo)?; + return Ok(slice.len()); + } + Err(e) => { + writer.abort().await?; + return Err(ErrorKind::BodyIo(e).into()); + } } - _ => unreachable!("the response type must match") } - } - None => { - response.into_error("object_store_ffi internal channel closed (may be missing initialization)"); - return CResult::Error; - } + Err(e) => { + writer.abort().await?; + return Err(ErrorKind::BodyIo(e).into()); + } + }; + } + pub async fn multipart_put(&self, path: &Path, slice: &[u8]) -> crate::Result { + with_retries!(self, self.multipart_put_impl(path, slice).await) } } + +export_queued_op!( + get, + Response, + |config, response| { + let path = unsafe { std::ffi::CStr::from_ptr(path) }; + let path = unsafe{ cstr_to_path(path) }; + let slice = unsafe { std::slice::from_raw_parts_mut(buffer, size) }; + Ok(Request::Get(path, slice, config, response)) + }, + path: *const c_char, buffer: *mut u8, size: usize +); + +export_queued_op!( + put, + Response, + |config, response| { + let path = unsafe { std::ffi::CStr::from_ptr(path) }; + let path = unsafe{ cstr_to_path(path) }; + let slice = unsafe { std::slice::from_raw_parts(buffer, size) }; + Ok(Request::Put(path, slice, config, response)) + }, + path: *const c_char, buffer: *const u8, size: usize +); + +export_queued_op!( + delete, + Response, + |config, response| { + let path = unsafe { std::ffi::CStr::from_ptr(path) }; + let path = unsafe{ cstr_to_path(path) }; + 
Ok(Request::Delete(path, config, response)) + }, + path: *const c_char +); diff --git a/src/encryption.rs b/src/encryption.rs new file mode 100644 index 0000000..103e347 --- /dev/null +++ b/src/encryption.rs @@ -0,0 +1,1214 @@ +use openssl::symm::{self, encrypt_aead, decrypt_aead, Cipher, Crypter}; +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use zeroize::Zeroize; +use core::{ + pin::Pin, + task::{Context, Poll}, +}; +use bytes::Bytes; +use futures_util::ready; +use std::{fmt::{self, Debug}, ops::{Deref, DerefMut}}; + +use base64::prelude::*; +use rand::{RngCore, rngs::OsRng}; +use object_store::Attributes; + +use crate::util::AsyncUpload; + +const AES_GCM_TAG_BYTES: usize = 16; + +#[async_trait::async_trait] +pub(crate) trait CryptoMaterialProvider: + Send + + Sync + + Debug + + 'static { + async fn material_for_write(&self, path: &str, data_len: Option<usize>) -> crate::Result<(ContentCryptoMaterial, Attributes)>; + async fn material_from_metadata(&self, path: &str, attr: &Attributes) -> crate::Result<ContentCryptoMaterial>; +} + +#[derive(Debug, Copy, Clone)] +pub(crate) enum CryptoScheme { + Aes256Gcm, + Aes128Cbc +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub(crate) enum Mode { + Encrypt, + Decrypt +} + +impl From<Mode> for openssl::symm::Mode { + fn from(v: Mode) -> openssl::symm::Mode { + match v { + Mode::Encrypt => openssl::symm::Mode::Encrypt, + Mode::Decrypt => openssl::symm::Mode::Decrypt + } + } +} + +impl CryptoScheme { + pub(crate) fn key_len(&self) -> usize { + self.cipher().key_len() + } + pub(crate) fn iv_len(&self) -> usize { + self.cipher().iv_len().expect("needs iv") + } + pub(crate) fn tag_len(&self) -> usize { + match self { + CryptoScheme::Aes256Gcm => AES_GCM_TAG_BYTES, + CryptoScheme::Aes128Cbc => 0 + } + } + pub(crate) fn cipher(&self) -> Cipher { + match self { + CryptoScheme::Aes256Gcm => Cipher::aes_256_gcm(), + CryptoScheme::Aes128Cbc => Cipher::aes_128_cbc(), + } + } +} + +#[derive(Clone)] +pub(crate) struct Iv { + bytes: Vec<u8> +} + +impl Deref for Iv { + type Target = [u8]; + + #[inline(always)] + fn deref(&self) -> &[u8] { + self.bytes.as_slice() + } +} + +impl DerefMut for Iv { + #[inline(always)] + fn deref_mut(&mut self) -> &mut [u8] { + self.bytes.as_mut_slice() + } +} + +impl Iv { + pub(crate) fn from_base64(iv: impl AsRef<str>) -> Result<Iv, base64::DecodeError> { + Ok(Iv { bytes: BASE64_STANDARD.decode(iv.as_ref())? }) + } + pub(crate) fn generate(len: usize) -> Iv { + let mut bytes = vec![0; len]; + OsRng.fill_bytes(&mut bytes); + Iv { bytes } + } + #[allow(unused)] + pub(crate) fn len(&self) -> usize { + self.bytes.len() + } + pub(crate) fn as_base64(&self) -> String { + BASE64_STANDARD.encode(&self.bytes) + } +} + +#[derive(Clone)] +pub(crate) struct Key { + bytes: Vec<u8> +} + +impl Deref for Key { + type Target = [u8]; + + #[inline(always)] + fn deref(&self) -> &[u8] { + self.bytes.as_slice() + } +} + +impl DerefMut for Key { + #[inline(always)] + fn deref_mut(&mut self) -> &mut [u8] { + self.bytes.as_mut_slice() + } +} + +impl Key { + pub(crate) fn from_base64(key: impl AsRef<str>) -> Result<Key, base64::DecodeError> { + Ok(Key { bytes: BASE64_STANDARD.decode(key.as_ref())?
}) + } + pub(crate) fn generate(len: usize) -> Key { + let mut bytes = vec![0; len]; + OsRng.fill_bytes(&mut bytes); + Key { bytes } + } + pub(crate) fn len(&self) -> usize { + self.bytes.len() + } + pub(crate) fn encrypt_aes_128_ecb(self, encryption_key: &Key) -> std::io::Result<EncryptedKey> { + let cipher = Cipher::aes_128_ecb(); + if encryption_key.len() != cipher.key_len() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key size")); + } + + let encrypted_bytes = symm::encrypt(cipher, &encryption_key, None, &self)?; + + Ok(EncryptedKey { bytes: encrypted_bytes }) + } + #[allow(unused)] + pub(crate) fn as_base64(&self) -> String { + BASE64_STANDARD.encode(&self.bytes) + } +} + +impl Drop for Key { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + +// Always encrypted with aes_128_ecb for now +#[derive(Clone)] +pub(crate) struct EncryptedKey { + bytes: Vec<u8>, +} + +impl EncryptedKey { + pub(crate) fn from_base64(key: impl AsRef<str>) -> Result<EncryptedKey, base64::DecodeError> { + Ok(EncryptedKey { bytes: BASE64_STANDARD.decode(key.as_ref())? }) + } + pub(crate) fn decrypt_aes_128_ecb(self, decryption_key: &Key) -> std::io::Result<Key> { + let cipher = Cipher::aes_128_ecb(); + if decryption_key.len() != cipher.key_len() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key size")); + } + let bytes = symm::decrypt(cipher, &decryption_key, None, &self.bytes)?; + + Ok(Key { bytes }) + } + pub(crate) fn as_base64(&self) -> String { + BASE64_STANDARD.encode(&self.bytes) + } +} + +impl Drop for EncryptedKey { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + +pub(crate) struct ContentCryptoMaterial { + pub scheme: CryptoScheme, + pub cek: Key, + pub iv: Iv, + pub aad: Option<Bytes> +} + +impl ContentCryptoMaterial { + pub fn generate(scheme: CryptoScheme) -> ContentCryptoMaterial { + let cek = Key::generate(scheme.key_len()); + let iv = Iv::generate(scheme.iv_len()); + ContentCryptoMaterial { scheme, cek, iv, aad: None } + } + + pub fn with_aad(self, aad: impl Into<Bytes>) -> ContentCryptoMaterial { + ContentCryptoMaterial { aad: Some(aad.into()), ..self } + } +} + +pub(crate) fn encrypt( + data: &[u8], + material: &ContentCryptoMaterial +) -> std::io::Result<Vec<u8>> { + match material.scheme { + CryptoScheme::Aes128Cbc => { + let ContentCryptoMaterial { scheme, cek, iv, .. } = material; + Ok(symm::encrypt(scheme.cipher(), &cek, Some(&iv), data)?) + + }, + CryptoScheme::Aes256Gcm => { + let ContentCryptoMaterial { scheme, cek, iv, aad, .. } = material; + let aad_ref = aad.as_deref().unwrap_or_default(); + let mut tag = vec![0; scheme.tag_len()]; + let mut ciphertext = encrypt_aead(scheme.cipher(), &cek, Some(&iv), aad_ref, data, &mut tag)?; + // Postfix tag + ciphertext.extend_from_slice(&tag); + Ok(ciphertext) + } + } +} + +#[allow(unused)] +pub(crate) fn decrypt( + ciphertext: &[u8], + material: &ContentCryptoMaterial +) -> std::io::Result<Vec<u8>> { + match material.scheme { + CryptoScheme::Aes128Cbc => { + let ContentCryptoMaterial { scheme, cek, iv, .. } = material; + Ok(symm::decrypt(scheme.cipher(), &cek, Some(&iv), ciphertext)?) + + }, + CryptoScheme::Aes256Gcm => { + let ContentCryptoMaterial { scheme, cek, iv, aad, ..
} = material; + let aad_ref = aad.as_deref().unwrap_or_default(); + // Postfix tag + let tag_offset = ciphertext.len() - AES_GCM_TAG_BYTES; + let data = decrypt_aead(scheme.cipher(), &cek, Some(&iv), aad_ref, &ciphertext[..tag_offset], &ciphertext[tag_offset..])?; + Ok(data) + } + } +} + +// This implementation consists of a Buffer that has three contiguous regions: the consumed +// region, the filled region and the unfilled region. The buffer starts completely unfilled and +// its state is modified by calling: +// - advance: to indicate that useful data has been written to the buffer, increasing the filled +// region (over the unfilled one) +// - consume: to indicate that useful data has been processed out of the buffer, increasing the +// consumed region (over the filled one) +// - compact: moves any bytes of the filled region to the start of the buffer, resetting the +// consumed and unfilled regions accordingly +// - clear: resets everything to the initial state +// +// consume ==>| advance ==>| +// +--------------+--------------+----------+ +// | consumed | filled | unfilled | +// +--------------+--------------+----------+ +struct Buffer { + buf: Vec<u8>, + pos: usize, + filled: usize +} + +impl Buffer { + fn with_capacity(cap: usize) -> Buffer { + Buffer { + buf: vec![0; cap], + pos: 0, + filled: 0 + } + } + + #[inline] + #[track_caller] + fn consume(&mut self, n: usize) { + assert!(self.pos + n <= self.filled); + self.pos += n; + } + + #[inline] + #[track_caller] + fn advance(&mut self, n: usize) { + assert!(self.filled + n <= self.capacity()); + self.filled += n; + } + + /// Returns the total capacity of the buffer. + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + /// Returns a shared reference to the filled portion of the buffer. + #[inline] + pub fn filled(&self) -> &[u8] { + &self.buf[self.pos..self.filled] + } + + /// Returns a mutable reference to the filled portion of the buffer. + #[inline] + #[allow(unused)] + pub fn filled_mut(&mut self) -> &mut [u8] { + &mut self.buf[self.pos..self.filled] + } + + #[inline] + pub fn unfilled(&self) -> &[u8] { + &self.buf[self.filled..] + } + + #[inline] + pub fn unfilled_mut(&mut self) -> &mut [u8] { + &mut self.buf[self.filled..] + } + + /// Returns the number of bytes at the end of the slice that have not yet been filled. + #[inline] + pub fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + /// Returns the number of filled bytes that have not yet been consumed. + #[inline] + #[allow(unused)] + pub fn available(&self) -> usize { + self.filled - self.pos + } + + /// Clears the buffer, resetting the filled region to empty. + #[inline] + pub fn clear(&mut self) { + self.pos = 0; + self.filled = 0; + } + + /// Compacts the filled portion to the start of the buffer. + #[inline] + pub fn compact(&mut self) { + let len = self.filled - self.pos; + self.buf.copy_within(self.pos..self.filled, 0); + self.pos = 0; + self.filled = len; + } + + /// Appends data to the buffer, advancing the filled position. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `buf.len()`.
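+    // Worked example of the region bookkeeping described above: with an 8-byte
+    // buffer, `put_slice(b"abcd")` leaves pos = 0 and filled = 4; `consume(2)`
+    // advances pos to 2 so `filled()` returns b"cd"; `compact()` then moves those
+    // two bytes to the front (pos = 0, filled = 2), making the full tail
+    // available to `unfilled_mut()` again.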
+ #[inline] + #[track_caller] + pub fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.remaining() >= buf.len(), + "buf.len() must fit in remaining(); buf.len() = {}, remaining() = {}", + buf.len(), + self.remaining() + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf[self.filled..end] + .as_mut_ptr() + .cast::<u8>() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + self.filled = end; + } +} + +#[derive(Debug)] +enum State { + Initial, + Filling, + Crypting, + Flushing, + Finalizing, + Done, +} + +// Asynchronous reader that transparently en/decrypts data on reads. +// It is capable of dealing with both AES CBC and AES GCM. +#[pin_project] +pub struct CrypterReader<R> { + #[pin] + reader: R, + crypter: Crypter, + mode: Mode, + tag_len: usize, + block_size: usize, + inbuf: Buffer, + outbuf: Buffer, + state: State, + last_flush: bool +} + +impl<R> fmt::Debug for CrypterReader<R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CrypterReader") + .field("block_size", &self.block_size) + .field("mode", &self.mode) + .field("state", &self.state) + .finish() + } +} + +impl<R> CrypterReader<R> { + pub fn new(reader: R, mode: Mode, material: &ContentCryptoMaterial) -> std::io::Result<Self> { + CrypterReader::with_capacity(reader, mode, material, 64 * 1024) + } + pub fn with_capacity(reader: R, mode: Mode, material: &ContentCryptoMaterial, capacity: usize) -> std::io::Result<Self> { + let ContentCryptoMaterial { scheme, cek, iv, aad, .. } = material; + let cipher = scheme.cipher(); + let block_size = cipher.block_size(); + let mut crypter = Crypter::new(cipher, mode.into(), cek, Some(iv))?; + let tag_len = scheme.tag_len(); + if let Some(aad) = aad { + crypter.aad_update(aad)?; + } + Ok(Self { + reader, + crypter, + mode, + tag_len, + block_size, + inbuf: Buffer::with_capacity(capacity), + outbuf: Buffer::with_capacity(block_size + block_size + tag_len), + state: State::Initial, + last_flush: false + }) + } + + pub fn get_ref(&self) -> &R { + &self.reader + } + + pub fn get_mut(&mut self) -> &mut R { + &mut self.reader + } + + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().reader + } + + pub fn into_inner(self) -> R { + self.reader + } +} + +impl<R: AsyncRead> AsyncRead for CrypterReader<R> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); // Maybe should error instead + } + + let mut this = self.project(); + let block_size = *this.block_size; + + // Size of the tag to extract from end of ciphertext + let extract_tag_len = match this.mode { + Mode::Encrypt => 0, + Mode::Decrypt => *this.tag_len + }; + // Size of the tag to append to end of ciphertext + let append_tag_len = match this.mode { + Mode::Encrypt => *this.tag_len, + Mode::Decrypt => 0 + }; + + loop { + *this.state = match this.state { + State::Initial => { + let mut wrapped = ReadBuf::new(this.inbuf.unfilled_mut()); + match this.reader.as_mut().poll_read(cx, &mut wrapped) { + Poll::Ready(res) => res?, + Poll::Pending => { + if buf.filled().len() > 0 { + // There is pending data in buf, return Ready instead of Pending + return Poll::Ready(Ok(())); + } else { + // We need to wait + return Poll::Pending; + } + } + } + let amt = wrapped.filled().len(); + this.inbuf.advance(amt); + if amt == 0 { + // never crypted anything, go to State::Done without flushing + State::Done + } else { + // got some
bytes, fill buffer + State::Filling + } + } + State::Filling => { + if this.inbuf.filled().len() <= extract_tag_len { + // we have up to tag_len bytes filled, bring them to the start of the + // buffer + this.inbuf.compact(); + } + if this.inbuf.unfilled().len() > 0 { + let mut wrapped = ReadBuf::new(this.inbuf.unfilled_mut()); + match this.reader.as_mut().poll_read(cx, &mut wrapped) { + Poll::Ready(res) => res?, + Poll::Pending => { + if buf.filled().len() > 0 { + // There is pending data in buf return Ready instead of Pending + return Poll::Ready(Ok(())); + } else { + // We need to wait + return Poll::Pending; + } + } + } + let amt = wrapped.filled().len(); + this.inbuf.advance(amt); + if amt == 0 { + // reached reader eof + if this.inbuf.filled().len() <= extract_tag_len { + // may have enough bytes for a tag + State::Finalizing + } else { + // we still have some extra bytes to crypt + State::Crypting + } + } else { + State::Filling + } + } else { + // we are full, go crypt + State::Crypting + } + } + State::Crypting => { + if this.inbuf.filled().len() > extract_tag_len { + if buf.remaining() > block_size { + // readbuf is big enough, crypt directly into it + let to_crypt = (this.inbuf.filled().len() - extract_tag_len).min(buf.remaining() - block_size); + let amount = this.crypter.update(&this.inbuf.filled()[..to_crypt], buf.initialize_unfilled())?; + buf.advance(amount); + this.inbuf.consume(to_crypt); + State::Crypting + } else { + // readbuf is too small, crypt to outbuf then flush + let to_crypt = (this.inbuf.filled().len() - extract_tag_len).min(this.outbuf.unfilled().len() - block_size); + let amount = this.crypter.update(&this.inbuf.filled()[..to_crypt], this.outbuf.unfilled_mut())?; + this.outbuf.advance(amount); + this.inbuf.consume(to_crypt); + State::Flushing + } + } else { + // not enough to continue crypting, go fill + State::Filling + } + } + State::Flushing => { + if this.outbuf.filled().len() > 0 { + let to_copy = this.outbuf.filled().len().min(buf.remaining()); + buf.put_slice(&this.outbuf.filled()[..to_copy]); + this.outbuf.consume(to_copy); + State::Flushing + } else { + if *this.last_flush { + // outbuf is empty and was last flush, done + this.outbuf.clear(); + State::Done + } else { + // outbuf is empty, reset it and go crypt + this.outbuf.clear(); + State::Crypting + } + } + } + State::Finalizing => { + if this.inbuf.filled().len() == extract_tag_len { + if extract_tag_len > 0 { + // we extracted some tag, set it + this.crypter.set_tag(this.inbuf.filled())?; + } + + if buf.remaining() > block_size + append_tag_len { + // readbuf is big enough, finalize directly into it + let amt = this.crypter.finalize(buf.initialize_unfilled())?; + buf.advance(amt); + if append_tag_len > 0 { + // we need to append a tag + this.crypter.get_tag(&mut buf.initialize_unfilled()[..append_tag_len])?; + buf.advance(append_tag_len); + } + State::Done + } else { + // readbuf is too small, finalize to outbuf then last flush + let amt = this.crypter.finalize(this.outbuf.unfilled_mut())?; + this.outbuf.advance(amt); + if append_tag_len > 0 { + // we need to append a tag + this.crypter.get_tag(&mut this.outbuf.unfilled_mut()[..append_tag_len])?; + this.outbuf.advance(append_tag_len); + } + // Important, flagging this flush as the last one + *this.last_flush = true; + State::Flushing + } + } else { + // The final bytes in the buffer do not match extract_tag_len + debug_assert!(this.inbuf.filled().len() < extract_tag_len); + return Err( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, 
+ "unable to read enough bytes for the required tag" + ) + ).into(); + } + } + State::Done => State::Done + }; + + if let State::Done = *this.state { + return Poll::Ready(Ok(())); + } + + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + } + } +} + +// Asynchronous writer that transparently en/decrypts data on write. +// It is capable of dealing with both AES CBC and AES GCM. +#[pin_project] +pub struct CrypterWriter { + #[pin] + writer: W, + crypter: Crypter, + mode: Mode, + extract_tag_len: usize, + append_tag_len: usize, + block_size: usize, + buf: Buffer, + tag_buf: Buffer, + was_updated: bool, + finalized: bool, +} + +impl fmt::Debug for CrypterWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CrypterWriter") + .field("block_size", &self.block_size) + .field("mode", &self.mode) + .field("was_updated", &self.was_updated) + .field("finalized", &self.finalized) + .finish() + } +} + + +#[async_trait::async_trait] +impl AsyncUpload for CrypterWriter { + async fn abort(&mut self) -> crate::Result<()> { + Ok(self.writer.abort().await?) + } +} + + +impl CrypterWriter { + pub fn new(writer: W, mode: Mode, material: &ContentCryptoMaterial) -> std::io::Result { + Self::with_capacity(writer, mode, material, 64 * 1024) + } + + pub fn with_capacity(writer: W, mode: Mode, material: &ContentCryptoMaterial, capacity: usize) -> std::io::Result { + let ContentCryptoMaterial { scheme, cek, iv, aad, .. } = material; + let cipher = scheme.cipher(); + let block_size = cipher.block_size(); + let mut crypter = Crypter::new(cipher, mode.into(), cek, Some(iv))?; + let (extract_tag_len, append_tag_len) = match mode { + Mode::Encrypt => (0, scheme.tag_len()), + Mode::Decrypt => (scheme.tag_len(), 0) + }; + if let Some(aad) = aad { + crypter.aad_update(aad)?; + } + Ok(Self { + writer, + crypter, + mode, + extract_tag_len, + append_tag_len, + block_size, + buf: Buffer::with_capacity(capacity), + tag_buf: Buffer::with_capacity(extract_tag_len * 2), + was_updated: false, + finalized: false + }) + } + + pub fn get_ref(&self) -> &W { + &self.writer + } + + pub fn get_mut(&mut self) -> &mut W { + &mut self.writer + } + + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().writer + } + + pub fn into_inner(self) -> W { + self.writer + } + + fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + while this.buf.filled().len() > 0 { + match ready!(this.writer.as_mut().poll_write(cx, &this.buf.filled())) { + Ok(0) => { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "failed to write the buffered data", + )).into(); + } + Ok(n) => this.buf.consume(n), + Err(e) => { + return Err(e).into(); + } + } + } + + this.buf.clear(); + + Poll::Ready(Ok(())) + } + + fn finalize(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> std::io::Result<()> { + let this = self.project(); + if *this.extract_tag_len > 0 { + if this.tag_buf.filled().len() == *this.extract_tag_len { + this.crypter.set_tag(this.tag_buf.filled())?; + this.tag_buf.consume(*this.extract_tag_len); + } else { + // The bytes in the tag buffer do not match extract_tag_len + debug_assert!(this.tag_buf.filled().len() < *this.extract_tag_len); + return Err( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "not enough written bytes for the required tag" + ) + ).into(); + } + } + let amt = this.crypter.finalize(this.buf.unfilled_mut())?; + this.buf.advance(amt); + if *this.append_tag_len > 0 { + this.crypter.get_tag(&mut 
this.buf.unfilled_mut()[..*this.append_tag_len])?; + this.buf.advance(*this.append_tag_len); + } + *this.finalized = true; + Ok(()) + } +} + +impl<W: AsyncWrite> AsyncWrite for CrypterWriter<W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + let block_size = self.block_size; + let extract_tag_len = self.extract_tag_len; + // We will only write the amount that fits our buffer. + let upper_bound = buf.len().min(self.buf.capacity().saturating_sub(block_size)); + let required_space = upper_bound + block_size; + // Truncate buffer to upper bound. + let buf = &buf[..upper_bound]; + if self.buf.unfilled().len() < required_space { + ready!(self.as_mut().flush_buf(cx))?; + } + + let this = self.project(); + + // Buffer last bytes without decrypting if we are expecting a tag + let last_bytes_offset = buf.len().saturating_sub(extract_tag_len); + let last = &buf[last_bytes_offset..]; + + if last.len() > 0 { + let overflow = (last.len() + this.tag_buf.filled().len()).saturating_sub(extract_tag_len); + if overflow > 0 { + // we have more bytes than extract_tag_len, flush the excess to the main buffer + debug_assert!(overflow <= last.len()); + let amt = this.crypter.update(&this.tag_buf.filled()[..overflow], this.buf.unfilled_mut())?; + this.buf.advance(amt); + this.tag_buf.consume(overflow); + // crypter was updated, needs to be finalized + *this.was_updated = true; + } + this.tag_buf.compact(); + this.tag_buf.put_slice(last); + } + + // These bytes cannot be the tag, crypt them into main buffer + let first = &buf[..last_bytes_offset]; + if first.len() > 0 { + let amt = this.crypter.update(first, this.buf.unfilled_mut())?; + this.buf.advance(amt); + // crypter was updated, needs to be finalized + *this.was_updated = true; + } + Poll::Ready(Ok(first.len() + last.len())) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + // crypter can only be finalized if it was updated at least once + if !self.finalized && self.was_updated { + if self.buf.unfilled().len() < self.block_size + self.append_tag_len { + ready!(self.as_mut().flush_buf(cx))?; + } + self.as_mut().finalize(cx)?; + } + + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use rand::Rng; + + #[tokio::test] + async fn crypter_reader_aes_cbc_round_trip() { + let data: Vec<u8> = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc); + + let mut reader = CrypterReader::new(data.as_slice(), Mode::Encrypt, &material).unwrap(); + let mut ciphertext = vec![]; + reader.read_to_end(&mut ciphertext).await.unwrap(); + + let mut reader = CrypterReader::new(ciphertext.as_slice(), Mode::Decrypt, &material).unwrap(); + let mut plaintext = vec![]; + reader.read_to_end(&mut plaintext).await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_reader_aes_gcm_round_trip() { + let data: Vec<u8> = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm) + .with_aad(b"test aad".as_slice()); + + let mut reader = CrypterReader::new(data.as_slice(), Mode::Encrypt, &material).unwrap(); + let mut ciphertext = vec![]; + reader.read_to_end(&mut
ciphertext).await.unwrap(); + + let mut reader = CrypterReader::new(ciphertext.as_slice(), Mode::Decrypt, &material).unwrap(); + let mut plaintext = vec![]; + reader.read_to_end(&mut plaintext).await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_writer_aes_cbc_round_trip() { + let data: Vec = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc); + + let mut ciphertext = vec![]; + let mut writer = CrypterWriter::new(&mut ciphertext, Mode::Encrypt, &material).unwrap(); + writer.write_all(&data).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut plaintext = vec![]; + let mut writer = CrypterWriter::new(&mut plaintext, Mode::Decrypt, &material).unwrap(); + writer.write_all(&ciphertext).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_writer_aes_gcm_round_trip() { + let data: Vec = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm) + .with_aad(b"test aad".as_slice()); + + let mut ciphertext = vec![]; + let mut writer = CrypterWriter::new(&mut ciphertext, Mode::Encrypt, &material).unwrap(); + writer.write_all(&data).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut plaintext = vec![]; + let mut writer = CrypterWriter::new(&mut plaintext, Mode::Decrypt, &material).unwrap(); + writer.write_all(&ciphertext).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_writer_and_reader_aes_cbc_round_trip() { + let data: Vec = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc); + + let mut ciphertext = vec![]; + let mut writer = CrypterWriter::new(&mut ciphertext, Mode::Encrypt, &material).unwrap(); + writer.write_all(&data).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut reader = CrypterReader::new(ciphertext.as_slice(), Mode::Decrypt, &material).unwrap(); + let mut plaintext = vec![]; + reader.read_to_end(&mut plaintext).await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_writer_and_reader_aes_gcm_round_trip() { + let data: Vec = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm) + .with_aad(b"test aad".as_slice()); + + let mut ciphertext = vec![]; + let mut writer = CrypterWriter::new(&mut ciphertext, Mode::Encrypt, &material).unwrap(); + writer.write_all(&data).await.unwrap(); + writer.flush().await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut reader = CrypterReader::new(ciphertext.as_slice(), Mode::Decrypt, &material).unwrap(); + let mut plaintext = vec![]; + reader.read_to_end(&mut plaintext).await.unwrap(); + assert_eq!(data, plaintext); + } + + #[tokio::test] + async fn crypter_reader_and_writer_aes_cbc_round_trip() { + let data: Vec = (0..100000u32).map(|n| (n % 256) as u8).collect(); + + let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc); + + let mut reader = CrypterReader::new(data.as_slice(), Mode::Encrypt, &material).unwrap(); + let mut ciphertext = vec![]; + reader.read_to_end(&mut ciphertext).await.unwrap(); + + let mut plaintext = vec![]; + 
+        let mut writer = CrypterWriter::new(&mut plaintext, Mode::Decrypt, &material).unwrap();
+        writer.write_all(&ciphertext).await.unwrap();
+        writer.flush().await.unwrap();
+        writer.shutdown().await.unwrap();
+
+        assert_eq!(data, plaintext);
+    }
+
+    #[tokio::test]
+    async fn crypter_reader_and_writer_aes_gcm_round_trip() {
+        let data: Vec<u8> = (0..100000u32).map(|n| (n % 256) as u8).collect();
+
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm)
+            .with_aad(b"test aad".as_slice());
+
+        let mut reader = CrypterReader::new(data.as_slice(), Mode::Encrypt, &material).unwrap();
+        let mut ciphertext = vec![];
+        reader.read_to_end(&mut ciphertext).await.unwrap();
+
+        let mut plaintext = vec![];
+        let mut writer = CrypterWriter::new(&mut plaintext, Mode::Decrypt, &material).unwrap();
+        writer.write_all(&ciphertext).await.unwrap();
+        writer.flush().await.unwrap();
+        writer.shutdown().await.unwrap();
+
+        assert_eq!(data, plaintext);
+    }
+
+    #[test]
+    fn content_encryption_aes_cbc_round_trip() {
+        let data: Vec<u8> = (0..100000u32).map(|n| (n % 256) as u8).collect();
+
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc);
+
+        let ciphertext = encrypt(&data, &material).unwrap();
+
+        let plaintext = decrypt(&ciphertext, &material).unwrap();
+
+        assert_eq!(data, plaintext);
+    }
+
+    #[test]
+    fn content_encryption_aes_gcm_round_trip() {
+        let data: Vec<u8> = (0..100000u32).map(|n| (n % 256) as u8).collect();
+
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm)
+            .with_aad("AES/GCM/NoPadding");
+
+        let ciphertext = encrypt(&data, &material).unwrap();
+
+        let plaintext = decrypt(&ciphertext, &material).unwrap();
+
+        assert_eq!(data, plaintext);
+    }
+
+    async fn random_crypter_writer_writes(
+        data: &[u8],
+        mode: Mode,
+        material: &ContentCryptoMaterial,
+        max_write_size: usize,
+        flush_prob: f64
+    ) -> Vec<u8> {
+        let mut result = vec![];
+        let mut writer = CrypterWriter::new(&mut result, mode, &material).unwrap();
+        let mut offset = 0;
+        loop {
+            let remaining = data.len() - offset;
+            assert!(remaining >= 1);
+            let upper_bound = remaining.min(max_write_size);
+            let write_size = if upper_bound == 1 {
+                1
+            } else {
+                rand::thread_rng().gen_range(1..upper_bound)
+            };
+            let write_end_offset = offset + write_size;
+            loop {
+                offset += writer.write(&data[offset..write_end_offset]).await.unwrap();
+                if offset == write_end_offset { break; }
+            }
+            assert!(offset <= data.len());
+            if offset == data.len() {
+                break;
+            }
+            if rand::thread_rng().gen::<f64>() <= flush_prob {
+                writer.flush().await.unwrap();
+            }
+        }
+        writer.flush().await.unwrap();
+        writer.shutdown().await.unwrap();
+        result
+    }
+
+    async fn random_crypter_reader_reads(
+        data: &[u8],
+        mode: Mode,
+        material: &ContentCryptoMaterial,
+        max_read_size: usize
+    ) -> Vec<u8> {
+        let mut result = vec![0; data.len() + material.scheme.cipher().block_size() + material.scheme.tag_len()];
+        let mut reader = CrypterReader::new(Cursor::new(data), mode, &material).unwrap();
+        let mut offset = 0;
+        loop {
+            let remaining = result.len() - offset;
+            if remaining == 0 {
+                let mut scratch = [0u8; 1];
+                if let Ok(0) = reader.read(scratch.as_mut_slice()).await {
+                    break;
+                } else {
+                    panic!("needs to read past result len");
+                }
+            }
+            assert!(remaining >= 1);
+            let upper_bound = remaining.min(max_read_size);
+            let read_size = if upper_bound == 1 {
+                1
+            } else {
+                rand::thread_rng().gen_range(1..upper_bound)
+            };
+            let n = reader.read(&mut result[offset..(offset + read_size)]).await.unwrap();
+            if n == 0 {
+                break;
+            }
+            offset += n;
+            assert!(offset <= result.len());
+        }
+        result.truncate(offset);
+        result
+    }
+
+    fn compare_large_slices(a: &[u8], b: &[u8]) -> bool {
+        if a.len() != b.len() {
+            eprintln!("expected lengths {} and {} to match", a.len(), b.len());
+            return false;
+        }
+
+        for i in 0..a.len() {
+            let i_a = a[i];
+            let i_b = b[i];
+            if i_a != i_b {
+                eprintln!("expected bytes {} and {} to match at index {i}", i_a, i_b);
+                let preceding_range = i.saturating_sub(10)..i;
+                let succeeding_range = i..i.saturating_add(10).min(a.len());
+                eprintln!("preceding bytes: expected {:?}, received {:?}", &a[preceding_range.clone()], &b[preceding_range]);
+                eprintln!("succeeding bytes: expected {:?}, received {:?}", &a[succeeding_range.clone()], &b[succeeding_range]);
+                return false;
+            }
+        }
+
+        true
+    }
+
+    #[tokio::test]
+    async fn randomized_crypter_writer_aes_cbc() {
+        let size = 500 * 1024 * 1024;
+        let max_write_size = 20 * 1024 * 1024;
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc);
+        let data: Vec<u8> = (0..size).map(|n| (n % 256) as u8).collect();
+
+        for _i in 0..50 {
+            let ciphertext = random_crypter_writer_writes(&data, Mode::Encrypt, &material, max_write_size, 0.1).await;
+
+            let plaintext = decrypt(&ciphertext, &material).unwrap();
+            assert!(compare_large_slices(&data, &plaintext));
+
+            let plaintext = random_crypter_writer_writes(&ciphertext, Mode::Decrypt, &material, max_write_size, 0.1).await;
+            assert!(compare_large_slices(&data, &plaintext));
+        }
+    }
+
+    #[tokio::test]
+    async fn randomized_crypter_reader_aes_cbc() {
+        let size = 500 * 1024 * 1024;
+        let max_read_size = 20 * 1024 * 1024;
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes128Cbc);
+        let data: Vec<u8> = (0..size).map(|n| (n % 256) as u8).collect();
+
+        for _i in 0..50 {
+            let ciphertext = random_crypter_reader_reads(&data, Mode::Encrypt, &material, max_read_size).await;
+
+            let plaintext = decrypt(&ciphertext, &material).unwrap();
+            assert!(compare_large_slices(&data, &plaintext));
+
+            let plaintext = random_crypter_reader_reads(&ciphertext, Mode::Decrypt, &material, max_read_size).await;
+            assert!(compare_large_slices(&data, &plaintext));
+        }
+    }
+
+    #[tokio::test]
+    async fn randomized_crypter_writer_aes_gcm() {
+        let size = 100 * 1024 * 1024;
+        let max_write_size = 20 * 1024 * 1024;
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm);
+        let data: Vec<u8> = (0..size).map(|n| (n % 256) as u8).collect();
+
+        for _i in 0..50 {
+            let ciphertext = random_crypter_writer_writes(&data, Mode::Encrypt, &material, max_write_size, 0.1).await;
+
+            let plaintext = decrypt(&ciphertext, &material).unwrap();
+            assert!(compare_large_slices(&data, &plaintext));
+
+            let plaintext = random_crypter_writer_writes(&ciphertext, Mode::Decrypt, &material, max_write_size, 0.1).await;
+            assert!(compare_large_slices(&data, &plaintext));
+        }
+    }
+
+    #[tokio::test]
+    async fn randomized_crypter_reader_aes_gcm() {
+        let size = 100 * 1024 * 1024;
+        let max_read_size = 20 * 1024 * 1024;
+        let material = ContentCryptoMaterial::generate(CryptoScheme::Aes256Gcm)
+            .with_aad("AES/GCM/NoPadding");
+        let data: Vec<u8> = (0..size).map(|n| (n % 256) as u8).collect();
+
+        for _i in 0..50 {
+            let ciphertext = random_crypter_reader_reads(&data, Mode::Encrypt, &material, max_read_size).await;
+
+            let ciphertext2 = encrypt(&data, &material).unwrap();
+            assert!(compare_large_slices(&ciphertext, &ciphertext2));
+
+            let plaintext = decrypt(&ciphertext, &material).unwrap();
+            assert!(compare_large_slices(&data, &plaintext));
+
+            let plaintext = random_crypter_reader_reads(&ciphertext, Mode::Decrypt, &material, max_read_size).await;
+            assert!(compare_large_slices(&data, &plaintext));
+        }
+    }
+}
diff --git a/src/error.rs b/src/error.rs
index 3fb4def..451037a 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,121 +1,346 @@
-use std::time::Duration;
+use std::{sync::Arc, time::Duration};
 use backoff::backoff::Backoff;
 use object_store::RetryConfig;
 use once_cell::sync::Lazy;
 use std::error::Error as StdError;
-use anyhow::anyhow;
+use thiserror::Error;
+use std::fmt;
 
 // These regexes are used to extract error info from some object_store private errors.
 // We construct the regexes lazily and reuse them due to the runtime compilation cost.
 static CLIENT_ERR_REGEX: Lazy<regex::Regex> = Lazy::new(|| regex::Regex::new(r"Client \{ status: (?<status>\d+?),").unwrap());
 static REQWEST_ERR_REGEX: Lazy<regex::Regex> = Lazy::new(|| regex::Regex::new(r"Reqwest \{ retries: (?<retries>\d+?), max_retries: (?<max_retries>\d+?),").unwrap());
 
+type BoxedStdError = Box<dyn StdError + Send + Sync>;
 
-#[derive(Debug, Clone)]
-pub(crate) struct ErrorInfo {
-    pub(crate) retries: Option<usize>,
-    pub(crate) reason: ErrorReason
+#[derive(Debug)]
+pub(crate) struct Metadata {
+    start: std::time::Instant,
+    attempts: Vec<ErrorInfo>
 }
 
-#[derive(Debug, Clone)]
-pub(crate) enum ErrorReason {
-    Unknown,
-    Code(u16),
-    Io,
-    Timeout
+impl Metadata {
+    pub(crate) fn retries(&self) -> usize {
+        let prev_retries = self.attempts.iter()
+            .map(|a| a.retries.unwrap_or_default())
+            .sum::<usize>();
+        prev_retries + self.attempts.len().saturating_sub(1)
+    }
+    pub(crate) fn retry_report(&self) -> String {
+        use std::fmt::Write;
+        let mut report = String::new();
+        if !self.attempts.is_empty() {
+            let attempts_to_display = self.attempts.len().min(10);
+            write!(report, "Recent attempts ({} out of {}):\n", attempts_to_display, self.attempts.len()).unwrap();
+            self.attempts
+                .iter()
+                .rev()
+                .take(10)
+                .rev()
+                .for_each(|info| {
+                    write!(
+                        report,
+                        "  reason: {:?} after {} retries\n",
+                        info.reason,
+                        info.retries.unwrap_or_default()
+                    ).unwrap()
+                });
+        } else {
+            write!(report, "There were no attempts\n").unwrap();
+        }
+        write!(report, "Total retries: {}\n", self.retries()).unwrap();
+        write!(report, "Total Time: {:?}\n", self.start.elapsed()).unwrap();
+        report
+    }
 }
 
-pub(crate) fn format_err(error: &anyhow::Error) -> String {
-    use std::fmt::Write;
-    let mut error_string = format!("{}\n\nCaused by:\n", error);
-    error.chain()
-        .skip(1)
-        .enumerate()
-        .for_each(|(idx, cause)| write!(error_string, "  {}: {}\n", idx, cause).unwrap());
-    error_string
+#[derive(Debug)]
+pub struct Error {
+    kind: Kind,
+    metadata: Option<Metadata>
 }
 
-pub(crate) fn extract_error_info(error: &anyhow::Error) -> ErrorInfo {
-    let mut retries = None;
-
-    // Here we go through the chain of type erased errors that led to the current one
-    // trying to downcast each to concrete types. We fallback to error string parsing only on
-    // private errors (as we don't have the type) and mainly to extract helpfull information
-    // on a best effort basis.
-    for e in error.chain() {
-        if let Some(e) = e.downcast_ref::<reqwest::Error>() {
-            if let Some(code) = e.status() {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Code(code.into())
-                }
-            }
-            if e.is_timeout() {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Timeout
-                }
-            }
-            if e.is_body() || e.is_connect() || e.is_request() {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Io
-                }
-            }
-        }
-
-        if let Some(e) = e.downcast_ref::<hyper::Error>() {
-            if e.is_closed() || e.is_incomplete_message() || e.is_body_write_aborted() {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Io
-                }
-            } else if e.is_timeout() {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Timeout
-                }
-            }
-        }
-
-        if let Some(e) = e.downcast_ref::<std::io::Error>() {
-            if e.kind() == std::io::ErrorKind::TimedOut {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Timeout
-                }
-            } else if e.kind() == std::io::ErrorKind::Other && e.source().is_some() {
-                // Continue to source error
-                continue
-            } else {
-                return ErrorInfo {
-                    retries,
-                    reason: ErrorReason::Io
-                }
-            }
-        }
-
-        let error_debug = format!("{:?}", e);
-        if error_debug.starts_with("Client {") {
-            if let Some(caps) = CLIENT_ERR_REGEX.captures(&error_debug) {
-                if let Ok(status) = caps["status"].parse() {
-                    // If we find this error we try to extract the status code from its debug
-                    // representation
-                    return ErrorInfo {
-                        retries,
-                        reason: ErrorReason::Code(status)
-                    }
-                }
-            }
-        } else if error_debug.starts_with("Reqwest {") {
-            if let Some(caps) = REQWEST_ERR_REGEX.captures(&error_debug) {
-                // If we find this error we try to extract the retries from its debug
-                // representation
-                retries = caps["retries"].parse::<usize>().ok();
-            }
-        }
-    }
-
-    ErrorInfo { retries, reason: ErrorReason::Unknown }
-}
+impl<T: Into<Kind>> From<T> for Error {
+    fn from(value: T) -> Self {
+        Error { kind: value.into(), metadata: None }
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        fmt::Display::fmt(&self.kind, formatter)?;
+        if self.kind.source().is_some() {
+            write!(formatter, "\n\nCaused by:\n")?;
+            self.kind.chain()
+                .skip(1)
+                .enumerate()
+                .map(|(idx, cause)| {
+                    write!(formatter, "  {}: {}\n", idx, cause)
+                })
+                .collect::<fmt::Result>()?;
+        }
+        if let Some(metadata) = self.metadata.as_ref() {
+            let report = metadata.retry_report();
+            write!(formatter, "\n{}", report)?;
+        }
+        Ok(())
+    }
+}
+
+impl StdError for Error {
+    fn source(&self) -> Option<&(dyn StdError + 'static)> {
+        self.kind.source()
+    }
+}
+
+impl Error {
+    fn error_info(&self) -> ErrorInfo {
+        let mut info = self.kind.error_info();
+        if let Some(metadata) = self.metadata.as_ref() {
+            let retries = metadata.retries();
+            if retries > 0 {
+                info.retries = Some(metadata.retries());
+            }
+        }
+        info
+    }
+    #[allow(unused)]
+    pub(crate) fn not_implemented(msg: impl Into<String>) -> Error { Kind::NotImplemented(msg.into()).into() }
+    pub(crate) fn required_config(msg: impl Into<String>) -> Error { Kind::RequiredConfig(msg.into()).into() }
+    pub(crate) fn invalid_config(msg: impl Into<String>) -> Error { Kind::InvalidConfig { msg: msg.into(), source: None }.into() }
+    pub(crate) fn invalid_config_src(msg: impl Into<String>, source: impl Into<BoxedStdError>) -> Error {
+        Kind::InvalidConfig { msg: msg.into(), source: Some(source.into()) }.into()
+    }
+    pub(crate) fn invalid_config_err<E>(msg: &'static str) -> impl Fn(E) -> Error
+    where
+        E: Into<BoxedStdError>
+    {
+        return |e: E| Kind::InvalidConfig { msg: msg.into(), source: Some(e.into()) }.into()
+    }
+    #[allow(unused)]
+    pub(crate) fn deserialize_response(msg: impl Into<String>, error: serde_path_to_error::Error<serde_json::Error>) -> Error {
+        Kind::DeserializeResponse { response: msg.into(), source: error }.into()
+    }
+    pub(crate) fn deserialize_response_err(msg: &'static str) -> impl Fn(serde_path_to_error::Error<serde_json::Error>) -> Error
+    {
+        return |e| Kind::DeserializeResponse { response: msg.into(), source: e }.into()
+    }
+    pub(crate) fn invalid_response(msg: impl Into<String>) -> Error { Kind::InvalidResponse(msg.into()).into() }
+    pub(crate) fn error_response(msg: impl Into<String>) -> Error { Kind::ErrorResponse(msg.into()).into() }
+}
+
+pub(crate) trait ErrorExt<T> {
+    fn to_err(self) -> Result<T>;
+}
+
+impl<T> ErrorExt<T> for Result<T, anyhow::Error> {
+    fn to_err(self) -> Result<T> {
+        self.map_err(|e| Kind::Other(e).into())
+    }
+}
+
+#[derive(Debug, Error)]
+#[non_exhaustive]
+pub enum Kind {
+    #[error("{0}")]
+    Request(#[from] reqwest::Error),
+    #[error("{0}")]
+    ObjectStore(#[from] object_store::Error),
+    #[error("{0}")]
+    ErrorResponse(String),
+    #[error("{0}")]
+    InvalidResponse(String),
+    #[error("Failed to deserialize `{response}` response")]
+    DeserializeResponse {
+        response: String,
+        #[source]
+        source: serde_path_to_error::Error<serde_json::Error>
+    },
+    #[error("Storage referenced as `{0}` is not encrypted")]
+    StorageNotEncrypted(String),
+    #[error("{msg}")]
+    InvalidConfig {
+        msg: String,
+        #[source]
+        source: Option<BoxedStdError>
+    },
+    #[error("Missing required config `{0}`")]
+    RequiredConfig(String),
+    #[error("{0} is not implemented")]
+    NotImplemented(String),
+    #[error("failed to decode encryption material: {0}")]
+    MaterialDecode(#[source] base64::DecodeError),
+    #[error("failed to encrypt or decrypt encryption material: {0}")]
+    MaterialCrypt(#[source] std::io::Error),
+    #[error("failed to encrypt the object contents: {0}")]
+    ContentEncrypt(#[source] std::io::Error),
+    #[error("failed to decrypt the object contents: {0}")]
+    ContentDecrypt(#[source] std::io::Error),
+    #[error("Supplied buffer was too small")]
+    BufferTooSmall,
+    #[error("io error while streaming body: {0}")]
+    BodyIo(#[source] std::io::Error),
+    #[error("{0}")]
+    TaskFailed(#[from] tokio::task::JoinError),
+    #[allow(dead_code)]
+    #[error("{context}")]
+    Context {
+        context: String,
+        #[source]
+        source: BoxedStdError
+    },
+    #[error("{0}")]
+    Wrapped(#[from] Arc<Error>),
+    #[error("{0}")]
+    Other(#[source] anyhow::Error)
+}
+
+impl Kind {
+    pub(crate) fn chain(&self) -> Chain {
+        Chain(Some(self))
+    }
+
+    pub(crate) fn error_info(&self) -> ErrorInfo {
+        let mut retries = None;
+
+        match self {
+            Kind::DeserializeResponse { .. } => {
+                return ErrorInfo {
+                    retries: None,
+                    reason: ErrorReason::Io
+                }
+            },
+            Kind::Wrapped(e) => return e.error_info(),
+            Kind::Context { source, .. } => {
+                match source.downcast_ref::<Error>() {
+                    Some(e) => return e.error_info(),
+                    None => {}
+                }
+            }
+            _ => {}
+        }
+
+        // Here we go through the chain of type erased errors that led to the current one
+        // trying to downcast each to concrete types. We fall back to error string parsing only on
+        // private errors (as we don't have the type) and mainly to extract helpful information
+        // on a best effort basis.
+        for e in self.chain() {
+            if let Some(e) = e.downcast_ref::<reqwest::Error>() {
+                if let Some(code) = e.status() {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Code(code.into())
+                    }
+                }
+                if e.is_timeout() {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Timeout
+                    }
+                }
+                if e.is_body() || e.is_connect() || e.is_request() {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Io
+                    }
+                }
+            }
+
+            if let Some(e) = e.downcast_ref::<hyper::Error>() {
+                if e.is_closed() || e.is_incomplete_message() || e.is_body_write_aborted() {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Io
+                    }
+                } else if e.is_timeout() {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Timeout
+                    }
+                }
+            }
+
+            if let Some(e) = e.downcast_ref::<std::io::Error>() {
+                if e.kind() == std::io::ErrorKind::TimedOut {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Timeout
+                    }
+                } else if e.kind() == std::io::ErrorKind::NotFound {
+                    if e.source().is_some() { continue; }
+                    // Do not retry root NotFound
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Unknown
+                    }
+                } else if e.kind() == std::io::ErrorKind::Other && e.source().is_some() {
+                    // Continue to source error
+                    continue
+                } else {
+                    return ErrorInfo {
+                        retries,
+                        reason: ErrorReason::Io
+                    }
+                }
+            }
+
+            let error_debug = format!("{:?}", e);
+            if error_debug.starts_with("Client {") {
+                if let Some(caps) = CLIENT_ERR_REGEX.captures(&error_debug) {
+                    if let Ok(status) = caps["status"].parse() {
+                        // If we find this error we try to extract the status code from its debug
+                        // representation
+                        return ErrorInfo {
+                            retries,
+                            reason: ErrorReason::Code(status)
+                        }
+                    }
+                }
+            } else if error_debug.starts_with("Reqwest {") {
+                if let Some(caps) = REQWEST_ERR_REGEX.captures(&error_debug) {
+                    // If we find this error we try to extract the retries from its debug
+                    // representation
+                    retries = caps["retries"].parse::<usize>().ok();
+                }
+            }
+        }
+
+        ErrorInfo { retries, reason: ErrorReason::Unknown }
+    }
+}
+
+pub(crate) struct Chain<'a>(Option<&'a (dyn StdError + 'static)>);
+
+impl<'a> Iterator for Chain<'a> {
+    type Item = &'a (dyn StdError + 'static);
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.0.take() {
+            Some(e) => {
+                self.0 = e.source();
+                Some(e)
+            }
+            None => None
+        }
+    }
+}
+
+pub type Result<T, E = Error> = core::result::Result<T, E>;
+
+#[derive(Debug, Clone)]
+pub(crate) struct ErrorInfo {
+    pub(crate) retries: Option<usize>,
+    pub(crate) reason: ErrorReason
+}
+
+#[derive(Debug, Clone)]
+pub(crate) enum ErrorReason {
+    Unknown,
+    Code(u16),
+    Io,
+    Timeout
 }
 
 #[derive(Debug)]
@@ -175,9 +400,13 @@ impl RetryState {
                 // Retry timeouts up to retry_timeout or max_retries
                 all_retries < max_retries && elapsed < retry_timeout
             }
-            ErrorReason::Code(_code) => {
-                // TODO manage custom status_code retries
-                false
+            ErrorReason::Code(code) => {
+                if (500..600).contains(&code) {
+                    all_retries < max_retries && elapsed < retry_timeout
+                } else {
+                    // TODO manage custom status_code retries
+                    false
+                }
             }
             ErrorReason::Io => {
                 // Retry io errors up to retry_timeout or max_retries
@@ -189,42 +418,17 @@ impl RetryState {
         }
     }
 
-    pub(crate) fn retry_report(&self) -> String {
-        use std::fmt::Write;
-        let mut report = String::new();
-        if !self.attempts.is_empty() {
-            let attempts_to_display = self.attempts.len().min(10);
-            write!(report, "Recent attempts ({} out of {}):\n", attempts_to_display, self.attempts.len()).unwrap();
-            self.attempts
-                .iter()
-                .rev()
-                .take(10)
-                .rev()
-                .for_each(|info| {
-                    write!(
-                        report,
-                        "  reason: {:?} after {} retries\n",
-                        info.reason,
-                        info.retries.unwrap_or_default()
-                    ).unwrap()
-                });
-        } else {
-            write!(report, "There were no attempts\n").unwrap();
-        }
-        write!(report, "Total retries: {}\n", self.retries()).unwrap();
-        write!(report, "Total Time: {:?}\n", self.start.elapsed()).unwrap();
-        report
-    }
-
-    pub(crate) fn should_retry(&mut self, error: &anyhow::Error) -> anyhow::Result<(ErrorInfo, Duration)> {
-        let info = extract_error_info(error);
+    pub(crate) fn should_retry(&mut self, mut error: Error) -> Result<(Error, ErrorInfo, Duration), Error> {
+        let info = error.error_info();
         self.log_attempt(info.clone());
-        let decision = if self.should_retry_logic() {
-            Ok((info, self.next_backoff()))
+        if self.should_retry_logic() {
+            Ok((error, info, self.next_backoff()))
         } else {
-            let error_report = format_err(error);
-            Err(anyhow!("{}\n{}", error_report, self.retry_report()))
-        };
-        decision
+            error.metadata = Some(Metadata {
+                start: self.start.clone(),
+                attempts: self.attempts.clone()
+            });
+            Err(error)
+        }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 90abbe7..8433028 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,41 +1,49 @@
 #[global_allocator]
 static GLOBAL: metrics::InstrumentedAllocator = metrics::InstrumentedAllocator {};
 
+use encryption::CryptoMaterialProvider;
+use error::Error;
 use object_store::RetryConfig;
 use once_cell::sync::OnceCell;
+use tokio::io::AsyncRead;
 use tokio::runtime::Runtime;
 use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
 use std::collections::HashMap;
 use std::ffi::{c_char, c_void, CString};
 use std::sync::Arc;
 use std::time::Duration;
-use anyhow::anyhow;
 use std::collections::hash_map::DefaultHasher;
 use std::hash::Hasher;
+use std::any::Any;
 use object_store::{path::Path, ObjectStore};
 use moka::future::Cache;
 
 mod util;
-use util::Compression;
+use util::{deserialize_str, string_to_path, AsyncUpload, Compression};
 
 mod error;
-use error::RetryState;
+pub(crate) use error::Result;
 
 mod list;
-use list::{handle_list, handle_list_stream, ListResponse, ListStreamResponse};
+use list::{ListResponse, ListStreamResponse};
 
 mod crud_ops;
-use crud_ops::{handle_get, handle_put, handle_delete, Response};
+use crud_ops::Response;
 
 mod stream;
-use stream::{handle_get_stream, handle_put_stream, GetStreamResponse, PutStreamResponse};
+use stream::{GetStreamResponse, PutStreamResponse};
 
 mod metrics;
 pub use metrics::{InstrumentedAllocator, METRICS};
 
+mod snowflake;
+use snowflake::build_store_for_snowflake_stage;
+
+mod encryption;
+
 // Our global variables needed by our library at runtime. Note that we follow Rust's
 // safety rules here by making them immutable with write-exactly-once semantics using
 // either Lazy or OnceCell.
@@ -72,6 +80,32 @@ fn result_cb(handle: *const c_void) -> i32 {
     unsafe { RESULT_CB.get().expect("no result callback")(handle) }
 }
 
+type BoxedReader = Box<dyn AsyncRead + Send + Unpin>;
+type BoxedUpload = Box<dyn AsyncUpload + Send + Unpin>;
+
+// TODO Use this in the future if we really need to downcast to
+// a concrete store. Currently we are using ClientContext for
+// this
+//
+// pub(crate) trait ObjectStoreAny: ObjectStore {
+//     fn as_any(&self) -> &dyn Any;
+//     fn as_dyn_store<'a>(self: Arc<Self>) -> Arc<dyn ObjectStore + 'a>
+//     where
+//         Self: 'a;
+// }
+//
+//
+// impl<T: ObjectStore> ObjectStoreAny for T {
+//     fn as_any(&self) -> &dyn Any {
+//         self
+//     }
+//     fn as_dyn_store<'a>(self: Arc<Self>) -> Arc<dyn ObjectStore + 'a>
+//     where
+//         Self: 'a {
+//         self
+//     }
+// }
+
 // The result type used for the API functions exposed to Julia. This is used for both
 // synchronous errors, e.g. our dispatch channel is full, and for async errors such
 // as HTTP connection errors as part of the async Response.
@@ -119,9 +153,9 @@ impl ResponseGuard {
         *response.context_mut() = Arc::into_raw(context.clone());
         ResponseGuard { response, context, handle }
     }
-    pub(crate) fn success(self, payload: T::Payload) {
+    pub(crate) fn success(self, payload: impl Into<T::Payload>) {
         *self.response.result_mut() = CResult::Ok;
-        self.response.set_payload(Some(payload));
+        self.response.set_payload(Some(payload.into()));
         *self.response.error_message_mut() = std::ptr::null_mut();
     }
 }
@@ -166,6 +200,42 @@ enum Request {
 
 unsafe impl Send for Request {}
 
+impl Request {
+    fn cancelled(&self) -> WaitForCancellationFuture {
+        match self {
+            Request::Get( .. , response) => response.cancelled(),
+            Request::Put( .. , response) => response.cancelled(),
+            Request::Delete( .. , response) => response.cancelled(),
+            Request::List( .. , response) => response.cancelled(),
+            Request::ListStream( .. , response) => response.cancelled(),
+            Request::GetStream( .. , response) => response.cancelled(),
+            Request::PutStream( .. , response) => response.cancelled()
+        }
+    }
+    fn into_error(self, error: impl std::fmt::Display) where Self: Sized {
+        match self {
+            Request::Get( .. , response) => response.into_error(error),
+            Request::Put( .. , response) => response.into_error(error),
+            Request::Delete( .. , response) => response.into_error(error),
+            Request::List( .. , response) => response.into_error(error),
+            Request::ListStream( .. , response) => response.into_error(error),
+            Request::GetStream( .. , response) => response.into_error(error),
+            Request::PutStream( .. , response) => response.into_error(error)
+        }
+    }
+    fn raw_config(&self) -> &RawConfig {
+        match self {
+            Request::Get( .. , config, _response) => config,
+            Request::Put( .. , config, _response) => config,
+            Request::Delete( .. , config, _response) => config,
+            Request::List( .. , config, _response) => config,
+            Request::ListStream( .. , config, _response) => config,
+            Request::GetStream( .. , config, _response) => config,
+            Request::PutStream( .. , config, _response) => config
+        }
+    }
+}
+
 // We use `jl_adopt_thread` to ensure Rust can call into Julia when notifying
 // the Base.Event that is waiting for the Rust result.
 // Note that this will be linked in from the Julia process, we do not try
@@ -201,16 +271,17 @@ impl RawConfig {
         config_string.to_str().expect("julia strings are valid utf8")
     }
 
-    fn as_json(&self) -> anyhow::Result<serde_json::Value> {
-        Ok(serde_json::from_str(self.as_str())?)
+    fn as_json(&self) -> crate::Result<serde_json::Value> {
+        let json = deserialize_str(self.as_str())
+            .map_err(Error::invalid_config_err("failed to parse json serialized config"))?;
+        Ok(json)
     }
 
-    fn as_map(&self) -> anyhow::Result<HashMap<String, String>> {
-        let value = self.as_json()
-            .map_err(|e| anyhow!("failed to parse json serialized config: {}", e))?;
+    fn as_map(&self) -> crate::Result<HashMap<String, String>> {
+        let value = self.as_json()?;
 
         let map: HashMap<String, String> = serde_json::from_value(value)
-            .map_err(|e| anyhow!("config must be a json serialized object: {}", e))?;
+            .map_err(Error::invalid_config_err("config must be a json serialized object"))?;
 
         Ok(map)
     }
@@ -218,31 +289,98 @@ impl RawConfig {
 
 #[derive(Debug, Clone)]
 struct Config {
-    retry_config: RetryConfig
+    prefix: Option<String>,
+    retry_config: RetryConfig,
+    multipart_get_threshold: usize,
+    multipart_get_part_size: usize,
+    multipart_get_concurrency: usize,
+    multipart_put_threshold: usize,
+    multipart_put_part_size: usize,
+    multipart_put_concurrency: usize
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Config {
+            prefix: None,
+            retry_config: RetryConfig::default(),
+            multipart_get_threshold: 8 * 1024 * 1024,
+            multipart_get_part_size: 8 * 1024 * 1024,
+            multipart_get_concurrency: 16,
+            multipart_put_threshold: 10 * 1024 * 1024,
+            multipart_put_part_size: 10 * 1024 * 1024,
+            multipart_put_concurrency: 16
+        }
+    }
+}
+
+macro_rules! extract_and_parse {
+    ($map: expr, $config: expr, $field: ident) => {
+        $map.remove(stringify!($field))
+            .map(|s| s.parse())
+            .transpose()
+            .map_err(|e| Error::invalid_config_src(format!("failed to parse {}", stringify!($field)), e))?
+            .inspect(|v| {
+                $config.$field = *v;
+            });
+    };
+}
+
+impl Config {
+    fn extract_from_map(map: &mut HashMap<String, String>) -> crate::Result<Config> {
+        let mut config = Config::default();
+        config.retry_config = parse_retry_config(map)?;
+        config.prefix = map.remove("prefix");
+        extract_and_parse!(map, config, multipart_get_threshold);
+        extract_and_parse!(map, config, multipart_get_part_size);
+        extract_and_parse!(map, config, multipart_get_concurrency);
+        extract_and_parse!(map, config, multipart_put_threshold);
+        extract_and_parse!(map, config, multipart_put_part_size);
+        extract_and_parse!(map, config, multipart_put_concurrency);
+        Ok(config)
+    }
+}
+
+#[async_trait::async_trait]
+pub(crate) trait Extension: std::fmt::Debug + Send + Sync + 'static {
+    #[allow(dead_code)]
+    fn as_any(&self) -> &dyn Any;
+    async fn current_stage_info(&self) -> crate::Result<String> {
+        return Err(Error::not_implemented("current_stage_info is not implemented for this client"));
+    }
+}
+
+impl Extension for () {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
+type ClientExtension = Arc<dyn Extension>;
+
 #[derive(Debug, Clone)]
 pub struct Client {
     store: Arc<dyn ObjectStore>,
-    config: Config
+    crypto_material_provider: Option<Arc<dyn CryptoMaterialProvider>>,
+    config: Config,
+    extension: ClientExtension
 }
 
 impl Client {
-    // Caution: This function is not retried as it does not perform any network
-    // operations currently. If this invariant no longer holds retries must be
-    // added here
-    async fn from_raw_config(config: &RawConfig) -> anyhow::Result<Client> {
-        let mut map = config.as_map()?;
-
+    // Caution: This function is not retried, retries must be handled internally.
+    async fn from_raw_config(config: &RawConfig) -> crate::Result<Client> {
+        Ok(Client::from_config_map(config.as_map()?).await?)
+    }
+    async fn from_config_map(mut map: HashMap<String, String>) -> crate::Result<Client> {
         let url = map.remove("url")
-            .ok_or(anyhow!("config object must have a key named 'url'"))?;
+            .ok_or(Error::invalid_config("config object must have a key named 'url'"))?;
 
         let url = url::Url::parse(&url)
-            .map_err(|e| anyhow!("failed to parse `url`: {}", e))?;
+            .map_err(Error::invalid_config_err("failed to parse `url`"))?;
 
         if let Some(v) = map.remove("azurite_host") {
             let mut azurite_host = url::Url::parse(&v)
-                .map_err(|e| anyhow!("failed to parse azurite_host: {}", e))?;
+                .map_err(Error::invalid_config_err("failed to parse azurite_host"))?;
             azurite_host.set_path("");
             unsafe { std::env::set_var("AZURITE_BLOB_STORAGE_URL", azurite_host.as_str()) };
             map.insert("allow_invalid_certificates".into(), "true".into());
@@ -251,69 +389,110 @@ impl Client {
 
         if let Some(v) = map.remove("minio_host") {
             let mut minio_host = url::Url::parse(&v)
-                .map_err(|e| anyhow!("failed to parse minio_host: {}", e))?;
+                .map_err(Error::invalid_config_err("failed to parse minio_host"))?;
             minio_host.set_path("");
             map.insert("allow_http".into(), "true".into());
             map.insert("aws_endpoint".into(), minio_host.as_str().trim_end_matches('/').to_string());
         }
 
-        let retry_config = parse_retry_config(&mut map)?;
+        let mut config = Config::extract_from_map(&mut map)?;
 
-        let store: Arc<dyn ObjectStore> = match url.scheme() {
+        let client = match url.scheme() {
             "s3" => {
                 let mut builder = object_store::aws::AmazonS3Builder::default()
                     .with_url(url)
-                    .with_retry(retry_config.clone());
+                    .with_retry(config.retry_config.clone());
 
                 for (key, value) in map {
                     builder = builder.with_config(key.parse()?, value);
                 }
-                Arc::new(builder.build()?)
+
+                Client {
+                    store: Arc::new(builder.build()?),
+                    crypto_material_provider: None,
+                    config,
+                    extension: Arc::new(()),
+                }
             }
             "az" | "azure" => {
                 let mut builder = object_store::azure::MicrosoftAzureBuilder::default()
                     .with_url(url)
-                    .with_retry(retry_config.clone());
+                    .with_retry(config.retry_config.clone());
 
                 for (key, value) in map {
                     builder = builder.with_config(key.parse()?, value);
                 }
-                Arc::new(builder.build()?)
+
+                Client {
+                    store: Arc::new(builder.build()?),
+                    crypto_material_provider: None,
+                    config,
+                    extension: Arc::new(())
+                }
+            }
+            "snowflake" => {
+                let (store, crypto_material_provider, stage_prefix, extension) = build_store_for_snowflake_stage(map, config.retry_config.clone()).await?;
+
+                let prefix = match (stage_prefix, config.prefix) {
+                    (s, Some(u)) if s.ends_with("/") => Some(format!("{s}{u}")),
+                    (s, Some(u)) => Some(format!("{s}/{u}")),
+                    (s, None) => Some(s)
+                };
+
+                config.prefix = prefix;
+
+                Client {
+                    store,
+                    crypto_material_provider,
+                    config,
+                    extension
+                }
             }
             _ => unimplemented!("unknown url scheme")
         };
 
-        Ok(Client {
-            store,
-            config: Config {
-                retry_config
+        Ok(client)
+    }
+
+    fn full_path(&self, path: &Path) -> Path {
+        if let Some(prefix) = self.config.prefix.as_ref() {
+            // FIXME We should always prepend the prefix here; this inner check prevents
+            // intentionally duplicating the prefix.
+            // We don't do it because currently some other code may still pass
+            // the prefix and we don't want to apply it twice.
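+            // Current behavior: a path that already starts with the configured
+            // prefix is passed through unchanged; any other path gets the
+            // prefix prepended before hitting the underlying store.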
+            if path.as_ref().starts_with(prefix) {
+                path.clone()
+            } else {
+                unsafe { string_to_path(format!("{prefix}{path}")) }
             }
-        })
+        } else {
+            path.clone()
+        }
     }
 }
 
-fn parse_retry_config(map: &mut HashMap<String, String>) -> anyhow::Result<RetryConfig> {
+fn parse_retry_config(map: &mut HashMap<String, String>) -> crate::Result<RetryConfig> {
     let mut retry_config = RetryConfig::default();
 
     if let Some(value) = map.remove("max_retries") {
         retry_config.max_retries = value.parse()
-            .map_err(|e| anyhow!("failed to parse max_retries: {}", e))?;
+            .map_err(Error::invalid_config_err("failed to parse max_retries"))?;
     }
 
     if let Some(value) = map.remove("retry_timeout_secs") {
         retry_config.retry_timeout = std::time::Duration::from_secs(value.parse()
-            .map_err(|e| anyhow!("failed to parse retry_timeout_sec: {}", e))?
+            .map_err(Error::invalid_config_err("failed to parse retry_timeout_secs"))?
         );
     }
 
     if let Some(value) = map.remove("initial_backoff_ms") {
         retry_config.backoff.init_backoff = std::time::Duration::from_millis(value.parse()
-            .map_err(|e| anyhow!("failed to parse initial_backoff_ms: {}", e))?
+            .map_err(Error::invalid_config_err("failed to parse initial_backoff_ms"))?
        );
     }
 
     if let Some(value) = map.remove("max_backoff_ms") {
         retry_config.backoff.max_backoff = std::time::Duration::from_millis(value.parse()
-            .map_err(|e| anyhow!("failed to parse max_backoff_ms: {}", e))?
+            .map_err(Error::invalid_config_err("failed to parse max_backoff_ms"))?
        );
     }
 
     if let Some(value) = map.remove("backoff_exp_base") {
         retry_config.backoff.base = value.parse()
-            .map_err(|e| anyhow!("failed to parse backoff_exp_base: {}", e))?;
+            .map_err(Error::invalid_config_err("failed to parse backoff_exp_base"))?;
     }
 
     Ok(retry_config)
@@ -347,70 +526,38 @@ impl Default for StaticConfig {
     }
 }
 
-macro_rules! ensure_client {
-    ($response: expr, $config: expr) => {
-        match clients()
-            .try_get_with($config.get_hash(), Client::from_raw_config($config)).await
-            .map_err(|e| anyhow!(e))
-        {
-            Ok(client) => {
-                client
-            },
-            Err(e) => {
-                tracing::warn!("{}", e);
-                $response.into_error(e);
-                continue;
-            }
-        }
-    };
-}
-
-macro_rules! with_retries_and_cancellation {
-    ($client:expr, $response:expr, $op: expr) => {
-        with_retries_and_cancellation!($client, $response, $op, true)
-    };
-    ($client:expr, $response:expr, $op: expr, $emit_warn: expr) => {
-        let mut retry_state = RetryState::new($client.config.retry_config.clone());
-        'retry: loop {
-            $response.ensure_active();
-            tokio::select! {
-                _ = $response.cancelled() => {
-                    tracing::warn!("operation was cancelled");
-                    $response.into_error("operation was cancelled");
-                    break 'retry;
-                }
-                res = $op => {
-                    match res {
-                        Ok(v) => {
-                            $response.success(v);
-                            break 'retry;
-                        }
-                        Err(e) => {
-                            match retry_state.should_retry(&e) {
-                                Ok((info, duration)) => {
-                                    tracing::info!("retrying error (reason: {:?}) after {:?}: {}", info.reason, duration, e);
-                                    tokio::time::sleep(duration).await;
-                                    continue 'retry;
-                                }
-                                Err(e) => {
-                                    if $emit_warn {
-                                        tracing::warn!("{}", e);
-                                    }
-                                    $response.into_error(e);
-                                    break 'retry;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    };
-}
+#[macro_export]
+macro_rules! with_retries {
+    ($this:expr, $op: expr) => {{
+        let mut retry_state = crate::error::RetryState::new($this.config.retry_config.clone());
+        'retry: loop {
+            match $op {
+                Err(e) => {
+                    match retry_state.should_retry(e) {
+                        Ok((e, info, duration)) => {
+                            tracing::info!("retrying error (reason: {:?}) after {:?}: {}", info.reason, duration, e);
+                            tokio::time::sleep(duration).await;
+                            continue 'retry;
+                        }
+                        Err(e) => {
+                            break Err(e);
+                        }
+                    }
+                }
+                ok => {
+                    break ok;
+                }
+            }
+        }
+    }};
+}
 
 #[macro_export]
 macro_rules! with_cancellation {
     ($op:expr, $response:expr) => {
+        with_cancellation!($op, $response, true)
+    };
+    ($op:expr, $response:expr, $emit_warn: expr) => {
         $response.ensure_active();
         tokio::select! {
             _ = $response.cancelled() => {
@@ -423,7 +570,9 @@ macro_rules! with_cancellation {
                         $response.success(v);
                     },
                     Err(e) => {
-                        tracing::warn!("{}", e);
+                        if $emit_warn {
+                            tracing::warn!("{}", e);
+                        }
                         $response.into_error(e);
                     }
                 }
@@ -459,6 +608,83 @@ macro_rules! destroy_with_runtime {
     };
 }
 
+#[macro_export]
+macro_rules! export_queued_op {
+    ($name: ident, $response: ty, $builder: expr, $($v:ident: $t:ty),+) => {
+        #[no_mangle]
+        pub extern "C" fn $name($($v: $t),+, config: *const RawConfig, response: *mut $response, handle: *const c_void) -> CResult {
+            let response = unsafe { ResponseGuard::new(response, handle) };
+            let config = unsafe { & (*config) };
+            let req_result: Result<_, (ResponseGuard<$response>, anyhow::Error)> = $builder(config, response);
+            let req = match req_result {
+                Ok(req) => req,
+                Err((response, e)) => {
+                    response.into_error(e);
+                    return CResult::Error;
+                }
+            };
+            match SQ.get() {
+                Some(sq) => {
+                    match sq.try_send(req) {
+                        Ok(_) => CResult::Ok,
+                        Err(flume::TrySendError::Full(req)) => {
+                            req.into_error("object_store_ffi internal channel full, backoff");
+                            CResult::Backoff
+                        }
+                        Err(flume::TrySendError::Disconnected(req)) => {
+                            req.into_error("object_store_ffi internal channel closed (may be missing initialization)");
+                            CResult::Error
+                        }
+                    }
+                }
+                None => {
+                    req.into_error("object_store_ffi internal channel closed (may be missing initialization)");
+                    return CResult::Error;
+                }
+            }
+        }
+    };
+}
+
+// TODO use macro for exporting runtime operations
+#[macro_export]
+macro_rules! export_runtime_op {
+    ($name: ident, $response: ty, $builder: expr, $state: ident, $asyncop: expr, $($v:ident: $t:ty),+) => {
+        #[no_mangle]
+        pub extern "C" fn $name(
+            $($v: $t),+,
+            response: *mut $response,
+            handle: *const c_void
+        ) -> CResult {
+            let response = unsafe { ResponseGuard::new(response, handle) };
+            let state_result: Result<_, anyhow::Error> = $builder();
+            let $state = match state_result {
+                Ok(s) => s,
+                Err(e) => {
+                    response.into_error(e);
+                    return CResult::Error;
+                }
+            };
+
+            match RT.get() {
+                Some(runtime) => {
+                    runtime.spawn(async move {
+                        let op = $asyncop;
+
+                        with_cancellation!(op, response);
+                    });
+                    CResult::Ok
+                }
+                None => {
+                    response.into_error("object_store_ffi runtime not started (may be missing initialization)");
+                    return CResult::Error;
+                }
+            }
+        }
+    };
+}
+
 trait NotifyGuard {
     fn is_uninitialized(&mut self) -> bool;
     fn condition_handle(&self) -> *const c_void;
@@ -567,63 +793,43 @@ pub extern "C" fn start(
                 break;
             }
         };
+        let raw_config = req.raw_config();
+        let client = tokio::select! {
+            _ = req.cancelled() => {
+                tracing::warn!("operation was cancelled");
+                req.into_error("operation was cancelled");
+                continue;
+            }
+            res = clients().try_get_with(raw_config.get_hash(), Client::from_raw_config(raw_config)) => match res {
+                Ok(client) => client,
+                Err(e) => {
+                    tracing::warn!("{}", e);
+                    req.into_error(e);
+                    continue;
+                }
+            }
+        };
         match req {
-            Request::Get(path, slice, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_get(client.clone(), slice, &path)
-                );
+            Request::Get(path, slice, _config, response) => {
+                with_cancellation!(client.get(&path, slice), response);
             }
-            Request::Put(path, slice, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_put(client.clone(), slice, &path)
-                );
+            Request::Put(path, slice, _config, response) => {
+                with_cancellation!(client.put(&path, slice.into()), response);
             }
-            Request::Delete(path, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_delete(client.clone(), &path),
-                    false
-                );
+            Request::Delete(path, _config, response) => {
+                with_cancellation!(client.delete(&path), response, false);
             }
-            Request::List(prefix, offset, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_list(client.clone(), &prefix, offset.as_ref())
-                );
+            Request::List(prefix, offset, _config, response) => {
+                with_cancellation!(client.list(&prefix, offset.as_ref()), response);
             }
-            Request::ListStream(prefix, offset, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_list_stream(client.clone(), &prefix, offset.as_ref())
-                );
+            Request::ListStream(prefix, offset, _config, response) => {
+                with_cancellation!(client.list_stream(&prefix, offset.as_ref()), response);
             }
-            Request::GetStream(path, size_hint, compression, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_get_stream(client.clone(), &path, size_hint, compression)
-                );
+            Request::GetStream(path, size_hint, compression, _config, response) => {
+                with_cancellation!(client.get_stream(&path, size_hint, compression), response);
             }
-            Request::PutStream(path, compression, config, response) => {
-                let client = ensure_client!(response, config);
-                with_retries_and_cancellation!(
-                    client,
-                    response,
-                    handle_put_stream(client.clone(), &path, compression)
-                );
+            Request::PutStream(path, compression, _config, response) => {
+                with_cancellation!(client.put_stream(&path, compression), response);
             }
         }
     }
diff --git a/src/list.rs b/src/list.rs
index 4a2ce3c..d620a98 100644
--- a/src/list.rs
+++ b/src/list.rs
@@ -1,40 +1,119 @@
-use crate::{destroy_with_runtime, util::cstr_to_path, with_cancellation, CResult, Client, Context, NotifyGuard, RawConfig, RawResponse, Request, ResponseGuard, RT, SQ};
+use crate::{destroy_with_runtime, duration_on_drop, export_queued_op, metrics, util::{cstr_to_path, string_to_path}, with_cancellation, with_retries, CResult, Client, Context, NotifyGuard, RawConfig, RawResponse, Request, ResponseGuard, RT, SQ};
 use object_store::{path::Path, ObjectStore, ObjectMeta};
+use pin_project::{pin_project, pinned_drop};
 use std::ffi::{c_char, c_void, CString};
-use futures_util::{StreamExt, stream::BoxStream};
+use futures_util::{stream::BoxStream, Stream, StreamExt};
 use std::sync::Arc;
 
-pub(crate) async fn handle_list(client: Client, prefix: &Path, offset: Option<&Path>) -> anyhow::Result<Vec<ObjectMeta>> {
-    let stream = if let Some(offset) = offset {
-        client.store.list_with_offset(Some(&prefix), offset)
-    } else {
-        client.store.list(Some(&prefix))
-    };
+#[pin_project(PinnedDrop)]
+pub struct ListStream {
+    store_ptr: Option<*const dyn ObjectStore>,
+    #[pin]
+    stream: BoxStream<'static, Vec<object_store::Result<ObjectMeta>>>
+}
 
-    let entries: Vec<_> = stream.collect().await;
-    let entries = entries.into_iter().collect::<Result<Vec<_>, _>>()?;
-    Ok(entries)
+impl ListStream {
+    fn new(client: &Client, prefix: &Path, offset: Option<&Path>) -> ListStream {
+        let base_prefix = client.config.prefix.clone();
+        let store_ptr = Arc::into_raw(client.store.clone());
+        // Safety: we coerce this to 'static to generate a static BoxStream from it.
+        // We ensure the store will outlive the stream by manually dropping ListStream.
+        let store: &'static dyn ObjectStore = unsafe { &*store_ptr };
+        let stream = match (base_prefix, offset) {
+            (None, None) => {
+                store.list(Some(&prefix)).chunks(1000).boxed()
+            }
+            (None, Some(offset)) => {
+                store.list_with_offset(Some(&prefix), offset).chunks(1000).boxed()
+            }
+            (Some(base), Some(offset)) => {
+                store.list_with_offset(Some(&prefix), offset)
+                    // Strip internal prefixes
+                    .scan(base, |base, meta| {
+                        let meta = meta.map(|mut meta| {
+                            if let Some(str) = meta.location.as_ref().strip_prefix(&*base) {
+                                meta.location = unsafe { string_to_path(str.to_string()) };
+                                meta
+                            } else {
+                                meta
+                            }
+                        });
+                        async { Some(meta) }
+                    })
+                    .chunks(1000)
+                    .boxed()
+            }
+            (Some(base), None) => {
+                store.list(Some(&prefix))
+                    // Strip internal prefixes
+                    .scan(base, |base, meta| {
+                        let meta = meta.map(|mut meta| {
+                            if let Some(str) = meta.location.as_ref().strip_prefix(&*base) {
+                                meta.location = unsafe { string_to_path(str.to_string()) };
+                                meta
+                            } else {
+                                meta
+                            }
+                        });
+                        async { Some(meta) }
+                    })
+                    .chunks(1000)
+                    .boxed()
+            }
+        };
+
+        ListStream {
+            store_ptr: Some(store_ptr),
+            stream
+        }
+    }
 }
 
-pub(crate) async fn handle_list_stream(client: Client, prefix: &Path, offset: Option<&Path>) -> anyhow::Result<Box<StreamWrapper>> {
-    let mut wrapper = Box::new(StreamWrapper {
-        client: client.store,
-        stream: None
-    });
+#[pinned_drop]
+impl PinnedDrop for ListStream {
+    fn drop(mut self: std::pin::Pin<&mut Self>) {
+        let ptr = self.store_ptr.take().expect("cannot drop twice");
+        let arc = unsafe { Arc::from_raw(ptr) };
+        let dummy_stream = Box::pin(futures_util::stream::empty::<Vec<object_store::Result<ObjectMeta>>>());
+        let stream = std::mem::replace(&mut self.stream, dummy_stream);
+        // Safety: Must drop the stream before the arc here
+        drop(stream);
+        drop(arc);
+    }
+}
 
-    let stream = if let Some(offset) = offset {
-        wrapper.client.list_with_offset(Some(&prefix), offset).chunks(1000).boxed()
-    } else {
-        wrapper.client.list(Some(&prefix)).chunks(1000).boxed()
-    };
-    // Safety: This is needed because the compiler cannot infer that the stream
-    // will outlive the client. We ensure this happens
-    // by droping the stream before droping the Arc on destroy_list_stream
-    wrapper.stream = Some(unsafe { std::mem::transmute(stream) });
+unsafe impl Send for ListStream {}
 
-    Ok(wrapper)
+impl Stream for ListStream {
+    type Item = Vec<object_store::Result<ObjectMeta>>;
+    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
+        self.project().stream.poll_next(cx)
+    }
+}
+
+impl Client {
+    async fn list_impl(&self, prefix: &Path, offset: Option<&Path>) -> crate::Result<Vec<ObjectMeta>> {
+        let _guard = duration_on_drop!(metrics::list_attempt_duration);
+        let stream = self.list_stream_impl(prefix, offset).await?;
+        let entries: Vec<_> = stream.collect().await;
+        let entries = entries.into_iter().flatten().collect::<Result<Vec<_>, _>>()?;
+        Ok(entries)
+    }
+    pub(crate) async fn list(&self, path: &Path, offset: Option<&Path>) -> crate::Result<Vec<ObjectMeta>> {
+        with_retries!(self, self.list_impl(path, offset).await)
+    }
+    async fn list_stream_impl(&self, prefix: &Path, offset: Option<&Path>) -> crate::Result<ListStream> {
+        let prefix = &self.full_path(prefix);
+        let offset = offset.map(|o| self.full_path(o));
+
+        Ok(ListStream::new(&self, prefix, offset.as_ref()))
+    }
+    pub(crate) async fn list_stream(&self, path: &Path, offset: Option<&Path>) -> crate::Result<ListStream> {
+        with_retries!(self, self.list_stream_impl(path, offset).await)
+    }
 }
 
 // Any non-Copy fields of ListEntry must be properly destroyed on destroy_list_entries
@@ -142,58 +221,28 @@ impl From<ObjectMeta> for ListEntry {
     }
 }
 
-#[no_mangle]
-pub extern "C" fn list(
-    prefix: *const c_char,
-    offset: *const c_char,
-    config: *const RawConfig,
-    response: *mut ListResponse,
-    handle: *const c_void
-) -> CResult {
-    let response = unsafe { ResponseGuard::new(response, handle) };
-    let prefix = unsafe { std::ffi::CStr::from_ptr(prefix) };
-    let prefix = unsafe{ cstr_to_path(prefix) };
-    let offset = if offset.is_null() {
-        None
-    } else {
-        Some(unsafe { cstr_to_path(std::ffi::CStr::from_ptr(offset)) })
-    };
-    let config = unsafe { & (*config) };
-    match SQ.get() {
-        Some(sq) => {
-            match sq.try_send(Request::List(prefix, offset, config, response)) {
-                Ok(_) => CResult::Ok,
-                Err(flume::TrySendError::Full(Request::List(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel full, backoff");
-                    CResult::Backoff
-                }
-                Err(flume::TrySendError::Disconnected(Request::List(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-                    CResult::Error
-                }
-                _ => unreachable!("the response type must match")
-            }
-        }
-        None => {
-            response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-            return CResult::Error;
-        }
-    }
-}
-
-pub struct StreamWrapper {
-    client: Arc<dyn ObjectStore>,
-    stream: Option<BoxStream<'static, Vec<object_store::Result<ObjectMeta>>>>
-}
+export_queued_op!(
+    list,
+    ListResponse,
+    |config, response| {
+        let prefix = unsafe { std::ffi::CStr::from_ptr(prefix) };
+        let prefix = unsafe{ cstr_to_path(prefix) };
+        let offset = if offset.is_null() {
+            None
+        } else {
+            Some(unsafe { cstr_to_path(std::ffi::CStr::from_ptr(offset)) })
+        };
+        Ok(Request::List(prefix, offset, config, response))
+    },
+    prefix: *const c_char, offset: *const c_char
+);
 
 #[no_mangle]
 pub extern "C" fn destroy_list_stream(
-    stream: *mut StreamWrapper
+    stream: *mut ListStream
 ) -> CResult {
     destroy_with_runtime!({
-        let mut boxed = unsafe { Box::from_raw(stream) };
-        // Safety: Must drop the stream before the client here
-        drop(boxed.stream.take());
+        let boxed = unsafe { Box::from_raw(stream) };
         drop(boxed);
     })
 }
@@ -201,7 +250,7 @@ pub extern "C" fn destroy_list_stream(
 
 #[repr(C)]
 pub struct ListStreamResponse {
     result: CResult,
-    stream: *mut StreamWrapper,
+    stream: *mut ListStream,
     error_message: *mut c_char,
     context: *const Context
 }
@@ -209,7 +258,7 @@ pub struct ListStreamResponse {
 unsafe impl Send for ListStreamResponse {}
 
 impl RawResponse for ListStreamResponse {
-    type Payload = Box<StreamWrapper>;
+    type Payload = ListStream;
     fn result_mut(&mut self) -> &mut CResult {
         &mut self.result
     }
@@ -222,7 +271,7 @@ impl RawResponse for ListStreamResponse {
     fn set_payload(&mut self, payload: Option<Self::Payload>) {
         match payload {
             Some(stream) => {
-                self.stream = Box::into_raw(stream);
+                self.stream = Box::into_raw(Box::new(stream));
             }
             None => {
                 self.stream = std::ptr::null_mut();
@@ -231,53 +280,30 @@ impl RawResponse for ListStreamResponse {
     }
 }
 
-#[no_mangle]
-pub extern "C" fn list_stream(
-    prefix: *const c_char,
-    offset: *const c_char,
-    config: *const RawConfig,
-    response: *mut ListStreamResponse,
-    handle: *const c_void
-) -> CResult {
-    let response = unsafe { ResponseGuard::new(response, handle) };
-    let prefix = unsafe { std::ffi::CStr::from_ptr(prefix) };
-    let prefix = unsafe{ cstr_to_path(prefix) };
-    let offset = if offset.is_null() {
-        None
-    } else {
-        Some(unsafe { cstr_to_path(std::ffi::CStr::from_ptr(offset)) })
-    };
-    let config = unsafe { & (*config) };
-    match SQ.get() {
-        Some(sq) => {
-            match sq.try_send(Request::ListStream(prefix, offset, config, response)) {
-                Ok(_) => CResult::Ok,
-                Err(flume::TrySendError::Full(Request::ListStream(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel full, backoff");
-                    CResult::Backoff
-                }
-                Err(flume::TrySendError::Disconnected(Request::ListStream(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-                    CResult::Error
-                }
-                _ => unreachable!("the response type must match")
-            }
-        }
-        None => {
-            response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-            return CResult::Error;
-        }
-    }
-}
+export_queued_op!(
+    list_stream,
+    ListStreamResponse,
+    |config, response| {
+        let prefix = unsafe { std::ffi::CStr::from_ptr(prefix) };
+        let prefix = unsafe{ cstr_to_path(prefix) };
+        let offset = if offset.is_null() {
+            None
+        } else {
+            Some(unsafe { cstr_to_path(std::ffi::CStr::from_ptr(offset)) })
+        };
+        Ok(Request::ListStream(prefix, offset, config, response))
+    },
+    prefix: *const c_char, offset: *const c_char
+);
 
 #[no_mangle]
 pub extern "C" fn next_list_stream_chunk(
-    stream: *mut StreamWrapper,
+    stream: *mut ListStream,
     response: *mut ListResponse,
     handle: *const c_void
 ) -> CResult {
     let response = unsafe { ResponseGuard::new(response, handle) };
-    let wrapper = match unsafe { stream.as_mut() } {
+    let stream = match unsafe { stream.as_mut() } {
         Some(w) => w,
         None => {
             response.into_error("null stream pointer");
@@ -289,8 +315,7 @@ pub extern "C" fn next_list_stream_chunk(
         Some(runtime) => {
             runtime.spawn(async move {
                 let list_op = async {
-                    let stream_ref = wrapper.stream.as_mut().unwrap();
-                    let option = match stream_ref.next().await {
+                    let option = match stream.next().await {
                         Some(vec) => {
                            vec.into_iter().collect::<Result<Vec<_>, _>>()?
                        }
diff --git a/src/metrics.rs b/src/metrics.rs
index 3c31e3d..01f1e34 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -1,5 +1,85 @@
 use std::alloc::{GlobalAlloc, Layout, System};
-use std::sync::atomic::{AtomicI64, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
+use std::time::Instant;
+
+use metrics::{describe_histogram, histogram, Unit};
+
+macro_rules! metric_const {
+    ($name: ident) => {
+        #[allow(non_upper_case_globals)]
+        pub(crate) const $name: &'static str = stringify!($name);
+    };
+}
+
+metric_const!(get_attempt_duration);
+metric_const!(put_attempt_duration);
+metric_const!(delete_attempt_duration);
+metric_const!(multipart_get_attempt_duration);
+metric_const!(multipart_put_attempt_duration);
+metric_const!(material_for_write_duration);
+metric_const!(material_from_metadata_duration);
+metric_const!(list_attempt_duration);
+metric_const!(sf_heartbeat_duration);
+metric_const!(sf_token_refresh_duration);
+metric_const!(sf_token_login_duration);
+metric_const!(sf_query_attempt_duration);
+metric_const!(sf_fetch_upload_info_retried_duration);
+metric_const!(sf_fetch_path_info_retried_duration);
+metric_const!(sf_get_presigned_url_retried_duration);
+metric_const!(total_get_ops);
+metric_const!(total_put_ops);
+metric_const!(total_delete_ops);
+metric_const!(total_keyring_get);
+metric_const!(total_keyring_miss);
+metric_const!(total_fetch_upload_info);
+metric_const!(total_fetch_path_info);
+
+#[allow(dead_code)]
+pub fn init_metrics() {
+    describe_histogram!(get_attempt_duration, Unit::Seconds, "Duration of a get operation attempt");
+    describe_histogram!(put_attempt_duration, Unit::Seconds, "Duration of a put operation attempt");
+    describe_histogram!(delete_attempt_duration, Unit::Seconds, "Duration of a delete operation attempt");
+    describe_histogram!(multipart_get_attempt_duration, Unit::Seconds, "Duration of a multipart get operation attempt");
+    describe_histogram!(multipart_put_attempt_duration, Unit::Seconds, "Duration of a multipart put operation attempt");
+    describe_histogram!(material_from_metadata_duration, Unit::Seconds, "Time to get a potentially cached key");
+    describe_histogram!(material_for_write_duration, Unit::Seconds, "Time to fetch a potentially cached key for writes");
+    describe_histogram!(list_attempt_duration, Unit::Seconds, "Duration of a list operation attempt");
+    describe_histogram!(sf_token_refresh_duration, Unit::Seconds, "Time to refresh a token from SF");
+    describe_histogram!(sf_token_login_duration, Unit::Seconds, "Time to get the first token from SF");
+    describe_histogram!(sf_query_attempt_duration, Unit::Seconds, "Time to perform a SF query attempt");
+    describe_histogram!(sf_fetch_upload_info_retried_duration, Unit::Seconds, "Time to fetch a new write key from SF");
+    describe_histogram!(sf_fetch_path_info_retried_duration, Unit::Seconds, "Time to fetch the key for a path from SF");
+    describe_histogram!(sf_get_presigned_url_retried_duration, Unit::Seconds, "Time to fetch a presigned url from SF");
+}
+
+pub(crate) struct DurationGuard {
+    name: &'static str,
+    t0: Instant,
+    discarded: AtomicBool
+}
+
+impl DurationGuard {
+    pub(crate) fn new(name: &'static str) -> DurationGuard {
+        DurationGuard { name, t0: Instant::now(), discarded: AtomicBool::new(false) }
+    }
+    pub(crate) fn discard(&self) {
+        self.discarded.store(true, Ordering::Relaxed)
+    }
+}
+
+impl Drop for DurationGuard {
+    fn drop(&mut self) {
+        if !self.discarded.load(Ordering::Relaxed) {
+            histogram!(self.name).record(Instant::now() - self.t0);
+        }
+    }
+}
+
+#[macro_export]
+macro_rules! duration_on_drop {
+    ($name: expr) => {
+        crate::metrics::DurationGuard::new($name)
+    };
+}
 
 #[derive(Debug, Default)]
 #[repr(C)]
diff --git a/src/snowflake/client.rs b/src/snowflake/client.rs
new file mode 100644
index 0000000..8bbd932
--- /dev/null
+++ b/src/snowflake/client.rs
@@ -0,0 +1,775 @@
+use ::metrics::counter;
+use object_store::RetryConfig;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use std::{collections::HashMap, sync::Arc, time::{Duration, Instant, SystemTime, UNIX_EPOCH}};
+use tokio::sync::Mutex;
+use zeroize::Zeroize;
+use moka::future::Cache;
+use crate::{duration_on_drop, error::{Error, RetryState}, metrics};
+use crate::util::{deserialize_str, deserialize_slice};
+// use anyhow::anyhow;
+
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum SnowflakeResponse<T> {
+    Success {
+        data: T,
+        success: bool,
+    },
+    Error {
+        data: SnowflakeErrorData,
+        code: String,
+        message: String,
+        success: bool,
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct SnowflakeTokenData {
+    session_token: String,
+    #[serde(rename = "validityInSecondsST")]
+    validity_in_seconds_st: u64,
+    master_token: String,
+    #[serde(rename = "validityInSecondsMT")]
+    validity_in_seconds_mt: u64
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct SnowflakeLoginData {
+    token: String,
+    validity_in_seconds: u64,
+    master_token: String,
+    master_validity_in_seconds: u64
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct SnowflakeColType {
+    name: String
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct SnowflakeResultChunk {
+    url: String,
+    row_count: usize,
+    uncompressed_size: usize,
+    compressed_size: usize
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct SnowflakeQueryData {
+    rowtype: Vec<SnowflakeColType>,
+    rowset: Vec<Vec<serde_json::Value>>,
+    total: usize,
+    returned: usize,
+    #[serde(default)]
+    chunk_headers: HashMap<String, String>,
+    #[serde(default)]
+    chunks: Vec<SnowflakeResultChunk>
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub(crate) struct SnowflakeStageCreds {
+    pub aws_key_id: String,
+    pub aws_secret_key: String,
+    pub aws_token: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct SnowflakeStageInfo {
+    pub location_type: String,
+    pub location: String,
+    pub path: String,
+    pub region: String,
+    pub storage_account: Option<String>,
+    pub is_client_side_encrypted: bool,
+    pub ciphers: Option<String>,
+    pub creds: SnowflakeStageCreds,
+    pub use_s3_regional_url: bool,
+    pub end_point: Option<String>,
+    // This field is not part of the gateway API
+    // it is only used for testing purposes
+    pub test_endpoint: Option<String>
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "stage_type")]
+#[non_exhaustive]
+pub(crate) enum NormalizedStageInfo {
+    S3 {
+        bucket: String,
+        prefix: String,
+        region: String,
+        aws_key_id: String,
+        aws_secret_key: String,
+        aws_token: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        end_point: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        test_endpoint: Option<String>,
+    },
+    BlobStorage {
+        storage_account: String,
+        container: String,
+        prefix: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        end_point: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        test_endpoint: Option<String>
+    }
+
+}
+
TryFrom<&SnowflakeStageInfo> for NormalizedStageInfo { + type Error = crate::error::Error; + fn try_from(value: &SnowflakeStageInfo) -> Result { + if value.location_type == "S3" { + let (bucket, prefix) = value.location.split_once('/') + .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the bucket name"))?; + return Ok(NormalizedStageInfo::S3 { + bucket: bucket.to_string(), + prefix: prefix.to_string(), + region: value.region.clone(), + aws_key_id: value.creds.aws_key_id.clone(), + aws_secret_key: value.creds.aws_secret_key.clone(), + aws_token: value.creds.aws_token.clone(), + end_point: value.end_point.clone(), + test_endpoint: value.test_endpoint.clone() + }) + } else { + return Err(Error::not_implemented("Azure BlobStorage is not implemented")); + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[derive(Clone)] +pub(crate) struct SnowflakeEncryptionMaterial { + pub query_stage_master_key: String, + pub query_id: String, + pub smk_id: u64, +} + +impl Drop for SnowflakeEncryptionMaterial { + fn drop(&mut self) { + self.query_stage_master_key.zeroize(); + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SnowflakeUploadData { + pub query_id: String, + pub encryption_material: Option, + pub stage_info: SnowflakeStageInfo, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SnowflakeDownloadData { + pub query_id: String, + #[serde(rename = "src_locations")] + pub src_locations: Vec, + pub encryption_material: Vec>, + pub stage_info: SnowflakeStageInfo, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct SnowflakeErrorData { + query_id: String +} +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum SnowflakeQueryResponse { + Success { + data: SnowflakeQueryData, + success: bool, + }, + Upload { + data: SnowflakeUploadData, + success: bool, + }, + Download { + data: SnowflakeDownloadData, + success: bool, + }, + Error { + data: SnowflakeErrorData, + code: String, + message: String, + success: bool, + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct SnowflakeQueryStatus { + id: String, + status: String, + state: String, + error_code: Option, + error_message: Option +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct SnowflakeStatusData { + queries: Vec +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct SnowflakeStatusResponse { + data: SnowflakeStatusData +} + +#[derive(Debug, Clone)] +pub(crate) struct SnowflakeClientConfig { + pub account: String, + pub database: String, + pub endpoint: String, + pub schema: String, + pub warehouse: Option, + pub username: Option, + pub password: Option, + pub role: Option, + pub master_token_path: Option, + pub stage_info_cache_ttl: Option, + pub retry_config: RetryConfig +} + +impl SnowflakeClientConfig { + #[allow(unused)] + pub(crate) fn from_env() -> anyhow::Result { + use std::env; + Ok(SnowflakeClientConfig { + account: env::var("SNOWFLAKE_ACCOUNT")?, + database: env::var("SNOWFLAKE_DATABASE")?, + endpoint: env::var("SNOWFLAKE_ENDPOINT") + .or(env::var("SNOWFLAKE_HOST").map(|h| format!("https://{h}")))?, + schema: env::var("SNOWFLAKE_SCHEMA")?, + warehouse: env::var("SNOWFLAKE_WAREHOUSE").ok(), + username: env::var("SNOWFLAKE_USERNAME").ok(), + password: env::var("SNOWFLAKE_PASSWORD").ok(), 
+ role: env::var("SNOWFLAKE_ROLE").ok(), + master_token_path: env::var("MASTER_TOKEN_PATH").ok(), + stage_info_cache_ttl: None, + retry_config: RetryConfig::default() + }) + } +} + +struct TokenState { + token: String, + expiration: Instant, + master_token: String, + #[allow(unused)] + master_expiration: Instant +} + +#[derive(Clone)] +pub(crate) struct SnowflakeClient { + config: SnowflakeClientConfig, + client: reqwest::Client, + token: Arc>>, + stage_info_cache: Cache> +} + +impl std::fmt::Debug for SnowflakeClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SnowflakeClient") + .field("config", &self.config) + .finish() + } +} + +impl SnowflakeClient { + pub(crate) fn new(config: SnowflakeClientConfig) -> Arc { + let cache_ttl = config.stage_info_cache_ttl.unwrap_or(Duration::from_secs(40 * 60)); + let client = SnowflakeClient { + config, + client: reqwest::Client::builder() + .timeout(Duration::from_secs(180)) + .build().unwrap(), + token: Arc::new(Mutex::new(None)), + stage_info_cache: Cache::builder() + .max_capacity(10) + // Time to live here manages the stage token lifecycle, removing it from the cache + // prior to expiration + .time_to_live(cache_ttl) + .build() + }; + + let client = Arc::new(client); + + { + let client = Arc::downgrade(&client); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(5 * 60)); + interval.tick().await; + loop { + interval.tick().await; + if let Some(client) = client.upgrade() { + match client.heartbeat().await { + Ok(_) => {}, + Err(e) => { + tracing::warn!("Heartbeat failed: {:#}", e); + } + } + } else { + // Client was dropped, stop heartbeat + break; + } + } + }); + } + + client + } + #[allow(unused)] + pub(crate) fn from_env() -> anyhow::Result> { + Ok(SnowflakeClient::new(SnowflakeClientConfig::from_env()?)) + } + async fn heartbeat(&self) -> crate::Result { + let token = { + let locked = self.token.lock().await; + match locked.as_ref() { + Some(TokenState { token, .. }) => token.clone(), + _ => { + return Ok(false); + } + } + }; + + let _guard = duration_on_drop!(metrics::sf_heartbeat_duration); + let response = self.client.post(format!("{}/session/heartbeat", self.config.endpoint)) + .header("Authorization", format!("Snowflake Token=\"{}\"", token)) + .send() + .await?; + + response.error_for_status_ref()?; + + return Ok(true); + } + pub(crate) async fn token(&self) -> crate::Result { + let mut locked = self.token.lock().await; + match locked.as_ref() { + Some(TokenState { token, expiration, .. }) if Instant::now() + Duration::from_secs(180) < *expiration => { + return Ok(token.clone()); + } + _ => {} + } + + let config = &self.config; + + if let Some(TokenState { token, master_token, .. 
+    pub(crate) async fn token(&self) -> crate::Result<String> {
+        let mut locked = self.token.lock().await;
+        match locked.as_ref() {
+            Some(TokenState { token, expiration, .. }) if Instant::now() + Duration::from_secs(180) < *expiration => {
+                return Ok(token.clone());
+            }
+            _ => {}
+        }
+
+        let config = &self.config;
+
+        if let Some(TokenState { token, master_token, .. }) = locked.as_ref() {
+            let _guard = duration_on_drop!(metrics::sf_token_refresh_duration);
+            // Renew
+            let response = self.client.post(format!("{}/session/token-request", config.endpoint))
+                .header("Content-Type", "application/json")
+                .header("Accept", "application/json")
+                .header("Authorization", format!("Snowflake Token=\"{}\"", master_token))
+                .query(&[
+                    ("requestId", uuid::Uuid::new_v4().to_string())
+                ])
+                .json(&serde_json::json!({
+                    "oldSessionToken": token,
+                    "requestType": "RENEW"
+                }))
+                .send()
+                .await?;
+
+            response.error_for_status_ref()
+                .inspect_err(|_| *locked = None)?;
+
+            let token_response_bytes = response
+                .bytes()
+                .await
+                .inspect_err(|_| *locked = None)?;
+
+            let token_response: SnowflakeResponse = deserialize_slice(&token_response_bytes)
+                .map_err(Error::deserialize_response_err("token"))
+                .inspect_err(|_| *locked = None)?;
+
+            match token_response {
+                SnowflakeResponse::Success { data, .. } => {
+                    *locked = Some(TokenState {
+                        token: data.session_token.clone(),
+                        expiration: Instant::now() + Duration::from_secs(data.validity_in_seconds_st),
+                        master_token: data.master_token,
+                        master_expiration: Instant::now() + Duration::from_secs(data.validity_in_seconds_mt)
+                    });
+
+                    return Ok(data.session_token);
+                }
+                SnowflakeResponse::Error { data, code, message, .. } => {
+                    *locked = None;
+                    return Err(Error::error_response(format!("Error (code: {}, query_id: {}): {}", code, data.query_id, message)).into())
+                }
+            }
+        } else {
+            let _guard = duration_on_drop!(metrics::sf_token_login_duration);
+            let response = if let (Some(username), Some(password)) = (&config.username, &config.password) {
+                // User Password Login
+                let mut qs = vec![
+                    ("accountName", &config.account),
+                    ("databaseName", &config.database),
+                    ("schemaName", &config.schema),
+                ];
+
+                if let Some(warehouse) = config.warehouse.as_ref() {
+                    qs.push(("warehouse", warehouse));
+                }
+
+                if let Some(role) = config.role.as_ref() {
+                    qs.push(("roleName", role));
+                }
+
+                let response = self.client.post(format!("{}/session/v1/login-request", config.endpoint))
+                    .header("Content-Type", "application/json")
+                    .header("Accept", "application/snowflake")
+                    .query(&qs)
+                    .json(&serde_json::json!({
+                        "data": {
+                            "PASSWORD": password,
+                            "LOGIN_NAME": username,
+                            "ACCOUNT_NAME": &config.account,
+                            "AUTHENTICATOR": "USERNAME_PASSWORD_MFA"
+                        }
+                    }))
+                    .send()
+                    .await?;
+
+                response
+            } else {
+                // Master Token Login
+                let master_token = std::fs::read_to_string(config.master_token_path.as_deref().unwrap_or("/snowflake/session/token"))
+                    .map_err(Error::invalid_config_err("Unable to access master token file"))?;
+
+                let mut qs = vec![
+                    ("accountName", &config.account),
+                    ("databaseName", &config.database),
+                    ("schemaName", &config.schema),
+                ];
+
+                if let Some(warehouse) = config.warehouse.as_ref() {
+                    qs.push(("warehouse", warehouse));
+                }
+
+                let response = self.client.post(format!("{}/session/v1/login-request", config.endpoint))
+                    .header("Content-Type", "application/json")
+                    .header("Accept", "application/snowflake")
+                    .query(&qs)
+                    .json(&serde_json::json!({
+                        "data": {
+                            "ACCOUNT_NAME": &config.account,
+                            "TOKEN": &master_token,
+                            "AUTHENTICATOR": "OAUTH"
+                        }
+                    }))
+                    .send()
+                    .await?;
+
+                response
+            };
+
+            response.error_for_status_ref()?;
+
+            let login_response_bytes = response
+                .bytes()
+                .await?;
+
+            let login_response: SnowflakeResponse = deserialize_slice(&login_response_bytes)
+                .map_err(Error::deserialize_response_err("login"))?;
+
+            match login_response {
+                SnowflakeResponse::Success { data, .. } => {
+                    *locked = Some(TokenState {
+                        token: data.token.clone(),
+                        expiration: Instant::now() + Duration::from_secs(data.validity_in_seconds),
+                        master_token: data.master_token,
+                        master_expiration: Instant::now() + Duration::from_secs(data.master_validity_in_seconds)
+                    });
+
+                    return Ok(data.token);
+                }
+                SnowflakeResponse::Error { data, code, message, .. } => {
+                    return Err(Error::error_response(format!("Error (code: {}, query_id: {}): {}", code, data.query_id, message)).into())
+                }
+            }
+        }
+    }
+    async fn query_impl<T: serde::de::DeserializeOwned>(&self, query: impl AsRef<str>) -> crate::Result<T> {
+        let token = self.token().await?;
+        let _guard = duration_on_drop!(metrics::sf_query_attempt_duration);
+        let config = &self.config;
+        let response = self.client.post(format!("{}/queries/v1/query-request", config.endpoint))
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .header("Authorization", format!("Snowflake Token=\"{}\"", token))
+            .query(&[
+                ("clientStartTime", SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs().to_string()),
+                ("requestId", uuid::Uuid::new_v4().to_string()),
+                ("request_guid", uuid::Uuid::new_v4().to_string())
+            ])
+            .json(&serde_json::json!({
+                "sqlText": query.as_ref(),
+                "asyncExec": false,
+                "sequenceId": 1,
+                "isInternal": false
+            }))
+            .send()
+            .await?;
+
+        response.error_for_status_ref()?;
+
+        let response_string = response
+            .text()
+            .await?;
+
+        let response: SnowflakeResponse<T> = deserialize_str(&response_string)
+            .map_err(Error::deserialize_response_err("query"))?;
+
+        match response {
+            SnowflakeResponse::Success { data, .. } => {
+                Ok(data)
+            }
+            SnowflakeResponse::Error { data, code, message, .. } => {
+                Err(Error::error_response(format!("Error (code: {}, query_id: {}): {}", code, data.query_id, message)))
+            }
+        }
+    }
+    async fn query<T: serde::de::DeserializeOwned>(&self, query: impl AsRef<str>) -> crate::Result<T> {
+        let mut retry_state = RetryState::new(self.config.retry_config.clone());
+        'retry: loop {
+            match self.query_impl(query.as_ref()).await {
+                Err(e) => {
+                    match retry_state.should_retry(e) {
+                        Ok((e, info, duration)) => {
+                            tracing::info!("retrying snowflake query error (reason: {:?}) after {:?}: {}", info.reason, duration, e);
+                            tokio::time::sleep(duration).await;
+                            continue 'retry;
+                        }
+                        Err(e) => {
+                            break Err(e);
+                        }
+                    }
+                }
+                ok => {
+                    break ok;
+                }
+            }
+        }
+    }
+    pub(crate) async fn fetch_upload_info(&self, stage: impl AsRef<str>) -> crate::Result<SnowflakeUploadData> {
+        let _guard = duration_on_drop!(metrics::sf_fetch_upload_info_retried_duration);
+        let upload_data: SnowflakeUploadData = self
+            .query(format!("PUT file:///tmp/whatever @{}", stage.as_ref()))
+            .await?;
+
+        counter!(metrics::total_fetch_upload_info).increment(1);
+        Ok(upload_data)
+    }
+    pub(crate) async fn fetch_path_info(&self, stage: impl AsRef<str>, path: impl AsRef<str>) -> crate::Result<SnowflakeDownloadData> {
+        let _guard = duration_on_drop!(metrics::sf_fetch_path_info_retried_duration);
+        let download_data: SnowflakeDownloadData = self
+            .query(format!("GET @{}/{} file:///tmp/whatever", stage.as_ref(), path.as_ref()))
+            .await?;
+
+        counter!(metrics::total_fetch_path_info).increment(1);
+        Ok(download_data)
+    }
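`query` wraps each `query_impl` attempt in the crate's `RetryState` loop. The general shape — retry a fallible async operation a bounded number of times with a backoff delay — can be sketched independently of the crate's error machinery (everything below is illustrative, not the crate's `RetryState`):

```rust
use std::time::Duration;

async fn with_simple_retries<T, E, F, Fut>(max_attempts: usize, mut op: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut attempt = 0;
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            Err(e) if attempt + 1 < max_attempts => {
                attempt += 1;
                // Exponential backoff: 200ms, 400ms, 800ms, ...
                tokio::time::sleep(Duration::from_millis(100u64 << attempt)).await;
                let _ = e; // a real implementation inspects retryability here
            }
            Err(e) => return Err(e),
        }
    }
}
```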
+    #[allow(unused)]
+    pub(crate) async fn get_presigned_url(&self, stage: impl AsRef<str>, path: impl AsRef<str>) -> crate::Result<String> {
+        let _guard = duration_on_drop!(metrics::sf_get_presigned_url_retried_duration);
+        let query_data: SnowflakeQueryData = self
+            .query(format!("SELECT get_presigned_url(@{}, '{}')", stage.as_ref(), path.as_ref()))
+            .await?;
+        let url = query_data.rowset
+            .get(0)
+            .and_then(|r| r.get(0))
+            .and_then(|c| c.as_str())
+            .ok_or_else(|| Error::invalid_response("Missing url from get_presigned_url response"))?;
+        Ok(url.to_string())
+    }
+    pub(crate) async fn current_upload_info(&self, stage: impl AsRef<str>) -> crate::Result<Arc<SnowflakeUploadData>> {
+        let stage = stage.as_ref();
+        let stage_info = self.stage_info_cache.try_get_with_by_ref(stage, async {
+            let info = self.fetch_upload_info(stage).await?;
+
+            // TODO schedule task to update token before deadline
+            Ok::<_, Error>(Arc::new(info))
+        }).await?;
+        Ok(stage_info)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::AtomicUsize;
+
+    use crate::{metrics::init_metrics, Client};
+    use super::*;
+    // use futures_util::StreamExt;
+    use ::metrics::Unit;
+    use metrics_util::debugging::Snapshot;
+    use object_store::path::Path;
+
+    fn format_unit(unit: Option<Unit>, v: f64) -> String {
+        match unit {
+            Some(Unit::Seconds) => format!("{:?}", Duration::from_secs_f64(v)),
+            _ => format!("{}", v)
+        }
+    }
+    fn print_snapshot(snapshot: Snapshot) {
+        println!("=======================");
+        for (key, unit, _shared, value) in snapshot.into_vec() {
+            let (_kind, key) = key.into_parts();
+            let f = |v| format_unit(unit, v);
+            let value_str = match value {
+                metrics_util::debugging::DebugValue::Histogram(mut vals) => {
+                    vals.sort();
+                    let p = |v: f64| {
+                        if vals.len() == 0 {
+                            0.0f64
+                        } else {
+                            *vals[((vals.len() as f64 * v).floor() as usize).min(vals.len() - 1)]
+                        }
+                    };
+                    format!("min: {}, p50: {}, p99: {}, p999: {}, max: {}", f(p(0.0)), f(p(0.5)), f(p(0.99)), f(p(0.999)), f(p(1.0)))
+                }
+                metrics_util::debugging::DebugValue::Counter(v) => {
+                    format!("total: {}", v)
+                }
+                metrics_util::debugging::DebugValue::Gauge(v) => {
+                    format!("current: {}", f(*v))
+                }
+            };
+            println!("{} {}", key.name(), value_str);
+        }
+        println!("=======================");
+    }
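`current_upload_info` leans on moka's `try_get_with_by_ref`, which deduplicates concurrent loads: if many tasks miss the same key at once, only one runs the init future and the rest await its result, so a stampede of uploads produces a single `PUT` round-trip to Snowflake. A minimal sketch of that primitive (key and value types are illustrative):

```rust
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration;

#[tokio::main]
async fn main() {
    let cache: Cache<String, Arc<String>> = Cache::builder()
        .max_capacity(10)
        .time_to_live(Duration::from_secs(40 * 60))
        .build();

    // Only one of N concurrent callers executes this future; the rest wait on it.
    let value = cache
        .try_get_with_by_ref("stage", async {
            Ok::<_, std::io::Error>(Arc::new("stage info".to_string()))
        })
        .await
        .unwrap();
    println!("{value}");
}
```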
+
+    #[tokio::test(flavor = "multi_thread")]
+    async fn snowflake_gateway_stress_test() -> anyhow::Result<()> {
+        if std::env::var("SNOWFLAKE_HOST").is_ok() {
+            let recorder = metrics_util::debugging::DebuggingRecorder::new();
+            let snapshotter = recorder.snapshotter();
+            recorder.install()?;
+
+            init_metrics();
+
+            let cancel_token = tokio_util::sync::CancellationToken::new();
+            {
+                let cancel_token = cancel_token.clone();
+                tokio::spawn(async move {
+                    let mut interval = tokio::time::interval(Duration::from_secs(1));
+                    interval.tick().await;
+                    loop {
+                        tokio::select! {
+                            _ = cancel_token.cancelled() => {
+                                print_snapshot(snapshotter.snapshot());
+                                break;
+                            }
+                            _ = interval.tick() => {
+                                print_snapshot(snapshotter.snapshot());
+                            }
+                        }
+                    }
+                });
+            }
+            let _guard = cancel_token.drop_guard();
+
+            let config_map = serde_json::from_value(serde_json::json!({
+                "url": "snowflake://ENCRYPTED_STAGE_2",
+                "snowflake_stage": "ENCRYPTED_STAGE_2",
+                "max_retries": "3",
+                // "snowflake_stage_info_cache_ttl_secs": "0",
+                "snowflake_keyring_ttl_secs": "10000",
+                "snowflake_encryption_scheme": "AES_128_CBC"
+            }))?;
+
+            let client = Client::from_config_map(config_map).await?;
+
+            let n_files = 20_000;
+            let concurrency = 512;
+
+// let t0 = std::time::Instant::now();
+// futures_util::stream::iter(0..n_files)
+//     .map(|i| {
+//         let client = client.clone();
+//         return tokio::spawn(async move {
+//             let _ = client.put(&Path::from(format!("_blobs/{:08}.bin", i)), bytes::Bytes::new()).await.unwrap();
+//         })
+//     })
+//     .buffer_unordered(64)
+//     .for_each(|_| async {}).await;
+// println!("{:?}", t0.elapsed());
+//
+// let t0 = std::time::Instant::now();
+// let list = client.list(&Path::from("_blobs"), None).await.unwrap();
+// println!("{:?}", list.len());
+// println!("{:?}", t0.elapsed());
+
+            let error_count = Arc::new(AtomicUsize::new(0));
+            let (tx, rx) = flume::unbounded::<usize>();
+
+            let mut tasks = vec![];
+            for _taskn in 0..concurrency {
+                let client = client.clone();
+                let error_count = error_count.clone();
+                let rx = rx.clone();
+                tasks.push(tokio::spawn(async move {
+                    loop {
+                        let i = match rx.recv_async().await {
+                            Ok(r) => r,
+                            _ => {
+                                break;
+                            }
+                        };
+
+                        let mut buf = [0; 10];
+                        let res = client.get(&Path::from(format!("_blobs/{:08}.bin", i)), &mut buf).await;
+                        if let Err(e) = res {
+                            if !format!("{e}").contains("Generic S3 error") {
+                                println!("failed with: {e}");
+                            }
+                            error_count.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
+                        }
+                    }
+                }));
+            }
+
+            let t0 = std::time::Instant::now();
+            for _round in 0..10 {
+                for i in 0..n_files {
+                    tx.send_async(i).await?;
+                }
+            }
+            drop(tx);
+            drop(rx);
+
+            for task in tasks {
+                task.await?;
+            }
+            println!("{:?}", t0.elapsed());
+            println!("error_count {:?}", error_count);
+        } else {
+            println!("Ignoring test as it is not running on an SPCS service");
+        }
+        Ok(())
+    }
+}
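The stress test distributes work to a fixed pool of tasks over a flume MPMC channel: cloning the receiver gives every worker a handle, and once every sender is dropped the `recv_async` loops terminate. A condensed sketch of just that pattern:

```rust
#[tokio::main]
async fn main() {
    let (tx, rx) = flume::unbounded::<usize>();

    // Fixed pool of workers, all pulling from the same channel.
    let workers: Vec<_> = (0..8)
        .map(|_| {
            let rx = rx.clone();
            tokio::spawn(async move {
                // recv_async returns Err once every sender is dropped.
                while let Ok(job) = rx.recv_async().await {
                    let _ = job; // process the job here
                }
            })
        })
        .collect();

    for job in 0..100 {
        tx.send_async(job).await.unwrap();
    }
    drop(tx); // close the channel so the workers exit

    for w in workers {
        w.await.unwrap();
    }
}
```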
diff --git a/src/snowflake/kms.rs b/src/snowflake/kms.rs
new file mode 100644
index 0000000..789791b
--- /dev/null
+++ b/src/snowflake/kms.rs
@@ -0,0 +1,195 @@
+use crate::{duration_on_drop, encryption::{ContentCryptoMaterial, CryptoMaterialProvider, CryptoScheme, EncryptedKey, Iv, Key}, error::{Error, ErrorExt}, metrics, snowflake::SnowflakeClient, util::deserialize_str};
+use crate::error::Kind as ErrorKind;
+
+use ::metrics::counter;
+use serde::{Serialize, Deserialize};
+use object_store::{Attributes, Attribute, AttributeValue};
+use anyhow::Context;
+use moka::future::Cache;
+use std::sync::Arc;
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct MaterialDescription {
+    pub smk_id: String,
+    pub query_id: String,
+    pub key_size: String
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct SnowflakeStageKmsConfig {
+    pub crypto_scheme: CryptoScheme,
+    pub keyring_capacity: usize,
+    pub keyring_ttl: std::time::Duration
+}
+
+impl Default for SnowflakeStageKmsConfig {
+    fn default() -> Self {
+        SnowflakeStageKmsConfig {
+            crypto_scheme: CryptoScheme::Aes128Cbc,
+            keyring_capacity: 100_000,
+            // We keep the ttl at 10 minutes to preserve the SF TSS guarantee
+            // that data cannot be decrypted after this period if the customer
+            // revokes the customer key
+            keyring_ttl: std::time::Duration::from_secs(10 * 60)
+        }
+    }
+}
+
+#[derive(Clone)]
+pub(crate) struct SnowflakeStageKms {
+    client: Arc<SnowflakeClient>,
+    stage: String,
+    prefix: String,
+    config: SnowflakeStageKmsConfig,
+    keyring: Cache<String, Key>
+}
+
+impl std::fmt::Debug for SnowflakeStageKms {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SnowflakeStageKms")
+            .field("client", &self.client)
+            .field("stage", &self.stage)
+            .field("config", &self.config)
+            .field("keyring", &"redacted")
+            .finish()
+    }
+}
+
+impl SnowflakeStageKms {
+    pub(crate) fn new(
+        client: Arc<SnowflakeClient>,
+        stage: impl Into<String>,
+        prefix: impl Into<String>,
+        config: SnowflakeStageKmsConfig
+    ) -> SnowflakeStageKms {
+        SnowflakeStageKms {
+            client,
+            stage: stage.into(),
+            prefix: prefix.into(),
+            keyring: Cache::builder()
+                .max_capacity(config.keyring_capacity as u64)
+                .time_to_live(config.keyring_ttl)
+                .build(),
+            config
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl CryptoMaterialProvider for SnowflakeStageKms {
+    async fn material_for_write(&self, _path: &str, data_len: Option<usize>) -> crate::Result<(ContentCryptoMaterial, Attributes)> {
+        let _guard = duration_on_drop!(metrics::material_for_write_duration);
+        let info = self.client.current_upload_info(&self.stage).await?;
+
+        let encryption_material = info.encryption_material.as_ref()
+            .ok_or_else(|| ErrorKind::StorageNotEncrypted(self.stage.clone()))?;
+        let description = MaterialDescription {
+            smk_id: encryption_material.smk_id.to_string(),
+            query_id: encryption_material.query_id.clone(),
+            key_size: "128".to_string()
+        };
+        let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
+            .map_err(ErrorKind::MaterialDecode)?;
+
+        let scheme = self.config.crypto_scheme;
+        let mut material = ContentCryptoMaterial::generate(scheme);
+        let encrypted_cek = material.cek.clone().encrypt_aes_128_ecb(&master_key)
+            .map_err(ErrorKind::MaterialCrypt)?;
+
+        let mut attributes = Attributes::new();
+        attributes.insert(
+            Attribute::Metadata("x-amz-key".into()),
+            AttributeValue::from(encrypted_cek.as_base64())
+        );
+        attributes.insert(
+            Attribute::Metadata("x-amz-iv".into()),
+            AttributeValue::from(material.iv.as_base64())
+        );
+        if let Some(data_len) = data_len {
+            attributes.insert(
+                Attribute::Metadata("x-amz-unencrypted-content-length".into()),
+                AttributeValue::from(format!("{}", data_len))
+            );
+        }
+        attributes.insert(
+            Attribute::Metadata("x-amz-matdesc".into()),
+            AttributeValue::from(serde_json::to_string(&description).context("failed to encode matdesc").to_err()?)
+        );
+
+        let cek_alg = match scheme {
+            CryptoScheme::Aes256Gcm => {
+                let cek_alg = "AES/GCM/NoPadding";
+                material = material.with_aad(cek_alg);
+                cek_alg
+            },
+            CryptoScheme::Aes128Cbc => "AES/CBC/PKCS5Padding"
+        };
+
+        attributes.insert(
+            Attribute::Metadata("x-amz-cek-alg".into()),
+            AttributeValue::from(cek_alg)
+        );
+
+        Ok((material, attributes))
+    }
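`encrypt_aes_128_ecb` above wraps the freshly generated content key under the stage master key — the classic S3 client-side-encryption envelope. The crate's `Key`/`EncryptedKey` types are internal, but the underlying operation can be sketched directly against the `openssl` crate this project already depends on (an assumed helper, not the crate's actual implementation):

```rust
use openssl::symm::{decrypt, encrypt, Cipher};

// Wrap a 16-byte content key under a 16-byte master key using AES-128-ECB.
// ECB is tolerable here only because the plaintext is a single random key,
// never structured data.
fn wrap_key(master_key: &[u8; 16], cek: &[u8; 16]) -> Vec<u8> {
    encrypt(Cipher::aes_128_ecb(), master_key, None, cek).expect("wrap failed")
}

fn unwrap_key(master_key: &[u8; 16], wrapped: &[u8]) -> Vec<u8> {
    decrypt(Cipher::aes_128_ecb(), master_key, None, wrapped).expect("unwrap failed")
}

fn main() {
    let master = [7u8; 16];
    let cek = [42u8; 16];
    let wrapped = wrap_key(&master, &cek);
    assert_eq!(unwrap_key(&master, &wrapped), cek);
}
```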
+
+    async fn material_from_metadata(&self, path: &str, attr: &Attributes) -> crate::Result<ContentCryptoMaterial> {
+        let _guard = duration_on_drop!(metrics::material_from_metadata_duration);
+        let path = path.strip_prefix(&self.prefix).unwrap_or(path);
+        let required_attribute = |key: &'static str| {
+            let v: &str = attr.get(&Attribute::Metadata(key.into()))
+                .ok_or_else(|| Error::required_config(format!("missing required attribute `{}`", key)))?
+                .as_ref();
+            Ok::<_, Error>(v)
+        };
+
+        let material_description: MaterialDescription = deserialize_str(required_attribute("x-amz-matdesc")?)
+            .map_err(Error::deserialize_response_err("failed to deserialize matdesc"))?;
+
+        let master_key = self.keyring.try_get_with(material_description.query_id, async {
+            let info = self.client.fetch_path_info(&self.stage, path).await?;
+            let position = info.src_locations.iter().position(|l| l == path)
+                .ok_or_else(|| Error::invalid_response("path not found"))?;
+            let encryption_material = info.encryption_material.get(position)
+                .cloned()
+                .ok_or_else(|| Error::invalid_response("src locations and encryption material length mismatch"))?
+                .ok_or_else(|| Error::invalid_response("path not encrypted"))?;
+
+            let master_key = Key::from_base64(&encryption_material.query_stage_master_key)
+                .map_err(ErrorKind::MaterialDecode)?;
+            counter!(metrics::total_keyring_miss).increment(1);
+            Ok::<_, Error>(master_key)
+        }).await?;
+        counter!(metrics::total_keyring_get).increment(1);
+
+        let cek = EncryptedKey::from_base64(required_attribute("x-amz-key")?)
+            .map_err(ErrorKind::MaterialDecode)?;
+        let cek = cek.decrypt_aes_128_ecb(&master_key)
+            .map_err(ErrorKind::MaterialCrypt)?;
+        let iv = Iv::from_base64(required_attribute("x-amz-iv")?)
+            .map_err(ErrorKind::MaterialDecode)?;
+        let alg = required_attribute("x-amz-cek-alg");
+
+        let scheme = match alg {
+            Ok("AES/GCM/NoPadding") => CryptoScheme::Aes256Gcm,
+            Ok("AES/CBC/PKCS5Padding") | Err(_) => CryptoScheme::Aes128Cbc,
+            Ok(v) => unimplemented!("cek alg `{}` not implemented", v)
+        };
+
+        let aad = match alg {
+            Ok("AES/GCM/NoPadding") => Some("AES/GCM/NoPadding".into()),
+            _ => None
+        };
+
+        let content_material = ContentCryptoMaterial {
+            scheme,
+            cek,
+            iv,
+            aad
+        };
+
+        Ok(content_material)
+    }
+}
diff --git a/src/snowflake/mod.rs b/src/snowflake/mod.rs
new file mode 100644
index 0000000..66e6c8c
--- /dev/null
+++ b/src/snowflake/mod.rs
@@ -0,0 +1,310 @@
+use crate::{clients, encryption::{CryptoMaterialProvider, CryptoScheme}, error::{Error, ErrorExt, Kind as ErrorKind}, CResult, Client, ClientExtension, Context, Extension, NotifyGuard, RawConfig, RawResponse, ResponseGuard};
+use crate::{RT, with_cancellation};
+
+pub(crate) mod client;
+use anyhow::Context as AnyhowContext;
+use client::{NormalizedStageInfo, SnowflakeClient, SnowflakeClientConfig};
+
+pub(crate) mod kms;
+use kms::{SnowflakeStageKms, SnowflakeStageKmsConfig};
+
+use object_store::{RetryConfig, ObjectStore};
+use tokio::sync::Mutex;
+use std::sync::Arc;
+
+use std::collections::HashMap;
+use std::ffi::{CString, c_char, c_void};
+
+#[derive(Debug)]
+pub(crate) struct SnowflakeS3Extension {
+    stage: String,
+    client: Arc<SnowflakeClient>
+}
+
+#[async_trait::async_trait]
+impl Extension for SnowflakeS3Extension {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+    async fn current_stage_info(&self) -> crate::Result<String> {
+        let stage_info = &self
+            .client
+            .current_upload_info(&self.stage)
+            .await?
+            .stage_info;
+        let stage_info: NormalizedStageInfo = stage_info.try_into()?;
+        let string = serde_json::to_string(&stage_info)
+            .context("failed to encode stage_info as json").to_err()?;
+        Ok(string)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct S3StageCredentialProvider {
+    stage: String,
+    client: Arc<SnowflakeClient>,
+    cached: Mutex<Option<Arc<object_store::aws::AwsCredential>>>
+}
+
+impl S3StageCredentialProvider {
+    pub(crate) fn new(stage: impl AsRef<str>, client: Arc<SnowflakeClient>) -> S3StageCredentialProvider {
+        S3StageCredentialProvider { stage: stage.as_ref().to_string(), client, cached: Mutex::new(None) }
+    }
+}
+
+#[async_trait::async_trait]
+impl object_store::CredentialProvider for S3StageCredentialProvider {
+    type Credential = object_store::aws::AwsCredential;
+    async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
+        let info = self.client.current_upload_info(&self.stage).await
+            .map_err(|e| {
+                object_store::Error::Generic {
+                    store: "S3",
+                    source: e.into()
+                }
+            })?;
+
+        let mut locked = self.cached.lock().await;
+
+        match locked.as_ref() {
+            Some(creds) => if creds.key_id == info.stage_info.creds.aws_key_id {
+                return Ok(Arc::clone(creds));
+            }
+            _ => {}
+        }
+
+        // The session token is empty when testing against minio
+        let token = match info.stage_info.creds.aws_token.trim() {
+            "" => None,
+            token => Some(token.to_string())
+        };
+
+        let creds = Arc::new(object_store::aws::AwsCredential {
+            key_id: info.stage_info.creds.aws_key_id.clone(),
+            secret_key: info.stage_info.creds.aws_secret_key.clone(),
+            token
+        });
+
+        *locked = Some(Arc::clone(&creds));
+
+        Ok(creds)
+    }
+}
+
+#[repr(C)]
+pub struct StageInfoResponse {
+    result: CResult,
+    stage_info: *mut c_char,
+    error_message: *mut c_char,
+    context: *const Context
+}
+
+unsafe impl Send for StageInfoResponse {}
+
+impl RawResponse for StageInfoResponse {
+    type Payload = String;
+    fn result_mut(&mut self) -> &mut CResult {
+        &mut self.result
+    }
+    fn context_mut(&mut self) -> &mut *const Context {
+        &mut self.context
+    }
+    fn error_message_mut(&mut self) -> &mut *mut c_char {
+        &mut self.error_message
+    }
+    fn set_payload(&mut self, payload: Option<Self::Payload>) {
+        match payload {
+            Some(serialized_info) => {
+                let c_string = CString::new(serialized_info).expect("should not have nulls");
+                self.stage_info = c_string.into_raw();
+            }
+            None => {
+                self.stage_info = std::ptr::null_mut();
+            }
+        }
+    }
+}
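`set_payload` hands the serialized stage info to the caller via `CString::into_raw`, which transfers ownership of the allocation across the FFI boundary; the pointer must eventually be reclaimed with `CString::from_raw` or it leaks. A sketch of what a matching destructor could look like (a hypothetical function — the crate's real destructors live elsewhere):

```rust
use std::ffi::{c_char, CString};

/// Hypothetical destructor for a string produced by `CString::into_raw`.
/// Safety: `ptr` must originate from `CString::into_raw` and be freed exactly once.
#[no_mangle]
pub unsafe extern "C" fn destroy_cstring(ptr: *mut c_char) {
    if !ptr.is_null() {
        // Re-own the allocation; it is dropped (freed) at end of scope.
        drop(CString::from_raw(ptr));
    }
}
```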
+
+#[no_mangle]
+pub extern "C" fn current_stage_info(
+    config: *const RawConfig,
+    response: *mut StageInfoResponse,
+    handle: *const c_void
+) -> CResult {
+    let response = unsafe { ResponseGuard::new(response, handle) };
+    let config = unsafe { & (*config) };
+
+    match RT.get() {
+        Some(runtime) => {
+            runtime.spawn(async move {
+                let info_op = async {
+                    let client = clients()
+                        .try_get_with(config.get_hash(), Client::from_raw_config(config)).await?;
+                    Ok::<_, crate::Error>(client.extension.current_stage_info().await?)
+                };
+
+                with_cancellation!(info_op, response);
+            });
+            CResult::Ok
+        }
+        None => {
+            response.into_error("object_store_ffi runtime not started (may be missing initialization)");
+            return CResult::Error;
+        }
+    }
+}
+
+#[derive(Clone)]
+pub(crate) struct SnowflakeConfig {
+    pub stage: String,
+    pub client_config: SnowflakeClientConfig,
+    pub kms_config: Option<SnowflakeStageKmsConfig>
+}
+
+pub(crate) fn validate_config_for_snowflake(map: &mut HashMap<String, String>, retry_config: RetryConfig) -> crate::Result<SnowflakeConfig> {
+    let mut required_or_env = |field: &str| {
+        map
+            .remove(field)
+            .or(std::env::var(field.to_uppercase()).ok())
+            .ok_or_else(|| {
+                Error::required_config(field)
+            })
+    };
+
+    let client_config = SnowflakeClientConfig {
+        account: required_or_env("snowflake_account")?,
+        database: required_or_env("snowflake_database")?,
+        endpoint: required_or_env("snowflake_endpoint")
+            .or(required_or_env("snowflake_host").map(|h| format!("https://{h}")))?,
+        schema: required_or_env("snowflake_schema")?,
+        warehouse: map.remove("snowflake_warehouse").or(std::env::var("SNOWFLAKE_WAREHOUSE").ok()),
+        username: map.remove("snowflake_username").or(std::env::var("SNOWFLAKE_USERNAME").ok()),
+        password: map.remove("snowflake_password").or(std::env::var("SNOWFLAKE_PASSWORD").ok()),
+        role: map.remove("snowflake_role").or(std::env::var("SNOWFLAKE_ROLE").ok()),
+        master_token_path: map.remove("snowflake_master_token_path").or(std::env::var("MASTER_TOKEN_PATH").ok()),
+        stage_info_cache_ttl: map.remove("snowflake_stage_info_cache_ttl_secs")
+            .map(|s| s.parse::<u64>())
+            .transpose()
+            .map_err(|e| Error::invalid_config_src("Failed to parse `snowflake_stage_info_cache_ttl_secs`", e))?
+            .map(|n| std::time::Duration::from_secs(n)),
+        retry_config
+    };
+
+    let kms_config = if let Some(scheme_str) = map.remove("snowflake_encryption_scheme") {
+        Some(SnowflakeStageKmsConfig {
+            crypto_scheme: match scheme_str.as_str() {
+                "AES_256_GCM" => CryptoScheme::Aes256Gcm,
+                "AES_128_CBC" => CryptoScheme::Aes128Cbc,
+                _ => return Err(Error::invalid_config("Invalid value for snowflake_encryption_scheme").into()),
+            },
+            keyring_capacity: match map.remove("snowflake_keyring_capacity").map(|s| s.parse::<usize>()) {
+                Some(Ok(cap)) => cap,
+                Some(Err(e)) => return Err(Error::invalid_config_src("Failed to parse `snowflake_keyring_capacity`", e).into()),
+                None => 100_000
+            },
+            keyring_ttl: match map.remove("snowflake_keyring_ttl_secs").map(|s| s.parse::<u64>()) {
+                Some(Ok(secs)) => std::time::Duration::from_secs(secs),
+                Some(Err(e)) => return Err(Error::invalid_config_src("Failed to parse `snowflake_keyring_ttl_secs`", e).into()),
+                None => std::time::Duration::from_secs(10 * 60)
+            }
+        })
+    } else {
+        None
+    };
+
+    let config = SnowflakeConfig {
+        stage: map.remove("snowflake_stage")
+            .ok_or_else(|| Error::required_config("snowflake_stage"))?,
+        client_config,
+        kms_config
+    };
+
+    for (key, _value) in map {
+        if key.starts_with("snowflake") {
+            return Err(Error::invalid_config(format!("Unknown config `{key}` found while validating snowflake config")).into());
+        }
+    }
+
+    Ok(config)
+}
+
+pub(crate) async fn build_store_for_snowflake_stage(
+    mut config_map: HashMap<String, String>,
+    retry_config: RetryConfig
+) -> crate::Result<(
+    Arc<dyn ObjectStore>,
+    Option<Arc<dyn CryptoMaterialProvider>>,
+    String,
+    ClientExtension
+)> {
+    let config = validate_config_for_snowflake(&mut config_map, retry_config.clone())?;
+    let client = SnowflakeClient::new(config.client_config);
+    let info = client.current_upload_info(&config.stage).await?;
+
+    match info.stage_info.location_type.as_ref() {
+        "S3" => {
+            let (bucket, stage_prefix) = info.stage_info.location.split_once('/')
+                .ok_or_else(|| Error::invalid_response("Stage information from snowflake is missing the bucket name"))?;
+
+            let provider = S3StageCredentialProvider::new(&config.stage, client.clone());
+            let store = if let Some(test_endpoint) = info.stage_info.test_endpoint.as_deref() {
+                config_map.insert("allow_http".into(), "true".into());
+                let mut builder = object_store::aws::AmazonS3Builder::default()
+                    .with_region(info.stage_info.region.clone())
+                    .with_bucket_name(bucket)
+                    .with_credentials(Arc::new(provider))
+                    .with_virtual_hosted_style_request(false)
+                    .with_unsigned_payload(true)
+                    .with_retry(retry_config)
+                    .with_endpoint(test_endpoint);
+
+                for (key, value) in config_map {
+                    builder = builder.with_config(key.parse()?, value);
+                }
+
+                builder.build()?
+            } else {
+                let mut builder = object_store::aws::AmazonS3Builder::default()
+                    .with_region(info.stage_info.region.clone())
+                    .with_bucket_name(bucket)
+                    .with_credentials(Arc::new(provider))
+                    .with_virtual_hosted_style_request(true)
+                    .with_unsigned_payload(true)
+                    .with_retry(retry_config);
+
+                if let Some(end_point) = info.stage_info.end_point.as_deref() {
+                    builder = builder.with_endpoint(format!("https://{bucket}.{end_point}"));
+                }
+
+                for (key, value) in config_map {
+                    builder = builder.with_config(key.parse()?, value);
+                }
+
+                builder.build()?
+            };
+
+            if config.kms_config.is_some() && !info.stage_info.is_client_side_encrypted {
+                return Err(ErrorKind::StorageNotEncrypted(config.stage.clone()).into());
+            }
+
+            let crypto_material_provider = if info.stage_info.is_client_side_encrypted {
+                let kms_config = config.kms_config.unwrap_or_default();
+                let stage_kms = SnowflakeStageKms::new(client.clone(), &config.stage, stage_prefix, kms_config);
+                Some::<Arc<dyn CryptoMaterialProvider>>(Arc::new(stage_kms))
+            } else {
+                None
+            };
+
+            let extension = Arc::new(SnowflakeS3Extension {
+                stage: config.stage.clone(),
+                client
+            });
+
+            Ok((Arc::new(store), crypto_material_provider, stage_prefix.to_string(), extension))
+        }
+        _ => {
+            unimplemented!("unknown stage location type");
+        }
+    }
+}
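For orientation, this is roughly how `build_store_for_snowflake_stage` is meant to be driven from the rest of the crate. The config keys mirror `validate_config_for_snowflake`; the snippet is illustrative rather than an exact call site:

```rust
use std::collections::HashMap;

async fn example() -> crate::Result<()> {
    let mut config_map: HashMap<String, String> = HashMap::new();
    config_map.insert("snowflake_stage".into(), "MY_STAGE".into());
    config_map.insert("snowflake_account".into(), "my_account".into());
    config_map.insert("snowflake_database".into(), "MY_DB".into());
    config_map.insert("snowflake_schema".into(), "PUBLIC".into());
    config_map.insert("snowflake_encryption_scheme".into(), "AES_128_CBC".into());

    // Returns the S3 store, the optional client-side-encryption provider,
    // the in-bucket stage prefix, and the snowflake extension.
    let (store, crypto, prefix, _ext) =
        build_store_for_snowflake_stage(config_map, object_store::RetryConfig::default()).await?;
    let _ = (store, crypto, prefix);
    Ok(())
}
```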
diff --git a/src/stream.rs b/src/stream.rs
index d6aded0..eacde00 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -1,32 +1,66 @@
-use crate::{destroy_with_runtime, static_config, CResult, Client, Context, NotifyGuard, RawConfig, RawResponse, Request, ResponseGuard, RT, SQ};
+use crate::encryption::{CrypterReader, CrypterWriter, Mode};
+use crate::{destroy_with_runtime, export_queued_op, with_retries, BoxedReader, BoxedUpload, CResult, Client, Context, NotifyGuard, RawConfig, RawResponse, Request, ResponseGuard, RT, SQ};
 use crate::util::{size_to_ranges, Compression, CompressedWriter, with_decoder, cstr_to_path};
-use crate::error::RetryState;
+use crate::error::{Kind as ErrorKind, RetryState};
 use crate::with_cancellation;
 use object_store::{path::Path, ObjectStore};
 use bytes::Buf;
 use tokio_util::io::StreamReader;
-use tokio::io::{AsyncWriteExt, AsyncReadExt, AsyncRead, AsyncBufReadExt};
+use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
 use std::ffi::{c_char, c_void, c_longlong};
+use std::sync::Arc;
 use futures_util::{StreamExt, TryStreamExt};
 
-pub(crate) async fn handle_get_stream(client: Client, path: &Path, size_hint: usize, compression: Compression) -> anyhow::Result<(Box<ReadStream>, usize)> {
-    if size_hint > 0 && size_hint < static_config().multipart_get_threshold as usize {
-        // Perform a single get without the head request
-        let result = client.store.get(path).await?;
-        let full_size = result.meta.size;
-        let stream = result.into_stream().map_err(Into::<std::io::Error>::into).boxed();
-        let reader = StreamReader::new(stream);
-        let decoded = with_decoder(compression, reader);
-        return Ok((Box::new(ReadStream { reader: tokio::io::BufReader::with_capacity(64 * 1024, decoded) }), full_size));
-    } else {
+impl Client {
+    async fn put_stream_impl(&self, path: &Path, compression: Compression) -> crate::Result<BoxedUpload> {
+        let path = &self.full_path(path);
+        let part_size = self.config.multipart_put_part_size;
+        let concurrency = self.config.multipart_put_concurrency;
+        let writer: BoxedUpload = if let Some(cryptmp) = self.crypto_material_provider.as_ref() {
+            let (material, attrs) = cryptmp.material_for_write(path.as_ref(), None).await?;
+            let writer = object_store::buffered::BufWriter::with_capacity(
+                Arc::clone(&self.store),
+                path.clone(),
+                part_size
+            )
+            .with_attributes(attrs)
+            .with_max_concurrency(concurrency);
+            let encrypter_writer = CrypterWriter::new(writer, Mode::Encrypt, &material)
+                .map_err(ErrorKind::ContentEncrypt)?;
+            Box::new(encrypter_writer)
+        } else {
+            Box::new(
+                object_store::buffered::BufWriter::with_capacity(
+                    Arc::clone(&self.store),
+                    path.clone(),
+                    part_size
+                )
+                .with_max_concurrency(concurrency)
+            )
+        };
+
+        let encoded = CompressedWriter::new(compression, writer);
+        return Ok(Box::new(encoded));
+    }
+    pub(crate) async fn put_stream(&self, path: &Path, compression: Compression) -> crate::Result<BoxedUpload> {
+        with_retries!(self, self.put_stream_impl(path, compression).await)
+    }
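The upload path stacks writers: bytes flow through the compression encoder, then the (optional) encrypting writer, and finally the multipart `BufWriter`. The same layering idea, minus the crate-specific types, looks like this with `async-compression` over any `AsyncWrite` sink:

```rust
use async_compression::tokio::write::GzipEncoder;
use tokio::io::AsyncWriteExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Innermost sink; in the crate this would be the multipart BufWriter,
    // optionally wrapped in an encrypting writer.
    let sink: Vec<u8> = Vec::new();

    // Outermost layer: compression. Writes pass through it into the sink.
    let mut writer = GzipEncoder::new(sink);
    writer.write_all(b"hello layered writers").await?;
    writer.shutdown().await?; // flushes the encoder's trailer

    let compressed = writer.into_inner();
    println!("compressed to {} bytes", compressed.len());
    Ok(())
}
```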
+    async fn get_stream_impl(&self, path: &Path, _size_hint: usize, compression: Compression) -> crate::Result<BoxedReader> {
+        let path = &self.full_path(path);
         // Perform head request and prefetch parts in parallel
-    let meta = client.store.head(&path).await?;
-    let part_ranges = size_to_ranges(meta.size);
+        let result = self.store.get_opts(
+            &path,
+            object_store::GetOptions {
+                head: true,
+                ..Default::default()
+            }
+        ).await?;
+        let part_ranges = size_to_ranges(result.meta.size, self.config.multipart_get_part_size);
         let state = (
-        client,
+            self.clone(),
             path.clone()
         );
         let stream = futures_util::stream::iter(part_ranges)
@@ -40,17 +74,17 @@ pub(crate) async fn handle_get_stream(client: Client, path: &Path, size_hint: us
             'retry: loop {
                 match client.store.get_range(&path, range.clone()).await.map_err(Into::into) {
                     Ok(bytes) => {
-                        return Ok::<_, anyhow::Error>(bytes)
+                        return Ok::<_, crate::error::Error>(bytes)
                     },
                     Err(e) => {
-                        match retry_state.should_retry(&e) {
-                            Ok((info, duration)) => {
+                        match retry_state.should_retry(e) {
+                            Ok((e, info, duration)) => {
                                 tracing::info!("retrying get stream error (reason: {:?}) after {:?}: {}", info.reason, duration, e);
                                 tokio::time::sleep(duration).await;
                                 continue 'retry;
                             }
-                            Err(report) => {
-                                tracing::warn!("[get stream] {}", report);
+                            Err(e) => {
+                                tracing::warn!("[get stream] {}", e);
                                 return Err(e);
                             }
                         }
@@ -59,26 +93,26 @@ pub(crate) async fn handle_get_stream(client: Client, path: &Path, size_hint: us
                 }
             }).await?;
         })
-        .buffered(64)
+        .buffered(self.config.multipart_get_concurrency)
         .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
         .boxed();
 
-    let reader = StreamReader::new(stream);
+        let reader: Box<dyn AsyncBufRead + Send + Unpin> = if let Some(cryptmp) = self.crypto_material_provider.as_ref() {
+            let material = cryptmp.material_from_metadata(path.as_ref(), &result.attributes).await?;
+            let decrypter_reader = CrypterReader::new(StreamReader::new(stream), Mode::Decrypt, &material)
+                .map_err(ErrorKind::ContentDecrypt)?;
+            let buffer_reader = BufReader::with_capacity(64 * 1024, decrypter_reader);
+            Box::new(buffer_reader)
+        } else {
+            Box::new(StreamReader::new(stream))
+        };
+
         let decoded = with_decoder(compression, reader);
-    return Ok((Box::new(ReadStream { reader: tokio::io::BufReader::with_capacity(64 * 1024, decoded) }), meta.size));
+        return Ok(Box::new(decoded));
+    }
+    pub(crate) async fn get_stream(&self, path: &Path, size_hint: usize, compression: Compression) -> crate::Result<BoxedReader> {
+        with_retries!(self, self.get_stream_impl(path, size_hint, compression).await)
+    }
-}
-
-pub(crate) async fn handle_put_stream(client: Client, path: &Path, compression: Compression) -> anyhow::Result<Box<WriteStream>> {
-    let writer = object_store::buffered::BufWriter::with_capacity(
-        client.store,
-        path.clone(),
-        10 * 1024 * 1024
-    )
-    .with_max_concurrency(64);
-
-    let encoded = CompressedWriter::new(compression, writer);
-    return Ok(Box::new(WriteStream { writer: encoded, aborted: false }));
 }
 
 #[repr(C)]
@@ -118,7 +152,13 @@ impl RawResponse for ReadResponse {
 }
 
 pub struct ReadStream {
-    reader: tokio::io::BufReader<Box<dyn AsyncRead + Send + Unpin>>
+    reader: BufReader<BoxedReader>
+}
+
+impl From<BoxedReader> for Box<ReadStream> {
+    fn from(value: BoxedReader) -> Self {
+        return Box::new(ReadStream { reader: BufReader::with_capacity(64 * 1024, value) });
+    }
 }
 
 #[no_mangle]
@@ -135,6 +175,7 @@
 pub struct GetStreamResponse {
     result: CResult,
     stream: *mut ReadStream,
+    // TODO get rid of this field (deprecated)
     object_size: u64,
     error_message: *mut c_char,
     context: *const Context
@@ -143,7 +184,7 @@
 unsafe impl Send for GetStreamResponse {}
 
 impl RawResponse for GetStreamResponse {
-    type Payload = (Box<ReadStream>, usize);
+    type Payload = Box<ReadStream>;
     fn result_mut(&mut self) -> &mut CResult {
         &mut self.result
     }
@@ -155,9 +196,9 @@
     }
     fn set_payload(&mut self, payload: Option<Self::Payload>) {
         match payload {
-            Some((stream, object_size)) => {
+            Some(stream) => {
                 self.stream = Box::into_raw(stream);
-                self.object_size = object_size as u64;
+                self.object_size = 0;
             }
             None => {
                 self.stream = std::ptr::null_mut();
@@ -167,48 +208,20 @@
     }
 }
 
-#[no_mangle]
-pub extern "C" fn get_stream(
-    path: *const c_char,
-    size_hint: usize,
-    decompress: *const c_char,
-    config: *const RawConfig,
-    response: *mut GetStreamResponse,
-    handle: *const c_void
-) -> CResult {
-    let response = unsafe { ResponseGuard::new(response, handle) };
-    let path = unsafe { std::ffi::CStr::from_ptr(path) };
-    let path = unsafe{ cstr_to_path(path) };
-    let decompress = match Compression::try_from(decompress) {
-        Ok(c) => c,
-        Err(e) => {
-            response.into_error(e);
-            return CResult::Error;
-        }
-    };
-    let config = unsafe { & (*config) };
-
-    match SQ.get() {
-        Some(sq) => {
-            match sq.try_send(Request::GetStream(path, size_hint, decompress, config, response)) {
-                Ok(_) => CResult::Ok,
-                Err(flume::TrySendError::Full(Request::GetStream(_, _, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel full, backoff");
-                    CResult::Backoff
-                }
-                Err(flume::TrySendError::Disconnected(Request::GetStream(_, _, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-                    CResult::Error
-                }
-                _ => unreachable!("the response type must match")
-            }
-        }
-        None => {
-            response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-            return CResult::Error;
-        }
-    }
-}
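`get_stream_impl` above splits the object into fixed-size ranges and fetches them with bounded concurrency while preserving order — `buffered`, not `buffer_unordered`, so the reassembled byte stream stays sequential. The skeleton of that technique, with a simplified `size_to_ranges` (the crate's real one also special-cases small objects):

```rust
use futures_util::StreamExt;
use std::ops::Range;

fn size_to_ranges(object_size: usize, part_size: usize) -> Vec<Range<usize>> {
    (0..object_size)
        .step_by(part_size.max(1))
        .map(|start| start..(start + part_size).min(object_size))
        .collect()
}

#[tokio::main]
async fn main() {
    let ranges = size_to_ranges(100, 32); // [0..32, 32..64, 64..96, 96..100]

    // Fetch up to 8 parts in parallel, but yield results in range order.
    let parts: Vec<Vec<u8>> = futures_util::stream::iter(ranges)
        .map(|range| async move { vec![0u8; range.len()] /* get_range(...) here */ })
        .buffered(8)
        .collect()
        .await;

    assert_eq!(parts.iter().map(Vec::len).sum::<usize>(), 100);
}
```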
+export_queued_op!(
+    get_stream,
+    GetStreamResponse,
+    |config, response| {
+        let path = unsafe { std::ffi::CStr::from_ptr(path) };
+        let path = unsafe{ cstr_to_path(path) };
+        let decompress = match Compression::try_from(decompress) {
+            Ok(d) => d,
+            Err(e) => return Err((response, e))
+        };
+        Ok(Request::GetStream(path, size_hint, decompress, config, response))
+    },
+    path: *const c_char, size_hint: usize, decompress: *const c_char
+);
 
 #[no_mangle]
 pub extern "C" fn read_from_stream(
@@ -236,10 +249,16 @@
             let amount_to_read = size.min(amount);
             let mut bytes_read = 0;
             while amount_to_read > bytes_read {
+                let len = slice.len();
                 let n = wrapper.reader.read_buf(&mut slice).await?;
                 if n == 0 {
-                    return Ok((bytes_read, true))
+                    if len == 0 {
+                        // cannot determine if it is eof
+                        return Ok((bytes_read, false))
+                    } else {
+                        return Ok((bytes_read, true))
+                    }
                 } else {
                     bytes_read += n;
                 }
@@ -355,10 +374,16 @@ impl RawResponse for WriteResponse {
 }
 
 pub struct WriteStream {
-    writer: CompressedWriter<object_store::buffered::BufWriter>,
+    writer: BoxedUpload,
     aborted: bool
 }
 
+impl From<BoxedUpload> for Box<WriteStream> {
+    fn from(value: BoxedUpload) -> Self {
+        return Box::new(WriteStream { writer: value, aborted: false });
+    }
+}
+
 #[no_mangle]
 pub extern "C" fn destroy_write_stream(
     writer: *mut WriteStream
@@ -414,47 +439,20 @@ impl RawResponse for PutStreamResponse {
 
 //
 // The written data can be optionally compressed by providing one of `gzip`, `deflate`, `zlib` or
 // `zstd` in the `compress` argument.
-#[no_mangle]
-pub extern "C" fn put_stream(
-    path: *const c_char,
-    compress: *const c_char,
-    config: *const RawConfig,
-    response: *mut PutStreamResponse,
-    handle: *const c_void
-) -> CResult {
-    let response = unsafe { ResponseGuard::new(response, handle) };
-    let path = unsafe { std::ffi::CStr::from_ptr(path) };
-    let path = unsafe{ cstr_to_path(path) };
-    let compress = match Compression::try_from(compress) {
-        Ok(c) => c,
-        Err(e) => {
-            response.into_error(e);
-            return CResult::Error;
-        }
-    };
-    let config = unsafe { & (*config) };
-
-    match SQ.get() {
-        Some(sq) => {
-            match sq.try_send(Request::PutStream(path, compress, config, response)) {
-                Ok(_) => CResult::Ok,
-                Err(flume::TrySendError::Full(Request::PutStream(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel full, backoff");
-                    CResult::Backoff
-                }
-                Err(flume::TrySendError::Disconnected(Request::PutStream(_, _, _, response))) => {
-                    response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-                    CResult::Error
-                }
-                _ => unreachable!("the response type must match")
-            }
-        }
-        None => {
-            response.into_error("object_store_ffi internal channel closed (may be missing initialization)");
-            return CResult::Error;
-        }
-    }
-}
+export_queued_op!(
+    put_stream,
+    PutStreamResponse,
+    |config, response| {
+        let path = unsafe { std::ffi::CStr::from_ptr(path) };
+        let path = unsafe{ cstr_to_path(path) };
+        let compress = match Compression::try_from(compress) {
+            Ok(c) => c,
+            Err(e) => return Err((response, e))
+        };
+        Ok(Request::PutStream(path, compress, config, response))
+    },
+    path: *const c_char, compress: *const c_char
+);
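The `read_from_stream` change above fixes an EOF ambiguity: `read_buf` returning `Ok(0)` proves end-of-stream only when the destination still had capacity; with a full buffer it says nothing about the reader. A focused demonstration using a `&mut [u8]` destination (the same shape the FFI slice has):

```rust
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let data: &[u8] = b"abc";
    let mut reader = data; // &[u8] implements AsyncRead

    let mut storage = [0u8; 3];
    let mut dst: &mut [u8] = &mut storage;

    // Fills the destination; dst is advanced as it is written.
    let n = reader.read_buf(&mut dst).await?;
    assert_eq!(n, 3);
    assert!(dst.is_empty());

    // Destination has no remaining capacity: Ok(0) does NOT prove EOF.
    let n = reader.read_buf(&mut dst).await?;
    assert_eq!(n, 0);

    // With fresh capacity, Ok(0) now genuinely means end-of-stream.
    let mut extra = [0u8; 1];
    let n = reader.read_buf(&mut &mut extra[..]).await?;
    assert_eq!(n, 0);
    Ok(())
}
```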
 
 // Writes bytes to the provided `WriteStream` and optionally flushes the internal buffers.
 // Any data written to the stream will be buffered and split into 10 MB chunks before being sent
@@ -547,7 +545,7 @@ pub extern "C" fn shutdown_write_stream(
         runtime.spawn(async move {
             let shutdown_op = async {
                 wrapper.writer.shutdown().await?;
-                Ok::<_, anyhow::Error>(0)
+                Ok::<_, anyhow::Error>(0usize)
             };
 
             // Manual cancellation due to cleanup
diff --git a/src/util.rs b/src/util.rs
index 796805c..590c885 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,18 +1,14 @@
-use crate::static_config;
-
 use std::ops::Range;
 use std::ffi::c_char;
 use anyhow::anyhow;
 use pin_project::pin_project;
 use tokio::io::{AsyncRead, AsyncBufRead, AsyncWrite};
 
-pub(crate) fn size_to_ranges(object_size: usize) -> Vec<Range<usize>> {
+pub(crate) fn size_to_ranges(object_size: usize, part_size: usize) -> Vec<Range<usize>> {
     if object_size == 0 {
         return vec![];
     }
 
-    let part_size: usize = static_config().multipart_get_part_size as usize;
-
     // If the object size happens to be smaller than part_size,
     // then we will end up doing a single range get of the whole
     // object.
@@ -60,6 +56,18 @@ impl TryFrom<*const c_char> for Compression {
     }
 }
 
+#[async_trait::async_trait]
+pub(crate) trait AsyncUpload: AsyncWrite {
+    async fn abort(&mut self) -> crate::Result<()>;
+}
+
+#[async_trait::async_trait]
+impl AsyncUpload for object_store::buffered::BufWriter {
+    async fn abort(&mut self) -> crate::Result<()> {
+        Ok(object_store::buffered::BufWriter::abort(self).await?)
+    }
+}
+
 #[derive(Debug)]
 #[pin_project(project = EncoderProj)]
 pub(crate) enum Encoder {
@@ -101,8 +109,24 @@
     }
 }
 
-impl CompressedWriter<object_store::buffered::BufWriter> {
-    pub(crate) async fn abort(&mut self) -> anyhow::Result<()> {
+#[async_trait::async_trait]
+impl AsyncUpload for CompressedWriter<object_store::buffered::BufWriter> {
+    async fn abort(&mut self) -> crate::Result<()> {
+        let writer = match &mut self.encoder {
+            Encoder::None(e) => e,
+            Encoder::Gzip(e) => e.get_mut(),
+            Encoder::Deflate(e) => e.get_mut(),
+            Encoder::Zlib(e) => e.get_mut(),
+            Encoder::Zstd(e) => e.get_mut(),
+        };
+
+        Ok(writer.abort().await?)
+    }
+}
+
+#[async_trait::async_trait]
+impl AsyncUpload for CompressedWriter<Box<dyn AsyncUpload + Send + Unpin>> {
+    async fn abort(&mut self) -> crate::Result<()> {
         let writer = match &mut self.encoder {
             Encoder::None(e) => e,
             Encoder::Gzip(e) => e.get_mut(),
@@ -211,3 +235,28 @@ pub(crate) unsafe fn cstr_to_path(cstr: &std::ffi::CStr) -> object_store::path::
     let path: object_store::path::Path = std::mem::transmute(raw_path);
     return path;
 }
+
+pub(crate) unsafe fn string_to_path(string: String) -> object_store::path::Path {
+    let raw_path = RawPath {
+        raw: string
+    };
+
+    let path: object_store::path::Path = std::mem::transmute(raw_path);
+    return path;
+}
+
+pub(crate) fn deserialize_slice<'a, T>(v: &'a [u8]) -> Result<T, serde_path_to_error::Error<serde_json::Error>>
+where
+    T: serde::Deserialize<'a>,
+{
+    let de = &mut serde_json::Deserializer::from_slice(v);
+    serde_path_to_error::deserialize(de)
+}
+
+pub(crate) fn deserialize_str<'a, T>(v: &'a str) -> Result<T, serde_path_to_error::Error<serde_json::Error>>
+where
+    T: serde::Deserialize<'a>,
+{
+    let de = &mut serde_json::Deserializer::from_str(v);
+    serde_path_to_error::deserialize(de)
+}
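The `deserialize_slice`/`deserialize_str` helpers exist so that parse failures report the JSON path that went wrong instead of only a byte offset, which matters when Snowflake responses are large. Typical behavior (the struct here is illustrative):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct StageCreds {
    aws_key_id: String,
}

fn main() {
    let bad = r#"{"aws_key_id": 42}"#; // wrong type for the field
    let de = &mut serde_json::Deserializer::from_str(bad);
    let err = serde_path_to_error::deserialize::<_, StageCreds>(de).unwrap_err();
    // Prints the offending path, e.g. "aws_key_id"
    println!("error at `{}`: {}", err.path(), err);
}
```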