Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: 新增 Lifespan.on_ready() 供适配器使用 #2483

Merged
merged 10 commits into from
Dec 10, 2023
12 changes: 0 additions & 12 deletions nonebot/drivers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup

from ._lifespan import LIFESPAN_FUNC, Lifespan

try:
import uvicorn
from fastapi.responses import Response
Expand Down Expand Up @@ -97,8 +95,6 @@ def __init__(self, env: Env, config: NoneBotConfig):

self.fastapi_config: Config = Config(**config.dict())

self._lifespan = Lifespan()

self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
Expand Down Expand Up @@ -155,14 +151,6 @@ async def _handle(websocket: WebSocket) -> None:
name=setup.name,
)

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)

@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
Expand Down
14 changes: 0 additions & 14 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver

from ._lifespan import LIFESPAN_FUNC, Lifespan

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand All @@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self._lifespan = Lifespan()

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False

Expand All @@ -52,16 +48,6 @@ def logger(self):
"""none driver 使用的 logger"""
return logger

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@override
def run(self, *args, **kwargs):
"""启动 none driver"""
Expand Down
27 changes: 3 additions & 24 deletions nonebot/drivers/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,7 @@
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast

from pydantic import BaseSettings

Expand Down Expand Up @@ -57,8 +46,6 @@
"Install with pip: `pip install nonebot2[quart]`"
) from e

_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])


def catch_closed(func):
@wraps(func)
Expand Down Expand Up @@ -102,6 +89,8 @@ def __init__(self, env: Env, config: NoneBotConfig):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)

@property
@override
Expand Down Expand Up @@ -150,16 +139,6 @@ async def _handle() -> None:
view_func=_handle,
)

@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore

@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore

@override
def run(
self,
Expand Down
4 changes: 4 additions & 0 deletions nonebot/internal/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, AsyncGenerator

from nonebot.config import Config
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
from nonebot.internal.driver import (
Driver,
Request,
Expand Down Expand Up @@ -97,6 +98,9 @@ async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
async with self.driver.websocket(setup) as ws:
yield ws

def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self.driver._lifespan.on_ready(func)

@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
Expand All @@ -21,6 +22,10 @@ def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func

def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func

@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
Expand All @@ -35,6 +40,9 @@ async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)

async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
Expand Down
18 changes: 9 additions & 9 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator

from nonebot.log import logger
from nonebot.config import Env, Config
Expand All @@ -16,6 +16,7 @@
T_BotDisconnectionHook,
)

from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, env: Env, config: Config):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -100,15 +102,13 @@ def run(self, *args, **kwargs):

self.on_shutdown(self._cleanup)

@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
Expand Down
32 changes: 24 additions & 8 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pytest
from nonebug import App

from utils import FakeAdapter
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
Expand All @@ -25,34 +25,50 @@


@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
@pytest.mark.parametrize(
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
async def test_lifespan(driver: Driver):
adapter = FakeAdapter(driver)

start_log = []
ready_log = []
shutdown_log = []

@lifespan.on_startup
@driver.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)

@lifespan.on_startup
@driver.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)

@lifespan.on_shutdown
@adapter.on_ready
def _ready1():
assert start_log == [1, 2]
assert ready_log == []
ready_log.append(1)

@adapter.on_ready
def _ready2():
assert ready_log == [1]
ready_log.append(2)

@driver.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)

@lifespan.on_shutdown
@driver.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)

async with lifespan:
async with driver._lifespan:
assert start_log == [1, 2]
assert ready_log == [1, 2]

assert shutdown_log == [1, 2]

Expand Down
Loading