Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use inline type hints in http/federation/, storage/ and util/ #10381

Merged
merged 4 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10381.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.
13 changes: 6 additions & 7 deletions synapse/http/federation/well_known_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@
logger = logging.getLogger(__name__)


_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
_had_valid_well_known_cache = TTLCache(
"had-valid-well-known"
) # type: TTLCache[bytes, bool]
_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")


@attr.s(slots=True, frozen=True)
Expand Down Expand Up @@ -130,9 +128,10 @@ async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
) # type: Optional[bytes], float
result: Optional[bytes]
cache_period: float

result, cache_period = await self._fetch_well_known(server_name)
Comment on lines +131 to +134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this fairly unreadable and would prefer the old way, but 🤷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some SO answers say this is the way to do it, i think this is the most pythonic you're gonna get, an alternative would be

res: Tuple[Optional[bytes], float] = await self._fetch_well_known(server_name)

result: Optional[bytes] = res[0]
cache_period: float = res[1]

Copy link
Contributor

@reivilibre reivilibre Jul 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, neither before nor after is particularly pretty, but I guess there's something to be said about adhering to the 'Pythonic'/standard way of doing things, and if this is it, then why not?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can't say the alternatives are very pretty either. I don't think we'd even need them except we have to mark result as optional. 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that explicitly annotating unpacked variables wont be needed if we have enough function signatures in place, after which we can just drop the inline type arguments (unless its needed for some explicit reason) and let mypy infer from the methods.

So i think this is okay coupled with the reason to please mypy for now, with the intention to remove it later, does that sound good?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that explicitly annotating unpacked variables wont be needed if we have enough function signatures in place

_fetch_well_known does have proper types, it just that it is Tuple[bytes, float], but the caller might have the result value be Optional[bytes] so it gets a bit awkward. 😄

Anyway, I think this is fine.


except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
Expand Down
16 changes: 7 additions & 9 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,12 @@ def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self.db_pool = database

# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]

self._background_update_performance = (
{}
) # type: Dict[str, BackgroundUpdatePerformance]
self._background_update_handlers = (
{}
) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._current_background_update: Optional[str] = None

self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
self._background_update_handlers: Dict[
str, Callable[[JsonDict, int], Awaitable[int]]
] = {}
self._all_done = False

def start_doing_background_updates(self) -> None:
Expand Down Expand Up @@ -411,7 +409,7 @@ def create_index_sqlite(conn: Connection) -> None:
c.execute(sql)

