Skip to content

Commit

Permalink
go back to HashMap<String,String> session data
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Jul 24, 2020
1 parent 02f5456 commit a4bb48d
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 93 deletions.
30 changes: 16 additions & 14 deletions examples/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,30 @@ async fn main() -> Result<(), std::io::Error> {
tide::log::start();
let mut app = tide::new();

#[derive(Clone, Debug, Default, PartialEq)]
struct SessionData {
visits: usize,
}

app.middleware(tide::sessions::SessionMiddleware::new(
tide::sessions::MemoryStore::<SessionData>::new(),
std::env::var("TIDE_SECRET").unwrap().as_bytes(),
tide::sessions::MemoryStore::new(),
b"use std::env::var(\"TIDE_SECRET\").unwrap().as_bytes() instead of a fixed value",
));

let visit_counter = tide::utils::Before(|mut request: tide::Request<()>| async move {
request.session_data_mut::<SessionData>().visits += 1;
request
});

app.middleware(visit_counter);
app.middleware(tide::utils::Before(
|mut request: tide::Request<()>| async move {
let visits: usize = request.session_get("visits").unwrap_or_default();
request.session_insert("visits", visits + 1).unwrap();
request
},
));

app.at("/").get(|req: tide::Request<()>| async move {
let SessionData { visits } = *req.session_data();
let visits: usize = req.session_get("visits").unwrap();
Ok(format!("you have visited this website {} times", visits))
});

app.at("/reset")
.get(|mut req: tide::Request<()>| async move {
req.session_mut().destroy();
Ok(tide::Redirect::new("/"))
});

app.listen("127.0.0.1:8080").await?;

Ok(())
Expand Down
37 changes: 13 additions & 24 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ impl<State> Request<State> {
self.req.ext().get()
}

#[must_use]
pub fn ext_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
self.req.ext_mut().get_mut()
}

/// Set a request extension value.
pub fn set_ext<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.req.ext_mut().insert(val)
Expand Down Expand Up @@ -507,43 +512,27 @@ impl<State> Request<State> {
}

