diff --git a/src/lib.rs b/src/lib.rs index 535c2e4..6a6d814 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use rand::{thread_rng, Rng}; use reqwest::header::{ HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_LENGTH, CONTENT_RANGE, RANGE, }; +use reqwest::Url; use std::collections::HashMap; use std::fs::remove_file; use std::io::SeekFrom; @@ -74,9 +75,9 @@ fn download( if path.exists() { match remove_file(filename) { Ok(_) => err, - Err(err) => PyException::new_err(format!( - "Error while removing corrupted file: {err:?}" - )), + Err(err) => { + PyException::new_err(format!("Error while removing corrupted file: {err}")) + } } } else { err @@ -168,20 +169,20 @@ async fn download_async( for (k, v) in input_headers { let name: HeaderName = k .try_into() - .map_err(|err| PyException::new_err(format!("Invalid header: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Invalid header: {err}")))?; let value: HeaderValue = AsRef::::as_ref(&v) .try_into() - .map_err(|err| PyException::new_err(format!("Invalid header value: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Invalid header value: {err}")))?; if name == AUTHORIZATION { - auth_token = Some(v); + auth_token = Some(value); } else { headers.insert(name, value); } } }; - let response = if let Some(token) = auth_token { - client.get(&url).bearer_auth(token) + let response = if let Some(token) = auth_token.as_ref() { + client.get(&url).header(AUTHORIZATION, token) } else { client.get(&url) } @@ -189,20 +190,29 @@ async fn download_async( .header(RANGE, "bytes=0-0") .send() .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))? + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))? .error_for_status() .map_err(|err| PyException::new_err(err.to_string()))?; // Only call the final redirect URL to avoid overloading the Hub with requests and also // altering the download count - let redirected_url = response.url().to_string(); + let redirected_url = response.url(); + if Url::parse(&url) + .map_err(|err| PyException::new_err(format!("failed to parse url: {err}")))? + .host() + == redirected_url.host() + { + if let Some(token) = auth_token { + headers.insert(AUTHORIZATION, token); + } + } let content_range = response .headers() .get(CONTENT_RANGE) .ok_or(PyException::new_err("No content length"))? .to_str() - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; let size: Vec<&str> = content_range.split('/').collect(); // Content-Range: bytes 0-0/702517648 @@ -213,14 +223,14 @@ async fn download_async( "Error while downloading: No size was detected", ))? .parse() - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; let mut handles = FuturesUnordered::new(); let semaphore = Arc::new(Semaphore::new(max_files)); let parallel_failures_semaphore = Arc::new(Semaphore::new(parallel_failures)); for start in (0..length).step_by(chunk_size) { - let url = redirected_url.clone(); + let url = redirected_url.to_string(); let filename = filename.clone(); let client = client.clone(); let headers = headers.clone(); @@ -232,19 +242,19 @@ async fn download_async( let permit = semaphore .acquire_owned() .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; let mut chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await; let mut i = 0; if parallel_failures > 0 { while let Err(dlerr) = chunk { if i >= max_retries { return Err(PyException::new_err(format!( - "Failed after too many retries ({max_retries:?}): {dlerr:?}" + "Failed after too many retries ({max_retries}): {dlerr}" ))); } let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| { PyException::new_err(format!( - "Failed too many failures in parallel ({parallel_failures:?}): {dlerr:?} ({err:?})" + "Failed too many failures in parallel ({parallel_failures}): {dlerr} ({err})" )) })?; @@ -274,7 +284,7 @@ async fn download_async( } Err(err) => { return Err(PyException::new_err(format!( - "Error while downloading: {err:?}" + "Error while downloading: {err}" ))); } } @@ -298,26 +308,26 @@ async fn download_chunk( .create(true) .open(filename) .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; file.seek(SeekFrom::Start(start as u64)) .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; let response = client .get(url) .headers(headers) .header(RANGE, range) .send() .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))? + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))? .error_for_status() - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; let content = response .bytes() .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; file.write_all(&content) .await - .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?; Ok(()) } @@ -357,13 +367,13 @@ async fn upload_async( while let Err(ul_err) = chunk { if i >= max_retries { return Err(PyException::new_err(format!( - "Failed after too many retries ({max_retries:?}): {ul_err:?}" + "Failed after too many retries ({max_retries}): {ul_err}" ))); } let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| { PyException::new_err(format!( - "Failed too many failures in parallel ({parallel_failures:?}): {ul_err:?} ({err:?})" + "Failed too many failures in parallel ({parallel_failures}): {ul_err} ({err})" )) })?;