if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql # type: Optional[Callable[[Connection], None]]
runner: Optional[Callable[[Connection], None]] = create_index_psql
elif psql_only:
runner = None
else:
Expand Down
14 changes: 7 additions & 7 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,8 @@ async def runInteraction(
Returns:
The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []

if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
Expand Down Expand Up @@ -1090,7 +1090,7 @@ def _getwhere(key):
return False

# We didn't find any existing rows, so insert a new one
allvalues = {} # type: Dict[str, Any]
allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def simple_upsert_txn_native_upsert(
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(insertion_values or {})

Expand Down Expand Up @@ -1257,7 +1257,7 @@ def simple_upsert_many_txn_native_upsert(
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames: List[str] = []
allnames.extend(key_names)
allnames.extend(value_names)

Expand Down Expand Up @@ -1566,7 +1566,7 @@ async def simple_select_many_batch(
"""
keyvalues = keyvalues or {}

results = [] # type: List[Dict[str, Any]]
results: List[Dict[str, Any]] = []

if not iterable:
return results
Expand Down Expand Up @@ -1978,7 +1978,7 @@ def simple_select_list_paginate_txn(
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
arg_list = [] # type: List[Any]
arg_list: List[Any] = []
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
Expand Down
4 changes: 1 addition & 3 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def _make_exclusive_regex(
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_pattern = re.compile(
exclusive_user_regex
) # type: Optional[Pattern]
exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
else:
# We handle this case specially otherwise the constructed regex
# will always match
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _get_e2e_device_keys_txn(

txn.execute(sql, query_params)

result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
Expand Down
26 changes: 13 additions & 13 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
)

# Cache of event ID to list of auth event IDs and their depths.
self._event_auth_cache = LruCache(
self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
500000, "_event_auth_cache", size_callback=len
) # type: LruCache[str, List[Tuple[str, int]]]
)

self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)

Expand Down Expand Up @@ -137,10 +137,10 @@ def _get_auth_chain_ids_using_cover_index_txn(
initial_events = set(event_ids)

# All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str]
seen_events: Set[str] = set()

# A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int]
event_chains: Dict[int, int] = {}

sql = """
SELECT event_id, chain_id, sequence_number
Expand Down Expand Up @@ -182,7 +182,7 @@ def _get_auth_chain_ids_using_cover_index_txn(
"""

# A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int]
chains: Dict[int, int] = {}

# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
Expand Down Expand Up @@ -353,14 +353,14 @@ def _get_auth_chain_difference_using_cover_index_txn(
initial_events = set(state_sets[0]).union(*state_sets[1:])

# Map from event_id -> (chain ID, seq no)
chain_info = {} # type: Dict[str, Tuple[int, int]]
chain_info: Dict[str, Tuple[int, int]] = {}

# Map from chain ID -> seq no -> event Id
chain_to_event = {} # type: Dict[int, Dict[int, str]]
chain_to_event: Dict[int, Dict[int, str]] = {}

# All the chains that we've found that are reachable from the state
# sets.
seen_chains = set() # type: Set[int]
seen_chains: Set[int] = set()

sql = """
SELECT event_id, chain_id, sequence_number
Expand Down Expand Up @@ -392,9 +392,9 @@ def _get_auth_chain_difference_using_cover_index_txn(

# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
set_to_chain = [] # type: List[Dict[int, int]]
set_to_chain: List[Dict[int, int]] = []
for state_set in state_sets:
chains = {} # type: Dict[int, int]
chains: Dict[int, int] = {}
set_to_chain.append(chains)

for event_id in state_set:
Expand Down Expand Up @@ -446,7 +446,7 @@ def _get_auth_chain_difference_using_cover_index_txn(

# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
chain_to_gap: Dict[int, Tuple[int, int]] = {}

for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
Expand Down Expand Up @@ -555,7 +555,7 @@ def _get_auth_chain_difference_txn(
}

# The sorted list of events whose auth chains we should walk.
search = [] # type: List[Tuple[int, str]]
search: List[Tuple[int, str]] = []

# We need to get the depth of the initial events for sorting purposes.
sql = """
Expand All @@ -578,7 +578,7 @@ def _get_auth_chain_difference_txn(
search.sort()

# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
event_to_auth_events: Dict[str, Set[str]] = {}

base_sql = """
SELECT a.event_id, auth_id, depth
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
Expand Down
38 changes: 18 additions & 20 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,8 @@ def __init__(

# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen = (
self.store._backfill_id_gen
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen

# This should only exist on instances that are configured to write
assert (
Expand Down Expand Up @@ -221,7 +219,7 @@ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[st
Returns:
Filtered event ids
"""
results = [] # type: List[str]
results: List[str] = []

def _get_events_which_are_prevs_txn(txn, batch):
sql = """
Expand Down Expand Up @@ -508,7 +506,7 @@ def _add_chain_cover_index(
"""

# Map from event ID to chain ID/sequence number.
chain_map = {} # type: Dict[str, Tuple[int, int]]
chain_map: Dict[str, Tuple[int, int]] = {}

# Set of event IDs to calculate chain ID/seq numbers for.
events_to_calc_chain_id_for = set(event_to_room_id)
Expand Down Expand Up @@ -817,8 +815,8 @@ def _allocate_chain_ids(
# new chain if the sequence number has already been allocated.
#

existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]
existing_chains: Set[int] = set()
tree: List[Tuple[str, Optional[str]]] = []

# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
Expand Down Expand Up @@ -848,7 +846,7 @@ def _allocate_chain_ids(
)
txn.execute(sql % (clause,), args)

chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}

# Allocate the new events chain ID/sequence numbers.
#
Expand All @@ -858,8 +856,8 @@ def _allocate_chain_ids(
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.

unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
unallocated_chain_ids: Set[object] = set()
new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
Expand All @@ -870,7 +868,7 @@ def _allocate_chain_ids(
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]

new_chain_tuple = None # type: Optional[Tuple[Any, int]]
new_chain_tuple: Optional[Tuple[Any, int]] = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
Expand All @@ -897,9 +895,9 @@ def _allocate_chain_ids(
)

# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
chain_id_to_allocated_map: Dict[Any, int] = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
)
chain_id_to_allocated_map.update((c, c) for c in existing_chains)

return {
Expand Down Expand Up @@ -1175,9 +1173,9 @@ def _filter_events_and_contexts_for_duplicates(
Returns:
list[(EventBase, EventContext)]: filtered list
"""
new_events_and_contexts = (
OrderedDict()
) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
new_events_and_contexts: OrderedDict[
str, Tuple[EventBase, EventContext]
] = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
Expand Down Expand Up @@ -1205,7 +1203,7 @@ def _update_room_depths_txn(
we are persisting
backfilled (bool): True if the events were backfilled
"""
depth_updates = {} # type: Dict[str, int]
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
Expand Down Expand Up @@ -1885,7 +1883,7 @@ def _set_push_actions_for_event_and_users_txn(
),
)

room_to_event_ids = {} # type: Dict[str, List[str]]
room_to_event_ids: Dict[str, List[str]] = {}
for e, _ in events_and_contexts:
room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)

Expand Down Expand Up @@ -2012,7 +2010,7 @@ def _update_backward_extremeties(self, txn, events):

Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {} # type: Dict[str, List[EventBase]]
events_by_room: Dict[str, List[EventBase]] = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)

Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/events_bg_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,9 @@ def _calculate_chain_cover_txn(
event_to_types = {row[0]: (row[1], row[2]) for row in rows}

# Calculate the new last position we've processed up to.
new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int
new_last_room_id = rows[-1][5] if rows else "" # type: str
new_last_depth: int = rows[-1][3] if rows else last_depth
new_last_stream: int = rows[-1][4] if rows else last_stream
new_last_room_id: str = rows[-1][5] if rows else ""

# Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also
Expand All @@ -989,7 +989,7 @@ def _calculate_chain_cover_txn(
retcols=("event_id", "auth_id"),
)

event_to_auth_chain = {} # type: Dict[str, List[str]]
event_to_auth_chain: Dict[str, List[str]] = {}
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])

Expand Down
Loading