Skip to content

Commit cc43498

Browse files
committed
Use yarl and add generic C-S discovery
Fixes #20 Fixes #25
1 parent 351cce7 commit cc43498

File tree

10 files changed

+234
-138
lines changed

10 files changed

+234
-138
lines changed

mautrix/api.py

+52-32
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,25 @@
1212
import logging
1313
import asyncio
1414

15+
from yarl import URL
1516
from aiohttp import ClientSession
1617
from aiohttp.client_exceptions import ContentTypeError, ClientError
1718

1819
from mautrix.errors import make_request_error, MatrixConnectionError
20+
from mautrix.util.logging import TraceLogger
1921

2022
if TYPE_CHECKING:
2123
from mautrix.types import JSON
2224

2325

2426
class APIPath(Enum):
25-
"""The known Matrix API path prefixes."""
26-
CLIENT = "/_matrix/client/r0"
27-
CLIENT_UNSTABLE = "/_matrix/client/unstable"
28-
MEDIA = "/_matrix/media/r0"
29-
IDENTITY = "/_matrix/identity/r0"
27+
"""
28+
The known Matrix API path prefixes.
29+
These don't start with a slash so they can be used nicely with yarl.
30+
"""
31+
CLIENT = "_matrix/client/r0"
32+
CLIENT_UNSTABLE = "_matrix/client/unstable"
33+
MEDIA = "_matrix/media/r0"
3034

