Skip to content

Commit

Permalink
Progress bars for publish
Browse files Browse the repository at this point in the history
  • Loading branch information
konstin committed Sep 21, 2024
1 parent 2d5af07 commit fdce2be
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions crates/uv-fs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use fs2::FileExt;
use std::fmt::Display;
use std::io;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
use tempfile::NamedTempFile;
use tokio::io::{AsyncRead, ReadBuf};
use tracing::{debug, error, info, trace, warn};

pub use crate::path::*;
Expand Down Expand Up @@ -387,3 +391,32 @@ impl Drop for LockedFile {
}
}
}

/// An asynchronous reader that reports progress as bytes are read.
pub struct ProgressReader<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> {
reader: Reader,
callback: Callback,
}

impl<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> ProgressReader<Reader, Callback> {
/// Create a new [`ProgressReader`] that wraps another reader.
pub fn new(reader: Reader, callback: Callback) -> Self {
Self { reader, callback }
}
}

impl<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> AsyncRead
for ProgressReader<Reader, Callback>
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.as_mut().reader)
.poll_read(cx, buf)
.map_ok(|()| {
(self.callback)(buf.filled().len());
})
}
}
1 change: 1 addition & 0 deletions crates/uv-publish/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ serde_json = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true , features = ["io"] }
tracing = { workspace = true }
url = { workspace = true }

Expand Down
46 changes: 38 additions & 8 deletions crates/uv-publish/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{fmt, io};
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tokio_util::io::ReaderStream;
use tracing::{debug, enabled, trace, Level};
use url::Url;
use uv_client::BaseClient;
use uv_fs::Simplified;
use uv_fs::{ProgressReader, Simplified};
use uv_metadata::read_metadata_async_seek;

#[derive(Error, Debug)]
Expand Down Expand Up @@ -79,6 +81,13 @@ pub enum PublishSendError {
RedirectError(Url),
}

pub trait Reporter: Send + Sync + 'static {
fn on_progress(&self, name: &str, id: usize);
fn on_download_start(&self, name: &str, size: Option<u64>) -> usize;
fn on_download_progress(&self, id: usize, inc: u64);
fn on_download_complete(&self);
}

