Skip to content

Commit

Permalink
Merge pull request #41 from flatcar/kai/dl-print
Browse files Browse the repository at this point in the history
download: Make progress reporting opt-in
  • Loading branch information
pothos authored Dec 7, 2023
2 parents 2c58396 + 407fa0b commit f624f27
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/download_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
println!("fetching {}...", url);

let data = Vec::new();
let res = download_and_hash(&client, url, data).await?;
let res = download_and_hash(&client, url, data, false).await?;

println!("hash: {}", res.hash);

Expand Down
2 changes: 1 addition & 1 deletion examples/full_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
// std::io::BufWriter wrapping an std::fs::File is probably the right choice.
// std::io::sink() is basically just /dev/null
let data = std::io::sink();
let res = ue_rs::download_and_hash(&client, url.clone(), data).await.context(format!("download_and_hash({url:?}) failed"))?;
let res = ue_rs::download_and_hash(&client, url.clone(), data, false).await.context(format!("download_and_hash({url:?}) failed"))?;

println!("\texpected sha256: {}", expected_sha256);
println!("\tcalculated sha256: {}", res.hash);
Expand Down
28 changes: 21 additions & 7 deletions src/bin/download_sysext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl<'a> Package<'a> {
Ok(())
}

async fn download(&mut self, into_dir: &Path, client: &reqwest::Client) -> Result<()> {
async fn download(&mut self, into_dir: &Path, client: &reqwest::Client, print_progress: bool) -> Result<()> {
// FIXME: use _range_start for completing downloads
let _range_start = match self.status {
PackageStatus::ToDownload => 0,
Expand All @@ -107,7 +107,7 @@ impl<'a> Package<'a> {
let path = into_dir.join(&*self.name);
let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?;

let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file).await {
let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file, print_progress).await {
Ok(ok) => ok,
Err(err) => {
error!("Downloading failed with error {}", err);
Expand Down Expand Up @@ -243,14 +243,14 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet)
}

// Read data from remote URL into File
async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client) -> Result<Package<'a>>
async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result<Package<'a>>
where
U: reqwest::IntoUrl + From<U> + std::clone::Clone + std::fmt::Debug,
Url: From<U>,
{
let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?;

ue_rs::download_and_hash(client, input_url.clone(), &mut file).await.context(format!("unable to download data(url {:?})", input_url))?;
ue_rs::download_and_hash(client, input_url.clone(), &mut file, print_progress).await.context(format!("unable to download data(url {:?})", input_url))?;

Ok(Package {
name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")),
Expand All @@ -261,10 +261,10 @@ where
})
}

async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client) -> Result<()> {
async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> {
pkg.check_download(unverified_dir)?;

pkg.download(unverified_dir, client).await.context(format!("unable to download \"{:?}\"", pkg.name))?;
pkg.download(unverified_dir, client, print_progress).await.context(format!("unable to download \"{:?}\"", pkg.name))?;

// Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz".
// Verified payload is stored in e.g. "output_dir/oem.raw".
Expand Down Expand Up @@ -304,6 +304,10 @@ struct Args {
/// may be specified multiple times.
#[argh(option, short = 'm')]
image_match: Vec<String>,

/// report download progress
#[argh(switch, short = 'v')]
print_progress: bool,
}

impl Args {
Expand Down Expand Up @@ -369,6 +373,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
&temp_payload_path,
Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?,
&client,
args.print_progress,
)
.await?;
do_download_verify(
Expand All @@ -377,6 +382,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
unverified_dir.as_path(),
args.pubkey_file.as_str(),
&client,
args.print_progress,
)
.await?;

Expand Down Expand Up @@ -404,7 +410,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
////

for pkg in pkgs_to_dl.iter_mut() {
do_download_verify(pkg, output_dir, unverified_dir.as_path(), args.pubkey_file.as_str(), &client).await?;
do_download_verify(
pkg,
output_dir,
unverified_dir.as_path(),
args.pubkey_file.as_str(),
&client,
args.print_progress,
)
.await?;
}

// clean up data
Expand Down
19 changes: 10 additions & 9 deletions src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub fn hash_on_disk_sha256(path: &Path, maxlen: Option<usize>) -> Result<omaha::
Ok(omaha::Hash::from_bytes(hasher.finalize().into()))
}

pub async fn download_and_hash<U, W>(client: &reqwest::Client, url: U, mut data: W) -> Result<DownloadResult<W>>
pub async fn download_and_hash<U, W>(client: &reqwest::Client, url: U, mut data: W, print_progress: bool) -> Result<DownloadResult<W>>
where
U: reqwest::IntoUrl + Clone,
W: io::Write,
Expand Down Expand Up @@ -100,14 +100,15 @@ where
hasher.update(&chunk);
data.write_all(&chunk).context("failed to write_all chunk")?;

// TODO: better way to report progress?
print!(
"\rread {}/{} ({:3}%)",
bytes_read,
bytes_to_read,
((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor()
);
io::stdout().flush().context("failed to flush stdout")?;
if print_progress {
print!(
"\rread {}/{} ({:3}%)",
bytes_read,
bytes_to_read,
((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor()
);
io::stdout().flush().context("failed to flush stdout")?;
}
}

data.flush().context("failed to flush data")?;
Expand Down

0 comments on commit f624f27

Please sign in to comment.