Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/stats.py #11653

Merged
Merged 1 commit on Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11653.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ exclude = (?x)
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/

Expand Down Expand Up @@ -214,6 +213,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.stats]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True

Expand Down
94 changes: 53 additions & 41 deletions synapse/storage/databases/main/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
import logging
from enum import Enum
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast

from typing_extensions import Counter

from twisted.internet.defer import DeferredLock

from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
Expand Down Expand Up @@ -122,7 +126,9 @@ def __init__(
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")

async def _populate_stats_process_users(self, progress, batch_size):
async def _populate_stats_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
"""
This is a background update which regenerates statistics for users.
"""
Expand All @@ -134,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):

last_user_id = progress.get("last_user_id", "")

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT name FROM users
WHERE name > ?
Expand Down Expand Up @@ -168,7 +174,9 @@ def _get_next_batch(txn):

return len(users_to_work_on)

async def _populate_stats_process_rooms(self, progress, batch_size):
async def _populate_stats_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
Expand All @@ -178,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):

last_room_id = progress.get("last_room_id", "")

def _get_next_batch(txn):
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
Expand Down Expand Up @@ -307,7 +315,7 @@ async def bulk_update_stats_delta(
stream_id: Current position.
"""

def _bulk_update_stats_delta_txn(txn):
def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.debug(
Expand Down Expand Up @@ -339,7 +347,7 @@ async def update_stats_delta(
stats_type: str,
stats_id: str,
fields: Dict[str, int],
complete_with_stream_id: Optional[int],
complete_with_stream_id: int,
Comment on lines -342 to +350
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Optional[int] is in conflict with

# Keep the delta stream ID field up to date
absolute_field_overrides = absolute_field_overrides.copy()
absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id

I could not see where Optional is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update_stats_delta has two callers: one calculates complete_with_stream_id by calling get_room_max_stream_ordering(); the other gets it from _get_max_stream_id_in_current_state_deltas_txn, which will always return an int — so I agree. 👍

absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
"""
Expand Down Expand Up @@ -372,14 +380,14 @@ async def update_stats_delta(

def _update_stats_delta_txn(
self,
txn,
ts,
stats_type,
stats_id,
fields,
complete_with_stream_id,
absolute_field_overrides=None,
):
txn: LoggingTransaction,
ts: int,
stats_type: str,
stats_id: str,
fields: Dict[str, int],
complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None:
if absolute_field_overrides is None:
absolute_field_overrides = {}

Expand Down Expand Up @@ -422,20 +430,23 @@ def _update_stats_delta_txn(
)

def _upsert_with_additive_relatives_txn(
self, txn, table, keyvalues, absolutes, additive_relatives
):
self,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
absolutes: Dict[str, Any],
additive_relatives: Dict[str, int],
) -> None:
"""Used to update values in the stats tables.

This is basically a slightly convoluted upsert that *adds* to any
existing rows.

Args:
txn
table (str): Table name
keyvalues (dict[str, any]): Row-identifying key values
absolutes (dict[str, any]): Absolute (set) fields
additive_relatives (dict[str, int]): Fields that will be added onto
if existing row present.
table: Table name
keyvalues: Row-identifying key values
absolutes: Absolute (set) fields
additive_relatives: Fields that will be added onto if existing row present.
"""
if self.database_engine.can_native_upsert:
absolute_updates = [
Expand Down Expand Up @@ -491,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)

async def _calculate_and_set_initial_state_for_room(
self, room_id: str
) -> Tuple[dict, dict, int]:
async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
"""Calculate and insert an entry into room_stats_current.

Args:
room_id: The room ID under calculation.

Returns:
A tuple of room state, membership counts and stream position.
"""

def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
def _fetch_current_state_stats(
txn: LoggingTransaction,
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]

rows = self.db_pool.simple_select_many_txn(
txn,
Expand All @@ -524,7 +532,7 @@ def _fetch_current_state_stats(txn):
retcols=["event_id"],
)

event_ids = [row["event_id"] for row in rows]
event_ids = cast(List[str], [row["event_id"] for row in rows])

txn.execute(
"""
Expand All @@ -544,9 +552,9 @@ def _fetch_current_state_stats(txn):
(room_id,),
)

(current_state_events_count,) = txn.fetchone()
current_state_events_count = cast(Tuple[int], txn.fetchone())[0]

users_in_room = self.get_users_in_room_txn(txn, room_id)
users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]

return (
event_ids,
Expand All @@ -566,7 +574,7 @@ def _fetch_current_state_stats(txn):
"get_initial_state_for_room", _fetch_current_state_stats
)

state_event_map = await self.get_events(event_ids, get_prev_content=False)
state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]

room_state = {
"join_rules": None,
Expand Down Expand Up @@ -622,8 +630,10 @@ def _fetch_current_state_stats(txn):
},
)

async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn):
async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
def _calculate_and_set_initial_state_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)

txn.execute(
Expand All @@ -634,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
""",
(user_id,),
)
(count,) = txn.fetchone()
count = cast(Tuple[int], txn.fetchone())[0]
return count, pos

joined_rooms, pos = await self.db_pool.runInteraction(
Expand Down Expand Up @@ -678,7 +688,9 @@ async def get_users_media_usage_paginate(
users that exist given this query
"""

def get_users_media_usage_paginate_txn(txn):
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]

Expand Down Expand Up @@ -733,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
sql_base=sql_base,
)
txn.execute(sql, args)
count = txn.fetchone()[0]
count = cast(Tuple[int], txn.fetchone())[0]

sql = """
SELECT
Expand Down