diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index b11f4513b6df..7419d6463dcc 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -38,18 +38,17 @@ use futures::{StreamExt, TryStreamExt}; use reqwest::header::{HeaderName, IF_MATCH, IF_NONE_MATCH}; use reqwest::{Method, StatusCode}; use std::{sync::Arc, time::Duration}; -use tokio::io::AsyncWrite; use url::Url; use crate::aws::client::{RequestError, S3Client}; use crate::client::get::GetClientExt; use crate::client::list::ListClientExt; use crate::client::CredentialProvider; -use crate::multipart::{MultiPartStore, PartId, PutPart, WriteMultiPart}; +use crate::multipart::{MultiPartStore, PartId}; use crate::signer::Signer; use crate::{ Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, PutMode, - PutOptions, PutResult, Result, + PutOptions, PutResult, Result, Upload, UploadPart, }; static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging"); @@ -85,6 +84,7 @@ const STORE: &str = "S3"; /// [`CredentialProvider`] for [`AmazonS3`] pub type AwsCredentialProvider = Arc>; +use crate::client::parts::Parts; pub use credential::{AwsAuthorizer, AwsCredential}; /// Interface for [Amazon S3](https://aws.amazon.com/s3/). @@ -211,25 +211,18 @@ impl ObjectStore for AmazonS3 { } } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let id = self.client.create_multipart(location).await?; - - let upload = S3MultiPartUpload { - location: location.clone(), - upload_id: id.clone(), - client: Arc::clone(&self.client), - }; - - Ok((id, Box::new(WriteMultiPart::new(upload, 8)))) - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - self.client - .delete_request(location, &[("uploadId", multipart_id)]) - .await + async fn upload(&self, location: &Path) -> Result> { + let upload_id = self.client.create_multipart(location).await?; + + Ok(Box::new(S3MultiPartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + upload_id: upload_id.clone(), + parts: Default::default(), + }), + })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -319,25 +312,50 @@ impl ObjectStore for AmazonS3 { } } +#[derive(Debug)] struct S3MultiPartUpload { + part_idx: usize, + state: Arc, +} + +#[derive(Debug)] +struct UploadState { + parts: Parts, location: Path, upload_id: String, client: Arc, } #[async_trait] -impl PutPart for S3MultiPartUpload { - async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { - self.client - .put_part(&self.location, &self.upload_id, part_idx, buf.into()) +impl Upload for S3MultiPartUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state + .client + .put_part(&state.location, &state.upload_id, idx, data) + .await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .complete_multipart(&self.state.location, &self.state.upload_id, parts) .await } - async fn complete(&self, completed_parts: Vec) -> Result<()> { - self.client - .complete_multipart(&self.location, &self.upload_id, completed_parts) - .await?; - Ok(()) + async fn abort(&mut self) -> Result<()> { + self.state + .client + .delete_request(&self.state.location, &[("uploadId", &self.state.upload_id)]) + .await } } @@ -377,7 +395,6 @@ mod tests { use crate::{client::get::GetClient, tests::*}; use bytes::Bytes; use hyper::HeaderMap; - use tokio::io::AsyncWriteExt; const NON_EXISTENT_NAME: &str = "nonexistentname"; @@ -542,9 +559,9 @@ mod tests { store.put(&locations[0], data.clone()).await.unwrap(); store.copy(&locations[0], &locations[1]).await.unwrap(); - let (_, mut writer) = store.put_multipart(&locations[2]).await.unwrap(); - writer.write_all(&data).await.unwrap(); - writer.shutdown().await.unwrap(); + let mut upload = store.upload(&locations[2]).await.unwrap(); + upload.put_part(data.clone()).await.unwrap(); + upload.complete().await.unwrap(); for location in &locations { let res = store diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 423052bfee68..8bf139884820 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -27,7 +27,7 @@ use crate::{ path::Path, signer::Signer, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, - Result, + Result, Upload, UploadPart, }; use async_trait::async_trait; use bytes::Bytes; @@ -36,7 +36,6 @@ use reqwest::Method; use std::fmt::Debug; use std::sync::Arc; use std::time::Duration; -use tokio::io::AsyncWrite; use url::Url; use crate::client::get::GetClientExt; @@ -50,6 +49,8 @@ mod credential; /// [`CredentialProvider`] for [`MicrosoftAzure`] pub type AzureCredentialProvider = Arc>; +use crate::azure::client::AzureClient; +use crate::client::parts::Parts; pub use builder::{AzureConfigKey, MicrosoftAzureBuilder}; pub use credential::AzureCredential; @@ -90,22 +91,17 @@ impl ObjectStore for MicrosoftAzure { self.client.put_blob(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let inner = AzureMultiPartUpload { - client: Arc::clone(&self.client), - location: location.to_owned(), - }; - Ok((String::new(), Box::new(WriteMultiPart::new(inner, 8)))) + async fn upload(&self, location: &Path) -> Result> { + Ok(Box::new(AzureMultiPartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + parts: Default::default(), + }), + })) } - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> { - // There is no way to drop blocks that have been uploaded. Instead, they simply - // expire in 7 days. - Ok(()) - } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { self.client.get_opts(location, options).await @@ -193,20 +189,43 @@ impl Signer for MicrosoftAzure { /// put_multipart_part -> PUT block /// complete -> PUT block list /// abort -> No equivalent; blocks are simply dropped after 7 days -#[derive(Debug, Clone)] +#[derive(Debug)] struct AzureMultiPartUpload { - client: Arc, + part_idx: usize, + state: Arc, +} + +#[derive(Debug)] +struct UploadState { location: Path, + parts: Parts, + client: Arc, } #[async_trait] -impl PutPart for AzureMultiPartUpload { - async fn put_part(&self, buf: Vec, idx: usize) -> Result { - self.client.put_block(&self.location, idx, buf.into()).await +impl Upload for AzureMultiPartUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state.client.put_block(&state.location, idx, data).await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .put_block_list(&self.state.location, parts) + .await } - async fn complete(&self, parts: Vec) -> Result<()> { - self.client.put_block_list(&self.location, parts).await?; + async fn abort(&mut self) -> Result<()> { + // Nothing to do Ok(()) } } diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index e7f26649da6e..6f6e6824c9b4 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -27,7 +27,7 @@ use std::io::{Error, ErrorKind, SeekFrom}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; /// The default buffer size used by [`BufReader`] pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024; @@ -367,7 +367,7 @@ mod tests { use super::*; use crate::memory::InMemory; use crate::path::Path; - use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt}; + use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; #[tokio::test] async fn test_buf_reader() { diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index ba549abed783..e5be8ce86106 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -37,7 +37,6 @@ use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use snafu::{OptionExt, ResultExt, Snafu}; -use tokio::io::AsyncWrite; use url::Url; use crate::client::get::GetClientExt; @@ -45,7 +44,7 @@ use crate::client::header::get_etag; use crate::http::client::Client; use crate::path::Path; use crate::{ - ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, + ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, ObjectMeta, ObjectStore, PutMode, PutOptions, PutResult, Result, RetryConfig, Upload, }; diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 8bcaa45ff4ce..a6f50110e267 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -2197,7 +2197,7 @@ mod tests { pub(crate) async fn tagging(storage: &dyn ObjectStore, validate: bool, get_tags: F) where F: Fn(Path) -> Fut + Send + Sync, - Fut: Future> + Send, + Fut: std::future::Future> + Send, { use bytes::Buf; use serde::Deserialize; diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index f73d78578f08..846648d1d285 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -25,7 +25,6 @@ use object_store::path::Path; use object_store::*; use std::fmt::Formatter; use tempfile::tempdir; -use tokio::io::AsyncWrite; #[derive(Debug)] struct MyStore(LocalFileSystem); @@ -42,16 +41,10 @@ impl ObjectStore for MyStore { self.0.put_opts(path, data, opts).await } - async fn put_multipart( - &self, - _: &Path, - ) -> Result<(MultipartId, Box)> { + async fn upload(&self, _location: &Path) -> Result> { todo!() } - async fn abort_multipart(&self, _: &Path, _: &MultipartId) -> Result<()> { - todo!() - } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { self.0.get_opts(location, options).await