diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 81f661160ce2..4870678f5027 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -35,7 +35,6 @@ Tuple, Type, TypeVar, - Union, cast, overload, ) @@ -1044,43 +1043,20 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: results = [dict(zip(col_headers, row)) for row in cursor] return results - @overload - async def execute( - self, desc: str, decoder: Literal[None], query: str, *args: Any - ) -> List[Tuple[Any, ...]]: - ... - - @overload - async def execute( - self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any - ) -> R: - ... - - async def execute( - self, - desc: str, - decoder: Optional[Callable[[Cursor], R]], - query: str, - *args: Any, - ) -> Union[List[Tuple[Any, ...]], R]: + async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: """Runs a single query for a result set. Args: desc: description of the transaction, for logging and metrics - decoder - The function which can resolve the cursor results to - something meaningful. query - The query string to execute *args - Query args. 
 
         Returns:
-            The result of decoder(results)
+            The result of txn.fetchall(): a list of row tuples.
         """
 
-        def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+        def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
             txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
+            return txn.fetchall()
 
         return await self.runInteraction(desc, interaction)
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec132..711fdddd4e96 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ async def _censor_redactions(self) -> None:
         """
 
         rows = await self.db_pool.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+            "_censor_redactions_fetch", sql, before_ts, 100
         )
 
         updates = []
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fc23d18eba53..328cd266c981 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -882,7 +882,6 @@ async def get_all_devices_changed(
 
             rows = await self.db_pool.execute(
                 "get_all_devices_changed",
-                None,
                 sql,
                 from_key,
                 to_key,
@@ -966,7 +965,7 @@ async def get_users_whose_signatures_changed(
                 WHERE from_user_id = ? AND stream_id > ?
""" rows = await self.db_pool.execute( - "get_users_whose_signatures_changed", None, sql, user_id, from_key + "get_users_whose_signatures_changed", sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} else: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f13d776b0d1a..f70f95eebaa5 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -155,7 +155,6 @@ async def get_e2e_device_keys_for_federation_query( """ rows = await self.db_pool.execute( "get_e2e_device_keys_for_federation_query_check", - None, sql, now_stream_id, user_id, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index c5fce1c82b12..00618051507e 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1310,12 +1310,9 @@ def process(txn: Cursor) -> None: # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the # indexes on it. - # We need to pass execute a dummy function to handle the txn's result otherwise - # it tries to call fetchall() on it and fails because there's no result to fetch. 
- await self.db_pool.execute( + await self.db_pool.runInteraction( "background_analyze_new_stream_ordering_column", - lambda txn: None, - "ANALYZE events(stream_ordering2)", + lambda txn: txn.execute("ANALYZE events(stream_ordering2)"), ) await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b244651ba6b9..30515a129900 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -683,7 +683,7 @@ async def get_remote_media_ids( return cast( List[Tuple[str, str, str]], - await self.db_pool.execute("get_remote_media_ids", None, sql, before_ts), + await self.db_pool.execute("get_remote_media_ids", sql, before_ts), ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3a87eba430d8..0003f64849ad 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -940,7 +940,7 @@ async def _check_host_room_membership( like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, membership, room_id, like_clause + "is_host_joined", sql, membership, room_id, like_clause ) if not rows: @@ -1165,7 +1165,7 @@ async def is_locally_forgotten_room(self, room_id: str) -> bool: AND forgotten = 0; """ - rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + rows = await self.db_pool.execute("is_forgotten_room", sql, room_id) # `count(*)` returns always an integer # If any rows still exist it means someone has not forgotten this room yet diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index b5acee913237..731ac0eaf600 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -510,7 +510,7 @@ async 
def search_msgs( # List of tuples of (rank, room_id, event_id). results = cast( List[Tuple[int, str, str]], - await self.db_pool.execute("search_msgs", None, sql, *args), + await self.db_pool.execute("search_msgs", sql, *args), ) results = list(filter(lambda row: row[1] in room_ids, results)) @@ -533,9 +533,7 @@ async def search_msgs( # List of tuples of (room_id, count). count_results = cast( List[Tuple[str, int]], - await self.db_pool.execute( - "search_rooms_count", None, count_sql, *count_args - ), + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) count = sum(row[1] for row in count_results if row[0] in room_ids) @@ -675,7 +673,7 @@ async def search_rooms( # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering). results = cast( List[Tuple[int, str, str, int, int]], - await self.db_pool.execute("search_rooms", None, sql, *args), + await self.db_pool.execute("search_rooms", sql, *args), ) results = list(filter(lambda row: row[1] in room_ids, results)) @@ -698,9 +696,7 @@ async def search_rooms( # List of tuples of (room_id, count). 
count_results = cast( List[Tuple[str, int]], - await self.db_pool.execute( - "search_rooms_count", None, count_sql, *count_args - ), + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) count = sum(row[1] for row in count_results if row[0] in room_ids) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 872df6bda12c..2225f8272d93 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1078,7 +1078,7 @@ async def get_current_topological_token(self, room_id: str, stream_key: int) -> """ row = await self.db_pool.execute( - "get_current_topological_token", None, sql, room_id, room_id, stream_key + "get_current_topological_token", sql, room_id, room_id, stream_key ) return row[0][0] if row else 0 @@ -1636,7 +1636,6 @@ async def get_timeline_gaps( rows = await self.db_pool.execute( "get_timeline_gaps", - None, sql, room_id, from_token.stream if from_token else 0, diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 8a4dd75e52ed..a9f5d68b639a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -1146,7 +1146,7 @@ async def search_user_dir( results = cast( List[Tuple[str, Optional[str], Optional[str]]], - await self.db_pool.execute("search_user_dir", None, sql, *args), + await self.db_pool.execute("search_user_dir", sql, *args), ) limited = len(results) > limit diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 6ff533a129bd..0f9c550b27e4 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -359,7 +359,6 @@ async def _background_deduplicate_state( if max_group is None: rows = await self.db_pool.execute( "_background_deduplicate_state", - None, "SELECT coalesce(max(id), 0) FROM state_groups", ) 
             max_group = rows[0][0]
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 75ae740b435d..08214b001316 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -100,7 +100,6 @@ def get_destination_room(self, room: str, destination: str = "host2") -> dict:
         event_id, stream_ordering = self.get_success(
             self.hs.get_datastores().main.db_pool.execute(
                 "test:get_destination_rooms",
-                None,
                 """
                 SELECT event_id, stream_ordering
                     FROM destination_rooms dr
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index abf7d0564d81..3f23df0b2e48 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -456,17 +456,17 @@ def test_not_null_constraint(self) -> None:
         );
         """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(table_sql)
             )
         )
 
         # We add an index so that we can check that its correctly recreated when
         # using SQLite.
         index_sql = "CREATE INDEX test_index ON test_constraint(a)"
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, index_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(index_sql)
             )
         )
 
@@ -563,20 +563,20 @@ def test_foreign_constraint(self) -> None:
         );
         """
 
         table_sql = """
         CREATE TABLE test_constraint(
             a INT PRIMARY KEY,
             b INT NOT NULL
         );
         """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, base_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
             )
         )
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
             )
         )
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 95f99f413011..6afb5403bd13 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -120,7 +120,7 @@ def f(txn: LoggingTransaction) -> None:
 
         res = self.get_success(
             self.store.db_pool.execute(
-                "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+                "", "SELECT full_user_id from profiles ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index d4637d9d1ebb..2da6a018e8a8 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -87,7 +87,7 @@ def f(txn: LoggingTransaction) -> None:
 
         res = self.get_success(
             self.store.db_pool.execute(
-                "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+                "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))