diff --git a/changelog.d/3520.bugfix b/changelog.d/3520.bugfix new file mode 100644 index 000000000000..9278cb3708db --- /dev/null +++ b/changelog.d/3520.bugfix @@ -0,0 +1 @@ +Correctly announce deleted devices over federation diff --git a/synapse/notifier.py b/synapse/notifier.py index 51cbd66f0628..e650c3e4941a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -274,7 +274,7 @@ def _notify_app_services(self, room_stream_id): logger.exception("Error notifying application services of event") def on_new_event(self, stream_key, new_token, users=[], rooms=[]): - """ Used to inform listeners that something has happend event wise. + """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index ec68e39f1e87..cc3cdf2ebccc 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -248,17 +248,31 @@ def update_remote_device_list_cache_entry(self, user_id, device_id, content, def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, content, stream_id): - self._simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "content": json.dumps(content), - } - ) + if content.get("deleted"): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + ) + + txn.call_after( + self.device_id_exists_cache.invalidate, (user_id, device_id,) + ) + else: + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) @@ -366,7 +380,7 @@ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, now_stream_id = max(stream_id for stream_id in itervalues(query_map)) devices = self._get_e2e_device_keys_txn( - txn, query_map.keys(), include_all_devices=True + txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True ) prev_sent_id_sql = """ @@ -393,12 +407,15 @@ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, prev_id = stream_id - key_json = device.get("key_json", None) - if key_json: - result["keys"] = json.loads(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True results.append(result) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 7ae5c6548275..523b4360c3cb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -64,12 +64,18 @@ def _set_e2e_device_keys_txn(txn): ) @defer.inlineCallbacks - def get_e2e_device_keys(self, query_list, include_all_devices=False): + def get_e2e_device_keys( + self, query_list, include_all_devices=False, + include_deleted_devices=False, + ): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. include_all_devices (bool): whether to include entries for devices that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". @@ -79,7 +85,7 @@ def get_e2e_device_keys(self, query_list, include_all_devices=False): results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, + query_list, include_all_devices, include_deleted_devices, ) for user_id, device_keys in iteritems(results): @@ -88,10 +94,19 @@ def get_e2e_device_keys(self, query_list, include_all_devices=False): defer.returnValue(results) - def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, + include_deleted_devices=False, + ): query_clauses = [] query_params = [] + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) @@ -119,8 +134,14 @@ def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): result = {} for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) result.setdefault(row["user_id"], {})[row["device_id"]] = row + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + return result @defer.inlineCallbacks