Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement PNI #245

Merged
merged 12 commits into from
May 30, 2024
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ resolver = "2"
[patch.crates-io]
curve25519-dalek = { git = 'https://github.com/signalapp/curve25519-dalek', tag = 'signal-curve25519-4.1.1' }

# [patch."https://github.com/whisperfish/libsignal-service-rs.git"]
# libsignal-service = { path = "../libsignal-service-rs/libsignal-service" }
# libsignal-service-hyper = { path = "../libsignal-service-rs/libsignal-service-hyper" }
[patch."https://github.com/whisperfish/libsignal-service-rs.git"]
libsignal-service = { path = "../libsignal-service-rs/libsignal-service" }
libsignal-service-hyper = { path = "../libsignal-service-rs/libsignal-service-hyper" }
1 change: 1 addition & 0 deletions presage-store-sled/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ thiserror = "1.0"
prost = "> 0.10, <= 0.12"
sha2 = "0.10"
quickcheck_macros = "1.0.0"
chrono = "0.4.35"

[dev-dependencies]
anyhow = "1.0"
Expand Down
7 changes: 4 additions & 3 deletions presage-store-sled/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ pub enum SledStoreError {

impl StoreError for SledStoreError {}

impl SledStoreError {
pub(crate) fn into_signal_error(self) -> SignalProtocolError {
SignalProtocolError::InvalidState("presage error", self.to_string())
impl From<SledStoreError> for SignalProtocolError {
fn from(error: SledStoreError) -> Self {
log::error!("presage store error: {error}");
Self::InvalidState("presage store error", error.to_string())
}
}
115 changes: 88 additions & 27 deletions presage-store-sled/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use presage::libsignal_service::{
content::Content,
groups_v2::Group,
models::Contact,
pre_keys::PreKeysStore,
pre_keys::{PreKeysStore, ServiceKyberPreKeyStore},
prelude::{ProfileKey, Uuid},
protocol::{
Direction, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore,
Expand Down Expand Up @@ -49,6 +49,7 @@ const SLED_TREE_SENDER_KEYS: &str = "sender_keys";
const SLED_TREE_SESSIONS: &str = "sessions";
const SLED_TREE_SIGNED_PRE_KEYS: &str = "signed_pre_keys";
const SLED_TREE_KYBER_PRE_KEYS: &str = "kyber_pre_keys";
const SLED_TREE_KYBER_PRE_KEYS_LAST_RESORT: &str = "kyber_pre_keys_last_resort";
const SLED_TREE_STATE: &str = "state";
const SLED_TREE_THREADS_PREFIX: &str = "threads";
const SLED_TREE_PROFILES: &str = "profiles";
Expand Down Expand Up @@ -717,6 +718,78 @@ impl PreKeysStore for SledStore {
}
}

#[async_trait(?Send)]
impl ServiceKyberPreKeyStore for SledStore {
async fn store_last_resort_kyber_pre_key(
&mut self,
kyber_prekey_id: KyberPreKeyId,
record: &KyberPreKeyRecord,
) -> Result<(), SignalProtocolError> {
self.insert(
SLED_TREE_KYBER_PRE_KEYS_LAST_RESORT,
kyber_prekey_id.to_string(),
record.serialize()?,
)
.map_err(|e| {
log::error!("sled error: {}", e);
SignalProtocolError::InvalidState(
"store_last_resort_kyber_pre_key",
"sled error".into(),
)
})?;
Ok(())
}

async fn load_last_resort_kyber_pre_keys(
&self,
) -> Result<Vec<KyberPreKeyRecord>, SignalProtocolError> {
self
.db
.read()
.expect("poisoned mutex")
.open_tree(SLED_TREE_KYBER_PRE_KEYS_LAST_RESORT).map_err(|e| {
log::error!("sled error: {}", e);
SignalProtocolError::InvalidState(
"load_last_resort_kyber_pre_keys",
"sled error".into(),
)
})?
.iter()
.values()
.filter_map(Result::ok)
.map(|data| {
KyberPreKeyRecord::deserialize(&data)
})
.collect()
}

async fn remove_kyber_pre_key(
&mut self,
kyber_prekey_id: KyberPreKeyId,
) -> Result<(), SignalProtocolError> {
self.remove(SLED_TREE_KYBER_PRE_KEYS_LAST_RESORT, kyber_prekey_id.to_string())?;
self.remove(SLED_TREE_KYBER_PRE_KEYS, kyber_prekey_id.to_string())?;
Ok(())
}

/// Analogous to markAllOneTimeKyberPreKeysStaleIfNecessary
async fn mark_all_one_time_kyber_pre_keys_stale_if_necessary(
&mut self,
_stale_time: chrono::DateTime<chrono::Utc>,
) -> Result<(), SignalProtocolError> {
unimplemented!("should not be used yet")
}

/// Analogue of deleteAllStaleOneTimeKyberPreKeys
async fn delete_all_stale_one_time_kyber_pre_keys(
&mut self,
_threshold: chrono::DateTime<chrono::Utc>,
_min_count: usize,
) -> Result<(), SignalProtocolError> {
unimplemented!("should not be used yet")
}
}

impl Store for SledStore {
type Error = SledStoreError;

Expand Down Expand Up @@ -972,8 +1045,7 @@ impl SessionStore for SledStore {
address: &ProtocolAddress,
) -> Result<Option<SessionRecord>, SignalProtocolError> {
let session = self
.get(SLED_TREE_SESSIONS, address.to_string())
.map_err(SledStoreError::into_signal_error)?;
.get(SLED_TREE_SESSIONS, address.to_string())?;
trace!("loading session {} / exists={}", address, session.is_some());
session
.map(|b: Vec<u8>| SessionRecord::deserialize(&b))
Expand All @@ -986,8 +1058,7 @@ impl SessionStore for SledStore {
record: &SessionRecord,
) -> Result<(), SignalProtocolError> {
trace!("storing session {}", address);
self.insert(SLED_TREE_SESSIONS, address.to_string(), record.serialize()?)
.map_err(SledStoreError::into_signal_error)?;
self.insert(SLED_TREE_SESSIONS, address.to_string(), record.serialize()?)?;
Ok(())
}
}
Expand All @@ -1003,8 +1074,7 @@ impl SessionStoreExt for SledStore {
let session_ids: Vec<u32> = self
.read()
.open_tree(SLED_TREE_SESSIONS)
.map_err(Into::into)
.map_err(SledStoreError::into_signal_error)?
.map_err(SledStoreError::from)?
.scan_prefix(&session_prefix)
.filter_map(|r| {
let (key, _) = r.ok()?;
Expand All @@ -1021,8 +1091,7 @@ impl SessionStoreExt for SledStore {
trace!("deleting session {}", address);
self.write()
.open_tree(SLED_TREE_SESSIONS)
.map_err(Into::into)
.map_err(SledStoreError::into_signal_error)?
.map_err(SledStoreError::from)?
.remove(address.to_string())
.map_err(|_e| SignalProtocolError::SessionNotFound(address.clone()))?;
Ok(())
Expand All @@ -1035,8 +1104,7 @@ impl SessionStoreExt for SledStore {
let db = self.write();
let sessions_tree = db
.open_tree(SLED_TREE_SESSIONS)
.map_err(Into::into)
.map_err(SledStoreError::into_signal_error)?;
.map_err(SledStoreError::from)?;

let mut batch = Batch::default();
sessions_tree
Expand All @@ -1048,8 +1116,7 @@ impl SessionStoreExt for SledStore {
.for_each(|k| batch.remove(k));

db.apply_batch(batch)
.map_err(SledStoreError::Db)
.map_err(SledStoreError::into_signal_error)?;
.map_err(SledStoreError::Db)?;

let len = sessions_tree.len();
sessions_tree.clear().map_err(|_e| {
Expand All @@ -1064,23 +1131,21 @@ impl IdentityKeyStore for SledStore {
async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
trace!("getting identity_key_pair");
let data = self
.load_registration_data()
.map_err(SledStoreError::into_signal_error)?
.load_registration_data()?
.ok_or(SignalProtocolError::InvalidState(
"failed to load identity key pair",
"no registration data".into(),
))?;

Ok(IdentityKeyPair::new(
IdentityKey::new(data.aci_public_key()),
data.aci_identity_key(),
data.aci_private_key(),
))
}

async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
let data = self
.load_registration_data()
.map_err(SledStoreError::into_signal_error)?
.load_registration_data()?
.ok_or(SignalProtocolError::InvalidState(
"failed to load registration ID",
"no registration data".into(),
Expand All @@ -1102,7 +1167,7 @@ impl IdentityKeyStore for SledStore {
)
.map_err(|e| {
error!("error saving identity for {:?}: {}", address, e);
e.into_signal_error()
e
})?;

self.save_trusted_identity_message(
Expand All @@ -1125,8 +1190,7 @@ impl IdentityKeyStore for SledStore {
_direction: Direction,
) -> Result<bool, SignalProtocolError> {
match self
.get(SLED_TREE_IDENTITIES, address.to_string())
.map_err(SledStoreError::into_signal_error)?
.get(SLED_TREE_IDENTITIES, address.to_string())?
.map(|b: Vec<u8>| IdentityKey::decode(&b))
.transpose()?
{
Expand All @@ -1153,8 +1217,7 @@ impl IdentityKeyStore for SledStore {
&self,
address: &ProtocolAddress,
) -> Result<Option<IdentityKey>, SignalProtocolError> {
self.get(SLED_TREE_IDENTITIES, address.to_string())
.map_err(SledStoreError::into_signal_error)?
self.get(SLED_TREE_IDENTITIES, address.to_string())?
.map(|b: Vec<u8>| IdentityKey::decode(&b))
.transpose()
}
Expand All @@ -1174,8 +1237,7 @@ impl SenderKeyStore for SledStore {
sender.device_id(),
distribution_id
);
self.insert(SLED_TREE_SENDER_KEYS, key, record.serialize()?)
.map_err(SledStoreError::into_signal_error)?;
self.insert(SLED_TREE_SENDER_KEYS, key, record.serialize()?)?;
Ok(())
}

Expand All @@ -1190,8 +1252,7 @@ impl SenderKeyStore for SledStore {
sender.device_id(),
distribution_id
);
self.get(SLED_TREE_SENDER_KEYS, key)
.map_err(SledStoreError::into_signal_error)?
self.get(SLED_TREE_SENDER_KEYS, key)?
.map(|b: Vec<u8>| SenderKeyRecord::deserialize(&b))
.transpose()
}
Expand Down
68 changes: 56 additions & 12 deletions presage/src/manager/confirmation.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use libsignal_service::configuration::{ServiceConfiguration, SignalServers};
use libsignal_service::messagepipe::ServiceCredentials;
use libsignal_service::prelude::phonenumber::PhoneNumber;
use libsignal_service::protocol::KeyPair;
use libsignal_service::protocol::IdentityKeyPair;
use libsignal_service::provisioning::generate_registration_id;
use libsignal_service::push_service::{
AccountAttributes, DeviceCapabilities, PushService, RegistrationMethod, ServiceIds,
AccountAttributes, DeviceActivationRequest, DeviceCapabilities, PushService,
RegistrationMethod, ServiceIds,
};
use libsignal_service::zkgroup::profiles::ProfileKey;
use libsignal_service_hyper::push_service::HyperPushService;
Expand Down Expand Up @@ -35,7 +36,7 @@ impl<S: Store> Manager<S, Confirmation> {
/// Returns a [registered manager](Manager::load_registered) that you can use
/// to send and receive messages.
pub async fn confirm_verification_code(
self,
mut self,
confirmation_code: impl AsRef<str>,
) -> Result<Manager<S, Registered>, Error<S::Error>> {
trace!("confirming verification code");
Expand Down Expand Up @@ -87,7 +88,50 @@ impl<S: Store> Manager<S, Confirmation> {

let profile_key = ProfileKey::generate(profile_key);

let aci_identity_key_pair = IdentityKeyPair::generate(&mut rng);
let pni_identity_key_pair = IdentityKeyPair::generate(&mut rng);

let (_aci_pre_keys, aci_signed_pre_key, _aci_pq_pre_keys, aci_pq_last_resort_pre_key) =
libsignal_service::pre_keys::replenish_pre_keys(
&mut self.store,
&aci_identity_key_pair,
&mut rng,
true,
0,
0,
)
.await?;
gferon marked this conversation as resolved.
Show resolved Hide resolved

let aci_pq_last_resort_pre_key =
aci_pq_last_resort_pre_key.expect("requested last resort key");
assert!(_aci_pre_keys.is_empty());
assert!(_aci_pq_pre_keys.is_empty());

let (_pni_pre_keys, pni_signed_pre_key, _pni_pq_pre_keys, pni_pq_last_resort_pre_key) =
libsignal_service::pre_keys::replenish_pre_keys(
&mut self.store,
&pni_identity_key_pair,
&mut rng,
true,
0,
0,
)
.await?;

let pni_pq_last_resort_pre_key =
pni_pq_last_resort_pre_key.expect("requested last resort key");
assert!(_pni_pre_keys.is_empty());
assert!(_pni_pq_pre_keys.is_empty());

let skip_device_transfer = false;

let device_activation_request = DeviceActivationRequest {
aci_signed_pre_key: aci_signed_pre_key.try_into()?,
pni_signed_pre_key: pni_signed_pre_key.try_into()?,
aci_pq_last_resort_pre_key: aci_pq_last_resort_pre_key.try_into()?,
pni_pq_last_resort_pre_key: pni_pq_last_resort_pre_key.try_into()?,
};

let registered = push_service
.submit_registration_request(
RegistrationMethod::SessionId(&session_id),
Expand All @@ -105,18 +149,18 @@ impl<S: Store> Manager<S, Confirmation> {
unrestricted_unidentified_access: false, // TODO: make this configurable?
discoverable_by_phone_number: true,
capabilities: DeviceCapabilities {
gv2: true,
gv1_migration: true,
pni: true,
sender_key: true,
..Default::default()
},
},
skip_device_transfer,
aci_identity_key_pair.identity_key(),
pni_identity_key_pair.identity_key(),
device_activation_request,
)
.await?;

let aci_identity_key_pair = KeyPair::generate(&mut rng);
let pni_identity_key_pair = KeyPair::generate(&mut rng);

trace!("confirmed! (and registered)");

let mut manager = Manager {
Expand All @@ -135,10 +179,10 @@ impl<S: Store> Manager<S, Confirmation> {
device_id: None,
registration_id,
pni_registration_id: Some(pni_registration_id),
aci_private_key: aci_identity_key_pair.private_key,
aci_public_key: aci_identity_key_pair.public_key,
pni_private_key: Some(pni_identity_key_pair.private_key),
pni_public_key: Some(pni_identity_key_pair.public_key),
aci_private_key: *aci_identity_key_pair.private_key(),
aci_identity_key: *aci_identity_key_pair.identity_key(),
pni_private_key: Some(*pni_identity_key_pair.private_key()),
pni_identity_key: Some(*pni_identity_key_pair.identity_key()),
profile_key,
}),
};
Expand Down
4 changes: 2 additions & 2 deletions presage/src/manager/linking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ impl<S: Store> Manager<S, Linking> {
device_id: Some(d.device_id.into()),
registration_id: d.registration_id,
pni_registration_id: Some(d.pni_registration_id),
aci_public_key: d.aci_public_key,
aci_identity_key: d.aci_public_key,
aci_private_key: d.aci_private_key,
pni_public_key: Some(d.pni_public_key),
pni_identity_key: Some(d.pni_public_key),
pni_private_key: Some(d.pni_private_key),
profile_key: d.profile_key,
};
Expand Down
4 changes: 2 additions & 2 deletions presage/src/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ mod tests {

let data: RegistrationData =
serde_json::from_value(previous_state).expect("should deserialize");
assert_eq!(data.aci_public_key, key_pair.public_key);
assert_eq!(data.aci_identity_key, key_pair.public_key);
assert!(data.aci_private_key == key_pair.private_key);
assert!(data.pni_public_key.is_none());
assert!(data.pni_identity_key.is_none());
}
}
Loading
Loading