fix: forward all headers on identical host (#45)
* fix: forward all headers on identical host

* fix: `.bearer_auth()` -> `.header()`
McPatate authored Jul 23, 2024
1 parent 27c5ab5 commit a1feb83
Showing 1 changed file with 35 additions and 25 deletions.
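As the diff below shows, the caller-supplied Authorization header is now re-attached to the per-chunk requests, but only when the redirect target returned for the download has the same host as the originally requested URL; on a cross-host redirect the credential is left out. Below is a small self-contained sketch of that host check, assuming the reqwest crate (which re-exports url::Url) is available; should_forward_auth and the example URLs are illustrative and not part of the crate.

// Hypothetical helper, shown only to illustrate the check added in this commit.
use reqwest::Url;

fn should_forward_auth(original: &Url, redirected: &Url) -> bool {
    // Forward Authorization only when the redirect stays on the same host.
    original.host() == redirected.host()
}

fn main() {
    let original = Url::parse("https://huggingface.co/some/repo/resolve/main/model.bin").unwrap();
    let same_host = Url::parse("https://huggingface.co/another/path").unwrap();
    let other_host = Url::parse("https://cdn.example.com/object").unwrap();
    assert!(should_forward_auth(&original, &same_host));
    assert!(!should_forward_auth(&original, &other_host));
}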
src/lib.rs: 60 changes (35 additions & 25 deletions)
@@ -7,6 +7,7 @@ use rand::{thread_rng, Rng};
 use reqwest::header::{
     HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_LENGTH, CONTENT_RANGE, RANGE,
 };
+use reqwest::Url;
 use std::collections::HashMap;
 use std::fs::remove_file;
 use std::io::SeekFrom;
@@ -74,9 +75,9 @@ fn download(
         if path.exists() {
             match remove_file(filename) {
                 Ok(_) => err,
-                Err(err) => PyException::new_err(format!(
-                    "Error while removing corrupted file: {err:?}"
-                )),
+                Err(err) => {
+                    PyException::new_err(format!("Error while removing corrupted file: {err}"))
+                }
             }
         } else {
             err
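A recurring change throughout this diff is swapping `{err:?}` for `{err}` in error messages, i.e. formatting errors with Display instead of Debug, which gives friendlier text in the raised Python exceptions. A tiny standalone illustration, using std::io::Error as a stand-in for the reqwest/tokio errors involved here:

fn main() {
    let err = std::io::Error::new(std::io::ErrorKind::NotFound, "No such file");
    println!("{err:?}"); // Debug:   Custom { kind: NotFound, error: "No such file" }
    println!("{err}"); // Display: No such file
}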
@@ -168,41 +169,50 @@ async fn download_async(
         for (k, v) in input_headers {
             let name: HeaderName = k
                 .try_into()
-                .map_err(|err| PyException::new_err(format!("Invalid header: {err:?}")))?;
+                .map_err(|err| PyException::new_err(format!("Invalid header: {err}")))?;
             let value: HeaderValue = AsRef::<str>::as_ref(&v)
                 .try_into()
-                .map_err(|err| PyException::new_err(format!("Invalid header value: {err:?}")))?;
+                .map_err(|err| PyException::new_err(format!("Invalid header value: {err}")))?;
             if name == AUTHORIZATION {
-                auth_token = Some(v);
+                auth_token = Some(value);
             } else {
                 headers.insert(name, value);
             }
         }
     };

-    let response = if let Some(token) = auth_token {
-        client.get(&url).bearer_auth(token)
+    let response = if let Some(token) = auth_token.as_ref() {
+        client.get(&url).header(AUTHORIZATION, token)
     } else {
         client.get(&url)
     }
     .headers(headers.clone())
     .header(RANGE, "bytes=0-0")
     .send()
     .await
-    .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?
+    .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?
     .error_for_status()
     .map_err(|err| PyException::new_err(err.to_string()))?;

     // Only call the final redirect URL to avoid overloading the Hub with requests and also
     // altering the download count
-    let redirected_url = response.url().to_string();
+    let redirected_url = response.url();
+    if Url::parse(&url)
+        .map_err(|err| PyException::new_err(format!("failed to parse url: {err}")))?
+        .host()
+        == redirected_url.host()
+    {
+        if let Some(token) = auth_token {
+            headers.insert(AUTHORIZATION, token);
+        }
+    }

     let content_range = response
         .headers()
         .get(CONTENT_RANGE)
         .ok_or(PyException::new_err("No content length"))?
         .to_str()
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;

     let size: Vec<&str> = content_range.split('/').collect();
     // Content-Range: bytes 0-0/702517648
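On the `.bearer_auth()` -> `.header()` change in the hunk above: reqwest's RequestBuilder::bearer_auth(t) sends the header as "Bearer {t}", so feeding it an Authorization value that already carries the scheme would double the prefix, whereas `.header(AUTHORIZATION, token)` forwards the caller's HeaderValue untouched. A minimal sketch of the two call styles, assuming `client` is a reqwest::Client and `value` is the raw header value supplied by the caller (e.g. "Bearer hf_xxx"); this is an illustration, not the crate's code:

use reqwest::header::AUTHORIZATION;

fn build_requests(client: &reqwest::Client, url: &str, value: &str) {
    // Old style: bearer_auth() formats its argument as "Bearer {value}",
    // so a value that already starts with "Bearer " would be sent twice-prefixed.
    let _old = client.get(url).bearer_auth(value);

    // New style (this commit): the caller-supplied value is sent verbatim.
    let _new = client.get(url).header(AUTHORIZATION, value);
}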
@@ -213,14 +223,14 @@
             "Error while downloading: No size was detected",
         ))?
         .parse()
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;

     let mut handles = FuturesUnordered::new();
     let semaphore = Arc::new(Semaphore::new(max_files));
     let parallel_failures_semaphore = Arc::new(Semaphore::new(parallel_failures));

     for start in (0..length).step_by(chunk_size) {
-        let url = redirected_url.clone();
+        let url = redirected_url.to_string();
         let filename = filename.clone();
         let client = client.clone();
         let headers = headers.clone();
@@ -232,19 +242,19 @@
             let permit = semaphore
                 .acquire_owned()
                 .await
-                .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+                .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
             let mut chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await;
             let mut i = 0;
             if parallel_failures > 0 {
                 while let Err(dlerr) = chunk {
                     if i >= max_retries {
                         return Err(PyException::new_err(format!(
-                            "Failed after too many retries ({max_retries:?}): {dlerr:?}"
+                            "Failed after too many retries ({max_retries}): {dlerr}"
                         )));
                     }
                     let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| {
                         PyException::new_err(format!(
-                            "Failed too many failures in parallel ({parallel_failures:?}): {dlerr:?} ({err:?})"
+                            "Failed too many failures in parallel ({parallel_failures}): {dlerr} ({err})"
                         ))
                     })?;
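The retry loop touched above follows a pattern worth spelling out: each chunk is retried up to max_retries times, while a parallel_failures semaphore caps how many chunks may be retrying at once; try_acquire_owned() fails immediately instead of waiting, so the download aborts once too many chunks are failing simultaneously. A self-contained sketch of that pattern, with illustrative names (attempt_chunk, fetch_with_retries) that are not the crate's:

use std::sync::Arc;
use tokio::sync::Semaphore;

// Stand-in for download_chunk(); pretend it may fail.
async fn attempt_chunk(_start: usize) -> Result<(), String> {
    Ok(())
}

async fn fetch_with_retries(
    start: usize,
    max_retries: usize,
    parallel_failures_semaphore: Arc<Semaphore>,
) -> Result<(), String> {
    let mut chunk = attempt_chunk(start).await;
    let mut i = 0;
    while let Err(err) = chunk {
        if i >= max_retries {
            return Err(format!("Failed after too many retries ({max_retries}): {err}"));
        }
        // Fails fast when no permit is free, i.e. when `parallel_failures`
        // other chunks are already in their retry path.
        let _permit = parallel_failures_semaphore
            .clone()
            .try_acquire_owned()
            .map_err(|e| format!("Too many failures in parallel: {err} ({e})"))?;
        chunk = attempt_chunk(start).await;
        i += 1;
    }
    Ok(())
}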

@@ -274,7 +284,7 @@
             }
             Err(err) => {
                 return Err(PyException::new_err(format!(
-                    "Error while downloading: {err:?}"
+                    "Error while downloading: {err}"
                 )));
             }
         }
@@ -298,26 +308,26 @@ async fn download_chunk(
         .create(true)
         .open(filename)
         .await
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
     file.seek(SeekFrom::Start(start as u64))
         .await
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
     let response = client
         .get(url)
         .headers(headers)
         .header(RANGE, range)
         .send()
         .await
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?
         .error_for_status()
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
     let content = response
         .bytes()
         .await
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
     file.write_all(&content)
         .await
-        .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
+        .map_err(|err| PyException::new_err(format!("Error while downloading: {err}")))?;
     Ok(())
 }

@@ -357,13 +367,13 @@ async fn upload_async(
                 while let Err(ul_err) = chunk {
                     if i >= max_retries {
                         return Err(PyException::new_err(format!(
-                            "Failed after too many retries ({max_retries:?}): {ul_err:?}"
+                            "Failed after too many retries ({max_retries}): {ul_err}"
                         )));
                     }

                     let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| {
                         PyException::new_err(format!(
-                            "Failed too many failures in parallel ({parallel_failures:?}): {ul_err:?} ({err:?})"
+                            "Failed too many failures in parallel ({parallel_failures}): {ul_err} ({err})"
                         ))
                     })?;
