Skip to content

Commit

Permalink
Set up request headers
Browse files Browse the repository at this point in the history
  • Loading branch information
dvolodin7 committed Feb 27, 2024
1 parent 99af319 commit f948c55
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 17 deletions.
49 changes: 41 additions & 8 deletions src/async_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ use pyo3::{
prelude::*,
};
use pyo3_asyncio::tokio::future_into_py;
use reqwest::{redirect::Policy, RequestBuilder};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
redirect::Policy,
RequestBuilder,
};
use std::collections::HashMap;

#[pyclass(module = "gufo.http.async_client")]
pub struct AsyncClient {
Expand All @@ -20,19 +25,47 @@ pub struct AsyncClient {
#[pymethods]
impl AsyncClient {
#[new]
fn new(max_redirect: Option<usize>) -> PyResult<Self> {
fn new(max_redirect: Option<usize>, headers: Option<HashMap<&str, &[u8]>>) -> PyResult<Self> {
let builder = reqwest::Client::builder();
let mut builder = builder.redirect(match max_redirect {
Some(x) => Policy::limited(x),
None => Policy::none(),
});
// Set headers
if let Some(h) = headers {
let mut map = HeaderMap::with_capacity(h.len());
for (k, v) in h {
map.insert(
HeaderName::from_bytes(k.as_ref())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
HeaderValue::from_bytes(v).map_err(|e| PyValueError::new_err(e.to_string()))?,
);
}
builder = builder.default_headers(map);
}
//
let client = builder
.redirect(match max_redirect {
Some(x) => Policy::limited(x),
None => Policy::none(),
})
.build()
.map_err(|x| PyValueError::new_err(x.to_string()))?;
Ok(AsyncClient { client })
}
fn get<'a>(&self, py: Python<'a>, url: String) -> PyResult<&'a PyAny> {
self.process_request(py, self.client.get(url))
fn get<'a>(
&self,
py: Python<'a>,
url: String,
headers: Option<HashMap<&str, &[u8]>>,
) -> PyResult<&'a PyAny> {
let mut builder = self.client.get(url);
if let Some(h) = headers {
for (k, v) in h {
builder = builder.header(
HeaderName::from_bytes(k.as_ref())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
HeaderValue::from_bytes(v).map_err(|e| PyValueError::new_err(e.to_string()))?,
)
}
}
self.process_request(py, builder)
}
}

Expand Down
10 changes: 7 additions & 3 deletions src/gufo/http/_fast.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# ---------------------------------------------------------------------

# Python modules
from typing import Iterable, Optional, Tuple
from typing import Dict, Iterable, Optional, Tuple

class HttpError(Exception): ...

Expand All @@ -29,6 +29,10 @@ class AsyncResponse(object):

class AsyncClient(object):
def __init__(
self: "AsyncClient", max_redirects: Optional[int]
self: "AsyncClient",
max_redirects: Optional[int],
headers: Optional[Dict[str, bytes]],
) -> None: ...
async def get(self: "AsyncClient", url: str) -> AsyncResponse: ...
async def get(
self: "AsyncClient", url: str, headers: Optional[Dict[str, bytes]]
) -> AsyncResponse: ...
33 changes: 27 additions & 6 deletions src/gufo/http/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

# Python modules
from types import TracebackType
from typing import AsyncIterator, Dict, Iterable, Optional, Tuple, Type, Union

from typing import Dict, Optional, Type

# Gufo HTTP modules
from ._fast import AsyncClient, AsyncResponse
from .util import merge_dict

MAX_REDIRECTS = 10

Expand All @@ -22,15 +22,25 @@ class HttpClient(object):
"""
Asynchronous HTTP client.
Attributes:
headers: Headers to be added to every request.
Used in subclasses.
Args:
max_redirects: Maximal amount of redirects. Use `None`
to disable redirect processing.
"""

headers: Optional[Dict[str, bytes]] = None

def __init__(
self: "HttpClient", max_redirects: Optional[int] = MAX_REDIRECTS
self: "HttpClient",
max_redirects: Optional[int] = MAX_REDIRECTS,
headers: Optional[Dict[str, bytes]] = None,
) -> None:
self._client = AsyncClient(max_redirects)
self._client = AsyncClient(
max_redirects, merge_dict(self.headers, headers)
)

async def __aenter__(self: "HttpClient") -> "HttpClient":
"""Asynchronous context manager entry."""
Expand All @@ -44,11 +54,22 @@ async def __aexit__(
) -> None:
"""Asynchronous context manager exit."""

async def get(self: "HttpClient", url: str) -> AsyncResponse:
async def get(
self: "HttpClient",
url: str,
headers: Optional[Dict[str, bytes]] = None,
) -> AsyncResponse:
"""
Send HTTP GET request and receive a response.
Args:
url: Request url
headers: Optional request headers
Returns:
AsyncResponse instance.
"""
return await self._client.get(url)
return await self._client.get(url, headers)


__all__ = ["HttpClient", "AsyncResponse"]

0 comments on commit f948c55

Please sign in to comment.