3135
def __repr__(self):
3236
return self.value
@@ -60,7 +64,7 @@ class PathBuilder:
6064
>>> room_id = "!foo:example.com"
6165
>>> event_id = "$bar:example.com"
6266
>>> str(Path.rooms[room_id].event[event_id])
63-
"/_matrix/client/r0/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com"
67+
"_matrix/client/r0/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com"
6468
"""
6569

6670
def __init__(self, path: Union[str, APIPath] = "") -> None:
@@ -105,14 +109,21 @@ def __getitem__(self, append: Union[str, int]) -> 'PathBuilder':
105109
ClientPath = Path
106110
UnstableClientPath = PathBuilder(APIPath.CLIENT_UNSTABLE)
107111
MediaPath = PathBuilder(APIPath.MEDIA)
108-
IdentityPath = PathBuilder(APIPath.IDENTITY)
109112

110113

111114
class HTTPAPI:
112115
"""HTTPAPI is a simple asyncio Matrix API request sender."""
113116

114-
def __init__(self, base_url: str, token: str = "", *, client_session: ClientSession = None,
115-
txn_id: int = 0, log: Optional[logging.Logger] = None,
117+
base_url: URL
118+
token: str
119+
log: TraceLogger
120+
loop: asyncio.AbstractEventLoop
121+
session: ClientSession
122+
txn_id: Optional[int]
123+
124+
def __init__(self, base_url: Union[URL, str], token: str = "", *,
125+
client_session: ClientSession = None,
126+
txn_id: int = 0, log: Optional[TraceLogger] = None,
116127
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
117128
"""
118129
Args:
@@ -122,18 +133,18 @@ def __init__(self, base_url: str, token: str = "", *, client_session: ClientSess
122133
txn_id: The outgoing transaction ID to start with.
123134
log: The logging.Logger instance to log requests with.
124135
"""
125-
self.base_url: str = base_url
126-
self.token: str = token
127-
self.log: Optional[logging.Logger] = log or logging.getLogger("mau.http")
136+
self.base_url = URL(base_url)
137+
self.token = token
138+
self.log = log or logging.getLogger("mau.http")
128139
self.loop = loop or asyncio.get_event_loop()
129-
self.session: ClientSession = client_session or ClientSession(loop=self.loop)
140+
self.session = client_session or ClientSession(loop=self.loop)
130141
if txn_id is not None:
131-
self.txn_id: int = txn_id
142+
self.txn_id = txn_id
132143

133-
async def _send(self, method: Method, endpoint: str, content: Union[bytes, str],
144+
async def _send(self, method: Method, url: URL, content: Union[bytes, str],
134145
query_params: Dict[str, str], headers: Dict[str, str]) -> 'JSON':
135146
while True:
136-
request = self.session.request(str(method), endpoint, data=content,
147+
request = self.session.request(str(method), url, data=content,
137148
params=query_params, headers=headers)
138149
async with request as response:
139150
if response.status < 200 or response.status >= 300:
@@ -150,7 +161,10 @@ async def _send(self, method: Method, endpoint: str, content: Union[bytes, str],
150161

151162
if response.status == 429:
152163
resp = await response.json()
153-
await asyncio.sleep(resp["retry_after_ms"] / 1000, loop=self.loop)
164+
seconds = resp["retry_after_ms"] / 1000
165+
self.log.debug(f"Request to {url} returned 429, "
166+
f"waiting {seconds} seconds and retrying")
167+
await asyncio.sleep(seconds, loop=self.loop)
154168
else:
155169
return await response.json()
156170

@@ -161,7 +175,7 @@ def _log_request(self, method: Method, path: PathBuilder, content: Union[str, by
161175
log_content = content if not isinstance(content, bytes) else f"<{len(content)} bytes>"
162176
as_user = query_params.get("user_id", None)
163177
level = 1 if path == Path.sync else 5
164-
self.log.log(level, f"{method} {path} {log_content}".strip(" "),
178+
self.log.log(level, f"{method} /{path} {log_content}".strip(" "),
165179
extra={"matrix_http_request": {
166180
"method": str(method),
167181
"path": str(path),
@@ -170,23 +184,25 @@ def _log_request(self, method: Method, path: PathBuilder, content: Union[str, by
170184
"user": as_user,
171185
}})
172186

173-
async def request(self, method: Method, path: PathBuilder,
174-
content: Optional[Union['JSON', bytes, str]] = None,
187+
async def request(self, method: Method, path: Union[PathBuilder, str],
188+
content: Optional[Union[dict, list, bytes, str]] = None,
175189
headers: Optional[Dict[str, str]] = None,
176190
query_params: Optional[Dict[str, str]] = None) -> 'JSON':
177191
"""
178-
Make a raw HTTP request.
192+
Make a raw Matrix API request.
179193
180194
Args:
181195
method: The HTTP method to use.
182-
path: The API endpoint to call.
183-
Does not include the base path (e.g. /_matrix/client/r0).
184-
content: The content to post as a dict (json) or bytes/str (raw).
185-
headers: The dict of HTTP headers to send.
186-
query_params: The dict of query parameters to send.
196+
path: The full API endpoint to call (including the _matrix/... prefix)
197+
content: The content to post as a dict/list (will be serialized as JSON)
198+
or bytes/str (will be sent as-is).
199+
headers: A dict of HTTP headers to send.
200+
If the headers don't contain ``Content-Type``, it'll be set to ``application/json``.
201+
The ``Authorization`` header is always overridden if :attr:`token` is set.
202+
query_params: A dict of query parameters to send.
187203
188204
Returns:
189-
The response as a dict.
205+
The parsed response JSON.
190206
"""
191207
content = content or {}
192208
headers = headers or {}
@@ -203,18 +219,22 @@ async def request(self, method: Method, path: PathBuilder,
203219

204220
self._log_request(method, path, content, orig_content, query_params)
205221

206-
endpoint = self.base_url + str(path)
222+
path = str(path)
223+
if path and path[0] == "/":
224+
path = path[1:]
225+
207226
try:
208-
return await self._send(method, endpoint, content, query_params, headers or {})
227+
return await self._send(method, self.base_url / path,
228+
content, query_params, headers or {})
209229
except ClientError as e:
210230
raise MatrixConnectionError(str(e)) from e
211231

212232
def get_txn_id(self) -> str:
213233
"""Get a new unique transaction ID."""
214234
self.txn_id += 1
215-
return str(self.txn_id) + str(int(time() * 1000))
235+
return f"mautrix-python_R{self.txn_id}@T{int(time() * 1000)}"
216236

217-
def get_download_url(self, mxc_uri: str, download_type: str = "download") -> str:
237+
def get_download_url(self, mxc_uri: str, download_type: str = "download") -> URL:
218238
"""
219239
Get the full HTTP URL to download a mxc:// URI.
220240
@@ -234,6 +254,6 @@ def get_download_url(self, mxc_uri: str, download_type: str = "download") -> str
234254
"https://matrix.org/_matrix/media/r0/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6"
235255
"""
236256
if mxc_uri.startswith("mxc://"):
237-
return f"{self.base_url}{APIPath.MEDIA}/{download_type}/{mxc_uri[6:]}"
257+
return self.base_url / str(APIPath.MEDIA) / download_type / mxc_uri[6:]
238258
else:
239259
raise ValueError("MXC URI did not begin with `mxc://`")

mautrix/appservice/api/appservice.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99

1010
from aiohttp import ClientSession
11+
from yarl import URL
1112

1213
from mautrix.types import UserID
1314
from mautrix.api import HTTPAPI, Method, PathBuilder
@@ -39,7 +40,7 @@ class AppServiceAPI(HTTPAPI):
3940

4041
_bot_intent: Optional[IntentAPI]
4142

42-
def __init__(self, base_url: str, bot_mxid: UserID = None, token: str = None,
43+
def __init__(self, base_url: Union[URL, str], bot_mxid: UserID = None, token: str = None,
4344
identity: Optional[UserID] = None, log: TraceLogger = None,
4445
state_store: 'ASStateStore' = None, client_session: ClientSession = None,
4546
child: bool = False, real_user: bool = False,
@@ -96,7 +97,8 @@ def user(self, user: UserID) -> 'ChildAppServiceAPI':
9697
self.children[user] = child
9798
return child
9899

99-
def real_user(self, mxid: UserID, token: str, base_url: Optional[str] = None) -> 'AppServiceAPI':
100+
def real_user(self, mxid: UserID, token: str, base_url: Optional[URL] = None
101+
) -> 'AppServiceAPI':
100102
"""
101103
Get the AppServiceAPI for a real (non-appservice-managed) Matrix user.
102104

0 commit comments

Comments
 (0)