impl PublishSendError {
/// Extract `code` from the PyPI json error response, if any.
///
Expand Down Expand Up @@ -212,6 +221,7 @@ pub async fn upload(
client: &BaseClient,
username: Option<&str>,
password: Option<&str>,
reporter: Arc<impl Reporter>,
) -> Result<bool, PublishError> {
let form_metadata = form_metadata(file, filename)
.await
Expand All @@ -224,6 +234,7 @@ pub async fn upload(
username,
password,
form_metadata,
reporter,
)
.await
.map_err(|err| PublishError::PublishPrepare(file.to_path_buf(), err))?;
Expand Down Expand Up @@ -396,18 +407,23 @@ async fn build_request(
username: Option<&str>,
password: Option<&str>,
form_metadata: Vec<(&'static str, String)>,
reporter: Arc<impl Reporter>,
) -> Result<RequestBuilder, PublishPrepareError> {
let mut form = reqwest::multipart::Form::new();
for (key, value) in form_metadata {
form = form.text(key, value);
}

let file: tokio::fs::File = fs_err::tokio::File::open(file).await?.into();
let file_reader = Body::from(file);
form = form.part(
"content",
Part::stream(file_reader).file_name(filename.to_string()),
);
let file = fs_err::tokio::File::open(file).await?;
let idx = reporter.on_download_start(&filename.to_string(), Some(file.metadata().await?.len()));
let reader = ProgressReader::new(file, move |read| {
reporter.on_download_progress(idx, read as u64);
});
// Stream wrapping puts a static lifetime requirement on the reader (so the request doesn't have
// a lifetime) -> callback needs to be static -> reporter reference needs to be Arc'd.
let file_reader = Body::wrap_stream(ReaderStream::new(reader));
let part = Part::stream(file_reader).file_name(filename.to_string());
form = form.part("content", part);

let url = if let Some(username) = username {
if password.is_none() {
Expand Down Expand Up @@ -525,14 +541,26 @@ async fn handle_response(registry: &Url, response: Response) -> Result<bool, Pub

#[cfg(test)]
mod tests {
use crate::{build_request, form_metadata};
use crate::{build_request, form_metadata, Reporter};
use distribution_filename::DistFilename;
use insta::{assert_debug_snapshot, assert_snapshot};
use itertools::Itertools;
use std::path::PathBuf;
use std::sync::Arc;
use url::Url;
use uv_client::BaseClientBuilder;

struct DummyReporter;

impl Reporter for DummyReporter {
fn on_progress(&self, _name: &str, _id: usize) {}
fn on_download_start(&self, _name: &str, _size: Option<u64>) -> usize {
0
}
fn on_download_progress(&self, _id: usize, _inc: u64) {}
fn on_download_complete(&self) {}
}

/// Snapshot the data we send for an upload request for a source distribution.
#[tokio::test]
async fn upload_request_source_dist() {
Expand Down Expand Up @@ -602,6 +630,7 @@ mod tests {
Some("ferris"),
Some("F3RR!S"),
form_metadata,
Arc::new(DummyReporter),
)
.await
.unwrap();
Expand Down Expand Up @@ -744,6 +773,7 @@ mod tests {
Some("ferris"),
Some("F3RR!S"),
form_metadata,
Arc::new(DummyReporter),
)
.await
.unwrap();
Expand Down
5 changes: 5 additions & 0 deletions crates/uv/src/commands/publish.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::commands::reporters::PublishReporter;
use crate::commands::{human_readable_bytes, ExitStatus};
use crate::printer::Printer;
use anyhow::{bail, Result};
use owo_colors::OwoColorize;
use std::fmt::Write;
use std::sync::Arc;
use tracing::info;
use url::Url;
use uv_client::{BaseClientBuilder, Connectivity};
Expand Down Expand Up @@ -51,13 +53,16 @@ pub(crate) async fn publish(
"Uploading".bold().green(),
format!("({bytes:.1}{unit})").dimmed()
)?;
let reporter = PublishReporter::single(printer);
let uploaded = upload(
&file,
&filename,
&publish_url,
&client,
username.as_deref(),
password.as_deref(),
// Needs to be an `Arc` because the reqwest `Body` static lifetime requirement
Arc::new(reporter),
)
.await?; // Filename and/or URL are already attached, if applicable.
info!("Upload succeeded");
Expand Down
45 changes: 44 additions & 1 deletion crates/uv/src/commands/reporters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ impl ProgressReporter {
);

if size.is_some() {
// We're using binary bytes to match `human_readable_bytes`.
progress.set_style(
ProgressStyle::with_template(
"{msg:10.dim} {bar:30.green/dim} {decimal_bytes:>7}/{decimal_total_bytes:7}",
"{msg:10.dim} {bar:30.green/dim} {binary_bytes:>7}/{binary_total_bytes:7}",
)
.unwrap()
.progress_chars("--"),
Expand Down Expand Up @@ -485,6 +486,48 @@ impl uv_python::downloads::Reporter for PythonDownloadReporter {
}
}

#[derive(Debug)]
pub(crate) struct PublishReporter {
reporter: ProgressReporter,
}

impl PublishReporter {
/// Initialize a [`PublishReporter`] for a single upload.
pub(crate) fn single(printer: Printer) -> Self {
Self::new(printer, 1)
}

/// Initialize a [`PublishReporter`] for multiple uploads.
pub(crate) fn new(printer: Printer, length: u64) -> Self {
let multi_progress = MultiProgress::with_draw_target(printer.target());
let root = multi_progress.add(ProgressBar::with_draw_target(
Some(length),
printer.target(),
));
let reporter = ProgressReporter::new(root, multi_progress, printer);
Self { reporter }
}
}

impl uv_publish::Reporter for PublishReporter {
fn on_progress(&self, _name: &str, id: usize) {
self.reporter.on_download_complete(id);
}

fn on_download_start(&self, name: &str, size: Option<u64>) -> usize {
self.reporter.on_download_start(name.to_string(), size)
}

fn on_download_progress(&self, id: usize, inc: u64) {
self.reporter.on_download_progress(id, inc);
}

fn on_download_complete(&self) {
self.reporter.root.set_message("");
self.reporter.root.finish_and_clear();
}
}

/// Like [`std::fmt::Display`], but with colors.
trait ColorDisplay {
fn to_color_string(&self) -> String;
Expand Down

0 comments on commit fdce2be

Please sign in to comment.