Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for defining RedirectPolicy for a Client #29

Merged
merged 1 commit into from
Dec 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 32 additions & 19 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, Read};
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use hyper::client::IntoUrl;
use hyper::header::{Headers, ContentType, Location, Referer, UserAgent};
Expand All @@ -14,6 +14,7 @@ use serde_json;
use serde_urlencoded;

use ::body::{self, Body};
use ::redirect::{RedirectPolicy, check_redirect};

static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

Expand All @@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e
///
/// The `Client` holds a connection pool internally, so it is advised that
/// you create one and reuse it.
#[derive(Clone)]
pub struct Client {
inner: ClientRef, //::hyper::Client,
inner: Arc<ClientRef>, //::hyper::Client,
}

impl Client {
Expand All @@ -34,12 +36,18 @@ impl Client {
let mut client = try!(new_hyper_client());
client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone);
Ok(Client {
inner: ClientRef {
hyper: Arc::new(client),
}
inner: Arc::new(ClientRef {
hyper: client,
redirect_policy: Mutex::new(RedirectPolicy::default()),
}),
})
}

/// Set a `RedirectPolicy` for this client.
pub fn redirect(&mut self, policy: RedirectPolicy) {
*self.inner.redirect_policy.lock().unwrap() = policy;
}

/// Convenience method to make a `GET` request to a URL.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
self.request(Method::Get, url)
Expand Down Expand Up @@ -75,13 +83,15 @@ impl Client {

impl fmt::Debug for Client {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad("Client")
f.debug_struct("Client")
.field("redirect_policy", &self.inner.redirect_policy)
.finish()
}
}

#[derive(Clone)]
struct ClientRef {
hyper: Arc<::hyper::Client>,
hyper: ::hyper::Client,
redirect_policy: Mutex<RedirectPolicy>,
}

