diff --git a/bindings/matrix-sdk-ffi/src/room_list.rs b/bindings/matrix-sdk-ffi/src/room_list.rs index 8437df88b49..362bb55267c 100644 --- a/bindings/matrix-sdk-ffi/src/room_list.rs +++ b/bindings/matrix-sdk-ffi/src/room_list.rs @@ -23,6 +23,7 @@ use matrix_sdk_ui::{ BoxedFilterFn, }, timeline::default_event_filter, + unable_to_decrypt_hook::UtdHookManager, }; use tokio::sync::RwLock; @@ -107,6 +108,7 @@ impl From for matrix_sdk_ui::room_list_service::Input { #[derive(uniffi::Object)] pub struct RoomListService { pub(crate) inner: Arc, + pub(crate) utd_hook: Option>, } #[uniffi::export(async_runtime = "tokio")] @@ -128,6 +130,7 @@ impl RoomListService { Ok(Arc::new(RoomListItem { inner: Arc::new(RUNTIME.block_on(async { self.inner.room(room_id).await })?), + utd_hook: self.utd_hook.clone(), })) } @@ -478,6 +481,7 @@ impl FilterWrapper { #[derive(uniffi::Object)] pub struct RoomListItem { inner: Arc, + utd_hook: Option>, } #[uniffi::export(async_runtime = "tokio")] @@ -549,6 +553,11 @@ impl RoomListItem { default_event_filter(event, room_version_id) && event_type_filter.filter(event) }); } + + if let Some(utd_hook) = self.utd_hook.clone() { + timeline_builder = timeline_builder.with_unable_to_decrypt_hook(utd_hook); + } + self.inner.init_timeline_with_builder(timeline_builder).map_err(RoomListError::from).await } diff --git a/bindings/matrix-sdk-ffi/src/sync_service.rs b/bindings/matrix-sdk-ffi/src/sync_service.rs index 069966801de..4b0c05489a5 100644 --- a/bindings/matrix-sdk-ffi/src/sync_service.rs +++ b/bindings/matrix-sdk-ffi/src/sync_service.rs @@ -12,13 +12,18 @@ // See the License for that specific language governing permissions and // limitations under the License. -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc, time::Duration}; use futures_util::pin_mut; use matrix_sdk::Client; -use matrix_sdk_ui::sync_service::{ - State as MatrixSyncServiceState, SyncService as MatrixSyncService, - SyncServiceBuilder as MatrixSyncServiceBuilder, +use matrix_sdk_ui::{ + sync_service::{ + State as MatrixSyncServiceState, SyncService as MatrixSyncService, + SyncServiceBuilder as MatrixSyncServiceBuilder, + }, + unable_to_decrypt_hook::{ + UnableToDecryptHook, UnableToDecryptInfo as SdkUnableToDecryptInfo, UtdHookManager, + }, }; use crate::{ @@ -53,12 +58,16 @@ pub trait SyncServiceStateObserver: Send + Sync + Debug { #[derive(uniffi::Object)] pub struct SyncService { pub(crate) inner: Arc, + utd_hook: Option>, } #[uniffi::export(async_runtime = "tokio")] impl SyncService { pub fn room_list_service(&self) -> Arc { - Arc::new(RoomListService { inner: self.inner.room_list_service() }) + Arc::new(RoomListService { + inner: self.inner.room_list_service(), + utd_hook: self.utd_hook.clone(), + }) } pub async fn start(&self) { @@ -85,11 +94,13 @@ impl SyncService { #[derive(Clone, uniffi::Object)] pub struct SyncServiceBuilder { builder: MatrixSyncServiceBuilder, + + utd_hook: Option>, } impl SyncServiceBuilder { pub(crate) fn new(client: Client) -> Arc { - Arc::new(Self { builder: MatrixSyncService::builder(client) }) + Arc::new(Self { builder: MatrixSyncService::builder(client), utd_hook: None }) } } @@ -101,17 +112,88 @@ impl SyncServiceBuilder { ) -> Arc { let this = unwrap_or_clone_arc(self); let builder = this.builder.with_unified_invites_in_room_list(with_unified_invites); - Arc::new(Self { builder }) + Arc::new(Self { builder, utd_hook: this.utd_hook }) } pub fn with_cross_process_lock(self: Arc, app_identifier: Option) -> Arc { let this = unwrap_or_clone_arc(self); let builder = this.builder.with_cross_process_lock(app_identifier); - Arc::new(Self { builder }) + Arc::new(Self { builder, utd_hook: this.utd_hook }) + } + + pub fn with_utd_hook(self: Arc, delegate: Box) -> Arc { + // UTDs detected before this duration may be reclassified as "late decryption" + // events (or discarded, if they get decrypted fast enough). + const UTD_HOOK_GRACE_PERIOD: Duration = Duration::from_secs(60); + + let this = unwrap_or_clone_arc(self); + let utd_hook = Some(Arc::new( + UtdHookManager::new(Arc::new(UtdHook { delegate })) + .with_max_delay(UTD_HOOK_GRACE_PERIOD), + )); + Arc::new(Self { builder: this.builder, utd_hook }) } pub async fn finish(self: Arc) -> Result, ClientError> { let this = unwrap_or_clone_arc(self); - Ok(Arc::new(SyncService { inner: Arc::new(this.builder.build().await?) })) + Ok(Arc::new(SyncService { + inner: Arc::new(this.builder.build().await?), + utd_hook: this.utd_hook, + })) + } +} + +#[uniffi::export(callback_interface)] +pub trait UnableToDecryptDelegate: Sync + Send { + fn on_utd(&self, info: UnableToDecryptInfo); +} + +struct UtdHook { + delegate: Box, +} + +impl std::fmt::Debug for UtdHook { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UtdHook").finish_non_exhaustive() + } +} + +impl UnableToDecryptHook for UtdHook { + fn on_utd(&self, info: SdkUnableToDecryptInfo) { + const IGNORE_UTD_PERIOD: Duration = Duration::from_secs(4); + + // UTDs that have been decrypted in the `IGNORE_UTD_PERIOD` are just ignored and + // not considered UTDs. + if let Some(duration) = &info.time_to_decrypt { + if *duration < IGNORE_UTD_PERIOD { + return; + } + } + + // Report the UTD to the client. + self.delegate.on_utd(info.into()); + } +} + +#[derive(uniffi::Record)] +pub struct UnableToDecryptInfo { + /// The identifier of the event that couldn't get decrypted. + event_id: String, + + /// If the event could be decrypted late (that is, the event was encrypted + /// at first, but could be decrypted later on), then this indicates the + /// time it took to decrypt the event. If it is not set, this is + /// considered a definite UTD. + /// + /// If set, this is in milliseconds. + pub time_to_decrypt_ms: Option, +} + +impl From for UnableToDecryptInfo { + fn from(value: SdkUnableToDecryptInfo) -> Self { + Self { + event_id: value.event_id.to_string(), + time_to_decrypt_ms: value.time_to_decrypt.map(|ttd| ttd.as_millis() as u64), + } } } diff --git a/crates/matrix-sdk-ui/src/lib.rs b/crates/matrix-sdk-ui/src/lib.rs index 0e33ac0e7d0..3537a39e281 100644 --- a/crates/matrix-sdk-ui/src/lib.rs +++ b/crates/matrix-sdk-ui/src/lib.rs @@ -21,6 +21,7 @@ pub mod notification_client; pub mod room_list_service; pub mod sync_service; pub mod timeline; +pub mod unable_to_decrypt_hook; pub use self::{room_list_service::RoomListService, timeline::Timeline}; diff --git a/crates/matrix-sdk-ui/src/timeline/builder.rs b/crates/matrix-sdk-ui/src/timeline/builder.rs index 54edde8d3bb..cbbe52d432f 100644 --- a/crates/matrix-sdk-ui/src/timeline/builder.rs +++ b/crates/matrix-sdk-ui/src/timeline/builder.rs @@ -36,6 +36,7 @@ use super::{ queue::send_queued_messages, BackPaginationStatus, Timeline, TimelineDropHandle, }; +use crate::unable_to_decrypt_hook::UtdHookManager; /// Builder that allows creating and configuring various parts of a /// [`Timeline`]. @@ -45,11 +46,29 @@ pub struct TimelineBuilder { room: Room, prev_token: Option, settings: TimelineInnerSettings, + + /// An optional hook to call whenever we run into an unable-to-decrypt or a + /// late-decryption event. + unable_to_decrypt_hook: Option>, } impl TimelineBuilder { pub(super) fn new(room: &Room) -> Self { - Self { room: room.clone(), prev_token: None, settings: TimelineInnerSettings::default() } + Self { + room: room.clone(), + prev_token: None, + settings: TimelineInnerSettings::default(), + unable_to_decrypt_hook: None, + } + } + + /// Sets up a hook to catch unable-to-decrypt (UTD) events for the timeline + /// we're building. + /// + /// If it was previously set before, will overwrite the previous one. + pub fn with_unable_to_decrypt_hook(mut self, hook: Arc) -> Self { + self.unable_to_decrypt_hook = Some(hook); + self } /// Add initial events to the timeline. @@ -119,7 +138,7 @@ impl TimelineBuilder { ) )] pub async fn build(self) -> event_cache::Result { - let Self { room, prev_token, settings } = self; + let Self { room, prev_token, settings, unable_to_decrypt_hook } = self; let client = room.client(); let event_cache = client.event_cache(); @@ -133,7 +152,7 @@ impl TimelineBuilder { let has_events = !events.is_empty(); let track_read_marker_and_receipts = settings.track_read_receipts; - let mut inner = TimelineInner::new(room).with_settings(settings); + let mut inner = TimelineInner::new(room, unable_to_decrypt_hook).with_settings(settings); if track_read_marker_and_receipts { inner.populate_initial_user_receipt(ReceiptType::Read).await; diff --git a/crates/matrix-sdk-ui/src/timeline/event_handler.rs b/crates/matrix-sdk-ui/src/timeline/event_handler.rs index c49cdc3206d..f2cccb5f7e8 100644 --- a/crates/matrix-sdk-ui/src/timeline/event_handler.rs +++ b/crates/matrix-sdk-ui/src/timeline/event_handler.rs @@ -297,6 +297,14 @@ impl<'a, 'o> TimelineEventHandler<'a, 'o> { AnyMessageLikeEventContent::RoomEncrypted(c) => { // TODO: Handle replacements if the replaced event is also UTD self.add(true, TimelineItemContent::unable_to_decrypt(c)); + + // Let the hook know that we ran into an unable-to-decrypt that is added to the + // timeline. + if let Some(hook) = self.meta.unable_to_decrypt_hook.as_ref() { + if let Flow::Remote { event_id, .. } = &self.ctx.flow { + hook.on_utd(event_id); + } + } } AnyMessageLikeEventContent::Sticker(content) => { self.add(should_add, TimelineItemContent::Sticker(Sticker { content })); diff --git a/crates/matrix-sdk-ui/src/timeline/inner/mod.rs b/crates/matrix-sdk-ui/src/timeline/inner/mod.rs index 8d42e9c8742..a3df82cb9e2 100644 --- a/crates/matrix-sdk-ui/src/timeline/inner/mod.rs +++ b/crates/matrix-sdk-ui/src/timeline/inner/mod.rs @@ -66,7 +66,7 @@ use super::{ AnnotationKey, EventSendState, EventTimelineItem, InReplyToDetails, Message, Profile, RepliedToEvent, TimelineDetails, TimelineItem, TimelineItemContent, TimelineItemKind, }; -use crate::timeline::TimelineEventFilterFn; +use crate::{timeline::TimelineEventFilterFn, unable_to_decrypt_hook::UtdHookManager}; mod state; @@ -210,8 +210,12 @@ pub fn default_event_filter(event: &AnySyncTimelineEvent, room_version: &RoomVer } impl TimelineInner

{ - pub(super) fn new(room_data_provider: P) -> Self { - let state = TimelineInnerState::new(room_data_provider.room_version()); + pub(super) fn new( + room_data_provider: P, + unable_to_decrypt_hook: Option>, + ) -> Self { + let state = + TimelineInnerState::new(room_data_provider.room_version(), unable_to_decrypt_hook); Self { state: Arc::new(RwLock::new(state)), room_data_provider, @@ -786,11 +790,13 @@ impl TimelineInner

{ let settings = self.settings.clone(); let room_data_provider = self.room_data_provider.clone(); let push_rules_context = room_data_provider.push_rules_and_context().await; + let unable_to_decrypt_hook = state.unable_to_decrypt_hook.clone(); matrix_sdk::executor::spawn(async move { let retry_one = |item: Arc| { let decryptor = decryptor.clone(); let should_retry = &should_retry; + let unable_to_decrypt_hook = unable_to_decrypt_hook.clone(); async move { let event_item = item.as_event()?; @@ -824,6 +830,12 @@ impl TimelineInner

{ trace!( "Successfully decrypted event that previously failed to decrypt" ); + + // Notify observers that we managed to eventually decrypt an event. + if let Some(hook) = unable_to_decrypt_hook { + hook.on_late_decrypt(&remote_event.event_id); + } + Some(event) } Err(e) => { diff --git a/crates/matrix-sdk-ui/src/timeline/inner/state.rs b/crates/matrix-sdk-ui/src/timeline/inner/state.rs index 2e315208b7e..9b627c63e06 100644 --- a/crates/matrix-sdk-ui/src/timeline/inner/state.rs +++ b/crates/matrix-sdk-ui/src/timeline/inner/state.rs @@ -54,6 +54,7 @@ use crate::{ AnnotationKey, Error as TimelineError, Profile, ReactionSenderData, TimelineItem, TimelineItemKind, VirtualTimelineItem, }, + unable_to_decrypt_hook::UtdHookManager, }; #[derive(Debug)] @@ -63,13 +64,16 @@ pub(in crate::timeline) struct TimelineInnerState { } impl TimelineInnerState { - pub(super) fn new(room_version: RoomVersionId) -> Self { + pub(super) fn new( + room_version: RoomVersionId, + unable_to_decrypt_hook: Option>, + ) -> Self { Self { // Upstream default capacity is currently 16, which is making // sliding-sync tests with 20 events lag. This should still be // small enough. items: ObservableVector::with_capacity(32), - meta: TimelineInnerMetadata::new(room_version), + meta: TimelineInnerMetadata::new(room_version, unable_to_decrypt_hook), } } @@ -806,10 +810,16 @@ pub(in crate::timeline) struct TimelineInnerMetadata { /// /// Private because it's not needed by `TimelineEventHandler`. back_pagination_tokens: VecDeque<(OwnedEventId, String)>, + + /// The hook to call whenever we run into a unable-to-decrypt event. + pub(crate) unable_to_decrypt_hook: Option>, } impl TimelineInnerMetadata { - fn new(room_version: RoomVersionId) -> TimelineInnerMetadata { + fn new( + room_version: RoomVersionId, + unable_to_decrypt_hook: Option>, + ) -> Self { Self { all_events: Default::default(), next_internal_id: Default::default(), @@ -824,6 +834,7 @@ impl TimelineInnerMetadata { in_flight_reaction: Default::default(), room_version, back_pagination_tokens: VecDeque::new(), + unable_to_decrypt_hook, } } diff --git a/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs b/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs index be2290e69a8..dd24b9c1e3a 100644 --- a/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs +++ b/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs @@ -14,7 +14,11 @@ #![cfg(not(target_arch = "wasm32"))] -use std::{io::Cursor, iter}; +use std::{ + io::Cursor, + iter, + sync::{Arc, Mutex}, +}; use assert_matches::assert_matches; use assert_matches2::assert_let; @@ -32,10 +36,13 @@ use ruma::{ use stream_assert::assert_next_matches; use super::TestTimeline; -use crate::timeline::{EncryptedMessage, TimelineItemContent}; +use crate::{ + timeline::{EncryptedMessage, TimelineItemContent}, + unable_to_decrypt_hook::{UnableToDecryptHook, UnableToDecryptInfo, UtdHookManager}, +}; #[async_test] -async fn retry_message_decryption() { +async fn test_retry_message_decryption() { const SESSION_ID: &str = "gM8i47Xhu0q52xLfgUXzanCMpLinoyVyH7R58cBuVBU"; const SESSION_KEY: &[u8] = b"\ -----BEGIN MEGOLM SESSION DATA-----\n\ @@ -51,7 +58,21 @@ async fn retry_message_decryption() { HztoSJUr/2Y\n\ -----END MEGOLM SESSION DATA-----"; - let timeline = TestTimeline::new(); + #[derive(Debug, Default)] + struct DummyUtdHook { + utds: Mutex>, + } + + impl UnableToDecryptHook for DummyUtdHook { + fn on_utd(&self, info: UnableToDecryptInfo) { + self.utds.lock().unwrap().push(info); + } + } + + let hook = Arc::new(DummyUtdHook::default()); + let utd_hook = Arc::new(UtdHookManager::new(hook.clone())); + + let timeline = TestTimeline::with_unable_to_decrypt_hook(utd_hook.clone()); let mut stream = timeline.subscribe().await; timeline @@ -92,6 +113,13 @@ async fn retry_message_decryption() { ); assert_eq!(session_id, SESSION_ID); + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event.event_id().unwrap()); + assert!(utds[0].time_to_decrypt.is_none()); + } + let own_user_id = user_id!("@example:morheus.localhost"); let exported_keys = decrypt_room_key_export(Cursor::new(SESSION_KEY), "1234").unwrap(); @@ -115,10 +143,23 @@ async fn retry_message_decryption() { assert_let!(TimelineItemContent::Message(message) = event.content()); assert_eq!(message.body(), "It's a secret to everybody"); assert!(!event.is_highlighted()); + + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 2); + + // The previous UTD report is still there. + assert_eq!(utds[0].event_id, event.event_id().unwrap()); + assert!(utds[0].time_to_decrypt.is_none()); + + // The UTD is now *also* reported as a late-decryption event. + assert_eq!(utds[1].event_id, event.event_id().unwrap()); + assert!(utds[1].time_to_decrypt.is_some()); + } } #[async_test] -async fn retry_edit_decryption() { +async fn test_retry_edit_decryption() { const SESSION1_KEY: &[u8] = b"\ -----BEGIN MEGOLM SESSION DATA-----\n\ AXou7bY+PWm0GrxTioyoKTkxAgfrQ5lGIla62WoBMrqWAAAACgXidLIt0gaK5NT3mGigzFAPjh/M0ibXjSvo\ @@ -224,7 +265,7 @@ async fn retry_edit_decryption() { } #[async_test] -async fn retry_edit_and_more() { +async fn test_retry_edit_and_more() { const DEVICE_ID: &str = "MTEGRRVPEN"; const SENDER_KEY: &str = "NFPM2+ucU3n3sEdbDdwwv48Bsj4AiQ185lGuRFjy+gs"; const SESSION_ID: &str = "SMNh04luorH5E8J3b4XYuOBFp8dldO5njacq0OFO70o"; @@ -329,7 +370,7 @@ async fn retry_edit_and_more() { } #[async_test] -async fn retry_message_decryption_highlighted() { +async fn test_retry_message_decryption_highlighted() { const SESSION_ID: &str = "C25PoE+4MlNidQD0YU5ibZqHawV0zZ/up7R8vYJBYTY"; const SESSION_KEY: &[u8] = b"\ -----BEGIN MEGOLM SESSION DATA-----\n\ diff --git a/crates/matrix-sdk-ui/src/timeline/tests/mod.rs b/crates/matrix-sdk-ui/src/timeline/tests/mod.rs index ba1a4a6d8bc..82750649f73 100644 --- a/crates/matrix-sdk-ui/src/timeline/tests/mod.rs +++ b/crates/matrix-sdk-ui/src/timeline/tests/mod.rs @@ -54,6 +54,7 @@ use super::{ traits::RoomDataProvider, EventTimelineItem, Profile, TimelineInner, TimelineItem, }; +use crate::unable_to_decrypt_hook::UtdHookManager; mod basic; mod echo; @@ -81,7 +82,17 @@ impl TestTimeline { } fn with_room_data_provider(room_data_provider: TestRoomDataProvider) -> Self { - Self { inner: TimelineInner::new(room_data_provider), event_builder: EventBuilder::new() } + Self { + inner: TimelineInner::new(room_data_provider, None), + event_builder: EventBuilder::new(), + } + } + + fn with_unable_to_decrypt_hook(hook: Arc) -> Self { + Self { + inner: TimelineInner::new(TestRoomDataProvider::default(), Some(hook)), + event_builder: EventBuilder::new(), + } } fn with_settings(mut self, settings: TimelineInnerSettings) -> Self { diff --git a/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs b/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs new file mode 100644 index 00000000000..98adbf89a66 --- /dev/null +++ b/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs @@ -0,0 +1,373 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This module provides a generic interface to subscribe to unable-to-decrypt +//! events, and notable updates to such events. +//! +//! This provides a general trait that a consumer may implement, as well as +//! utilities to simplify usage of this trait. + +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +use ruma::{EventId, OwnedEventId}; +use tokio::{spawn, task::JoinHandle, time::sleep}; + +/// A generic interface which methods get called whenever we observe a +/// unable-to-decrypt (UTD) event. +pub trait UnableToDecryptHook: std::fmt::Debug + Send + Sync { + /// Called every time the hook observes an encrypted event that couldn't be + /// decrypted. + /// + /// If the hook manager was configured with a max delay, this could also + /// contain extra information for late-decrypted events. See details in + /// [`UnableToDecryptInfo::time_to_decrypt`]. + fn on_utd(&self, info: UnableToDecryptInfo); +} + +/// Information about an event we were unable to decrypt (UTD). +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct UnableToDecryptInfo { + /// The identifier of the event that couldn't get decrypted. + pub event_id: OwnedEventId, + + /// If the event could be decrypted late (that is, the event was encrypted + /// at first, but could be decrypted later on), then this indicates the + /// time it took to decrypt the event. If it is not set, this is + /// considered a definite UTD. + pub time_to_decrypt: Option, +} + +type PendingUtdReports = Vec<(OwnedEventId, JoinHandle<()>)>; + +/// A manager over an existing [`UnableToDecryptHook`] that deduplicates UTDs +/// on similar events, and adds basic consistency checks. +/// +/// It can also implement a grace period before reporting an event as a UTD, if +/// configured with [`Self::with_max_delay`]. Instead of immediately reporting +/// the UTD, the reporting will be delayed by the max delay at most; if the +/// event could eventually get decrypted, it may be reported before the end of +/// that delay. +#[derive(Debug)] +pub struct UtdHookManager { + /// The parent hook we'll call, when we have found a unique UTD. + parent: Arc, + + /// A mapping of events we've marked as UTDs, and the time at which we + /// observed those UTDs. + /// + /// Note: this is unbounded, because we have absolutely no idea how long it + /// will take for a UTD to resolve, or if it will even resolve at any + /// point. + known_utds: Arc>>, + + /// An optional delay before marking the event as UTD ("grace period"). + max_delay: Option, + + /// The set of outstanding tasks to report deferred UTDs, including the + /// event relating to the task. + /// + /// Note: this is empty if no [`Self::max_delay`] is set. + /// + /// Note: this is theoretically unbounded in size, although this set of + /// tasks will degrow over time, as tasks expire after the max delay. + pending_delayed: Arc>, +} + +impl UtdHookManager { + /// Create a new [`UtdHookManager`] for the given hook. + pub fn new(parent: Arc) -> Self { + Self { + parent, + known_utds: Default::default(), + max_delay: None, + pending_delayed: Default::default(), + } + } + + /// Reports UTDs with the given max delay. + /// + /// Note: late decryptions are always reported, even if there was a grace + /// period set for the reporting of the UTD. + pub fn with_max_delay(mut self, delay: Duration) -> Self { + self.max_delay = Some(delay); + self + } + + /// The function to call whenever a UTD is seen for the first time. + /// + /// Pipe in any information that needs to be included in the final report. + pub(crate) fn on_utd(&self, event_id: &EventId) { + // Only let the parent hook know if the event wasn't already handled. + { + let mut known_utds = self.known_utds.lock().unwrap(); + // Note: we don't want to replace the previous time, so don't look at the result + // of insert to know whether the entry was already present or not. + if known_utds.contains_key(event_id) { + return; + } + known_utds.insert(event_id.to_owned(), Instant::now()); + } + + let info = UnableToDecryptInfo { event_id: event_id.to_owned(), time_to_decrypt: None }; + + let Some(max_delay) = self.max_delay else { + // No delay: immediately report the event to the parent hook. + self.parent.on_utd(info); + return; + }; + + let event_id = info.event_id.clone(); + + // Clone Arc'd pointers shared with the task below. + let known_utds = self.known_utds.clone(); + let pending_delayed = self.pending_delayed.clone(); + let parent = self.parent.clone(); + + // Spawn a task that will wait for the given delay, and maybe call the parent + // hook then. + let handle = spawn(async move { + // Wait for the given delay. + sleep(max_delay).await; + + // In any case, remove the task from the outstanding set. + pending_delayed.lock().unwrap().retain(|(event_id, _)| *event_id != info.event_id); + + // Check if the event is still in the map: if not, it's been decrypted since + // then! + if known_utds.lock().unwrap().contains_key(&info.event_id) { + parent.on_utd(info); + } + }); + + // Add the task to the set of pending tasks. + self.pending_delayed.lock().unwrap().push((event_id, handle)); + } + + /// The function to call whenever an event that was marked as a UTD has + /// eventually been decrypted. + /// + /// Note: if this is called for an event that was never marked as a UTD + /// before, it has no effect. + pub(crate) fn on_late_decrypt(&self, event_id: &EventId) { + // Only let the parent hook know if the event was known to be a UTDs. + let Some(marked_utd_at) = self.known_utds.lock().unwrap().remove(event_id) else { + return; + }; + + let info = UnableToDecryptInfo { + event_id: event_id.to_owned(), + time_to_decrypt: Some(marked_utd_at.elapsed()), + }; + + // Cancel and remove the task from the outstanding set immediately. + self.pending_delayed.lock().unwrap().retain(|(event_id, task)| { + if *event_id == info.event_id { + task.abort(); + false + } else { + true + } + }); + + // Report to the parent hook. + self.parent.on_utd(info); + } +} + +impl Drop for UtdHookManager { + fn drop(&mut self) { + // Cancel all the outstanding delayed tasks to report UTDs. + let mut pending_delayed = self.pending_delayed.lock().unwrap(); + for (_, task) in pending_delayed.drain(..) { + task.abort(); + } + } +} + +#[cfg(test)] +mod tests { + use matrix_sdk_test::async_test; + use ruma::event_id; + + use super::*; + + #[derive(Debug, Default)] + struct Dummy { + utds: Mutex>, + } + + impl UnableToDecryptHook for Dummy { + fn on_utd(&self, info: UnableToDecryptInfo) { + self.utds.lock().unwrap().push(info); + } + } + + #[test] + fn test_deduplicates_utds() { + // If I create a dummy hook, + let hook = Arc::new(Dummy::default()); + + // And I wrap with the UtdHookManager, + let wrapper = UtdHookManager::new(hook.clone()); + + // And I call the `on_utd` method multiple times, sometimes on the same event, + wrapper.on_utd(event_id!("$1")); + wrapper.on_utd(event_id!("$1")); + wrapper.on_utd(event_id!("$2")); + wrapper.on_utd(event_id!("$1")); + wrapper.on_utd(event_id!("$2")); + wrapper.on_utd(event_id!("$3")); + + // Then the event ids have been deduplicated, + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 3); + assert_eq!(utds[0].event_id, event_id!("$1")); + assert_eq!(utds[1].event_id, event_id!("$2")); + assert_eq!(utds[2].event_id, event_id!("$3")); + + // No event is a late-decryption event. + assert!(utds[0].time_to_decrypt.is_none()); + assert!(utds[1].time_to_decrypt.is_none()); + assert!(utds[2].time_to_decrypt.is_none()); + } + } + + #[test] + fn test_on_late_decrypted_no_effect() { + // If I create a dummy hook, + let hook = Arc::new(Dummy::default()); + + // And I wrap with the UtdHookManager, + let wrapper = UtdHookManager::new(hook.clone()); + + // And I call the `on_late_decrypt` method before the event had been marked as + // utd, + wrapper.on_late_decrypt(event_id!("$1")); + + // Then nothing is registered in the parent hook. + assert!(hook.utds.lock().unwrap().is_empty()); + } + + #[test] + fn test_on_late_decrypted_after_utd_no_grace_period() { + // If I create a dummy hook, + let hook = Arc::new(Dummy::default()); + + // And I wrap with the UtdHookManager, + let wrapper = UtdHookManager::new(hook.clone()); + + // And I call the `on_utd` method for an event, + wrapper.on_utd(event_id!("$1")); + + // Then the UTD has been notified, but not as late-decrypted event. + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event_id!("$1")); + assert!(utds[0].time_to_decrypt.is_none()); + } + + // And when I call the `on_late_decrypt` method, + wrapper.on_late_decrypt(event_id!("$1")); + + // Then the event is now reported as a late-decryption too. + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 2); + + // The previous report is still there. (There was no grace period.) + assert_eq!(utds[0].event_id, event_id!("$1")); + assert!(utds[0].time_to_decrypt.is_none()); + + // The new report with a late-decryption is there. + assert_eq!(utds[1].event_id, event_id!("$1")); + assert!(utds[1].time_to_decrypt.is_some()); + } + } + + #[cfg(not(target_arch = "wasm32"))] // wasm32 has no time for that + #[async_test] + async fn test_delayed_utd() { + // If I create a dummy hook, + let hook = Arc::new(Dummy::default()); + + // And I wrap with the UtdHookManager, configured to delay reporting after 2 + // seconds. + let wrapper = UtdHookManager::new(hook.clone()).with_max_delay(Duration::from_secs(2)); + + // And I call the `on_utd` method for an event, + wrapper.on_utd(event_id!("$1")); + + // Then the UTD is not being reported immediately. + assert!(hook.utds.lock().unwrap().is_empty()); + assert_eq!(wrapper.pending_delayed.lock().unwrap().len(), 1); + + // If I wait for 1 second, then it's still not been notified yet. + sleep(Duration::from_secs(1)).await; + + assert!(hook.utds.lock().unwrap().is_empty()); + assert_eq!(wrapper.pending_delayed.lock().unwrap().len(), 1); + + // But if I wait just a bit more, then it's getting notified as a definite UTD. + sleep(Duration::from_millis(1500)).await; + + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event_id!("$1")); + assert!(utds[0].time_to_decrypt.is_none()); + } + + assert!(wrapper.pending_delayed.lock().unwrap().is_empty()); + } + + #[cfg(not(target_arch = "wasm32"))] // wasm32 has no time for that + #[async_test] + async fn test_delayed_late_decryption() { + // If I create a dummy hook, + let hook = Arc::new(Dummy::default()); + + // And I wrap with the UtdHookManager, configured to delay reporting after 2 + // seconds. + let wrapper = UtdHookManager::new(hook.clone()).with_max_delay(Duration::from_secs(2)); + + // And I call the `on_utd` method for an event, + wrapper.on_utd(event_id!("$1")); + + // Then the UTD has not been notified quite yet. + assert!(hook.utds.lock().unwrap().is_empty()); + assert_eq!(wrapper.pending_delayed.lock().unwrap().len(), 1); + + // If I wait for 1 second, and mark the event as late-decrypted, + sleep(Duration::from_secs(1)).await; + + wrapper.on_late_decrypt(event_id!("$1")); + + // Then it's being immediately reported as a late-decryption UTD. + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event_id!("$1")); + assert!(utds[0].time_to_decrypt.is_some()); + } + + // And there aren't any pending delayed reports anymore. + assert!(wrapper.pending_delayed.lock().unwrap().is_empty()); + } +}