From fba19b0142daed54c181cdb8f634f29cf7d37f8d Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 27 Jul 2023 02:32:07 -0400 Subject: [PATCH] Cleanup multipart upload trait (#4572) * Cleanup multipart upload trait * Update object_store/src/multipart.rs Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- object_store/src/aws/client.rs | 4 +- object_store/src/aws/mod.rs | 30 ++++--------- object_store/src/azure/mod.rs | 17 +++----- object_store/src/gcp/mod.rs | 77 ++++++++++++++++------------------ object_store/src/multipart.rs | 50 +++++++--------------- 5 files changed, 69 insertions(+), 109 deletions(-) diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 971d2c60862e..188897620b91 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -23,7 +23,7 @@ use crate::client::list::ListClient; use crate::client::list_response::ListResponse; use crate::client::retry::RetryExt; use crate::client::GetOptionsExt; -use crate::multipart::UploadPart; +use crate::multipart::PartId; use crate::path::DELIMITER; use crate::{ ClientOptions, GetOptions, ListResult, MultipartId, Path, Result, RetryConfig, @@ -479,7 +479,7 @@ impl S3Client { &self, location: &Path, upload_id: &str, - parts: Vec, + parts: Vec, ) -> Result<()> { let parts = parts .into_iter() diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index e74e6f2dfc3e..5a29bd0fc6c7 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -56,7 +56,7 @@ use crate::client::{ TokenCredentialProvider, }; use crate::config::ConfigValue; -use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; +use crate::multipart::{PartId, PutPart, WriteMultiPart}; use crate::{ ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, RetryConfig, @@ -227,7 +227,7 @@ impl ObjectStore for AmazonS3 { 
client: Arc::clone(&self.client), }; - Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8)))) + Ok((id, Box::new(WriteMultiPart::new(upload, 8)))) } async fn abort_multipart( @@ -308,12 +308,8 @@ struct S3MultiPartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for S3MultiPartUpload { - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { +impl PutPart for S3MultiPartUpload { + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { use reqwest::header::ETAG; let part = (part_idx + 1).to_string(); @@ -326,26 +322,16 @@ impl CloudMultiPartUploadImpl for S3MultiPartUpload { ) .await?; - let etag = response - .headers() - .get(ETAG) - .context(MissingEtagSnafu) - .map_err(crate::Error::from)?; + let etag = response.headers().get(ETAG).context(MissingEtagSnafu)?; - let etag = etag - .to_str() - .context(BadHeaderSnafu) - .map_err(crate::Error::from)?; + let etag = etag.to_str().context(BadHeaderSnafu)?; - Ok(UploadPart { + Ok(PartId { content_id: etag.to_string(), }) } - async fn complete( - &self, - completed_parts: Vec, - ) -> Result<(), std::io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { self.client .complete_multipart(&self.location, &self.upload_id, completed_parts) .await?; diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index d2735038321b..8619319a5b25 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -28,7 +28,7 @@ //! after 7 days. 
use self::client::{BlockId, BlockList}; use crate::{ - multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + multipart::{PartId, PutPart, WriteMultiPart}, path::Path, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, @@ -42,7 +42,6 @@ use percent_encoding::percent_decode_str; use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt, Snafu}; use std::fmt::{Debug, Formatter}; -use std::io; use std::str::FromStr; use std::sync::Arc; use tokio::io::AsyncWrite; @@ -186,7 +185,7 @@ impl ObjectStore for MicrosoftAzure { client: Arc::clone(&self.client), location: location.to_owned(), }; - Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8)))) + Ok((String::new(), Box::new(WriteMultiPart::new(inner, 8)))) } async fn abort_multipart( @@ -243,12 +242,8 @@ struct AzureMultiPartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for AzureMultiPartUpload { - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { +impl PutPart for AzureMultiPartUpload { + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { let content_id = format!("{part_idx:20}"); let block_id: BlockId = content_id.clone().into(); @@ -264,10 +259,10 @@ impl CloudMultiPartUploadImpl for AzureMultiPartUpload { ) .await?; - Ok(UploadPart { content_id }) + Ok(PartId { content_id }) } - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { let blocks = completed_parts .into_iter() .map(|part| BlockId::from(part.content_id)) diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index d4d370373d0d..d98e6b068d4f 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -29,7 +29,6 @@ //! to abort the upload and drop those unneeded parts. In addition, you may wish to //! 
consider implementing automatic clean up of unused parts that are older than one //! week. -use std::io; use std::str::FromStr; use std::sync::Arc; @@ -52,7 +51,7 @@ use crate::client::{ TokenCredentialProvider, }; use crate::{ - multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + multipart::{PartId, PutPart, WriteMultiPart}, path::{Path, DELIMITER}, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, @@ -117,6 +116,15 @@ enum Error { #[snafu(display("Error getting put response body: {}", source))] PutResponseBody { source: reqwest::Error }, + #[snafu(display("Got invalid put response: {}", source))] + InvalidPutResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Error performing post request {}: {}", path, source))] + PostRequest { + source: crate::client::retry::Error, + path: String, + }, + #[snafu(display("Error decoding object size: {}", source))] InvalidSize { source: std::num::ParseIntError }, @@ -148,6 +156,12 @@ enum Error { #[snafu(display("Configuration key: '{}' is not known.", key))] UnknownConfigurationKey { key: String }, + + #[snafu(display("ETag Header missing from response"))] + MissingEtag, + + #[snafu(display("Received header containing non-ASCII data"))] + BadHeader { source: header::ToStrError }, } impl From for super::Error { @@ -283,14 +297,9 @@ impl GoogleCloudStorageClient { })?; let data = response.bytes().await.context(PutResponseBodySnafu)?; - let result: InitiateMultipartUploadResult = quick_xml::de::from_reader( - data.as_ref().reader(), - ) - .context(InvalidXMLResponseSnafu { - method: "POST".to_string(), - url, - data, - })?; + let result: InitiateMultipartUploadResult = + quick_xml::de::from_reader(data.as_ref().reader()) + .context(InvalidPutResponseSnafu)?; Ok(result.upload_id) } @@ -472,24 +481,16 @@ struct GCSMultipartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for GCSMultipartUpload { +impl PutPart for GCSMultipartUpload 
{ /// Upload an object part - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", self.client.base_url, self.client.bucket_name_encoded, self.encoded_path ); - let credential = self - .client - .get_credential() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let credential = self.client.get_credential().await?; let response = self .client @@ -504,26 +505,24 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .header(header::CONTENT_LENGTH, format!("{}", buf.len())) .body(buf) .send_retry(&self.client.retry_config) - .await?; + .await + .context(PutRequestSnafu { + path: &self.encoded_path, + })?; let content_id = response .headers() .get("ETag") - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "response headers missing ETag", - ) - })? + .context(MissingEtagSnafu)? .to_str() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + .context(BadHeaderSnafu)? .to_string(); - Ok(UploadPart { content_id }) + Ok(PartId { content_id }) } /// Complete a multipart upload - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", @@ -539,16 +538,11 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { }) .collect(); - let credential = self - .client - .get_credential() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - + let credential = self.client.get_credential().await?; let upload_info = CompleteMultipartUpload { parts }; let data = quick_xml::se::to_string(&upload_info) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? + .context(InvalidPutResponseSnafu)? 
// We cannot disable the escaping that transforms "/" to "&quote;" :( // https://github.com/tafia/quick-xml/issues/362 // https://github.com/tafia/quick-xml/issues/350 @@ -561,7 +555,10 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .query(&[("uploadId", upload_id)]) .body(data) .send_retry(&self.client.retry_config) - .await?; + .await + .context(PostRequestSnafu { + path: &self.encoded_path, + })?; Ok(()) } @@ -588,7 +585,7 @@ impl ObjectStore for GoogleCloudStorage { multipart_id: upload_id.clone(), }; - Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8)))) + Ok((upload_id, Box::new(WriteMultiPart::new(inner, 8)))) } async fn abort_multipart( diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs index 5f9b7e67488f..d4c911fceab4 100644 --- a/object_store/src/multipart.rs +++ b/object_store/src/multipart.rs @@ -31,40 +31,33 @@ use crate::Result; type BoxedTryFuture = Pin> + Send>>; /// A trait that can be implemented by cloud-based object stores -/// and used in combination with [`CloudMultiPartUpload`] to provide +/// and used in combination with [`WriteMultiPart`] to provide /// multipart upload support #[async_trait] -pub trait CloudMultiPartUploadImpl: 'static { +pub trait PutPart: Send + Sync + 'static { /// Upload a single part - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result; + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result; /// Complete the upload with the provided parts /// /// `completed_parts` is in order of part number - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error>; + async fn complete(&self, completed_parts: Vec) -> Result<()>; } /// Represents a part of a file that has been successfully uploaded in a multipart upload process. #[derive(Debug, Clone)] -pub struct UploadPart { +pub struct PartId { /// Id of this part pub content_id: String, } -/// Struct that manages and controls multipart uploads to a cloud storage service.
-pub struct CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl, -{ +/// Wrapper around a [`PutPart`] that implements [`AsyncWrite`] +pub struct WriteMultiPart { inner: Arc, /// A list of completed parts, in sequential order. - completed_parts: Vec>, + completed_parts: Vec>, /// Part upload tasks currently running - tasks: FuturesUnordered>, + tasks: FuturesUnordered>, /// Maximum number of upload tasks to run concurrently max_concurrency: usize, /// Buffer that will be sent in next upload. @@ -80,10 +73,7 @@ where completion_task: Option>, } -impl CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl, -{ +impl WriteMultiPart { /// Create a new multipart upload with the implementation and the given maximum concurrency pub fn new(inner: T, max_concurrency: usize) -> Self { Self { @@ -114,7 +104,7 @@ where } /// Poll current tasks - pub fn poll_tasks( + fn poll_tasks( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Result<(), io::Error> { @@ -130,12 +120,7 @@ where } Ok(()) } -} -impl CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl + Send + Sync, -{ // The `poll_flush` function will only flush the in-progress tasks. // The `final_flush` method called during `poll_shutdown` will flush // the `current_buffer` along with in-progress tasks. 
@@ -153,7 +138,7 @@ where let inner = Arc::clone(&self.inner); let part_idx = self.current_part_idx; self.tasks.push(Box::pin(async move { - let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + let upload_part = inner.put_part(out_buffer, part_idx).await?; Ok((part_idx, upload_part)) })); } @@ -169,10 +154,7 @@ where } } -impl AsyncWrite for CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl + Send + Sync, -{ +impl AsyncWrite for WriteMultiPart { fn poll_write( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -199,7 +181,7 @@ where let inner = Arc::clone(&self.inner); let part_idx = self.current_part_idx; self.tasks.push(Box::pin(async move { - let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + let upload_part = inner.put_part(out_buffer, part_idx).await?; Ok((part_idx, upload_part)) })); self.current_part_idx += 1; @@ -269,9 +251,9 @@ where } } -impl std::fmt::Debug for CloudMultiPartUpload { +impl std::fmt::Debug for WriteMultiPart { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CloudMultiPartUpload") + f.debug_struct("WriteMultiPart") .field("completed_parts", &self.completed_parts) .field("tasks", &self.tasks) .field("max_concurrency", &self.max_concurrency)