Skip to content

Commit

Permalink
Merge pull request #44 from flatcar/kai/noasync
Browse files Browse the repository at this point in the history
Switch to reqwest::blocking
  • Loading branch information
pothos authored Dec 21, 2023
2 parents f624f27 + 8244bc4 commit 517b4ec
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 89 deletions.
22 changes: 9 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ env_logger = "0.10"
globset = "0.4"
log = "0.4"
protobuf = "3.2.0"
reqwest = "0.11"
reqwest = { version = "0.11", features = ["blocking"] }
sha2 = "0.10"
tempfile = "3.8.1"
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
url = "2"
uuid = "1.2"

Expand Down
11 changes: 6 additions & 5 deletions examples/download_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use std::str::FromStr;

use ue_rs::download_and_hash;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::blocking::Client::new();

let url = Url::from_str(std::env::args().nth(1).expect("missing URL (second argument)").as_str())?;

println!("fetching {}...", url);

let data = Vec::new();
let res = download_and_hash(&client, url, data, false).await?;
let tempdir = tempfile::tempdir()?;
let path = tempdir.path().join("tmpfile");
let res = download_and_hash(&client, url, &path, false)?;
tempdir.close()?;

println!("hash: {}", res.hash);

Expand Down
16 changes: 7 additions & 9 deletions examples/full_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ fn get_pkgs_to_download(resp: &omaha::Response) -> Result<Vec<(Url, omaha::Hash<
Ok(to_download)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::blocking::Client::new();

const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100";
const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe";
Expand All @@ -59,7 +58,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
track: Cow::Borrowed(TRACK_DEFAULT),
};

