From d8c8a0b5700aecc8c3930d9d44af22da1de22ab4 Mon Sep 17 00:00:00 2001 From: Gaius Date: Wed, 24 Jan 2024 20:46:16 +0800 Subject: [PATCH] feat: add make_download_task_request to proxy (#232) Signed-off-by: Gaius --- src/bin/dfget/main.rs | 8 +-- src/config/mod.rs | 5 ++ src/lib.rs | 16 +++++ src/proxy/header.rs | 141 ++++++++++++++++++++++++++++++++++++++++++ src/proxy/mod.rs | 91 ++++++++++++++++++++++++++- 5 files changed, 255 insertions(+), 6 deletions(-) create mode 100644 src/proxy/header.rs diff --git a/src/bin/dfget/main.rs b/src/bin/dfget/main.rs index 7fd207da7d3..55349ea05fc 100644 --- a/src/bin/dfget/main.rs +++ b/src/bin/dfget/main.rs @@ -15,11 +15,9 @@ */ use clap::Parser; -use dragonfly_api::common::v2::Download; -use dragonfly_api::common::v2::TaskType; +use dragonfly_api::common::v2::{Download, TaskType}; use dragonfly_api::dfdaemon::v2::DownloadTaskRequest; -use dragonfly_client::config::dfdaemon; -use dragonfly_client::config::dfget; +use dragonfly_client::config::{self, dfdaemon, dfget}; use dragonfly_client::grpc::dfdaemon_download::DfdaemonDownloadClient; use dragonfly_client::grpc::health::HealthClient; use dragonfly_client::tracing::init_tracing; @@ -79,7 +77,7 @@ struct Args { #[arg( long = "piece-length", - default_value_t = 4194304, + default_value_t = config::default_piece_length(), help = "Specify the byte length of the piece" )] piece_length: u64, diff --git a/src/config/mod.rs b/src/config/mod.rs index 036b04fcabb..e51307e58ac 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -97,3 +97,8 @@ pub fn default_cache_dir() -> PathBuf { #[cfg(target_os = "macos")] return home::home_dir().unwrap().join(".dragonfly").join("cache"); } + +// default_piece_length is the default piece length for task. 
+pub fn default_piece_length() -> u64 {
+    4 * 1024 * 1024 // 4 MiB, expressed in bytes.
+}
diff --git a/src/lib.rs b/src/lib.rs
index 8dc81d6d3c1..7db634acd10 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -96,6 +96,22 @@ pub enum Error {
     #[error(transparent)]
     Elapsed(#[from] tokio_stream::Elapsed),
 
+    // InvalidUri is the error for invalid uri.
+    #[error(transparent)]
+    InvalidUri(#[from] http::uri::InvalidUri),
+
+    // InvalidUriParts is the error for invalid uri parts.
+    #[error(transparent)]
+    InvalidUriParts(#[from] http::uri::InvalidUriParts),
+
+    // InvalidHeaderValue is the error for invalid header value.
+    #[error(transparent)]
+    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
+
+    // HeaderToStrError is the error for converting a header value to a string.
+    #[error(transparent)]
+    HeaderToStrError(#[from] reqwest::header::ToStrError),
+
     // RangeUnsatisfiableError is the error for range unsatisfiable.
     #[error(transparent)]
     RangeUnsatisfiableError(#[from] http_range_header::RangeUnsatisfiableError),
diff --git a/src/proxy/header.rs b/src/proxy/header.rs
new file mode 100644
index 00000000000..fec174953b2
--- /dev/null
+++ b/src/proxy/header.rs
@@ -0,0 +1,141 @@
+/*
+ * Copyright 2024 The Dragonfly Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+use crate::config;
+use crate::utils::http::parse_range_header;
+use crate::Result;
+use dragonfly_api::common::v2::{Priority, Range};
+use reqwest::header::HeaderMap;
+use tracing::error;
+
+// DRAGONFLY_TAG_HEADER is the header key of tag in http request.
+pub const DRAGONFLY_TAG_HEADER: &str = "X-Dragonfly-Tag";
+
+// DRAGONFLY_APPLICATION_HEADER is the header key of application in http request.
+pub const DRAGONFLY_APPLICATION_HEADER: &str = "X-Dragonfly-Application";
+
+// DRAGONFLY_PRIORITY_HEADER is the header key of priority in http request,
+// refer to https://github.com/dragonflyoss/api/blob/main/proto/common.proto#L67.
+pub const DRAGONFLY_PRIORITY_HEADER: &str = "X-Dragonfly-Priority";
+
+// DRAGONFLY_FILTERED_QUERY_PARAMS_HEADER is the header key of filtered query params in http request,
+// it is the filtered query params to generate the task id.
+// When filter is "X-Dragonfly-Filtered-Query-Params: Signature,Expires,ns" for example:
+// http://example.com/xyz?Expires=e1&Signature=s1&ns=docker.io and http://example.com/xyz?Expires=e2&Signature=s2&ns=docker.io
+// will generate the same task id.
+// Default value includes the filtered query params of s3, gcs, oss, obs, cos.
+pub const DRAGONFLY_FILTERED_QUERY_PARAMS_HEADER: &str = "X-Dragonfly-Filtered-Query-Params";
+
+// DRAGONFLY_PIECE_LENGTH_HEADER is the header key of piece length in http request,
+// it specifies the piece length of the task.
+pub const DRAGONFLY_PIECE_LENGTH_HEADER: &str = "X-Dragonfly-Piece-Length";
+
+// get_range gets the range from the Range http header, it returns Ok(None) when the header is absent and an error when the value can not be read or parsed.
+pub fn get_range(header: &HeaderMap, content_length: u64) -> Result<Option<Range>> {
+    match header.get(reqwest::header::RANGE) {
+        Some(range) => {
+            let range = range.to_str()?;
+            Ok(Some(parse_range_header(range, content_length)?))
+        }
+        None => Ok(None),
+    }
+}
+
+// get_tag gets the tag from http header.
+pub fn get_tag(header: &HeaderMap) -> Option<String> {
+    match header.get(DRAGONFLY_TAG_HEADER) {
+        Some(tag) => match tag.to_str() {
+            Ok(tag) => Some(tag.to_string()),
+            Err(err) => {
+                error!("get tag from header failed: {}", err);
+                None
+            }
+        },
+        None => None,
+    }
+}
+
+// get_application gets the application from http header, it returns None when the header is absent or can not be converted to a string.
+pub fn get_application(header: &HeaderMap) -> Option<String> {
+    match header.get(DRAGONFLY_APPLICATION_HEADER) {
+        Some(application) => match application.to_str() {
+            Ok(application) => Some(application.to_string()),
+            Err(err) => {
+                error!("get application from header failed: {}", err);
+                None
+            }
+        },
+        None => None,
+    }
+}
+
+// get_priority gets the priority from http header, it falls back to Priority::Level6 when the header is absent or is not a valid i32.
+pub fn get_priority(header: &HeaderMap) -> i32 {
+    let default_priority = Priority::Level6 as i32;
+    match header.get(DRAGONFLY_PRIORITY_HEADER) {
+        Some(priority) => match priority.to_str() {
+            Ok(priority) => match priority.parse::<i32>() {
+                Ok(priority) => priority,
+                Err(err) => {
+                    error!("parse priority from header failed: {}", err);
+                    default_priority
+                }
+            },
+            Err(err) => {
+                error!("get priority from header failed: {}", err);
+                default_priority
+            }
+        },
+        None => default_priority,
+    }
+}
+
+// get_filtered_query_params gets the comma-separated filtered query params from http header (whitespace around each entry is trimmed), it falls back to default_filtered_query_params when the header is absent or can not be converted to a string.
+pub fn get_filtered_query_params(
+    header: &HeaderMap,
+    default_filtered_query_params: Vec<String>,
+) -> Vec<String> {
+    match header.get(DRAGONFLY_FILTERED_QUERY_PARAMS_HEADER) {
+        Some(filters) => match filters.to_str() {
+            Ok(filters) => filters.split(',').map(|s| s.trim().to_string()).collect(),
+            Err(err) => {
+                error!("get filters from header failed: {}", err);
+                default_filtered_query_params
+            }
+        },
+        None => default_filtered_query_params,
+    }
+}
+
+// get_piece_length gets the piece length from http header.
+pub fn get_piece_length(header: &HeaderMap) -> u64 {
+    match header.get(DRAGONFLY_PIECE_LENGTH_HEADER) {
+        Some(piece_length) => match piece_length.to_str() {
+            Ok(piece_length) => match piece_length.parse::<u64>() {
+                Ok(piece_length) => piece_length,
+                Err(err) => {
+                    error!("parse piece length from header failed: {}", err);
+                    config::default_piece_length()
+                }
+            },
+            Err(err) => {
+                error!("get piece length from header failed: {}", err);
+                config::default_piece_length()
+            }
+        },
+        None => config::default_piece_length(),
+    }
+}
diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs
index 0fa338732d4..8d7535674b7 100644
--- a/src/proxy/mod.rs
+++ b/src/proxy/mod.rs
@@ -14,10 +14,13 @@
  * limitations under the License.
  */
 
-use crate::config::dfdaemon::Config;
+use crate::config::dfdaemon::{Config, Rule};
 use crate::shutdown;
+use crate::utils::http::headermap_to_hashmap;
 use crate::Result as ClientResult;
 use bytes::Bytes;
+use dragonfly_api::common::v2::{Download, TaskType};
+use dragonfly_api::dfdaemon::v2::DownloadTaskRequest;
 use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
 use hyper::client::conn::http1::Builder;
 use hyper::server::conn::http1;
@@ -32,6 +35,8 @@ use tokio::net::TcpStream;
 use tokio::sync::mpsc;
 use tracing::{error, info, instrument, Span};
 
+pub mod header;
+
 // Proxy is the proxy server.
 pub struct Proxy {
     // config is the configuration of the dfdaemon.
@@ -219,6 +224,102 @@ pub async fn https_handler(
     }
 }
 
+// make_download_task_request makes a request for downloading the task.
+// The url is built from the request uri and the rule, and the range, tag,
+// application, priority, filtered query params and piece length are read from
+// the request headers. It returns an error when the url cannot be built or
+// when the range header cannot be parsed.
+// NOTE(review): the request body type is assumed to be hyper::body::Incoming,
+// only the uri and the headers of the request are read here.
+#[instrument(skip_all)]
+fn make_download_task_request(
+    request: Request<hyper::body::Incoming>,
+    rule: Rule,
+    content_length: u64,
+) -> ClientResult<DownloadTaskRequest> {
+    // Construct the download url.
+    let url = make_download_url(request.uri(), rule.use_tls, rule.redirect)?;
+
+    // TODO: Remove the conversion after the http crate version is the same.
+    // Convert the Hyper header to the Reqwest header, because the http crate
+    // versions are different. Reqwest header depends on the http crate
+    // version 0.2, but the Hyper header depends on another http crate version
+    // (presumably 1.0, because the hyper 1.x API is used here).
+    let mut header = reqwest::header::HeaderMap::new();
+    for (raw_header_key, raw_header_value) in request.headers() {
+        // Skip the header if its name can not be converted.
+        let header_name: reqwest::header::HeaderName = match raw_header_key.as_str().parse() {
+            Ok(header_name) => header_name,
+            Err(err) => {
+                error!("parse header name error: {}", err);
+                continue;
+            }
+        };
+
+        // Skip the header if its value can not be converted.
+        let header_value: reqwest::header::HeaderValue = match raw_header_value.to_str() {
+            Ok(header_value) => match header_value.parse() {
+                Ok(header_value) => header_value,
+                Err(err) => {
+                    error!("parse header value error: {}", err);
+                    continue;
+                }
+            },
+            Err(err) => {
+                error!("parse header value error: {}", err);
+                continue;
+            }
+        };
+
+        header.insert(header_name, header_value);
+    }
+
+    Ok(DownloadTaskRequest {
+        download: Some(Download {
+            url,
+            digest: None,
+            range: header::get_range(&header, content_length)?,
+            r#type: TaskType::Dfdaemon as i32,
+            tag: header::get_tag(&header),
+            application: header::get_application(&header),
+            priority: header::get_priority(&header),
+            filters: header::get_filtered_query_params(&header, rule.filtered_query_params),
+            request_header: headermap_to_hashmap(&header),
+            piece_length: header::get_piece_length(&header),
+            output_path: None,
+            timeout: None,
+            need_back_to_source: false,
+        }),
+    })
+}
+
+// make_download_url makes a download url by the given uri.
+// The scheme is replaced with https when use_tls is true, and the authority is
+// replaced with the redirect address when it is set. It returns an error when
+// the redirect address is not a valid authority, or when the parts do not
+// form a valid uri.
+#[instrument(skip_all)]
+fn make_download_url(
+    uri: &hyper::Uri,
+    use_tls: bool,
+    redirect: Option<String>,
+) -> ClientResult<String> {
+    let mut parts = uri.clone().into_parts();
+
+    // Set the scheme to https if the rule uses tls.
+    if use_tls {
+        parts.scheme = Some(http::uri::Scheme::HTTPS);
+    }
+
+    // Set the authority to the redirect address, an invalid address is
+    // returned as an error.
+    if let Some(redirect) = redirect {
+        parts.authority = Some(redirect.parse::<http::uri::Authority>()?);
+    }
+
+    Ok(http::Uri::from_parts(parts)?.to_string())
+}
+
 // empty returns an empty body.
 #[instrument(skip_all)]
 fn empty() -> BoxBody {