diff --git a/tower-sessions-core/src/session.rs b/tower-sessions-core/src/session.rs index ad0cd93..986e034 100644 --- a/tower-sessions-core/src/session.rs +++ b/tower-sessions-core/src/session.rs @@ -829,7 +829,8 @@ impl Session { let old_session_id = record_guard.id; record_guard.id = Id::default(); - *self.session_id.lock() = Some(record_guard.id); + *self.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's + // `create` method. self.store .delete(&old_session_id) @@ -950,3 +951,84 @@ pub enum Expiry { /// [`set_expiry`](Session::set_expiry). AtDateTime(OffsetDateTime), } + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + use mockall::{ + mock, + predicate::{self, always}, + }; + + use super::*; + + mock! { + #[derive(Debug)] + pub Store {} + + #[async_trait] + impl SessionStore for Store { + async fn create(&self, record: &mut Record) -> session_store::Result<()>; + async fn save(&self, record: &Record) -> session_store::Result<()>; + async fn load(&self, session_id: &Id) -> session_store::Result>; + async fn delete(&self, session_id: &Id) -> session_store::Result<()>; + } + } + + #[tokio::test] + async fn test_cycle_id() { + let mut mock_store = MockStore::new(); + + let initial_id = Id::default(); + let new_id = Id::default(); + + // Set up expectations for the mock store + mock_store + .expect_save() + .with(always()) + .times(1) + .returning(|_| Ok(())); + mock_store + .expect_load() + .with(predicate::eq(initial_id)) + .times(1) + .returning(move |_| { + Ok(Some(Record { + id: initial_id, + data: Data::default(), + expiry_date: OffsetDateTime::now_utc(), + })) + }); + mock_store + .expect_delete() + .with(predicate::eq(initial_id)) + .times(1) + .returning(|_| Ok(())); + mock_store + .expect_create() + .times(1) + .returning(move |record| { + record.id = new_id; + Ok(()) + }); + + let store = Arc::new(mock_store); + let session = Session::new(Some(initial_id), store.clone(), None); + + // Insert some data and save the session + session.insert("foo", 42).await.unwrap(); + session.save().await.unwrap(); + + // Cycle the session ID + session.cycle_id().await.unwrap(); + + // Verify that the session ID has changed and the data is still present + assert_ne!(session.id(), Some(initial_id)); + assert!(session.id().is_none()); // The session ID should be None + assert_eq!(session.get::("foo").await.unwrap(), Some(42)); + + // Save the session to update the ID in the session object + session.save().await.unwrap(); + assert_eq!(session.id(), Some(new_id)); + } +}