Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
boxdot committed Jul 15, 2024
1 parent f4a8e18 commit 3e17255
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 43 deletions.
33 changes: 22 additions & 11 deletions presage-store-sled/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,13 +512,29 @@ mod tests {
ServiceAddress,
};
use presage::store::ContentsStore;
use protocol::SledPreKeyId;
use quickcheck::{Arbitrary, Gen};
use quickcheck_macros::quickcheck;

use crate::SledPreKeyId;
use crate::SchemaVersion;

use super::SledStore;
use super::*;

#[test]
fn test_migration_steps() {
let steps: Vec<_> = SchemaVersion::steps(SchemaVersion::V0).collect();
assert_eq!(
steps,
[
SchemaVersion::V1,
SchemaVersion::V2,
SchemaVersion::V3,
SchemaVersion::V4,
SchemaVersion::V5,
SchemaVersion::V6,
]
)
}
#[derive(Debug, Clone)]
struct Thread(presage::store::Thread);

Expand Down Expand Up @@ -577,16 +593,11 @@ mod tests {
}

#[quickcheck]
fn compare_pre_keys(pre_key_id: u32, next_pre_key_id: u32) {
if pre_key_id < next_pre_key_id {
assert!(
PreKeyId::from(pre_key_id).sled_key() < PreKeyId::from(next_pre_key_id).sled_key()
)
} else {
assert!(
PreKeyId::from(pre_key_id).sled_key() > PreKeyId::from(next_pre_key_id).sled_key()
)
fn compare_pre_keys(mut pre_key_id: u32, mut next_pre_key_id: u32) {
if pre_key_id > next_pre_key_id {
std::mem::swap(&mut pre_key_id, &mut next_pre_key_id);
}
assert!(PreKeyId::from(pre_key_id).sled_key() <= PreKeyId::from(next_pre_key_id).sled_key())
}

#[quickcheck_async::tokio]
Expand Down
125 changes: 93 additions & 32 deletions presage-store-sled/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ impl<T: SledTrees> SledProtocolStore<T> {
.and_then(|data| Some(u32::from_be_bytes(data.as_ref().try_into().ok()?)))
.map_or(0, |id| id + 1))
}

/// Whether to force a pre key refresh.
///
/// Check whether we have:
/// - 1 signed EC pre key
/// - 1 Kyber last resort key
async fn needs_pre_key_refresh(&self) -> Result<bool, SignalProtocolError> {
let has_signed_pre_keys = self.signed_pre_keys_count().await? > 0;
let has_last_resort_kyber_pre_keys = self.kyber_pre_keys_count(true).await? > 0;
Ok(has_signed_pre_keys && has_last_resort_kyber_pre_keys)
}
}

pub trait SledTrees: Clone {
Expand Down Expand Up @@ -174,27 +163,16 @@ impl SledTrees for PniSledStore {
}
}

trait SledPreKeyId {
fn sled_key(self) -> [u8; 4];
}

impl SledPreKeyId for PreKeyId {
pub(crate) trait SledPreKeyId: Into<u32> {
fn sled_key(self) -> [u8; 4] {
u32::from(self).to_be_bytes()
let idx: u32 = self.into();
idx.to_be_bytes()
}
}

impl SledPreKeyId for SignedPreKeyId {
fn sled_key(self) -> [u8; 4] {
u32::from(self).to_be_bytes()
}
}

impl SledPreKeyId for KyberPreKeyId {
fn sled_key(self) -> [u8; 4] {
u32::from(self).to_be_bytes()
}
}
impl SledPreKeyId for PreKeyId {}
impl SledPreKeyId for SignedPreKeyId {}
impl SledPreKeyId for KyberPreKeyId {}

impl<T: SledTrees> SledProtocolStore<T> {
pub(crate) fn clear(&self, clear_sessions: bool) -> Result<(), SledStoreError> {
Expand Down Expand Up @@ -679,13 +657,17 @@ mod tests {

use base64::prelude::*;
use presage::{
libsignal_service::protocol::{
self, Direction, GenericSignedPreKey, IdentityKeyStore, PreKeyRecord, PreKeyStore,
SessionRecord, SessionStore, SignedPreKeyRecord, SignedPreKeyStore, Timestamp,
libsignal_service::{
pre_keys::PreKeysStore,
protocol::{
self, Direction, GenericSignedPreKey, IdentityKeyStore, PreKeyId, PreKeyRecord,
PreKeyStore, SessionRecord, SessionStore, SignedPreKeyId, SignedPreKeyRecord,
SignedPreKeyStore, Timestamp,
},
},
store::Store,
};
use quickcheck::{Arbitrary, Gen};
use quickcheck::{Arbitrary, Gen, TestResult};

use super::SledStore;

Expand Down Expand Up @@ -789,4 +771,83 @@ mod tests {
.unwrap()
== signed_pre_key_record.serialize().unwrap()
}

#[derive(Debug, Clone)]
struct ArbPreKeyRecord(protocol::PreKeyRecord);

impl Arbitrary for ArbPreKeyRecord {
fn arbitrary(g: &mut Gen) -> Self {
let id = u32::arbitrary(g);
let key_pair = KeyPair::arbitrary(g);
Self(protocol::PreKeyRecord::new(id.into(), &key_pair.0))
}
}

#[derive(Debug, Clone)]
struct ArbSignedPreKeyRecord(protocol::SignedPreKeyRecord);

impl Arbitrary for ArbSignedPreKeyRecord {
fn arbitrary(g: &mut Gen) -> Self {
let id = u32::arbitrary(g);
let timestamp = Arbitrary::arbitrary(g);
let key_pair = KeyPair::arbitrary(g);
let signature: Vec<u8> = Arbitrary::arbitrary(g);
Self(protocol::SignedPreKeyRecord::new(
id.into(),
protocol::Timestamp::from_epoch_millis(timestamp),
&key_pair.0,
&signature,
))
}
}

#[quickcheck_async::tokio]
async fn get_next_pre_key_ids(
key1: ArbPreKeyRecord,
key2: ArbPreKeyRecord,
signed_key: ArbSignedPreKeyRecord,
) {
let db = SledStore::temporary().unwrap();
let mut store = db.aci_protocol_store();

assert_eq!(store.next_pre_key_id().await.unwrap(), 0);
assert_eq!(store.next_pq_pre_key_id().await.unwrap(), 0);
assert_eq!(store.next_signed_pre_key_id().await.unwrap(), 0);

store
.save_pre_key(PreKeyId::from(0), &key1.0)
.await
.unwrap();
store
.save_pre_key(PreKeyId::from(1), &key2.0)
.await
.unwrap();
store
.save_signed_pre_key(SignedPreKeyId::from(0), &signed_key.0)
.await
.unwrap();

assert_eq!(store.next_pre_key_id().await.unwrap(), 2);
assert_eq!(store.next_pq_pre_key_id().await.unwrap(), 0);
assert_eq!(store.next_signed_pre_key_id().await.unwrap(), 1);
}

#[quickcheck_async::tokio]
async fn test_next_key_id_is_max(keys: Vec<u32>, record: ArbPreKeyRecord) -> TestResult {
let db = SledStore::temporary().unwrap();
let mut store = db.aci_protocol_store();

for &key in &keys {
store.save_pre_key(key.into(), &record.0).await.unwrap();
if key == u32::MAX {
return TestResult::discard();
}
}
if keys.iter().copied().max().map(|id| id + 1).unwrap_or(0)
!= store.next_pre_key_id().await.unwrap()
{
return TestResult::failed();
}
TestResult::passed()
}
}

0 comments on commit 3e17255

Please sign in to comment.