Skip to content

Commit

Permalink
Tweak type annotations for mongo storage
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 12, 2024
1 parent a08f43a commit 171b412
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
39 changes: 23 additions & 16 deletions limits/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
# mypy: disable-error-code="no-untyped-def, misc, type-arg"

from __future__ import annotations

import calendar
import datetime
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from typing import Any, cast

from deprecated.sphinx import versionadded

from limits.typing import Dict, Optional, Tuple, Type, Union
from limits.typing import (
Dict,
MongoClient,
MongoCollection,
MongoDatabase,
Optional,
Tuple,
Type,
Union,
)

from ..util import get_dependency
from .base import MovingWindowSupport, Storage

if TYPE_CHECKING:
import pymongo


class MongoDBStorageBase(Storage, MovingWindowSupport, ABC):
"""
Expand Down Expand Up @@ -53,10 +57,10 @@ def __init__(
self.lib_errors, _ = get_dependency("pymongo.errors")
self._storage_uri = uri
self._storage_options = options
self._storage: Optional[Any] = None
self._storage: Optional[MongoClient] = None

@property
def storage(self):
def storage(self) -> MongoClient:
if self._storage is None:
self._storage = self._init_mongo_client(
self._storage_uri, **self._storage_options
Expand All @@ -65,19 +69,21 @@ def storage(self):
return self._storage

@property
def _database(self):
def _database(self) -> MongoDatabase:
return self.storage[self._database_name]

@property
def counters(self):
def counters(self) -> MongoCollection:
return self._database["counters"]

@property
def windows(self):
def windows(self) -> MongoCollection:
return self._database["windows"]

@abstractmethod
def _init_mongo_client(self, uri: Optional[str], **options: Union[int, str, bool]):
def _init_mongo_client(
self, uri: Optional[str], **options: Union[int, str, bool]
) -> MongoClient:
raise NotImplementedError()

@property
Expand Down Expand Up @@ -275,6 +281,7 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b
class MongoDBStorage(MongoDBStorageBase):
STORAGE_SCHEME = ["mongodb", "mongodb+srv"]

def _init_mongo_client(self, uri: Optional[str], **options: Union[int, str, bool]):
client: "pymongo.MongoClient" = self.lib.MongoClient(uri, **options)
return client
def _init_mongo_client(
self, uri: Optional[str], **options: Union[int, str, bool]
) -> MongoClient:
return cast(MongoClient, self.lib.MongoClient(uri, **options))
10 changes: 10 additions & 0 deletions limits/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Expand All @@ -8,6 +9,7 @@
Optional,
Tuple,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand All @@ -24,6 +26,7 @@
if TYPE_CHECKING:
import coredis
import coredis.commands.script
import pymongo
import redis


Expand Down Expand Up @@ -107,6 +110,10 @@ class ScriptP(Protocol[R_co]):
def __call__(self, keys: List[Serializable], args: List[Serializable]) -> R_co: ...


MongoClient: TypeAlias = "pymongo.MongoClient[Dict[str, Any]]" # type:ignore[misc]
MongoDatabase: TypeAlias = "pymongo.database.Database[Dict[str, Any]]" # type:ignore[misc]
MongoCollection: TypeAlias = "pymongo.collection.Collection[Dict[str, Any]]" # type:ignore[misc]

__all__ = [
"AsyncRedisClient",
"Awaitable",
Expand All @@ -118,6 +125,9 @@ def __call__(self, keys: List[Serializable], args: List[Serializable]) -> R_co:
"ItemP",
"List",
"MemcachedClientP",
"MongoClient",
"MongoCollection",
"MongoDatabase",
"NamedTuple",
"Optional",
"P",
Expand Down

0 comments on commit 171b412

Please sign in to comment.