Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to user admin API #9521

Merged
merged 4 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9521.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to user admin API.
84 changes: 55 additions & 29 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import hmac
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
Expand Down Expand Up @@ -47,13 +47,15 @@
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()

async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request)

Expand Down Expand Up @@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
Expand All @@ -165,7 +167,7 @@ def __init__(self, hs):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()

async def on_GET(self, request, user_id):
async def on_GET(self, request: SynapseRequest, user_id) -> Tuple[int, JsonDict]:
dklimpel marked this conversation as resolved.
Show resolved Hide resolved
await assert_requester_is_admin(self.auth, request)

target_user = UserID.from_string(user_id)
Expand All @@ -179,7 +181,9 @@ async def on_GET(self, request, user_id):

return 200, ret

async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)

Expand Down Expand Up @@ -272,7 +276,9 @@ async def on_PUT(self, request, user_id):
target_user.to_string()
)

user = await self.admin_handler.get_user(target_user)
user = await self.admin_handler.get_user(
target_user
) # type: JsonDict # type: ignore
clokep marked this conversation as resolved.
Show resolved Hide resolved
return 200, user

else: # create user
Expand Down Expand Up @@ -330,7 +336,9 @@ async def on_PUT(self, request, user_id):
target_user, requester, body["avatar_url"], True
)

ret = await self.admin_handler.get_user(target_user)
ret = await self.admin_handler.get_user(
target_user
) # type: JsonDict # type: ignore

return 201, ret

Expand All @@ -346,10 +354,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {}
self.nonces = {} # type: Dict
clokep marked this conversation as resolved.
Show resolved Hide resolved
self.hs = hs

def _clear_old_nonces(self):
Expand All @@ -362,7 +370,7 @@ def _clear_old_nonces(self):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]

def on_GET(self, request):
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Generate a new nonce.
"""
Expand All @@ -372,7 +380,7 @@ def on_GET(self, request):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()

if not self.hs.config.registration_shared_secret:
Expand Down Expand Up @@ -478,12 +486,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True)
)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()

async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
Expand All @@ -508,7 +518,9 @@ def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()

async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)

Expand Down Expand Up @@ -550,7 +562,7 @@ def __init__(self, hs):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

body = parse_json_object_from_request(request)
Expand Down Expand Up @@ -584,14 +596,16 @@ class ResetPasswordRestServlet(RestServlet):

PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()

async def on_POST(self, request, target_user_id):
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
Expand Down Expand Up @@ -626,12 +640,14 @@ class SearchUsersRestServlet(RestServlet):

PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()

async def on_GET(self, request, target_user_id):
async def on_GET(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, Optional[JsonDict]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one optional? I think search_users always returns a dict?

Copy link
Contributor Author

@dklimpel dklimpel Mar 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it can be None if search returns an empty result.

async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with
the matched term.
Args:
term: search term
Returns:
A list of dictionaries or None.
"""
return await self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I was looking at the handler code. 👍 This is why adding the type hints is useful!

"""Get request to search user table for specific users according to
search term.
This needs user to have a administrator access in Synapse.
Expand Down Expand Up @@ -682,12 +698,14 @@ class UserAdminServlet(RestServlet):

PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()

async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

target_user = UserID.from_string(user_id)
Expand All @@ -699,7 +717,9 @@ async def on_GET(self, request, user_id):

return 200, {"admin": is_admin}

async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
Expand Down Expand Up @@ -730,12 +750,14 @@ class UserMembershipRestServlet(RestServlet):

PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()

async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

room_ids = await self.store.get_rooms_for_user(user_id)
Expand All @@ -758,7 +780,7 @@ class PushersRestServlet(RestServlet):

PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
Expand Down Expand Up @@ -799,7 +821,7 @@ class UserMediaRestServlet(RestServlet):

PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
Expand Down Expand Up @@ -891,7 +913,9 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
Expand Down Expand Up @@ -943,7 +967,9 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()

async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

if not self.hs.is_mine_id(user_id):
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Optional, Tuple

from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
Expand All @@ -27,7 +27,7 @@
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import get_domain_from_id
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache

from .account_data import AccountDataStore
Expand Down Expand Up @@ -264,7 +264,7 @@ def _get_active_presence(self, db_conn):

return [UserPresenceState(**row) for row in rows]

async def get_users(self) -> List[Dict[str, Any]]:
async def get_users(self) -> JsonDict:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Function to retrieve a list of users in users table.

Returns:
Expand Down Expand Up @@ -292,7 +292,7 @@ async def get_users_paginate(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]:
) -> Tuple[JsonDict, int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Expand Down Expand Up @@ -353,7 +353,7 @@ def get_users_paginate_txn(txn):
"get_users_paginate_txn", get_users_paginate_txn
)

async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
async def search_users(self, term: str) -> Optional[JsonDict]:
"""Function to search users list for one or more users with
the matched term.

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def get_local_media_by_user_paginate(
start: int,
limit: int,
user_id: str,
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
Expand Down