fn new_hyper_client() -> ::Result<::hyper::Client> {
Expand All @@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> {

/// A builder to construct the properties of a `Request`.
pub struct RequestBuilder {
client: ClientRef,
client: Arc<ClientRef>,

method: Method,
url: Result<Url, ::UrlError>,
Expand Down Expand Up @@ -196,7 +206,7 @@ impl RequestBuilder {
None => None,
};

let mut redirect_count = 0;
let mut urls = Vec::new();

loop {
let res = {
Expand Down Expand Up @@ -237,14 +247,6 @@ impl RequestBuilder {
};

if should_redirect {
//TODO: turn this into self.redirect_policy.check()
if redirect_count > 10 {
return Err(::Error::TooManyRedirects);
}
redirect_count += 1;

headers.set(Referer(url.to_string()));

let loc = {
let loc = res.headers.get::<Location>().map(|loc| url.join(loc));
if let Some(loc) = loc {
Expand All @@ -257,7 +259,18 @@ impl RequestBuilder {
};

url = match loc {
Ok(u) => u,
Ok(loc) => {
headers.set(Referer(url.to_string()));
urls.push(url);
if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? {
loc
} else {
debug!("redirect_policy disallowed redirection to '{}'", loc);
return Ok(Response {
inner: res
})
}
},
Err(e) => {
debug!("Location header had invalid URI: {:?}", e);
return Ok(Response {
Expand Down
9 changes: 6 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub enum Error {
Serialize(Box<StdError + Send + Sync>),
/// A request tried to redirect too many times.
TooManyRedirects,
/// An infinite redirect loop was detected.
RedirectLoop,
#[doc(hidden)]
__DontMatchMe,
}
Expand All @@ -22,9 +24,8 @@ impl fmt::Display for Error {
match *self {
Error::Http(ref e) => fmt::Display::fmt(e, f),
Error::Serialize(ref e) => fmt::Display::fmt(e, f),
Error::TooManyRedirects => {
f.pad("Too many redirects")
},
Error::TooManyRedirects => f.pad("Too many redirects"),
Error::RedirectLoop => f.pad("Infinite redirect loop"),
Error::__DontMatchMe => unreachable!()
}
}
Expand All @@ -36,6 +37,7 @@ impl StdError for Error {
Error::Http(ref e) => e.description(),
Error::Serialize(ref e) => e.description(),
Error::TooManyRedirects => "Too many redirects",
Error::RedirectLoop => "Infinite redirect loop",
Error::__DontMatchMe => unreachable!()
}
}
Expand All @@ -45,6 +47,7 @@ impl StdError for Error {
Error::Http(ref e) => Some(e),
Error::Serialize(ref e) => Some(&**e),
Error::TooManyRedirects => None,
Error::RedirectLoop => None,
Error::__DontMatchMe => unreachable!()
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ pub use url::ParseError as UrlError;
pub use self::client::{Client, Response, RequestBuilder};
pub use self::error::{Error, Result};
pub use self::body::Body;
pub use self::redirect::RedirectPolicy;

mod body;
mod client;
mod error;
mod redirect;
mod tls;


Expand Down
159 changes: 157 additions & 2 deletions src/redirect.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,163 @@
use std::fmt;

use ::Url;

/// A type that controls the policy on how to handle the following of redirects.
///
/// The default value will catch redirect loops, and has a maximum of 10
/// redirects it will follow in a chain before returning an error.
#[derive(Debug)]
pub struct RedirectPolicy {
inner: ()
inner: Policy,
}

impl RedirectPolicy {

/// Create a RedirectPolicy with a maximum number of redirects.
///
/// A `Error::TooManyRedirects` will be returned if the max is reached.
pub fn limited(max: usize) -> RedirectPolicy {
RedirectPolicy {
inner: Policy::Limit(max),
}
}

/// Create a RedirectPolicy that does not follow any redirect.
pub fn none() -> RedirectPolicy {
RedirectPolicy {
inner: Policy::None,
}
}

/// Create a custom RedirectPolicy using the passed function.
///
/// # Note
///
/// The default RedirectPolicy handles redirect loops and a maximum loop
/// chain, but the custom variant does not do that for you automatically.
/// The custom policy should hanve some way of handling those.
///
/// There are variants on `::Error` for both cases that can be used as
/// return values.
///
/// # Example
///
/// ```no_run
/// # use reqwest::RedirectPolicy;
/// # let mut client = reqwest::Client::new().unwrap();
/// client.redirect(RedirectPolicy::custom(|next, previous| {
/// if previous.len() > 5 {
/// Err(reqwest::Error::TooManyRedirects)
/// } else if next.host_str() == Some("example.domain") {
/// // prevent redirects to 'example.domain'
/// Ok(false)
/// } else {
/// Ok(true)
/// }
/// }));
/// ```
pub fn custom<T>(policy: T) -> RedirectPolicy
where T: Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static {
RedirectPolicy {
inner: Policy::Custom(Box::new(policy)),
}
}

fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
match self.inner {
Policy::Custom(ref custom) => custom(next, previous),
Policy::Limit(max) => {
if previous.len() == max {
Err(::Error::TooManyRedirects)
} else if previous.contains(next) {
Err(::Error::RedirectLoop)
} else {
Ok(true)
}
},
Policy::None => Ok(false),
}
}
}

impl Default for RedirectPolicy {
fn default() -> RedirectPolicy {
RedirectPolicy::limited(10)
}
}

enum Policy {
Custom(Box<Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static>),
Limit(usize),
None,
}

impl fmt::Debug for Policy {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Policy::Custom(..) => f.pad("Custom"),
Policy::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
Policy::None => f.pad("None"),
}
}
}

pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result<bool> {
policy.redirect(next, previous)
}

/*
This was the desired way of doing it, but ran in to inference issues when
using closures, since the arguments received are references (&Url and &[Url]),
and the compiler could not infer the lifetimes of those references. That means
people would need to annotate the closure's argument types, which is garbase.

pub trait Redirect {
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool>;
}

impl<F> Redirect for F
where F: Fn(&Url, &[Url]) -> ::Result<bool> {
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
self(next, previous)
}
}
*/

#[test]
fn test_redirect_policy_limit() {
let policy = RedirectPolicy::default();
let next = Url::parse("http://x.y/z").unwrap();
let mut previous = (0..9)
.map(|i| Url::parse(&format!("http://a.b/c/{}", i)).unwrap())
.collect::<Vec<_>>();


match policy.redirect(&next, &previous) {
Ok(true) => {},
other => panic!("expected Ok(true), got: {:?}", other)
}

previous.push(Url::parse("http://a.b.d/e/33").unwrap());

match policy.redirect(&next, &previous) {
Err(::Error::TooManyRedirects) => {},
other => panic!("expected TooManyRedirects, got: {:?}", other)
}
}

#[test]
fn test_redirect_policy_custom() {
let policy = RedirectPolicy::custom(|next, _previous| {
if next.host_str() == Some("foo") {
Ok(false)
} else {
Ok(true)
}
});

let next = Url::parse("http://bar/baz").unwrap();
assert_eq!(policy.redirect(&next, &[]).unwrap(), true);

let next = Url::parse("http://foo/baz").unwrap();
assert_eq!(policy.redirect(&next, &[]).unwrap(), false);
}
53 changes: 53 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,56 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() {
assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code));
}
}

#[test]
fn test_redirect_policy_can_return_errors() {
let server = server! {
request: b"\
GET /loop HTTP/1.1\r\n\
Host: $HOST\r\n\
User-Agent: $USERAGENT\r\n\
\r\n\
",
response: b"\
HTTP/1.1 302 Found\r\n\
Server: test\r\n\
Location: /loop
Content-Length: 0\r\n\
\r\n\
"
};

let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err();
match err {
reqwest::Error::RedirectLoop => (),
e => panic!("wrong error received: {:?}", e),
}
}

#[test]
fn test_redirect_policy_can_stop_redirects_without_an_error() {
let server = server! {
request: b"\
GET /no-redirect HTTP/1.1\r\n\
Host: $HOST\r\n\
User-Agent: $USERAGENT\r\n\
\r\n\
",
response: b"\
HTTP/1.1 302 Found\r\n\
Server: test-dont\r\n\
Location: /dont
Content-Length: 0\r\n\
\r\n\
"
};
let mut client = reqwest::Client::new().unwrap();
client.redirect(reqwest::RedirectPolicy::none());

let res = client.get(&format!("http://{}/no-redirect", server.addr()))
.send()
.unwrap();

assert_eq!(res.status(), &reqwest::StatusCode::Found);
assert_eq!(res.headers().get(), Some(&reqwest::header::Server("test-dont".to_string())));
}