13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
from enum import Enum
16
- from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Tuple
16
+ from typing import (
17
+ TYPE_CHECKING ,
18
+ Any ,
19
+ Collection ,
20
+ Dict ,
21
+ Iterable ,
22
+ List ,
23
+ Optional ,
24
+ Tuple ,
25
+ Union ,
26
+ )
17
27
18
28
from synapse .storage ._base import SQLBaseStore
19
- from synapse .storage .database import DatabasePool
29
+ from synapse .storage .database import (
30
+ DatabasePool ,
31
+ LoggingDatabaseConnection ,
32
+ LoggingTransaction ,
33
+ )
34
+ from synapse .types import JsonDict , UserID
20
35
21
36
if TYPE_CHECKING :
22
37
from synapse .server import HomeServer
@@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
46
61
47
62
48
63
class MediaRepositoryBackgroundUpdateStore (SQLBaseStore ):
49
- def __init__ (self , database : DatabasePool , db_conn , hs : "HomeServer" ):
64
+ def __init__ (
65
+ self ,
66
+ database : DatabasePool ,
67
+ db_conn : LoggingDatabaseConnection ,
68
+ hs : "HomeServer" ,
69
+ ):
50
70
super ().__init__ (database , db_conn , hs )
51
71
52
72
self .db_pool .updates .register_background_index_update (
@@ -102,13 +122,15 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
102
122
self ._drop_media_index_without_method ,
103
123
)
104
124
105
- async def _drop_media_index_without_method (self , progress , batch_size ):
125
+ async def _drop_media_index_without_method (
126
+ self , progress : JsonDict , batch_size : int
127
+ ) -> int :
106
128
"""background update handler which removes the old constraints.
107
129
108
130
Note that this is only run on postgres.
109
131
"""
110
132
111
- def f (txn ) :
133
+ def f (txn : LoggingTransaction ) -> None :
112
134
txn .execute (
113
135
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
114
136
)
@@ -126,7 +148,12 @@ def f(txn):
126
148
class MediaRepositoryStore (MediaRepositoryBackgroundUpdateStore ):
127
149
"""Persistence for attachments and avatars"""
128
150
129
- def __init__ (self , database : DatabasePool , db_conn , hs : "HomeServer" ):
151
+ def __init__ (
152
+ self ,
153
+ database : DatabasePool ,
154
+ db_conn : LoggingDatabaseConnection ,
155
+ hs : "HomeServer" ,
156
+ ):
130
157
super ().__init__ (database , db_conn , hs )
131
158
self .server_name = hs .hostname
132
159
@@ -174,7 +201,9 @@ async def get_local_media_by_user_paginate(
174
201
plus the total count of all the user's media
175
202
"""
176
203
177
- def get_local_media_by_user_paginate_txn (txn ):
204
+ def get_local_media_by_user_paginate_txn (
205
+ txn : LoggingTransaction ,
206
+ ) -> Tuple [List [Dict [str , Any ]], int ]:
178
207
179
208
# Set ordering
180
209
order_by_column = MediaSortOrder (order_by ).value
@@ -184,14 +213,14 @@ def get_local_media_by_user_paginate_txn(txn):
184
213
else :
185
214
order = "ASC"
186
215
187
- args = [user_id ]
216
+ args : List [ Union [ str , int ]] = [user_id ]
188
217
sql = """
189
218
SELECT COUNT(*) as total_media
190
219
FROM local_media_repository
191
220
WHERE user_id = ?
192
221
"""
193
222
txn .execute (sql , args )
194
- count = txn .fetchone ()[0 ]
223
+ count = txn .fetchone ()[0 ] # type: ignore[index]
195
224
196
225
sql = """
197
226
SELECT
@@ -268,7 +297,7 @@ async def get_local_media_before(
268
297
)
269
298
sql += sql_keep
270
299
271
- def _get_local_media_before_txn (txn ) :
300
+ def _get_local_media_before_txn (txn : LoggingTransaction ) -> List [ str ] :
272
301
txn .execute (sql , (before_ts , before_ts , size_gt ))
273
302
return [row [0 ] for row in txn ]
274
303
@@ -278,13 +307,13 @@ def _get_local_media_before_txn(txn):
278
307
279
308
async def store_local_media (
280
309
self ,
281
- media_id ,
282
- media_type ,
283
- time_now_ms ,
284
- upload_name ,
285
- media_length ,
286
- user_id ,
287
- url_cache = None ,
310
+ media_id : str ,
311
+ media_type : str ,
312
+ time_now_ms : int ,
313
+ upload_name : Optional [ str ] ,
314
+ media_length : int ,
315
+ user_id : UserID ,
316
+ url_cache : Optional [ str ] = None ,
288
317
) -> None :
289
318
await self .db_pool .simple_insert (
290
319
"local_media_repository" ,
@@ -315,7 +344,7 @@ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
315
344
None if the URL isn't cached.
316
345
"""
317
346
318
- def get_url_cache_txn (txn ) :
347
+ def get_url_cache_txn (txn : LoggingTransaction ) -> Optional [ Dict [ str , Any ]] :
319
348
# get the most recently cached result (relative to the given ts)
320
349
sql = (
321
350
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@@ -359,7 +388,7 @@ def get_url_cache_txn(txn):
359
388
360
389
async def store_url_cache (
361
390
self , url , response_code , etag , expires_ts , og , media_id , download_ts
362
- ):
391
+ ) -> None :
363
392
await self .db_pool .simple_insert (
364
393
"local_media_repository_url_cache" ,
365
394
{
@@ -390,13 +419,13 @@ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]
390
419
391
420
async def store_local_thumbnail (
392
421
self ,
393
- media_id ,
394
- thumbnail_width ,
395
- thumbnail_height ,
396
- thumbnail_type ,
397
- thumbnail_method ,
398
- thumbnail_length ,
399
- ):
422
+ media_id : str ,
423
+ thumbnail_width : int ,
424
+ thumbnail_height : int ,
425
+ thumbnail_type : str ,
426
+ thumbnail_method : str ,
427
+ thumbnail_length : int ,
428
+ ) -> None :
400
429
await self .db_pool .simple_upsert (
401
430
table = "local_media_repository_thumbnails" ,
402
431
keyvalues = {
@@ -430,14 +459,14 @@ async def get_cached_remote_media(
430
459
431
460
async def store_cached_remote_media (
432
461
self ,
433
- origin ,
434
- media_id ,
435
- media_type ,
436
- media_length ,
437
- time_now_ms ,
438
- upload_name ,
439
- filesystem_id ,
440
- ):
462
+ origin : str ,
463
+ media_id : str ,
464
+ media_type : str ,
465
+ media_length : int ,
466
+ time_now_ms : int ,
467
+ upload_name : Optional [ str ] ,
468
+ filesystem_id : str ,
469
+ ) -> None :
441
470
await self .db_pool .simple_insert (
442
471
"remote_media_cache" ,
443
472
{
@@ -458,7 +487,7 @@ async def update_cached_last_access_time(
458
487
local_media : Iterable [str ],
459
488
remote_media : Iterable [Tuple [str , str ]],
460
489
time_ms : int ,
461
- ):
490
+ ) -> None :
462
491
"""Updates the last access time of the given media
463
492
464
493
Args:
@@ -467,7 +496,7 @@ async def update_cached_last_access_time(
467
496
time_ms: Current time in milliseconds
468
497
"""
469
498
470
- def update_cache_txn (txn ) :
499
+ def update_cache_txn (txn : LoggingTransaction ) -> None :
471
500
sql = (
472
501
"UPDATE remote_media_cache SET last_access_ts = ?"
473
502
" WHERE media_origin = ? AND media_id = ?"
@@ -488,7 +517,7 @@ def update_cache_txn(txn):
488
517
489
518
txn .execute_batch (sql , ((time_ms , media_id ) for media_id in local_media ))
490
519
491
- return await self .db_pool .runInteraction (
520
+ await self .db_pool .runInteraction (
492
521
"update_cached_last_access_time" , update_cache_txn
493
522
)
494
523
@@ -542,15 +571,15 @@ async def get_remote_media_thumbnail(
542
571
543
572
async def store_remote_media_thumbnail (
544
573
self ,
545
- origin ,
546
- media_id ,
547
- filesystem_id ,
548
- thumbnail_width ,
549
- thumbnail_height ,
550
- thumbnail_type ,
551
- thumbnail_method ,
552
- thumbnail_length ,
553
- ):
574
+ origin : str ,
575
+ media_id : str ,
576
+ filesystem_id : str ,
577
+ thumbnail_width : int ,
578
+ thumbnail_height : int ,
579
+ thumbnail_type : str ,
580
+ thumbnail_method : str ,
581
+ thumbnail_length : int ,
582
+ ) -> None :
554
583
await self .db_pool .simple_upsert (
555
584
table = "remote_media_cache_thumbnails" ,
556
585
keyvalues = {
@@ -566,7 +595,7 @@ async def store_remote_media_thumbnail(
566
595
desc = "store_remote_media_thumbnail" ,
567
596
)
568
597
569
- async def get_remote_media_before (self , before_ts ) :
598
+ async def get_remote_media_before (self , before_ts : int ) -> List [ Dict [ str , str ]] :
570
599
sql = (
571
600
"SELECT media_origin, media_id, filesystem_id"
572
601
" FROM remote_media_cache"
@@ -602,26 +631,24 @@ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
602
631
" LIMIT 500"
603
632
)
604
633
605
- def _get_expired_url_cache_txn (txn ) :
634
+ def _get_expired_url_cache_txn (txn : LoggingTransaction ) -> List [ str ] :
606
635
txn .execute (sql , (now_ts ,))
607
636
return [row [0 ] for row in txn ]
608
637
609
638
return await self .db_pool .runInteraction (
610
639
"get_expired_url_cache" , _get_expired_url_cache_txn
611
640
)
612
641
613
- async def delete_url_cache (self , media_ids ) :
642
+ async def delete_url_cache (self , media_ids : Collection [ str ]) -> None :
614
643
if len (media_ids ) == 0 :
615
644
return
616
645
617
646
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
618
647
619
- def _delete_url_cache_txn (txn ) :
648
+ def _delete_url_cache_txn (txn : LoggingTransaction ) -> None :
620
649
txn .execute_batch (sql , [(media_id ,) for media_id in media_ids ])
621
650
622
- return await self .db_pool .runInteraction (
623
- "delete_url_cache" , _delete_url_cache_txn
624
- )
651
+ await self .db_pool .runInteraction ("delete_url_cache" , _delete_url_cache_txn )
625
652
626
653
async def get_url_cache_media_before (self , before_ts : int ) -> List [str ]:
627
654
sql = (
@@ -631,19 +658,19 @@ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
631
658
" LIMIT 500"
632
659
)
633
660
634
- def _get_url_cache_media_before_txn (txn ) :
661
+ def _get_url_cache_media_before_txn (txn : LoggingTransaction ) -> List [ str ] :
635
662
txn .execute (sql , (before_ts ,))
636
663
return [row [0 ] for row in txn ]
637
664
638
665
return await self .db_pool .runInteraction (
639
666
"get_url_cache_media_before" , _get_url_cache_media_before_txn
640
667
)
641
668
642
- async def delete_url_cache_media (self , media_ids ) :
669
+ async def delete_url_cache_media (self , media_ids : Collection [ str ]) -> None :
643
670
if len (media_ids ) == 0 :
644
671
return
645
672
646
- def _delete_url_cache_media_txn (txn ) :
673
+ def _delete_url_cache_media_txn (txn : LoggingTransaction ) -> None :
647
674
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
648
675
649
676
txn .execute_batch (sql , [(media_id ,) for media_id in media_ids ])
@@ -652,6 +679,6 @@ def _delete_url_cache_media_txn(txn):
652
679
653
680
txn .execute_batch (sql , [(media_id ,) for media_id in media_ids ])
654
681
655
- return await self .db_pool .runInteraction (
682
+ await self .db_pool .runInteraction (
656
683
"delete_url_cache_media" , _delete_url_cache_media_txn
657
684
)
0 commit comments