#[cfg(feature = "sessions")]
pub fn session<SessionData>(&self) -> &crate::sessions::Session<SessionData>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
self.ext::<crate::sessions::Session<SessionData>>().expect(
pub fn session(&self) -> &crate::sessions::Session {
self.ext::<crate::sessions::Session>().expect(
"request session not initialized, did you enable tide::sessions::SessionMiddleware?",
)
}

#[cfg(feature = "sessions")]
pub fn session_mut<SessionData>(&mut self) -> &mut crate::sessions::Session<SessionData>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
self.req
.ext_mut().get_mut::<crate::sessions::Session<SessionData>>()
.expect(
pub fn session_mut(&mut self) -> &mut crate::sessions::Session {
self.ext_mut().expect(
"request session not initialized, did you enable tide::sessions::SessionMiddleware?",
)
}

#[cfg(feature = "sessions")]
pub fn session_data<SessionData>(&self) -> crate::sessions::SessionReader<'_, SessionData>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
self.session::<SessionData>().data()
pub fn session_get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
self.session().get(key)
}

#[cfg(feature = "sessions")]
pub fn session_data_mut<SessionData>(
&mut self,
) -> crate::sessions::SessionWriter<'_, SessionData>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
self.session_mut::<SessionData>().data_mut()
pub fn session_insert(&mut self, key: &str, value: impl serde::Serialize) -> crate::Result<()> {
Ok(self.session_mut().insert(key, value)?)
}

/// Get the length of the body stream, if it has been set.
Expand Down
39 changes: 16 additions & 23 deletions src/sessions/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,43 @@ const BASE64_DIGEST_LEN: usize = 44;
/// ## example:
/// ```rust
/// # async_std::task::block_on(async {
/// #[derive(Clone, Debug, Default, PartialEq)]
/// struct SessionData {
/// visits: usize,
/// }
/// let mut app = tide::new();
///
/// app.middleware(tide::sessions::SessionMiddleware::new(
/// tide::sessions::MemoryStore::<SessionData>::new(),
/// tide::sessions::MemoryStore::new(),
/// b"use std::env::var(\"TIDE_SECRET\").unwrap().as_bytes() instead of a fixed value"
/// ));
///
/// app.middleware(tide::utils::Before(|mut request: tide::Request<()>| async move {
/// request.session_data_mut::<SessionData>().visits += 1;
/// let visits: usize = request.session_get("visits").unwrap_or_default();
/// request.session_insert("visits", visits + 1).unwrap();
/// request
/// }));
///
/// app.at("/").get(|req: tide::Request<()>| async move {
/// let SessionData { visits } = *req.session_data();
/// let visits: usize = req.session_get("visits").unwrap();
/// Ok(format!("you have visited this website {} times", visits))
/// });
///
/// app.at("/reset")
/// .get(|mut req: tide::Request<()>| async move {
/// req.session_mut().destroy();
/// Ok(tide::Redirect::new("/"))
/// });
/// # })
/// ```
pub struct SessionMiddleware<SessionData, Store> {
pub struct SessionMiddleware<Store> {
store: Store,
cookie_path: String,
cookie_name: String,
session_ttl: Option<Duration>,
save_unchanged: bool,
same_site_policy: SameSite,
key: Key,
session_data: std::marker::PhantomData<SessionData>,
}

impl<SessionData, Store: SessionStore<SessionData>> std::fmt::Debug
for SessionMiddleware<SessionData, Store>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
impl<Store: SessionStore> std::fmt::Debug for SessionMiddleware<Store> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SessionMiddleware")
.field("store", &self.store)
Expand All @@ -68,10 +66,9 @@ where
}

#[async_trait]
impl<SessionData, Store, State> Middleware<State> for SessionMiddleware<SessionData, Store>
impl<Store, State> Middleware<State> for SessionMiddleware<Store>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
Store: SessionStore<SessionData>,
Store: SessionStore,
State: Clone + Send + Sync + 'static,
{
async fn handle(&self, mut request: Request<State>, next: Next<'_, State>) -> crate::Result {
Expand Down Expand Up @@ -111,10 +108,7 @@ where
}
}

impl<SessionData, Store: SessionStore<SessionData>> SessionMiddleware<SessionData, Store>
where
SessionData: Clone + Default + Send + Sync + std::fmt::Debug + PartialEq + 'static,
{
impl<Store: SessionStore> SessionMiddleware<Store> {
/// Creates a new SessionMiddleware with a mandatory cookie
/// signing secret. The `secret` MUST be at least 32 bytes long,
/// and should be cryptographically random. It is recommended to
Expand All @@ -137,7 +131,6 @@ where
same_site_policy: SameSite::Strict,
session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
key: Key::derive_from(secret),
session_data: std::marker::PhantomData,
}
}

Expand Down Expand Up @@ -191,7 +184,7 @@ where

//--- methods below here are private ---

async fn load_or_create(&self, cookie_value: Option<String>) -> Session<SessionData> {
async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
let session = match cookie_value {
Some(cookie_value) => self.store.load_session(cookie_value).await,
None => None,
Expand Down
4 changes: 1 addition & 3 deletions src/sessions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,4 @@ pub use middleware::SessionMiddleware;

mod middleware;

pub use async_session::{
CookieStore, MemoryStore, Session, SessionReader, SessionStore, SessionWriter,
};
pub use async_session::{CookieStore, MemoryStore, Session, SessionStore};
52 changes: 23 additions & 29 deletions tests/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tide::{
Method::{Get, Post},
Request, Response, Url,
},
sessions::{MemoryStore, SessionMiddleware, SessionWriter},
sessions::{MemoryStore, SessionMiddleware},
utils::Before,
};
#[derive(Clone, Debug, Default, PartialEq)]
Expand All @@ -22,17 +22,18 @@ struct SessionData {
async fn test_basic_sessions() {
let mut app = tide::new();
app.middleware(SessionMiddleware::new(
MemoryStore::<SessionData>::new(),
MemoryStore::new(),
b"12345678901234567890123456789012345",
));

app.middleware(Before(|mut request: tide::Request<()>| async move {
request.session_data_mut::<SessionData>().visits += 1;
let visits: usize = request.session_get("visits").unwrap_or_default();
request.session_insert("visits", visits + 1).unwrap();
request
}));

app.at("/").get(|req: tide::Request<()>| async move {
let SessionData { visits } = *req.session_data();
let visits: usize = req.session_get("visits").unwrap();
Ok(format!("you have visited this website {} times", visits))
});

Expand Down Expand Up @@ -62,29 +63,27 @@ async fn test_basic_sessions() {
async fn test_customized_sessions() {
let mut app = tide::new();
app.middleware(
SessionMiddleware::new(
MemoryStore::<SessionData>::new(),
b"12345678901234567890123456789012345",
)
.with_cookie_name("custom.cookie.name")
.with_cookie_path("/nested")
.with_same_site_policy(SameSite::Lax)
.with_session_ttl(Some(Duration::from_secs(1)))
.without_save_unchanged(),
SessionMiddleware::new(MemoryStore::new(), b"12345678901234567890123456789012345")
.with_cookie_name("custom.cookie.name")
.with_cookie_path("/nested")
.with_same_site_policy(SameSite::Lax)
.with_session_ttl(Some(Duration::from_secs(1)))
.without_save_unchanged(),
);

app.at("/").get(|_| async { Ok("/") });
app.at("/nested").get(|req: tide::Request<()>| async move {
Ok(format!(
"/nested {}",
req.session_data::<SessionData>().visits
req.session_get::<usize>("visits").unwrap_or_default()
))
});
app.at("/nested/incr")
.get(|mut req: tide::Request<()>| async move {
let mut data: SessionWriter<SessionData> = req.session_data_mut();
data.visits += 1;
Ok(format!("/nested/incr {}", data.visits))
let mut visits: usize = req.session_get("visits").unwrap_or_default();
visits += 1;
req.session_insert("visits", visits).unwrap();
Ok(format!("/nested/incr {}", visits))
});

let response = app.get("/").await;
Expand Down Expand Up @@ -131,23 +130,24 @@ async fn test_customized_sessions() {
async fn test_session_destruction() {
let mut app = tide::new();
app.middleware(SessionMiddleware::new(
MemoryStore::<SessionData>::new(),
MemoryStore::new(),
b"12345678901234567890123456789012345",
));

app.middleware(Before(|mut request: tide::Request<()>| async move {
request.session_data_mut::<SessionData>().visits += 1;
let visits: usize = request.session_get("visits").unwrap_or_default();
request.session_insert("visits", visits + 1).unwrap();
request
}));

app.at("/").get(|req: tide::Request<()>| async move {
let SessionData { visits } = *req.session_data();
let visits: usize = req.session_get("visits").unwrap();
Ok(format!("you have visited this website {} times", visits))
});

app.at("/logout")
.post(|mut req: tide::Request<()>| async move {
req.session_mut::<SessionData>().destroy();
req.session_mut().destroy();
Ok(Response::new(200))
});

Expand All @@ -156,16 +156,10 @@ async fn test_session_destruction() {

let mut second_request = Request::new(Post, Url::parse("https://whatever/logout").unwrap());
second_request.insert_header("Cookie", &cookies);
let mut second_response: Response = app.respond(second_request).await.unwrap();
let body = second_response.body_string().await.unwrap();
let second_response: Response = app.respond(second_request).await.unwrap();
let cookies = Cookies::from_response(&second_response);
assert_eq!(cookies["tide.sid"].value(), "");
assert_eq!(cookies.len(), 1);
dbg!(&cookies);

// let response = app.get("https://secure/").await;
// let cookies = Cookies::from_response(&response);
// let cookie = &cookies["tide.sid"];
// assert_eq!(cookie.secure(), Some(true));
}

#[derive(Debug, Clone)]
Expand Down

0 comments on commit a4bb48d

Please sign in to comment.