From 9f86bc9d1df004aaa2d20e7bacdaf13f58cdf827 Mon Sep 17 00:00:00 2001 From: Gaius Date: Sat, 27 Jan 2024 10:28:17 +0800 Subject: [PATCH] feat: intercept http request to download task by p2p in proxy (#234) Signed-off-by: Gaius --- Cargo.lock | 1 + Cargo.toml | 1 + src/bin/dfdaemon/main.rs | 1 + src/grpc/dfdaemon_download.rs | 18 ++- src/grpc/dfdaemon_upload.rs | 7 +- src/proxy/header.rs | 15 +-- src/proxy/mod.rs | 232 +++++++++++++++++++++++----------- src/task/mod.rs | 96 -------------- src/task/piece.rs | 10 +- src/utils/http.rs | 50 ++++++++ 10 files changed, 245 insertions(+), 186 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9d85ed8ae7c..d18a3dd43bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -551,6 +551,7 @@ dependencies = [ "fs2", "fslock", "futures", + "futures-util", "hashring", "hex", "home", diff --git a/Cargo.toml b/Cargo.toml index 7fdba98495b..7a2dae58521 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,3 +85,4 @@ hyper-rustls = "0.26" http-body-util = "0.1.0" regex = "1.10.2" http-range-header = "0.4.0" +futures-util = "0.3.30" diff --git a/src/bin/dfdaemon/main.rs b/src/bin/dfdaemon/main.rs index 973995f9cc1..d3155a94943 100644 --- a/src/bin/dfdaemon/main.rs +++ b/src/bin/dfdaemon/main.rs @@ -195,6 +195,7 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize proxy server. let proxy = Proxy::new( config.clone(), + task.clone(), shutdown.clone(), shutdown_complete_tx.clone(), ); diff --git a/src/grpc/dfdaemon_download.rs b/src/grpc/dfdaemon_download.rs index a5f0b077487..508484e4f97 100644 --- a/src/grpc/dfdaemon_download.rs +++ b/src/grpc/dfdaemon_download.rs @@ -16,7 +16,7 @@ use crate::shutdown; use crate::task; -use crate::utils::http::hashmap_to_headermap; +use crate::utils::http::{get_range, hashmap_to_headermap}; use crate::Result as ClientResult; use dragonfly_api::common::v2::Task; use dragonfly_api::dfdaemon::v2::{ @@ -150,7 +150,7 @@ impl DfdaemonDownload for DfdaemonDownloadServerHandler { let request = request.into_inner(); // Check whether the download is empty. - let download = request + let mut download = request .download .ok_or(Status::invalid_argument("missing download"))?; @@ -204,6 +204,20 @@ impl DfdaemonDownload for DfdaemonDownloadServerHandler { Status::internal(e.to_string()) })?; + // Download's range priority is higher than the request header's range. + // If download protocol is http, use the range of the request header. + // If download protocol is not http, use the range of the download. + if download.range.is_none() { + let content_length = task + .content_length() + .ok_or(Status::internal("missing content length in the response"))?; + + download.range = get_range(&request_header, content_length).map_err(|err| { + error!("get range failed: {}", err); + Status::failed_precondition(err.to_string()) + })?; + } + // Clone the task. let task_manager = self.task.clone(); diff --git a/src/grpc/dfdaemon_upload.rs b/src/grpc/dfdaemon_upload.rs index 15dce15b217..3f95109c7a1 100644 --- a/src/grpc/dfdaemon_upload.rs +++ b/src/grpc/dfdaemon_upload.rs @@ -277,7 +277,12 @@ impl DfdaemonUpload for DfdaemonUploadServerHandler { let mut reader = self .task .piece - .upload_from_local_peer_into_async_read(task_id.as_str(), piece_number, piece.length) + .upload_from_local_peer_into_async_read( + task_id.as_str(), + piece_number, + piece.length, + false, + ) .await .map_err(|err| { error!("read piece content from local storage: {}", err); diff --git a/src/proxy/header.rs b/src/proxy/header.rs index fec174953b2..e582c267765 100644 --- a/src/proxy/header.rs +++ b/src/proxy/header.rs @@ -15,9 +15,7 @@ */ use crate::config; -use crate::utils::http::parse_range_header; -use crate::Result; -use dragonfly_api::common::v2::{Priority, Range}; +use dragonfly_api::common::v2::Priority; use reqwest::header::HeaderMap; use tracing::error; @@ -43,17 +41,6 @@ pub const DRAGONFLY_FILTERED_QUERY_PARAMS_HEADER: &str = "X-Dragonfly-Filtered-Q // it specifies the piece length of the task. pub const DRAGONFLY_PIECE_LENGTH_HEADER: &str = "X-Dragonfly-Piece-Length"; -// get_range gets the range from http header. -pub fn get_range(header: &HeaderMap, content_length: u64) -> Result> { - match header.get(reqwest::header::RANGE) { - Some(range) => { - let range = range.to_str()?; - Ok(Some(parse_range_header(range, content_length)?)) - } - None => Ok(None), - } -} - // get_tag gets the tag from http header. pub fn get_tag(header: &HeaderMap) -> Option { match header.get(DRAGONFLY_TAG_HEADER) { diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index d0e3da15281..d6c1ac4230e 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -14,25 +14,31 @@ * limitations under the License. */ -use crate::config::dfdaemon::{Config, Rule}; +use crate::config::dfdaemon::Config; +use crate::grpc::dfdaemon_download::DfdaemonDownloadClient; use crate::shutdown; -use crate::utils::http::headermap_to_hashmap; +use crate::task::Task; +use crate::utils::http::{headermap_to_hashmap, hyper_headermap_to_reqwest_headermap}; use crate::Result as ClientResult; use bytes::Bytes; use dragonfly_api::common::v2::{Download, TaskType}; use dragonfly_api::dfdaemon::v2::DownloadTaskRequest; -use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; +use futures_util::TryStreamExt; +use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; +use hyper::body::Frame; use hyper::client::conn::http1::Builder; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::{Method, Request, Response}; use hyper_util::rt::tokio::TokioIo; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::sync::mpsc; +use tokio_util::io::ReaderStream; use tracing::{error, info, instrument, Span}; pub mod header; @@ -42,6 +48,9 @@ pub struct Proxy { // config is the configuration of the dfdaemon. config: Arc, + // task is the task manager. + task: Arc, + // addr is the address of the proxy server. addr: SocketAddr, @@ -57,11 +66,13 @@ impl Proxy { // new creates a new Proxy. pub fn new( config: Arc, + task: Arc, shutdown: shutdown::Shutdown, shutdown_complete_tx: mpsc::UnboundedSender<()>, ) -> Self { Self { config: config.clone(), + task: task.clone(), addr: SocketAddr::new(config.proxy.server.ip.unwrap(), config.proxy.server.port), shutdown, _shutdown_complete: shutdown_complete_tx, @@ -90,16 +101,15 @@ impl Proxy { let io = TokioIo::new(tcp); info!("accepted connection from {}", remote_address); - // Clone the config. let config = self.config.clone(); - + let task = self.task.clone(); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) .serve_connection( io, - service_fn(move |request| handler(config.clone(), request)), + service_fn(move |request| handler(config.clone(), task.clone(), request)), ) .with_upgrades() .await @@ -122,6 +132,7 @@ impl Proxy { #[instrument(skip_all, fields(uri, method))] pub async fn handler( config: Arc, + task: Arc, request: Request, ) -> Result>, hyper::Error> { info!("handle request: {:?}", request); @@ -135,13 +146,14 @@ pub async fn handler( return https_handler(config, request).await; } - return http_handler(config, request).await; + return http_handler(config, task, request).await; } // http_handler handles the http request. #[instrument(skip_all)] pub async fn http_handler( config: Arc, + task: Arc, request: Request, ) -> Result>, hyper::Error> { let Some(host) = request.uri().host() else { @@ -155,7 +167,148 @@ pub async fn http_handler( if let Some(rules) = config.proxy.rules.clone() { for rule in rules.iter() { if rule.regex.is_match(request.uri().to_string().as_str()) { - // TODO: handle https request. + // Convert the Reqwest header to the Hyper header. + let request_header = hyper_headermap_to_reqwest_headermap(request.headers()); + + // Construct the download url. + let url = + match make_download_url(request.uri(), rule.use_tls, rule.redirect.clone()) { + Ok(url) => url, + Err(err) => { + let mut response = Response::new(full( + err.to_string().to_string().as_bytes().to_vec(), + )); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + }; + + // Get parameters from the header. + let tag = header::get_tag(&request_header); + let application = header::get_application(&request_header); + let priority = header::get_priority(&request_header); + let piece_length = header::get_piece_length(&request_header); + let filtered_query_params = header::get_filtered_query_params( + &request_header, + rule.filtered_query_params.clone(), + ); + let request_header = headermap_to_hashmap(&request_header); + + // Initialize the dfdaemon download client. + let dfdaemon_download_client = match DfdaemonDownloadClient::new_unix( + config.download.server.socket_path.clone(), + ) + .await + { + Ok(client) => client, + Err(err) => { + let mut response = + Response::new(full(err.to_string().to_string().as_bytes().to_vec())); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + }; + + // Download the task by the dfdaemon download client. + let response = match dfdaemon_download_client + .download_task(DownloadTaskRequest { + download: Some(Download { + url, + digest: None, + // Download range use header range in HTTP protocol. + range: None, + r#type: TaskType::Dfdaemon as i32, + tag, + application, + priority, + filtered_query_params, + request_header, + piece_length, + output_path: None, + timeout: None, + need_back_to_source: false, + }), + }) + .await + { + Ok(response) => response, + Err(err) => { + let mut response = + Response::new(full(err.to_string().to_string().as_bytes().to_vec())); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + }; + + // Write the task data to the reader. + let (reader, mut writer) = tokio::io::duplex(1024); + + // Handle the response from the download grpc server. + let mut out_stream = response.into_inner(); + while let Some(message) = match out_stream.message().await { + Ok(message) => message, + Err(err) => { + let mut response = + Response::new(full(err.to_string().to_string().as_bytes().to_vec())); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + } { + let piece = match message.piece { + Some(piece) => piece, + None => { + let mut response = + Response::new(full("download task response piece is empty")); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + }; + + let mut need_piece_number = 0; + let piece_reader = match task + .piece + .download_from_local_peer_into_async_read( + message.task_id.as_str(), + piece.number, + piece.length, + true, + ) + .await + { + Ok(reader) => reader, + Err(err) => { + let mut response = Response::new(full( + err.to_string().to_string().as_bytes().to_vec(), + )); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + }; + + // Sort by piece number and return to reader in order. + let mut finished_piece_readers = HashMap::new(); + finished_piece_readers.insert(piece.number, piece_reader); + while let Some(piece_reader) = + finished_piece_readers.get_mut(&need_piece_number) + { + if let Err(err) = tokio::io::copy(piece_reader, &mut writer).await { + let mut response = Response::new(full( + err.to_string().to_string().as_bytes().to_vec(), + )); + *response.status_mut() = http::StatusCode::BAD_REQUEST; + return Ok(response); + } + need_piece_number += 1; + } + } + + // TODO: Construct the reader stream. + let reader_stream = ReaderStream::new(reader); + let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data)); + let boxed_body = stream_body.boxed(); + info!("boxed_body: {:?}", boxed_body); + + // TODO: handle http stream. let mut response = Response::new(full("CONNECT must be to a socket address")); *response.status_mut() = http::StatusCode::BAD_REQUEST; return Ok(response); @@ -224,69 +377,6 @@ pub async fn https_handler( } } -// make_download_task_request makes a request for downloading the task. -#[instrument(skip_all)] -fn make_download_task_request( - request: Request, - rule: Rule, - content_length: u64, -) -> ClientResult { - // Construct the download url. - let url = make_download_url(request.uri(), rule.use_tls, rule.redirect)?; - - // TODO: Remove the convertion after the http crate version is the same. - // Convert the Reqwest header to the Hyper header, because of the http crate - // version is different. Reqwest header depends on the http crate - // version 0.2, but the Hyper header depends on the http crate version 0.1. - let mut header = reqwest::header::HeaderMap::new(); - for (raw_header_key, raw_header_value) in request.headers() { - let header_name: reqwest::header::HeaderName = match raw_header_key.to_string().parse() { - Ok(header_name) => header_name, - Err(err) => { - error!("parse header name error: {}", err); - continue; - } - }; - - let header_value: reqwest::header::HeaderValue = match raw_header_value.to_str() { - Ok(header_value) => match header_value.parse() { - Ok(header_value) => header_value, - Err(err) => { - error!("parse header value error: {}", err); - continue; - } - }, - Err(err) => { - error!("parse header value error: {}", err); - continue; - } - }; - - header.insert(header_name, header_value); - } - - Ok(DownloadTaskRequest { - download: Some(Download { - url, - digest: None, - range: header::get_range(&header, content_length)?, - r#type: TaskType::Dfdaemon as i32, - tag: header::get_tag(&header), - application: header::get_application(&header), - priority: header::get_priority(&header), - filtered_query_params: header::get_filtered_query_params( - &header, - rule.filtered_query_params, - ), - request_header: headermap_to_hashmap(&header), - piece_length: header::get_piece_length(&header), - output_path: None, - timeout: None, - need_back_to_source: false, - }), - }) -} - // make_download_url makes a download url by the given uri. #[instrument(skip_all)] fn make_download_url( diff --git a/src/task/mod.rs b/src/task/mod.rs index c9643a7dd06..4092c62cda5 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -38,7 +38,6 @@ use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use std::time::Duration; -use tokio::io::AsyncRead; use tokio::sync::{ mpsc::{self, Sender}, Semaphore, @@ -330,101 +329,6 @@ impl Task { Ok(()) } - // download downloads a task. - #[allow(clippy::too_many_arguments)] - pub async fn download_into_async_read( - self: Arc, - task: metadata::Task, - host_id: &str, - peer_id: &str, - download: Download, - ) -> ClientResult> { - // Initialize stream channel. - let (out_stream_tx, mut out_stream_rx) = mpsc::channel(1024); - let self_clone = Arc::clone(&self); - let host_id = host_id.to_string(); - let peer_id = peer_id.to_string(); - let task_id = task.id.clone(); - let range = download.range.clone(); - - // Spawn the download task. - tokio::spawn( - async move { - match self_clone - .download( - task.clone(), - host_id.as_str(), - peer_id.as_str(), - download.clone(), - out_stream_tx.clone(), - ) - .await - { - Ok(_) => { - // Download task succeeded. - info!("download task succeeded"); - if download.range.is_none() { - if let Err(err) = self_clone.download_finished(task.id.as_str()) { - error!("download task finished: {}", err); - } - } - } - Err(e) => { - // Download task failed. - self_clone - .download_failed(task.id.as_str()) - .await - .unwrap_or_else(|err| { - error!("download task failed: {}", err); - }); - - error!("download failed: {}", e); - } - } - drop(out_stream_tx); - } - .in_current_span(), - ); - - // If the range is specified, read the task by range. - if let Some(range) = range { - while let Some(message) = out_stream_rx.recv().await { - message?.piece.ok_or(Error::UnexpectedResponse())?; - } - - let reader = self - .storage - .read_task_by_range(task_id.as_str(), range) - .await?; - - return Ok(Box::new(reader) as Box); - } - - // Return async read of the order of the pieces. - let mut need_piece_number = 0; - let (reader, mut writer) = tokio::io::duplex(1024); - while let Some(message) = out_stream_rx.recv().await { - let piece = message?.piece.ok_or(Error::UnexpectedResponse())?; - let piece_reader = self - .piece - .download_from_local_peer_into_async_read( - task_id.as_str(), - piece.number, - piece.length, - ) - .await?; - - let mut finished_piece_readers = HashMap::new(); - finished_piece_readers.insert(piece.number, piece_reader); - while let Some(piece_reader) = finished_piece_readers.get_mut(&need_piece_number) { - tokio::io::copy(piece_reader, &mut writer).await?; - need_piece_number += 1; - } - } - - Ok(Box::new(reader)) - } - // download_partial_with_scheduler downloads a partial task with scheduler. #[allow(clippy::too_many_arguments)] async fn download_partial_with_scheduler( diff --git a/src/task/piece.rs b/src/task/piece.rs index f994548b453..e1ef0601d64 100644 --- a/src/task/piece.rs +++ b/src/task/piece.rs @@ -213,9 +213,12 @@ impl Piece { task_id: &str, number: u32, length: u64, + disable_rate_limit: bool, ) -> Result { // Acquire the upload rate limiter. - self.upload_rate_limiter.acquire(length as usize).await; + if !disable_rate_limit { + self.upload_rate_limiter.acquire(length as usize).await; + } // Upload the piece content. self.storage.upload_piece(task_id, number).await @@ -227,9 +230,12 @@ impl Piece { task_id: &str, number: u32, length: u64, + disable_rate_limit: bool, ) -> Result { // Acquire the download rate limiter. - self.download_rate_limiter.acquire(length as usize).await; + if !disable_rate_limit { + self.download_rate_limiter.acquire(length as usize).await; + } // Upload the piece content. self.storage.upload_piece(task_id, number).await diff --git a/src/utils/http.rs b/src/utils/http.rs index 267638206b2..3da0877aeab 100644 --- a/src/utils/http.rs +++ b/src/utils/http.rs @@ -18,6 +18,7 @@ use crate::{Error, Result}; use dragonfly_api::common::v2::Range; use reqwest::header::{HeaderMap, HeaderValue}; use std::collections::HashMap; +use tracing::error; // headermap_to_hashmap converts a headermap to a hashmap. pub fn headermap_to_hashmap(header: &HeaderMap) -> HashMap { @@ -39,6 +40,44 @@ pub fn hashmap_to_headermap(header: &HashMap) -> Result reqwest::header::HeaderMap { + let mut reqwest_header = reqwest::header::HeaderMap::new(); + for (hyper_header_key, hyper_header_value) in hyper_header.iter() { + let reqwest_header_name: reqwest::header::HeaderName = + match hyper_header_key.to_string().parse() { + Ok(reqwest_header_name) => reqwest_header_name, + Err(err) => { + error!("parse header name error: {}", err); + continue; + } + }; + + let reqwest_header_value: reqwest::header::HeaderValue = match hyper_header_value.to_str() { + Ok(reqwest_header_value) => match reqwest_header_value.parse() { + Ok(reqwest_header_value) => reqwest_header_value, + Err(err) => { + error!("parse header value error: {}", err); + continue; + } + }, + Err(err) => { + error!("parse header value error: {}", err); + continue; + } + }; + + reqwest_header.insert(reqwest_header_name, reqwest_header_value); + } + + reqwest_header +} + // header_vec_to_hashmap converts a vector of header string to a hashmap. pub fn header_vec_to_hashmap(raw_header: Vec) -> Result> { let mut header = HashMap::new(); @@ -52,6 +91,17 @@ pub fn header_vec_to_hashmap(raw_header: Vec) -> Result Result> { + match header.get(reqwest::header::RANGE) { + Some(range) => { + let range = range.to_str()?; + Ok(Some(parse_range_header(range, content_length)?)) + } + None => Ok(None), + } +} + // parse_range_header parses a Range header string as per RFC 7233, // supported Range Header: "Range": "bytes=100-200", "Range": "bytes=-50", // "Range": "bytes=150-", "Range": "bytes=0-0,-1".