let response_text = ue_rs::request::perform(&client, parameters).await.context(format!(
let response_text = ue_rs::request::perform(&client, parameters).context(format!(
"perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed"
))?;

Expand All @@ -79,11 +78,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
for (url, expected_sha256) in pkgs_to_dl {
println!("downloading {}...", url);

// TODO: use a file or anything that implements std::io::Write here.
// std::io::BufWriter wrapping an std::fs::File is probably the right choice.
// std::io::sink() is basically just /dev/null
let data = std::io::sink();
let res = ue_rs::download_and_hash(&client, url.clone(), data, false).await.context(format!("download_and_hash({url:?}) failed"))?;
let tempdir = tempfile::tempdir()?;
let path = tempdir.path().join("tmpfile");
let res = ue_rs::download_and_hash(&client, url.clone(), &path, false).context(format!("download_and_hash({url:?}) failed"))?;
tempdir.close()?;

println!("\texpected sha256: {}", expected_sha256);
println!("\tcalculated sha256: {}", res.hash);
Expand Down
7 changes: 3 additions & 4 deletions examples/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ use anyhow::Context;

use ue_rs::request;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
fn main() -> Result<(), Box<dyn Error>> {
let client = reqwest::blocking::Client::new();

const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100";
const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe";
Expand All @@ -20,7 +19,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
track: Cow::Borrowed(TRACK_DEFAULT),
};

let response = request::perform(&client, parameters).await.context(format!(
let response = request::perform(&client, parameters).context(format!(
"perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed"
))?;

Expand Down
34 changes: 13 additions & 21 deletions src/bin/download_sysext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use argh::FromArgs;
use globset::{Glob, GlobSet, GlobSetBuilder};
use hard_xml::XmlRead;
use omaha::FileSize;
use reqwest::Client;
use reqwest::blocking::Client;
use reqwest::redirect::Policy;
use url::Url;

Expand Down Expand Up @@ -94,7 +94,7 @@ impl<'a> Package<'a> {
Ok(())
}

async fn download(&mut self, into_dir: &Path, client: &reqwest::Client, print_progress: bool) -> Result<()> {
fn download(&mut self, into_dir: &Path, client: &Client, print_progress: bool) -> Result<()> {
// FIXME: use _range_start for completing downloads
let _range_start = match self.status {
PackageStatus::ToDownload => 0,
Expand All @@ -105,9 +105,7 @@ impl<'a> Package<'a> {
info!("downloading {}...", self.url);

let path = into_dir.join(&*self.name);
let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?;

let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file, print_progress).await {
let res = match ue_rs::download_and_hash(client, self.url.clone(), &path, print_progress) {
Ok(ok) => ok,
Err(err) => {
error!("Downloading failed with error {}", err);
Expand Down Expand Up @@ -243,28 +241,26 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet)
}

// Read data from remote URL into File
async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result<Package<'a>>
fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result<Package<'a>>
where
U: reqwest::IntoUrl + From<U> + std::clone::Clone + std::fmt::Debug,
Url: From<U>,
{
let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?;

ue_rs::download_and_hash(client, input_url.clone(), &mut file, print_progress).await.context(format!("unable to download data(url {:?})", input_url))?;
let r = ue_rs::download_and_hash(client, input_url.clone(), path, print_progress).context(format!("unable to download data(url {:?})", input_url))?;

Ok(Package {
name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")),
hash: hash_on_disk_sha256(path, None)?,
size: FileSize::from_bytes(file.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize),
hash: r.hash,
size: FileSize::from_bytes(r.data.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize),
url: input_url.into(),
status: PackageStatus::Unverified,
})
}

async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> {
fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> {
pkg.check_download(unverified_dir)?;

pkg.download(unverified_dir, client, print_progress).await.context(format!("unable to download \"{:?}\"", pkg.name))?;
pkg.download(unverified_dir, client, print_progress).context(format!("unable to download \"{:?}\"", pkg.name))?;

// Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz".
// Verified payload is stored in e.g. "output_dir/oem.raw".
Expand Down Expand Up @@ -322,8 +318,7 @@ impl Args {
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
fn main() -> Result<(), Box<dyn Error>> {
env_logger::init();

let args: Args = argh::from_env();
Expand Down Expand Up @@ -374,17 +369,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?,
&client,
args.print_progress,
)
.await?;
)?;
do_download_verify(
&mut pkg_fake,
output_dir,
unverified_dir.as_path(),
args.pubkey_file.as_str(),
&client,
args.print_progress,
)
.await?;
)?;

// verify only a fake package, early exit and skip the rest.
return Ok(());
Expand Down Expand Up @@ -417,8 +410,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
args.pubkey_file.as_str(),
&client,
args.print_progress,
)
.await?;
)?;
}

// clean up data
Expand Down
43 changes: 11 additions & 32 deletions src/download.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use anyhow::{Context, Result, bail};
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::io;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::fs::File;
use std::path::Path;
use log::info;
use url::Url;

use reqwest::StatusCode;
use reqwest::blocking::Client;

use sha2::{Sha256, Digest};

pub struct DownloadResult<W: std::io::Write> {
pub struct DownloadResult {
pub hash: omaha::Hash<omaha::Sha256>,
pub data: W,
pub data: File,
}

pub fn hash_on_disk_sha256(path: &Path, maxlen: Option<usize>) -> Result<omaha::Hash<omaha::Sha256>> {
Expand Down Expand Up @@ -57,18 +57,16 @@ pub fn hash_on_disk_sha256(path: &Path, maxlen: Option<usize>) -> Result<omaha::
Ok(omaha::Hash::from_bytes(hasher.finalize().into()))
}

pub async fn download_and_hash<U, W>(client: &reqwest::Client, url: U, mut data: W, print_progress: bool) -> Result<DownloadResult<W>>
pub fn download_and_hash<U>(client: &Client, url: U, path: &Path, print_progress: bool) -> Result<DownloadResult>
where
U: reqwest::IntoUrl + Clone,
W: io::Write,
Url: From<U>,
{
let client_url = url.clone();

#[rustfmt::skip]
let mut res = client.get(url)
.send()
.await
.context(format!("client get and send({:?}) failed", client_url.as_str()))?;

// Redirect was already handled at this point, so there is no need to touch
Expand All @@ -89,33 +87,14 @@ where
}
}

let mut hasher = Sha256::new();

let mut bytes_read = 0usize;
let bytes_to_read = res.content_length().unwrap_or(u64::MAX) as usize;

while let Some(chunk) = res.chunk().await.context("failed to get response chunk")? {
bytes_read += chunk.len();

hasher.update(&chunk);
data.write_all(&chunk).context("failed to write_all chunk")?;

if print_progress {
print!(
"\rread {}/{} ({:3}%)",
bytes_read,
bytes_to_read,
((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor()
);
io::stdout().flush().context("failed to flush stdout")?;
}
if print_progress {
println!("writing to {}", path.display());
}

data.flush().context("failed to flush data")?;
println!();
let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?;
res.copy_to(&mut file)?;

Ok(DownloadResult {
hash: omaha::Hash::from_bytes(hasher.finalize().into()),
data,
hash: hash_on_disk_sha256(path, None)?,
data: file,
})
}
5 changes: 2 additions & 3 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct Parameters<'a> {
pub machine_id: Cow<'a, str>,
}

pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) -> Result<String> {
pub fn perform<'a>(client: &reqwest::blocking::Client, parameters: Parameters<'a>) -> Result<String> {
let req_body = {
let r = omaha::Request {
protocol_version: Cow::Borrowed(PROTOCOL_VERSION),
Expand Down Expand Up @@ -78,8 +78,7 @@ pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) -
let resp = client.post(UPDATE_URL)
.body(req_body)
.send()
.await
.context("client post send({UPDATE_URL}) failed")?;

resp.text().await.context("failed to get response")
resp.text().context("failed to get response")
}

0 comments on commit 517b4ec

Please sign in to comment.