diff --git a/Cargo.lock b/Cargo.lock index ef5925419..45144265c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1308,8 +1308,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1831,6 +1833,21 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" +dependencies = [ + "base64 0.21.7", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -2232,6 +2249,7 @@ dependencies = [ "http-body 1.0.1", "hyper 0.14.31", "hyper-rustls", + "jsonwebtoken", "lz4_flex", "memory-stats", "mock_instant", @@ -2518,6 +2536,16 @@ dependencies = [ "bitflags", ] +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -3259,6 +3287,18 @@ dependencies = [ "libc", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 1.0.69", + "time", +] + [[package]] name = "slab" version = "0.4.9" @@ -3423,6 +3463,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", + "itoa", "num-conv", "powerfmt", "serde", diff --git a/nativelink-config/examples/README.md b/nativelink-config/examples/README.md index 50828e65c..d702fd7bf 100644 --- a/nativelink-config/examples/README.md +++ b/nativelink-config/examples/README.md @@ -41,7 +41,7 @@ The value of `stores` includes top-level keys, which are user supplied names sto ### Store Type Once the store has been named and its object exists, -the next key is the type of store. The options are `filesystem`, `memory`, `compression`, `dedup`, `fast_slow`, `verify`, and `experimental_s3_store`. +the next key is the type of store. The options are `filesystem`, `memory`, `compression`, `dedup`, `fast_slow`, `verify`, `experimental_s3_store` and `experimental_gcs_store`. ```json5 { diff --git a/nativelink-config/examples/gcs_backend.json5 b/nativelink-config/examples/gcs_backend.json5 deleted file mode 100644 index 6bfaafd97..000000000 --- a/nativelink-config/examples/gcs_backend.json5 +++ /dev/null @@ -1,168 +0,0 @@ -{ - "stores": { - "CAS_MAIN_STORE": { - "verify": { - "backend": { - "dedup": { - "index_store": { - "fast_slow": { - "fast": { - "filesystem": { - "content_path": "/tmp/nativelink/data/content_path-index", - "temp_path": "/tmp/nativelink/data/tmp_path-index", - "eviction_policy": { - // 500mb. - "max_bytes": 500000000 - } - } - }, - "slow": { - "experimental_gcs_store": { - "project_id": "inbound-entity-447014-k2", - "bucket": "test-bucket-aman-nativelink", - "key_prefix": "test-prefix-index/", - "retry": { - "max_retries": 6, - "delay": 0.3, - "jitter": 0.5 - }, - "max_concurrent_uploads": 10 - } - } - } - }, - "content_store": { - "compression": { - "compression_algorithm": { - "lz4": {} - }, - "backend": { - "fast_slow": { - "fast": { - "filesystem": { - "content_path": "/tmp/nativelink/data/content_path-content", - "temp_path": "/tmp/nativelink/data/tmp_path-content", - "eviction_policy": { - "max_bytes": 2000000000 - } - } - }, - "slow": { - "experimental_gcs_store": { - "project_id": "inbound-entity-447014-k2", - "bucket": "test-bucket-aman-nativelink", - "key_prefix": "test-prefix-dedup-cas/", - "retry": { - "max_retries": 6, - "delay": 0.3, - "jitter": 0.5 - }, - "max_concurrent_uploads": 10 - } - } - } - } - } - } - } - }, - "verify_size": true - } - }, - "AC_MAIN_STORE": { - "fast_slow": { - "fast": { - "memory": { - "eviction_policy": { - "max_bytes": 100000000 - } - }, - "filesystem": { - "content_path": "/tmp/nativelink/data/content_path-ac", - "temp_path": "/tmp/nativelink/data/tmp_path-ac", - "eviction_policy": { - // 500mb. - "max_bytes": 500000000 - } - } - }, - "slow": { - "experimental_gcs_store": { - "project_id": "inbound-entity-447014-k2", - // Name of the bucket to upload to. - "bucket": "test-bucket-aman-nativelink", - "key_prefix": "test-prefix-ac/", - "retry": { - "max_retries": 6, - "delay": 0.3, - "jitter": 0.5 - }, - "max_concurrent_uploads": 10 - } - } - } - }, - }, - "schedulers": { - "MAIN_SCHEDULER": { - "simple": { - "supported_platform_properties": { - "cpu_count": "minimum", - "memory_kb": "minimum", - "network_kbps": "minimum", - "disk_read_iops": "minimum", - "disk_read_bps": "minimum", - "disk_write_iops": "minimum", - "disk_write_bps": "minimum", - "shm_size": "minimum", - "gpu_count": "minimum", - "gpu_model": "exact", - "cpu_vendor": "exact", - "cpu_arch": "exact", - "cpu_model": "exact", - "kernel_version": "exact", - "docker_image": "priority", - "lre-rs": "priority" - } - } - } - }, - "servers": [{ - "listener": { - "http": { - "socket_address": "0.0.0.0:50051" - } - }, - "services": { - "cas": { - "main": { - "cas_store": "CAS_MAIN_STORE" - } - }, - "ac": { - "main": { - "ac_store": "AC_MAIN_STORE" - } - }, - "execution": { - "main": { - "cas_store": "CAS_MAIN_STORE", - "scheduler": "MAIN_SCHEDULER" - } - }, - "capabilities": { - "main": { - "remote_execution": { - "scheduler": "MAIN_SCHEDULER" - } - } - }, - "bytestream": { - "cas_stores": { - "main": "CAS_MAIN_STORE" - } - }, - "health": {} - } - }] -} diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs index 1bb7506bd..53401af82 100644 --- a/nativelink-config/src/stores.rs +++ b/nativelink-config/src/stores.rs @@ -85,7 +85,7 @@ pub enum StoreSpec { /// **Example JSON Config:** /// ```json /// "experimental_gcs_store": { - /// "project_id": "sample-project", + /// "service_email": "email@domain.com", /// "bucket": "test-bucket", /// "key_prefix": "test-prefix-index/", /// "retry": { @@ -813,7 +813,7 @@ pub struct S3Spec { pub struct GcsSpec { /// Project ID for the GCS service #[serde(default, deserialize_with = "convert_string_with_shellexpand")] - pub project_id: String, + pub service_email: String, /// Bucket name to use as the backend #[serde(default, deserialize_with = "convert_string_with_shellexpand")] diff --git a/nativelink-store/BUILD.bazel b/nativelink-store/BUILD.bazel index 56c6e1841..8ce0809b8 100644 --- a/nativelink-store/BUILD.bazel +++ b/nativelink-store/BUILD.bazel @@ -18,6 +18,10 @@ rust_library( "src/existence_cache_store.rs", "src/fast_slow_store.rs", "src/filesystem_store.rs", + "src/gcs_client/auth.rs", + "src/gcs_client/client.rs", + "src/gcs_client/grpc_client.rs", + "src/gcs_client/mod.rs", "src/gcs_store.rs", "src/grpc_store.rs", "src/lib.rs", @@ -62,6 +66,7 @@ rust_library( "@crates//:http-body", "@crates//:hyper-0.14.31", "@crates//:hyper-rustls", + "@crates//:jsonwebtoken", "@crates//:lz4_flex", "@crates//:parking_lot", "@crates//:patricia_tree", diff --git a/nativelink-store/Cargo.toml b/nativelink-store/Cargo.toml index ab98b9f84..47b86c3b8 100644 --- a/nativelink-store/Cargo.toml +++ b/nativelink-store/Cargo.toml @@ -40,6 +40,7 @@ fred = { version = "10.0.1", default-features = false, features = [ "subscriber-client", ] } googleapis-tonic-google-storage-v2 = "0.17.0" +jsonwebtoken = "9.3.0" patricia_tree = { version = "0.8.0", default-features = false } futures = { version = "0.3.31", default-features = false } hex = { version = "0.4.3", default-features = false } diff --git a/nativelink-store/src/gcs_client/auth.rs b/nativelink-store/src/gcs_client/auth.rs new file mode 100644 index 000000000..5647c800f --- /dev/null +++ b/nativelink-store/src/gcs_client/auth.rs @@ -0,0 +1,201 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; +use nativelink_config::stores::GcsSpec; +use nativelink_error::{make_err, Code, Error}; +use rand::Rng; +use serde::Serialize; +use tokio::sync::{Mutex, RwLock}; + +const SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; +const AUDIENCE: &str = "https://storage.googleapis.com/"; +const TOKEN_LIFETIME: Duration = Duration::from_secs(3600); // 1 hour +const REFRESH_WINDOW: Duration = Duration::from_secs(300); // 5 minutes +const MAX_REFRESH_ATTEMPTS: u32 = 3; +const RETRY_DELAY_BASE: Duration = Duration::from_secs(1); + +const ENV_PRIVATE_KEY: &str = "GCS_PRIVATE_KEY"; +const ENV_AUTH_TOKEN: &str = "GOOGLE_AUTH_TOKEN"; + +#[derive(Debug, Serialize)] +struct JwtClaims { + iss: String, + sub: String, + aud: String, + iat: u64, + exp: u64, + scope: String, +} + +#[derive(Clone)] +struct TokenInfo { + token: String, + refresh_at: u64, // Timestamp when token should be refreshed +} + +pub struct GcsAuth { + token_cache: RwLock>, + refresh_lock: Mutex<()>, + service_email: String, + private_key: String, +} + +impl GcsAuth { + pub async fn new(spec: &GcsSpec) -> Result { + // First try to get direct token from environment + if let Ok(token) = std::env::var(ENV_AUTH_TOKEN) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| make_err!(Code::Internal, "Failed to get system time: {}", e))? + .as_secs(); + + return Ok(Self { + token_cache: RwLock::new(Some(TokenInfo { + token, + refresh_at: now + TOKEN_LIFETIME.as_secs() - REFRESH_WINDOW.as_secs(), + })), + refresh_lock: Mutex::new(()), + service_email: String::new(), + private_key: String::new(), + }); + } + + let service_email = spec.service_email.clone(); + + // Get private key from environment + let private_key = std::env::var(ENV_PRIVATE_KEY).map_err(|_| { + make_err!( + Code::NotFound, + "Environment variable {} not found", + ENV_PRIVATE_KEY + ) + })?; + + Ok(Self { + token_cache: RwLock::new(None), + refresh_lock: Mutex::new(()), + service_email, + private_key, + }) + } + + fn add_jitter(duration: Duration) -> Duration { + let jitter = rand::thread_rng().gen_range(-5..=5); + duration.saturating_add(Duration::from_secs_f64(f64::from(jitter) * 0.1)) + } + + async fn generate_token(&self) -> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| make_err!(Code::Internal, "Failed to get system time: {}", e))? + .as_secs(); + + let expiry = now + TOKEN_LIFETIME.as_secs(); + let refresh_at = expiry - REFRESH_WINDOW.as_secs(); + + let claims = JwtClaims { + iss: self.service_email.clone(), + sub: self.service_email.clone(), + aud: AUDIENCE.to_string(), + iat: now, + exp: expiry, + scope: SCOPE.to_string(), + }; + + let header = Header::new(Algorithm::RS256); + let key = EncodingKey::from_rsa_pem(self.private_key.as_bytes()) + .map_err(|e| make_err!(Code::Internal, "Failed to create encoding key: {}", e))?; + + let token = encode(&header, &claims, &key) + .map_err(|e| make_err!(Code::Internal, "Failed to encode JWT: {}", e))?; + + Ok(TokenInfo { token, refresh_at }) + } + + async fn refresh_token(&self) -> Result { + let mut attempt = 0; + loop { + match self.generate_token().await { + Ok(token_info) => return Ok(token_info), + Err(e) => { + attempt += 1; + if attempt >= MAX_REFRESH_ATTEMPTS { + return Err(make_err!( + Code::Internal, + "Failed to refresh token after {} attempts: {}", + MAX_REFRESH_ATTEMPTS, + e + )); + } + let delay = Self::add_jitter(RETRY_DELAY_BASE * (2_u32.pow(attempt - 1))); + tokio::time::sleep(delay).await; + } + } + } + } + + pub async fn get_valid_token(&self) -> Result { + if let Some(token_info) = self.token_cache.read().await.as_ref() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| make_err!(Code::Internal, "Failed to get system time: {}", e))? + .as_secs(); + + if now < token_info.refresh_at { + return Ok(token_info.token.clone()); + } + } + + let _refresh_guard = self.refresh_lock.lock().await; + + if let Some(token_info) = self.token_cache.read().await.as_ref() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| make_err!(Code::Internal, "Failed to get system time: {}", e))? + .as_secs(); + + if now < token_info.refresh_at { + return Ok(token_info.token.clone()); + } + } + + let token_info = if self.private_key.is_empty() { + if let Ok(token) = std::env::var(ENV_AUTH_TOKEN) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| make_err!(Code::Internal, "Failed to get system time: {}", e))? + .as_secs(); + + TokenInfo { + token, + refresh_at: now + TOKEN_LIFETIME.as_secs() - REFRESH_WINDOW.as_secs(), + } + } else { + return Err(make_err!( + Code::Unauthenticated, + "No valid authentication method available" + )); + } + } else { + self.refresh_token().await? + }; + + *self.token_cache.write().await = Some(token_info.clone()); + + Ok(token_info.token) + } +} diff --git a/nativelink-store/src/gcs_client/client.rs b/nativelink-store/src/gcs_client/client.rs new file mode 100644 index 000000000..93a2dc431 --- /dev/null +++ b/nativelink-store/src/gcs_client/client.rs @@ -0,0 +1,488 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::time::Duration; + +use futures::{StreamExt, TryStreamExt}; +use googleapis_tonic_google_storage_v2::google::storage::v2::storage_client::StorageClient; +use googleapis_tonic_google_storage_v2::google::storage::v2::{ + bidi_write_object_request, write_object_request, BidiWriteObjectRequest, ChecksummedData, + Object, ReadObjectRequest, StartResumableWriteRequest, WriteObjectRequest, WriteObjectSpec, +}; +use nativelink_config::stores::GcsSpec; +use nativelink_error::{make_err, Code, Error, ResultExt}; +use nativelink_util::buf_channel::{make_buf_channel_pair, DropCloserReadHalf}; +use nativelink_util::retry::{Retrier, RetryResult}; +use tokio::sync::RwLock; +use tokio::time::sleep; +use tonic::metadata::{MetadataMap, MetadataValue}; +use tonic::transport::{Channel, ClientTlsConfig}; +use tonic::Request; + +use crate::gcs_client::auth::GcsAuth; +use crate::gcs_client::grpc_client::{ + BidiWriteObjectStream, GcsGrpcClient, GcsGrpcClientWrapper, WriteObjectStream, +}; + +const MAX_CHUNK_SIZE: usize = 4 * 1024 * 1000; // < 4 MiB + +#[derive(Clone)] +pub struct ObjectPath { + bucket: String, + path: String, +} + +async fn create_channel(spec: &GcsSpec) -> Result { + let endpoint = match &spec.endpoint { + Some(endpoint) => { + let prefix = if spec.insecure_allow_http { + "http://" + } else { + "https://" + }; + format!("{prefix}{endpoint}") + } + None => std::env::var("GOOGLE_STORAGE_ENDPOINT") + .unwrap_or_else(|_| "https://storage.googleapis.com".to_string()), + }; + + let mut channel = Channel::from_shared(endpoint) + .map_err(|e| make_err!(Code::InvalidArgument, "Invalid GCS endpoint: {e:?}"))? + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(30)) + .tcp_nodelay(true); + + if !spec.disable_http2 { + channel = channel + .http2_adaptive_window(true) + .http2_keep_alive_interval(Duration::from_secs(30)); + } + + if !spec.insecure_allow_http { + channel = channel + .tls_config(ClientTlsConfig::new().with_native_roots()) + .map_err(|e| make_err!(Code::InvalidArgument, "Failed to configure TLS: {e:?}"))?; + } + + channel + .connect() + .await + .map_err(|e| make_err!(Code::Unavailable, "Failed to connect to GCS: {e:?}")) +} + +impl ObjectPath { + pub fn new(bucket: String, path: &str) -> Self { + let normalized_path = path.replace('\\', "/").trim_start_matches('/').to_string(); + Self { + bucket, + path: normalized_path, + } + } + + pub fn get_formatted_bucket(&self) -> String { + format!("projects/_/buckets/{}", self.bucket) + } +} + +#[derive(Clone)] +pub struct GcsClient { + client: Arc>, + auth: Arc, + retrier: Arc, +} + +impl GcsClient { + pub async fn new_with_client( + client: impl GcsGrpcClient + 'static, + spec: &GcsSpec, + jitter_fn: Arc Duration + Send + Sync>, + ) -> Result { + Ok(Self { + client: Arc::new(RwLock::new(GcsGrpcClientWrapper::new(client))), + auth: Arc::new(GcsAuth::new(spec).await?), + retrier: Arc::new(Retrier::new( + Arc::new(|duration| Box::pin(sleep(duration))), + jitter_fn, + spec.retry.clone(), + )), + }) + } + + pub async fn new( + spec: &GcsSpec, + jitter_fn: Arc Duration + Send + Sync>, + ) -> Result { + let channel = create_channel(spec).await?; + let storage_client = StorageClient::new(channel); + Self::new_with_client(storage_client, spec, jitter_fn).await + } + + async fn add_auth_and_common_headers( + &self, + metadata: &mut MetadataMap, + object: ObjectPath, + ) -> Result<(), Error> { + // Add authorization header + let token = self.auth.get_valid_token().await?; + metadata.insert( + "authorization", + MetadataValue::try_from(&format!("Bearer {token}")).unwrap(), + ); + + // Add bucket parameter. This is required for all requests + let bucket = object.get_formatted_bucket(); + let encoded_bucket = urlencoding::encode(&bucket); + let params = format!("bucket={encoded_bucket}"); + + metadata.insert( + "x-goog-request-params", + MetadataValue::try_from(¶ms).unwrap(), + ); + + Ok(()) + } + + async fn prepare_request(&self, request: T, object: ObjectPath) -> Request { + let mut request = Request::new(request); + self.add_auth_and_common_headers(request.metadata_mut(), object) + .await + .expect("Failed to add headers"); + request + } + + fn create_write_spec(&self, object: &ObjectPath, size: i64) -> WriteObjectSpec { + WriteObjectSpec { + resource: Some(Object { + name: object.path.clone(), + bucket: object.get_formatted_bucket(), + size, + content_type: "application/octet-stream".to_string(), + ..Default::default() + }), + object_size: Some(size), + ..Default::default() + } + } + + pub async fn simple_upload( + &self, + object: ObjectPath, + reader: DropCloserReadHalf, + size: i64, + ) -> Result<(), Error> { + let retrier = self.retrier.clone(); + let client = self.client.clone(); + let object_clone = object.clone(); + let self_clone = self.clone(); + let write_spec = self.create_write_spec(&object, size); + + // Create a stream that will yield our operation result + let operation_stream = futures::stream::unfold( + (client, object_clone, self_clone, write_spec, reader), + move |(client, object, self_ref, write_spec, mut reader)| { + async move { + let (mut tx, mut rx) = make_buf_channel_pair(); + + let attempt_result = async { + let (upload_res, bind_res) = tokio::join!( + async { + let mut client_guard = client.write().await; + let mut buffer = Vec::with_capacity(size as usize); + while let Ok(Some(chunk)) = rx.try_next().await { + buffer.extend_from_slice(&chunk); + } + let crc32c = crc32c::crc32c(&buffer); + + let init_request = WriteObjectRequest { + first_message: Some( + write_object_request::FirstMessage::WriteObjectSpec( + write_spec.clone(), + ), + ), + write_offset: 0, + data: None, + finish_write: false, + ..Default::default() + }; + + let data_request = WriteObjectRequest { + first_message: None, + write_offset: 0, + data: Some(write_object_request::Data::ChecksummedData( + ChecksummedData { + content: buffer, + crc32c: Some(crc32c), + }, + )), + finish_write: true, + ..Default::default() + }; + + let request_stream = Box::pin(futures::stream::iter(vec![ + init_request, + data_request, + ])) + as WriteObjectStream; + let mut request = Request::new(request_stream); + + self_ref + .add_auth_and_common_headers( + request.metadata_mut(), + object.clone(), + ) + .await?; + + client_guard + .handle_request(|client| Box::pin(client.write_object(request))) + .await + }, + async { tx.bind_buffered(&mut reader).await } + ); + + match (upload_res, bind_res) { + (Ok(_), Ok(())) => Ok(()), + (Err(e), _) | (_, Err(e)) => Err(e), + } + } + .await; + + // Return both the result and the state for potential next retry + Some(( + RetryResult::Ok(attempt_result), + (client, object, self_ref, write_spec, reader), + )) + } + }, + ); + + retrier.retry(operation_stream).await? + } + + pub async fn resumable_upload( + &self, + object: ObjectPath, + reader: DropCloserReadHalf, + size: i64, + ) -> Result<(), Error> { + let retrier = self.retrier.clone(); + let client = self.client.clone(); + let object_clone = object.clone(); + let self_clone = self.clone(); + let write_spec = self.create_write_spec(&object, size); + + let operation_stream = futures::stream::unfold( + (client, object_clone, self_clone, write_spec, reader), + move |(client, object, self_ref, write_spec, mut reader)| async move { + let attempt_result = async { + let mut client_guard = client.write().await; + let start_request = StartResumableWriteRequest { + write_object_spec: Some(write_spec.clone()), + common_object_request_params: None, + object_checksums: None, + }; + + let request = self_ref + .prepare_request(start_request, object.clone()) + .await; + let response = client_guard + .handle_request(|client| Box::pin(client.start_resumable_write(request))) + .await?; + + let upload_id = response.into_inner().upload_id; + + let mut requests = Vec::new(); + requests.push(BidiWriteObjectRequest { + first_message: Some(bidi_write_object_request::FirstMessage::UploadId( + upload_id, + )), + write_offset: 0, + finish_write: false, + data: None, + ..Default::default() + }); + + let mut offset = 0; + while offset < size { + let chunk_size = std::cmp::min(MAX_CHUNK_SIZE, (size - offset) as usize); + + let chunk = reader + .consume(Some(chunk_size)) + .await + .err_tip(|| "Failed to read chunk")?; + + if chunk.is_empty() { + break; + } + + let chunk_len = chunk.len(); + let crc32c = crc32c::crc32c(&chunk); + let is_last = offset + (chunk_len as i64) >= size; + + requests.push(BidiWriteObjectRequest { + first_message: None, + write_offset: offset, + data: Some(bidi_write_object_request::Data::ChecksummedData( + ChecksummedData { + content: chunk.to_vec(), + crc32c: Some(crc32c), + }, + )), + finish_write: is_last, + ..Default::default() + }); + + offset += chunk_len as i64; + } + + let request_stream = + Box::pin(futures::stream::iter(requests)) as BidiWriteObjectStream; + let mut request = Request::new(request_stream); + + self_ref + .add_auth_and_common_headers(request.metadata_mut(), object.clone()) + .await?; + + client_guard + .handle_request(|client| Box::pin(client.bidi_write_object(request))) + .await?; + + Ok(()) + } + .await; + + Some(( + RetryResult::Ok(attempt_result), + (client, object, self_ref, write_spec, reader), + )) + }, + ); + + retrier.retry(operation_stream).await? + } + + async fn read_object( + &self, + object: ObjectPath, + read_offset: Option, + read_limit: Option, + metadata_only: bool, + ) -> Result, Vec)>, Error> { + let retrier = self.retrier.clone(); + let client = self.client.clone(); + let object_clone = object.clone(); + let self_clone = self.clone(); + + let operation_stream = futures::stream::unfold( + (client, object_clone, self_clone), + move |(client, object, self_ref)| { + let read_offset = read_offset; + let read_limit = read_limit; + let metadata_only = metadata_only; + + async move { + let attempt_result = async { + let mut client_guard = client.write().await; + let request = ReadObjectRequest { + bucket: object.get_formatted_bucket(), + object: object.path.clone(), + read_offset: read_offset.unwrap_or(0), + read_limit: read_limit.unwrap_or(0), + ..Default::default() + }; + + let auth_request = self_ref.prepare_request(request, object.clone()).await; + let response = client_guard + .handle_request(|client| { + let future = client.read_object(auth_request); + Box::pin(future) + }) + .await; + + match response { + Ok(response) => { + let mut content = Vec::new(); + let mut metadata = None; + let mut stream = response.into_inner(); + + if let Some(Ok(first_message)) = stream.next().await { + metadata = first_message.metadata; + if metadata_only { + return Ok(Some((metadata, content))); + } + } + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(data) => { + if let Some(checksummed_data) = data.checksummed_data { + content.extend(checksummed_data.content); + } + } + Err(e) => { + return Err(make_err!( + Code::Unavailable, + "Error reading object data: {e:?}" + )); + } + } + } + + Ok(Some((metadata, content))) + } + Err(e) => match e.code { + Code::NotFound => Ok(None), + _ => Err(make_err!( + Code::Unavailable, + "Failed to read object: {e:?}" + )), + }, + } + } + .await; + + Some((RetryResult::Ok(attempt_result), (client, object, self_ref))) + } + }, + ); + + retrier.retry(operation_stream).await? + } + + pub async fn read_object_metadata(&self, object: ObjectPath) -> Result, Error> { + Ok(self + .read_object(object, None, None, true) + .await? + .and_then(|(metadata, _)| metadata)) + } + + pub async fn read_object_content( + &self, + object: ObjectPath, + start: i64, + end: Option, + ) -> Result, Error> { + match self + .read_object(object.clone(), Some(start), end.map(|e| e - start), false) + .await? + { + Some((_, content)) => Ok(content), + None => Err(make_err!( + Code::NotFound, + "Object not found: {}", + object.path + )), + } + } +} diff --git a/nativelink-store/src/gcs_client/grpc_client.rs b/nativelink-store/src/gcs_client/grpc_client.rs new file mode 100644 index 000000000..86d7ce5ca --- /dev/null +++ b/nativelink-store/src/gcs_client/grpc_client.rs @@ -0,0 +1,154 @@ +// gcs_client/grpc_client.rs +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use futures::Stream; +use googleapis_tonic_google_storage_v2::google::storage::v2::storage_client::StorageClient; +use googleapis_tonic_google_storage_v2::google::storage::v2::{ + BidiWriteObjectRequest, BidiWriteObjectResponse, ReadObjectRequest, ReadObjectResponse, + StartResumableWriteRequest, StartResumableWriteResponse, WriteObjectRequest, + WriteObjectResponse, +}; +use nativelink_error::{make_err, Code, Error}; +use tonic::transport::Channel; +use tonic::{Request, Response, Status, Streaming}; + +pub type WriteObjectStream = Pin + Send + 'static>>; +pub type BidiWriteObjectStream = + Pin + Send + 'static>>; + +struct WriteObjectFuture<'a> { + client: &'a mut StorageClient, + request: Option>, +} + +impl Future for WriteObjectFuture<'_> { + type Output = Result, Status>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Safety: we're not moving any pinned fields + let this = unsafe { self.get_unchecked_mut() }; + + if let Some(request) = this.request.take() { + let fut = StorageClient::write_object(this.client, request); + futures::pin_mut!(fut); + fut.poll(cx) + } else { + Poll::Ready(Err(Status::internal("Request already taken"))) + } + } +} + +struct BidiWriteObjectFuture<'a> { + client: &'a mut StorageClient, + request: Option>, +} + +impl Future for BidiWriteObjectFuture<'_> { + type Output = Result>, Status>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Safety: we're not moving any pinned fields + let this = unsafe { self.get_unchecked_mut() }; + + if let Some(request) = this.request.take() { + let fut = StorageClient::bidi_write_object(this.client, request); + futures::pin_mut!(fut); + fut.poll(cx) + } else { + Poll::Ready(Err(Status::internal("Request already taken"))) + } + } +} + +#[async_trait] +pub trait GcsGrpcClient: Send + Sync + 'static { + async fn read_object( + &mut self, + request: Request, + ) -> Result>, Status>; + + async fn write_object( + &mut self, + request: Request, + ) -> Result, Status>; + + async fn start_resumable_write( + &mut self, + request: Request, + ) -> Result, Status>; + + async fn bidi_write_object( + &mut self, + request: Request, + ) -> Result>, Status>; +} + +#[async_trait] +impl GcsGrpcClient for StorageClient { + async fn read_object( + &mut self, + request: Request, + ) -> Result>, Status> { + StorageClient::read_object(self, request).await + } + + async fn write_object( + &mut self, + request: Request, + ) -> Result, Status> { + WriteObjectFuture { + client: self, + request: Some(request), + } + .await + } + + async fn start_resumable_write( + &mut self, + request: Request, + ) -> Result, Status> { + StorageClient::start_resumable_write(self, request).await + } + + async fn bidi_write_object( + &mut self, + request: Request, + ) -> Result>, Status> { + BidiWriteObjectFuture { + client: self, + request: Some(request), + } + .await + } +} + +// Client wrapper for connection management and configuration +pub struct GcsGrpcClientWrapper { + inner: Arc>, +} + +impl GcsGrpcClientWrapper { + pub fn new(client: T) -> Self { + Self { + inner: Arc::new(tokio::sync::Mutex::new(client)), + } + } + + pub async fn handle_request( + &mut self, + operation: impl FnOnce( + &mut dyn GcsGrpcClient, + ) + -> Pin, Status>> + Send + '_>>, + ) -> Result, Error> { + let mut guard = self.inner.lock().await; + operation(&mut *guard).await.map_err(|e| match e.code() { + tonic::Code::NotFound => make_err!(Code::NotFound, "Resource not found: {}", e), + _ => make_err!(Code::Unavailable, "Operation failed: {}", e), + }) + } +} diff --git a/nativelink-store/src/gcs_client/mod.rs b/nativelink-store/src/gcs_client/mod.rs new file mode 100644 index 000000000..fcf5e94af --- /dev/null +++ b/nativelink-store/src/gcs_client/mod.rs @@ -0,0 +1,3 @@ +pub mod auth; +pub mod client; +pub mod grpc_client; diff --git a/nativelink-store/src/gcs_store.rs b/nativelink-store/src/gcs_store.rs index c1de550a6..a45204154 100644 --- a/nativelink-store/src/gcs_store.rs +++ b/nativelink-store/src/gcs_store.rs @@ -1,21 +1,26 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::borrow::Cow; -use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use futures::stream::unfold; -use futures::{Stream, StreamExt, TryStreamExt}; -use googleapis_tonic_google_storage_v2::google::storage::v2::storage_client::StorageClient; -use googleapis_tonic_google_storage_v2::google::storage::v2::{ - bidi_write_object_request, write_object_request, BidiWriteObjectRequest, - BidiWriteObjectResponse, ChecksummedData, Object, ReadObjectRequest, ReadObjectResponse, - StartResumableWriteRequest, StartResumableWriteResponse, WriteObjectRequest, - WriteObjectResponse, WriteObjectSpec, -}; +use futures::TryStreamExt; use nativelink_config::stores::GcsSpec; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; @@ -28,14 +33,11 @@ use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::store_trait::{StoreDriver, StoreKey, UploadSizeInfo}; use rand::rngs::OsRng; use rand::Rng; -use tokio::sync::RwLock; -use tokio::time::{sleep, Instant}; -use tonic::metadata::{MetadataMap, MetadataValue}; -use tonic::transport::Channel; -use tonic::{Request, Response, Status, Streaming}; +use tokio::time::sleep; use tracing::{error, info}; use crate::cas_utils::is_zero_digest; +use crate::gcs_client::client::{GcsClient, ObjectPath}; // Constants for GCS operations // Unlike what is specified in the docs, there is a slight discrepancy between @@ -49,617 +51,6 @@ const MAX_CHUNK_SIZE: usize = 4 * 1024 * 1000; // < 4 MiB const DEFAULT_MAX_CONCURRENT_UPLOADS: usize = 10; const DEFAULT_MAX_RETRY_BUFFER_SIZE: usize = 4 * 1024 * 1000; // < MiB -#[derive(Clone)] -pub struct ObjectPath { - bucket: String, - path: String, -} - -impl ObjectPath { - fn new(bucket: String, path: &str) -> Self { - let normalized_path = path.replace('\\', "/").trim_start_matches('/').to_string(); - Self { - bucket, - path: normalized_path, - } - } - - fn get_formatted_bucket(&self) -> String { - format!("projects/_/buckets/{}", self.bucket) - } -} - -pub struct GcsAuth { - token: RwLock<(String, Instant)>, -} - -impl GcsAuth { - async fn new() -> Result { - let token = Self::fetch_token().await?; - Ok(Self { - token: RwLock::new((token, Instant::now() + Duration::from_secs(3600))), - }) - } - - async fn fetch_token() -> Result { - std::env::var("GOOGLE_AUTH_TOKEN").map_err(|_| { - make_err!( - Code::Unauthenticated, - "GOOGLE_AUTH_TOKEN environment variable not found" - ) - }) - } - - async fn get_valid_token(&self) -> Result { - let token_guard = self.token.read().await; - if Instant::now() < token_guard.1 { - return Ok(token_guard.0.clone()); - } - drop(token_guard); - - let mut token_guard = self.token.write().await; - if Instant::now() >= token_guard.1 { - token_guard.0 = Self::fetch_token().await?; - token_guard.1 = Instant::now() + Duration::from_secs(3600); - } - Ok(token_guard.0.clone()) - } -} - -// Define concrete stream types -pub type WriteObjectStream = Pin + Send + 'static>>; -pub type BidiWriteObjectStream = - Pin + Send + 'static>>; - -struct WriteObjectFuture<'a> { - client: &'a mut StorageClient, - request: Option>, -} - -impl Future for WriteObjectFuture<'_> { - type Output = Result, Status>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // Safety: we're not moving any pinned fields - let this = unsafe { self.get_unchecked_mut() }; - - if let Some(request) = this.request.take() { - let fut = StorageClient::write_object(this.client, request); - futures::pin_mut!(fut); - fut.poll(cx) - } else { - Poll::Ready(Err(Status::internal("Request already taken"))) - } - } -} - -struct BidiWriteObjectFuture<'a> { - client: &'a mut StorageClient, - request: Option>, -} - -impl Future for BidiWriteObjectFuture<'_> { - type Output = Result>, Status>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // Safety: we're not moving any pinned fields - let this = unsafe { self.get_unchecked_mut() }; - - if let Some(request) = this.request.take() { - let fut = StorageClient::bidi_write_object(this.client, request); - futures::pin_mut!(fut); - fut.poll(cx) - } else { - Poll::Ready(Err(Status::internal("Request already taken"))) - } - } -} - -// First, let's modify the trait definition to be explicit about lifetimes -#[async_trait::async_trait] -pub trait StorageOperations: Send + Sync + 'static { - async fn read_object( - &mut self, - request: Request, - ) -> Result>, Status>; - - async fn write_object<'a>( - &'a mut self, - request: Request, - ) -> Result, Status>; - - async fn start_resumable_write( - &mut self, - request: Request, - ) -> Result, Status>; - - async fn bidi_write_object<'a>( - &'a mut self, - request: Request, - ) -> Result>, Status>; -} - -// Then implement it with the same lifetime bounds -#[async_trait::async_trait] -impl StorageOperations for StorageClient { - async fn read_object( - &mut self, - request: Request, - ) -> Result>, Status> { - StorageClient::read_object(self, request).await - } - - async fn write_object( - &mut self, - request: Request, - ) -> Result, Status> { - WriteObjectFuture { - client: self, - request: Some(request), - } - .await - } - - async fn start_resumable_write( - &mut self, - request: Request, - ) -> Result, Status> { - StorageClient::start_resumable_write(self, request).await - } - - async fn bidi_write_object( - &mut self, - request: Request, - ) -> Result>, Status> { - BidiWriteObjectFuture { - client: self, - request: Some(request), - } - .await - } -} - -/// Wrapper for GCS API client operations -#[derive(Clone)] -pub struct GcsClientWrapper { - inner: Arc>, -} - -impl GcsClientWrapper { - fn new(client: T) -> Self { - Self { - inner: Arc::new(tokio::sync::Mutex::new(client)), - } - } - - // Modify handle_request to work with the mutex - async fn handle_request( - &mut self, - operation: impl FnOnce( - &mut dyn StorageOperations, - ) - -> Pin, Status>> + Send + '_>>, - ) -> Result, Error> { - let mut guard = self.inner.lock().await; - operation(&mut *guard).await.map_err(|e| match e.code() { - tonic::Code::NotFound => make_err!(Code::NotFound, "Resource not found: {}", e), - _ => make_err!(Code::Unavailable, "Operation failed: {}", e), - }) - } -} - -/// Client for Google Cloud Storage operations -#[derive(Clone)] -pub struct GcsClient { - pub client: Arc>, - pub auth: Arc, - pub retrier: Arc, -} - -impl GcsClient { - pub async fn new_with_client( - client: impl StorageOperations + 'static, - spec: &GcsSpec, - jitter_fn: Arc Duration + Send + Sync>, - ) -> Result { - Ok(Self { - client: Arc::new(RwLock::new(GcsClientWrapper::new(client))), - auth: Arc::new(GcsAuth::new().await?), - retrier: Arc::new(Retrier::new( - Arc::new(|duration| Box::pin(sleep(duration))), - jitter_fn, - spec.retry.clone(), - )), - }) - } - - async fn new( - spec: &GcsSpec, - jitter_fn: Arc Duration + Send + Sync>, - ) -> Result { - let endpoint = std::env::var("GOOGLE_STORAGE_ENDPOINT") - .unwrap_or_else(|_| "https://storage.googleapis.com".to_string()); - - let channel = Channel::from_shared(endpoint) - .map_err(|e| make_err!(Code::InvalidArgument, "Invalid GCS endpoint: {e:?}"))?; - - // Configure channel... - let channel = channel - .connect() - .await - .map_err(|e| make_err!(Code::Unavailable, "Failed to connect to GCS: {e:?}"))?; - - let storage_client = StorageClient::new(channel); - - Ok(Self { - client: Arc::new(RwLock::new(GcsClientWrapper::new(storage_client))), - auth: Arc::new(GcsAuth::new().await?), - retrier: Arc::new(Retrier::new( - Arc::new(|duration| Box::pin(sleep(duration))), - jitter_fn, - spec.retry.clone(), - )), - }) - } - - async fn add_auth_and_common_headers( - &self, - metadata: &mut MetadataMap, - object: ObjectPath, - ) -> Result<(), Error> { - // Add authorization header - let token = self.auth.get_valid_token().await?; - metadata.insert( - "authorization", - MetadataValue::try_from(&format!("Bearer {token}")).unwrap(), - ); - - // Add bucket parameter. This is required for all requests - let bucket = object.get_formatted_bucket(); - let encoded_bucket = urlencoding::encode(&bucket); - let params = format!("bucket={encoded_bucket}"); - - metadata.insert( - "x-goog-request-params", - MetadataValue::try_from(¶ms).unwrap(), - ); - - Ok(()) - } - - async fn prepare_request(&self, request: T, object: ObjectPath) -> Request { - let mut request = Request::new(request); - self.add_auth_and_common_headers(request.metadata_mut(), object) - .await - .expect("Failed to add headers"); - request - } - - fn create_write_spec(&self, object: &ObjectPath, size: i64) -> WriteObjectSpec { - WriteObjectSpec { - resource: Some(Object { - name: object.path.clone(), - bucket: object.get_formatted_bucket(), - size, - content_type: "application/octet-stream".to_string(), - ..Default::default() - }), - object_size: Some(size), - ..Default::default() - } - } - - async fn simple_upload( - &self, - object: ObjectPath, - reader: DropCloserReadHalf, - size: i64, - ) -> Result<(), Error> { - let retrier = self.retrier.clone(); - let client = self.client.clone(); - let object_clone = object.clone(); - let self_clone = self.clone(); - let write_spec = self.create_write_spec(&object, size); - - // Create a stream that will yield our operation result - let operation_stream = futures::stream::unfold( - (client, object_clone, self_clone, write_spec, reader), - move |(client, object, self_ref, write_spec, mut reader)| { - async move { - let (mut tx, mut rx) = make_buf_channel_pair(); - - let attempt_result = async { - let (upload_res, bind_res) = tokio::join!( - async { - let mut client_guard = client.write().await; - let mut buffer = Vec::with_capacity(size as usize); - while let Ok(Some(chunk)) = rx.try_next().await { - buffer.extend_from_slice(&chunk); - } - let crc32c = crc32c::crc32c(&buffer); - - let init_request = WriteObjectRequest { - first_message: Some( - write_object_request::FirstMessage::WriteObjectSpec( - write_spec.clone(), - ), - ), - write_offset: 0, - data: None, - finish_write: false, - ..Default::default() - }; - - let data_request = WriteObjectRequest { - first_message: None, - write_offset: 0, - data: Some(write_object_request::Data::ChecksummedData( - ChecksummedData { - content: buffer, - crc32c: Some(crc32c), - }, - )), - finish_write: true, - ..Default::default() - }; - - let request_stream = Box::pin(futures::stream::iter(vec![ - init_request, - data_request, - ])) - as WriteObjectStream; - let mut request = Request::new(request_stream); - - self_ref - .add_auth_and_common_headers( - request.metadata_mut(), - object.clone(), - ) - .await?; - - client_guard - .handle_request(|client| Box::pin(client.write_object(request))) - .await - }, - async { tx.bind_buffered(&mut reader).await } - ); - - match (upload_res, bind_res) { - (Ok(_), Ok(())) => Ok(()), - (Err(e), _) | (_, Err(e)) => Err(e), - } - } - .await; - - // Return both the result and the state for potential next retry - Some(( - RetryResult::Ok(attempt_result), - (client, object, self_ref, write_spec, reader), - )) - } - }, - ); - - retrier.retry(operation_stream).await? - } - - async fn resumable_upload( - &self, - object: ObjectPath, - reader: DropCloserReadHalf, - size: i64, - ) -> Result<(), Error> { - let retrier = self.retrier.clone(); - let client = self.client.clone(); - let object_clone = object.clone(); - let self_clone = self.clone(); - let write_spec = self.create_write_spec(&object, size); - - let operation_stream = futures::stream::unfold( - (client, object_clone, self_clone, write_spec, reader), - move |(client, object, self_ref, write_spec, mut reader)| async move { - let attempt_result = async { - let mut client_guard = client.write().await; - let start_request = StartResumableWriteRequest { - write_object_spec: Some(write_spec.clone()), - common_object_request_params: None, - object_checksums: None, - }; - - let request = self_ref - .prepare_request(start_request, object.clone()) - .await; - let response = client_guard - .handle_request(|client| Box::pin(client.start_resumable_write(request))) - .await?; - - let upload_id = response.into_inner().upload_id; - - let mut requests = Vec::new(); - requests.push(BidiWriteObjectRequest { - first_message: Some(bidi_write_object_request::FirstMessage::UploadId( - upload_id, - )), - write_offset: 0, - finish_write: false, - data: None, - ..Default::default() - }); - - let mut offset = 0; - while offset < size { - let chunk_size = std::cmp::min(MAX_CHUNK_SIZE, (size - offset) as usize); - - let chunk = reader - .consume(Some(chunk_size)) - .await - .err_tip(|| "Failed to read chunk")?; - - if chunk.is_empty() { - break; - } - - let chunk_len = chunk.len(); - let crc32c = crc32c::crc32c(&chunk); - let is_last = offset + (chunk_len as i64) >= size; - - requests.push(BidiWriteObjectRequest { - first_message: None, - write_offset: offset, - data: Some(bidi_write_object_request::Data::ChecksummedData( - ChecksummedData { - content: chunk.to_vec(), - crc32c: Some(crc32c), - }, - )), - finish_write: is_last, - ..Default::default() - }); - - offset += chunk_len as i64; - } - - let request_stream = - Box::pin(futures::stream::iter(requests)) as BidiWriteObjectStream; - let mut request = Request::new(request_stream); - - self_ref - .add_auth_and_common_headers(request.metadata_mut(), object.clone()) - .await?; - - client_guard - .handle_request(|client| Box::pin(client.bidi_write_object(request))) - .await?; - - Ok(()) - } - .await; - - Some(( - RetryResult::Ok(attempt_result), - (client, object, self_ref, write_spec, reader), - )) - }, - ); - - retrier.retry(operation_stream).await? - } - - async fn read_object( - &self, - object: ObjectPath, - read_offset: Option, - read_limit: Option, - metadata_only: bool, - ) -> Result, Vec)>, Error> { - let retrier = self.retrier.clone(); - let client = self.client.clone(); - let object_clone = object.clone(); - let self_clone = self.clone(); - - let operation_stream = futures::stream::unfold( - (client, object_clone, self_clone), - move |(client, object, self_ref)| { - let read_offset = read_offset; - let read_limit = read_limit; - let metadata_only = metadata_only; - - async move { - let attempt_result = async { - let mut client_guard = client.write().await; - let request = ReadObjectRequest { - bucket: object.get_formatted_bucket(), - object: object.path.clone(), - read_offset: read_offset.unwrap_or(0), - read_limit: read_limit.unwrap_or(0), - ..Default::default() - }; - - let auth_request = self_ref.prepare_request(request, object.clone()).await; - let response = client_guard - .handle_request(|client| { - let future = client.read_object(auth_request); - Box::pin(future) - }) - .await; - - match response { - Ok(response) => { - let mut content = Vec::new(); - let mut metadata = None; - let mut stream = response.into_inner(); - - if let Some(Ok(first_message)) = stream.next().await { - metadata = first_message.metadata; - if metadata_only { - return Ok(Some((metadata, content))); - } - } - - while let Some(chunk) = stream.next().await { - match chunk { - Ok(data) => { - if let Some(checksummed_data) = data.checksummed_data { - content.extend(checksummed_data.content); - } - } - Err(e) => { - return Err(make_err!( - Code::Unavailable, - "Error reading object data: {e:?}" - )); - } - } - } - - Ok(Some((metadata, content))) - } - Err(e) => match e.code { - Code::NotFound => Ok(None), - _ => Err(make_err!( - Code::Unavailable, - "Failed to read object: {e:?}" - )), - }, - } - } - .await; - - Some((RetryResult::Ok(attempt_result), (client, object, self_ref))) - } - }, - ); - - retrier.retry(operation_stream).await? - } - - async fn read_object_metadata(&self, object: ObjectPath) -> Result, Error> { - Ok(self - .read_object(object, None, None, true) - .await? - .and_then(|(metadata, _)| metadata)) - } - - async fn read_object_content( - &self, - object: ObjectPath, - start: i64, - end: Option, - ) -> Result, Error> { - match self - .read_object(object.clone(), Some(start), end.map(|e| e - start), false) - .await? - { - Some((_, content)) => Ok(content), - None => Err(make_err!( - Code::NotFound, - "Object not found: {}", - object.path - )), - } - } -} - /// The main Google Cloud Storage implementation. #[derive(MetricsComponent)] pub struct GcsStore { diff --git a/nativelink-store/src/lib.rs b/nativelink-store/src/lib.rs index 49b0105d0..e073d6a2f 100644 --- a/nativelink-store/src/lib.rs +++ b/nativelink-store/src/lib.rs @@ -21,6 +21,7 @@ pub mod default_store_factory; pub mod existence_cache_store; pub mod fast_slow_store; pub mod filesystem_store; +pub mod gcs_client; pub mod gcs_store; pub mod grpc_store; pub mod memory_store; diff --git a/nativelink-store/tests/gcs_store_test.rs b/nativelink-store/tests/gcs_store_test.rs index 497e0907b..5b4d25c02 100644 --- a/nativelink-store/tests/gcs_store_test.rs +++ b/nativelink-store/tests/gcs_store_test.rs @@ -1,38 +1,64 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; use futures::Stream; -use http::StatusCode; +use gcs_mock::{MockGcsClient, TestRequest, TestResponse}; +use googleapis_tonic_google_storage_v2::google::storage::v2::{ + bidi_write_object_request, write_object_request, BidiWriteObjectRequest, ChecksummedData, + Object, ReadObjectRequest, StartResumableWriteRequest, WriteObjectRequest, WriteObjectSpec, +}; use nativelink_config::stores::GcsSpec; use nativelink_error::Error; -use nativelink_store::gcs_store::{GcsClient, GcsStore, StorageOperations}; +use nativelink_macro::nativelink_test; +use nativelink_store::gcs_client::client::GcsClient; +use nativelink_store::gcs_client::grpc_client::GcsGrpcClient; +use nativelink_store::gcs_store::GcsStore; +use nativelink_util::buf_channel::make_buf_channel_pair; +use nativelink_util::common::DigestInfo; use nativelink_util::instant_wrapper::MockInstantWrapped; +use nativelink_util::origin_context::OriginContext; +use nativelink_util::spawn; use nativelink_util::store_trait::{StoreKey, StoreLike, UploadSizeInfo}; -use tonic::Status; +use sha2::Digest; +use tonic::{Request, Status}; -pub mod test_utils { +mod gcs_mock { + use std::pin::Pin; use std::str::FromStr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll}; use async_trait::async_trait; + use bytes::{Bytes, BytesMut}; use googleapis_tonic_google_storage_v2::google::storage::v2::{ bidi_write_object_request, write_object_request, BidiWriteObjectResponse, ChecksummedData, Object, ReadObjectRequest, ReadObjectResponse, StartResumableWriteRequest, StartResumableWriteResponse, WriteObjectResponse, }; - use nativelink_store::gcs_store::{ - BidiWriteObjectStream, StorageOperations, WriteObjectStream, + use http::StatusCode; + use nativelink_store::gcs_client::grpc_client::{ + BidiWriteObjectStream, GcsGrpcClient, WriteObjectStream, }; use tonic::metadata::{MetadataMap, MetadataValue}; - use tonic::{Request, Response, Streaming}; - - use super::*; - - // ===== Mock Response Types ===== + use tonic::{Request, Response, Status, Streaming}; + #[allow(dead_code)] #[derive(Debug)] pub struct TestResponse { pub status: Status, @@ -85,6 +111,7 @@ pub mod test_utils { self } + #[allow(dead_code)] #[must_use] pub fn with_status(status: Status) -> Self { Self { @@ -96,6 +123,7 @@ pub mod test_utils { } } + #[allow(dead_code)] #[derive(Debug)] pub struct TestRequest { pub method: &'static str, @@ -103,8 +131,53 @@ pub mod test_utils { pub response: TestResponse, } - // ===== Mock GCS Client ===== + // Mock Body Implementation + #[derive(Clone)] + struct MockBody { + chunks: Vec, + current: usize, + } + + impl MockBody { + fn new(responses: Vec) -> Self { + Self { + chunks: responses, + current: 0, + } + } + + fn encode_grpc_frame(message: &T) -> Bytes { + let msg_bytes = message.encode_to_vec(); + let mut buf = BytesMut::with_capacity(msg_bytes.len() + 5); + buf.extend_from_slice(&[0]); // Compression flag + buf.extend_from_slice(&(msg_bytes.len() as u32).to_be_bytes()); // Length + buf.extend_from_slice(&msg_bytes); + buf.freeze() + } + } + + impl http_body::Body for MockBody { + type Data = Bytes; + type Error = Box; + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = unsafe { self.get_unchecked_mut() }; + if this.current >= this.chunks.len() { + return Poll::Ready(None); + } + + let response = this.chunks[this.current].clone(); + this.current += 1; + let bytes = Self::encode_grpc_frame(&response); + Poll::Ready(Some(Ok(http_body::Frame::data(bytes)))) + } + } + // Mock Client Implementation + #[allow(dead_code)] #[derive(Clone)] pub struct MockGcsClient { expected_requests: Arc>, @@ -125,6 +198,7 @@ pub mod test_utils { self.request_log.lock().unwrap().len() } + #[allow(dead_code)] pub fn single_head_request(response: TestResponse) -> Self { Self::new(vec![TestRequest { method: "READ", @@ -155,81 +229,8 @@ pub mod test_utils { } } - // ===== Mock Streaming Response Implementation ===== - - #[derive(Clone)] - struct MockBody { - chunks: Vec, - current: usize, - } - - impl MockBody { - fn new(responses: Vec) -> Self { - Self { - chunks: responses, - current: 0, - } - } - - fn encode_grpc_frame(message: &T) -> Bytes { - let msg_bytes = message.encode_to_vec(); - let mut buf = BytesMut::with_capacity(msg_bytes.len() + 5); - buf.extend_from_slice(&[0]); // Compression flag (0 = uncompressed) - buf.extend_from_slice(&(msg_bytes.len() as u32).to_be_bytes()); // Message length - buf.extend_from_slice(&msg_bytes); - buf.freeze() - } - } - - impl http_body::Body for MockBody { - type Data = Bytes; - type Error = Box; - - fn poll_frame( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - // Safe to get a mutable reference because we never move out of `self` - let this = unsafe { self.get_unchecked_mut() }; - - if this.current >= this.chunks.len() { - return Poll::Ready(None); - } - - let response = this.chunks[this.current].clone(); - this.current += 1; - let bytes = Self::encode_grpc_frame(&response); - Poll::Ready(Some(Ok(http_body::Frame::data(bytes)))) - } - } - - struct MockDecoder(std::marker::PhantomData); - - impl MockDecoder { - fn new() -> Self { - Self(std::marker::PhantomData) - } - } - - impl tonic::codec::Decoder for MockDecoder { - type Item = T; - type Error = Status; - - fn decode( - &mut self, - buf: &mut tonic::codec::DecodeBuf<'_>, - ) -> Result, Self::Error> { - match T::decode(buf) { - Ok(response) => Ok(Some(response)), - Err(e) => Err(Status::internal(e.to_string())), - } - } - } - - // ===== StorageOperations Implementation ===== - #[async_trait] - impl StorageOperations for MockGcsClient { + impl GcsGrpcClient for MockGcsClient { async fn read_object( &mut self, request: Request, @@ -367,20 +368,30 @@ pub mod test_utils { Ok(Response::new(streaming)) } } -} -use googleapis_tonic_google_storage_v2::google::storage::v2::{ - bidi_write_object_request, write_object_request, BidiWriteObjectRequest, ChecksummedData, - Object, ReadObjectRequest, StartResumableWriteRequest, WriteObjectRequest, WriteObjectSpec, -}; -use nativelink_macro::nativelink_test; -use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::common::DigestInfo; -use nativelink_util::origin_context::OriginContext; -use nativelink_util::spawn; -use sha2::Digest; -use test_utils::{MockGcsClient, TestRequest, TestResponse}; -use tonic::Request; + struct MockDecoder(std::marker::PhantomData); + + impl MockDecoder { + fn new() -> Self { + Self(std::marker::PhantomData) + } + } + + impl tonic::codec::Decoder for MockDecoder { + type Item = T; + type Error = Status; + + fn decode( + &mut self, + buf: &mut tonic::codec::DecodeBuf<'_>, + ) -> Result, Self::Error> { + match T::decode(buf) { + Ok(response) => Ok(Some(response)), + Err(e) => Err(Status::internal(e.to_string())), + } + } + } +} // Tests #[nativelink_test]