|
6 | 6 | from typing import Tuple, Union, Optional, Dict, TYPE_CHECKING
|
7 | 7 | import logging
|
8 | 8 | import asyncio
|
9 |
| -import hashlib |
10 |
| -import hmac |
11 | 9 |
|
12 | 10 | from mautrix.types import (Filter, RoomFilter, EventFilter, RoomEventFilter, StateFilter, EventType,
|
13 | 11 | RoomID, Serializable, JSON, MessageEvent, EncryptedEvent, StateEvent,
|
14 |
| - EncryptedMegolmEventContent, RequestedKeyInfo, RoomKeyWithheldCode) |
| 12 | + EncryptedMegolmEventContent, RequestedKeyInfo, RoomKeyWithheldCode, |
| 13 | + LoginType) |
15 | 14 | from mautrix.appservice import AppService
|
16 | 15 | from mautrix.errors import EncryptionError
|
17 | 16 | from mautrix.client import Client, SyncStore
|
18 | 17 | from mautrix.crypto import (OlmMachine, CryptoStore, StateStore, PgCryptoStore, PickleCryptoStore,
|
19 | 18 | DeviceIdentity, RejectKeyShare, TrustState)
|
20 | 19 | from mautrix.util.logging import TraceLogger
|
21 | 20 |
|
22 |
| -from .crypto_state_store import GetPortalFunc, PgCryptoStateStore, SQLCryptoStateStore |
| 21 | +from .crypto_state_store import PgCryptoStateStore, SQLCryptoStateStore |
23 | 22 |
|
24 | 23 | try:
|
25 | 24 | from mautrix.client.state_store.sqlalchemy import UserProfile
|
@@ -47,23 +46,21 @@ class EncryptionManager:
|
47 | 46 |
|
48 | 47 | bridge: 'Bridge'
|
49 | 48 | az: AppService
|
50 |
| - login_shared_secret: bytes |
51 | 49 | _id_prefix: str
|
52 | 50 | _id_suffix: str
|
53 | 51 |
|
54 | 52 | sync_task: asyncio.Future
|
55 | 53 | _share_session_events: Dict[RoomID, asyncio.Event]
|
56 | 54 |
|
57 |
| - def __init__(self, bridge: 'Bridge', login_shared_secret: str, homeserver_address: str, |
58 |
| - user_id_prefix: str, user_id_suffix: str, db_url: str, |
59 |
| - key_sharing_config: Dict[str, bool] = None) -> None: |
| 55 | + def __init__(self, bridge: 'Bridge', homeserver_address: str, user_id_prefix: str, |
| 56 | + user_id_suffix: str, db_url: str, key_sharing_config: Dict[str, bool] = None |
| 57 | + ) -> None: |
60 | 58 | self.loop = bridge.loop or asyncio.get_event_loop()
|
61 | 59 | self.bridge = bridge
|
62 | 60 | self.az = bridge.az
|
63 | 61 | self.device_name = bridge.name
|
64 | 62 | self._id_prefix = user_id_prefix
|
65 | 63 | self._id_suffix = user_id_suffix
|
66 |
| - self.login_shared_secret = login_shared_secret.encode("utf-8") |
67 | 64 | self._share_session_events = {}
|
68 | 65 | self.key_sharing_config = key_sharing_config or {}
|
69 | 66 | pickle_key = "mautrix.bridge.e2ee"
|
@@ -161,17 +158,22 @@ async def decrypt(self, evt: EncryptedEvent) -> MessageEvent:
|
161 | 158 | self.log.trace("Decrypted event %s: %s", evt.event_id, decrypted)
|
162 | 159 | return decrypted
|
163 | 160 |
|
| 161 | + async def check_server_support(self) -> bool: |
| 162 | + flows = await self.client.get_login_flows() |
| 163 | + return flows.supports_type(LoginType.APPSERVICE) |
| 164 | + |
164 | 165 | async def start(self) -> None:
|
165 | 166 | self.log.debug("Logging in with bridge bot user")
|
166 |
| - password = hmac.new(self.login_shared_secret, self.az.bot_mxid.encode("utf-8"), |
167 |
| - hashlib.sha512).hexdigest() |
168 | 167 | if self.crypto_db:
|
169 | 168 | await self.crypto_db.start()
|
170 | 169 | await self.crypto_store.open()
|
171 | 170 | device_id = await self.crypto_store.get_device_id()
|
172 | 171 | if device_id:
|
173 | 172 | self.log.debug(f"Found device ID in database: {device_id}")
|
174 |
| - await self.client.login(password=password, device_name=self.device_name, |
| 173 | + # We set the API token to the AS token here to authenticate the appservice login |
| 174 | + # It'll get overridden after the login |
| 175 | + self.client.api.token = self.az.as_token |
| 176 | + await self.client.login(login_type=LoginType.APPSERVICE, device_name=self.device_name, |
175 | 177 | device_id=device_id, store_access_token=True, update_hs_url=False)
|
176 | 178 | await self.crypto.load()
|
177 | 179 | if not device_id:
|
|
0 commit comments