Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bilibili调度新增回避策略 #573

Merged
merged 15 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions nonebot_bison/platform/bilibili/fsm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import sys
import asyncio
import inspect
from enum import Enum
from dataclasses import dataclass
from collections.abc import Set as AbstractSet
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable

from nonebot import logger


class StrEnum(str, Enum): ...


TAddon = TypeVar("TAddon", contravariant=True)
TState = TypeVar("TState", contravariant=True)
TEvent = TypeVar("TEvent", contravariant=True)
TFSM = TypeVar("TFSM", bound="FSM", contravariant=True)


class StateError(Exception): ...


ActionReturn: TypeAlias = Any


@runtime_checkable
class SupportStateOnExit(Generic[TAddon], Protocol):
async def on_exit(self, addon: TAddon) -> None: ...


@runtime_checkable
class SupportStateOnEnter(Generic[TAddon], Protocol):
async def on_enter(self, addon: TAddon) -> None: ...


class Action(Generic[TState, TEvent, TAddon], Protocol):
async def __call__(self, from_: TState, event: TEvent, to: TState, addon: TAddon) -> ActionReturn: ...


ConditionFunc = Callable[[TAddon], Awaitable[bool]]


@dataclass(frozen=True)
class Condition(Generic[TAddon]):
call: ConditionFunc[TAddon]
not_: bool = False

def __repr__(self):
if inspect.isfunction(self.call) or inspect.isclass(self.call):
call_str = self.call.__name__
else:
call_str = repr(self.call)
return f"Condition(call={call_str})"

async def __call__(self, addon: TAddon) -> bool:
return (await self.call(addon)) ^ self.not_


# FIXME: Python 3.11+ 才支持 NamedTuple和TypedDict使用多继承添加泛型
# 所以什么时候 drop 3.10(?
if sys.version_info >= (3, 11) or TYPE_CHECKING:

class Transition(Generic[TState, TEvent, TAddon], NamedTuple):
action: Action[TState, TEvent, TAddon]
to: TState
conditions: AbstractSet[Condition[TAddon]] | None = None

class StateGraph(Generic[TState, TEvent, TAddon], TypedDict):
transitions: dict[
TState,
dict[
TEvent,
Transition[TState, TEvent, TAddon] | Sequence[Transition[TState, TEvent, TAddon]],
],
]
initial: TState

else:

class Transition(NamedTuple):
action: Action
to: Any
conditions: AbstractSet[Condition] | None = None

class StateGraph(TypedDict):
transitions: dict[Any, dict[Any, Transition]]
initial: Any


class FSM(Generic[TState, TEvent, TAddon]):
def __init__(self, graph: StateGraph[TState, TEvent, TAddon], addon: TAddon):
self.started = False
self.graph = graph
self.current_state = graph["initial"]
self.machine = self._core()
self.addon = addon

async def _core(self) -> AsyncGenerator[ActionReturn, TEvent]:
self.current_state = self.graph["initial"]
res = None
while True:
event = yield res

if not self.started:
raise StateError("FSM not started, please call start() first")

selected_transition = await self.cherry_pick(event)

logger.trace(f"exit state: {self.current_state}")
if isinstance(self.current_state, SupportStateOnExit):
logger.trace(f"do {self.current_state}.on_exit")
await self.current_state.on_exit(self.addon)

logger.trace(f"do action: {selected_transition.action}")
res = await selected_transition.action(self.current_state, event, selected_transition.to, self.addon)

logger.trace(f"enter state: {selected_transition.to}")
self.current_state = selected_transition.to

if isinstance(self.current_state, SupportStateOnEnter):
logger.trace(f"do {self.current_state}.on_enter")
await self.current_state.on_enter(self.addon)

async def start(self):
await anext(self.machine)
self.started = True
logger.trace(f"FSM started, initial state: {self.current_state}")

async def cherry_pick(self, event: TEvent) -> Transition[TState, TEvent, TAddon]:
transitions = self.graph["transitions"][self.current_state].get(event)
if transitions is None:
raise StateError(f"Invalid event {event} in state {self.current_state}")

if isinstance(transitions, Transition):
return transitions
elif isinstance(transitions, Sequence):
no_conds: list[Transition[TState, TEvent, TAddon]] = []
for transition in transitions:
if not transition.conditions:
no_conds.append(transition)
continue

values = await asyncio.gather(*(condition(self.addon) for condition in transition.conditions))

if all(values):
logger.trace(f"conditions {transition.conditions} passed")
return transition
else:
if no_conds:
return no_conds.pop()
else:
raise StateError(f"Invalid event {event} in state {self.current_state}")
else:
raise TypeError("Invalid transition type: {transitions}, expected Transition or Sequence[Transition]")

async def emit(self, event: TEvent):
return await self.machine.asend(event)

async def reset(self):
await self.machine.aclose()
self.started = False

del self.machine
self.machine = self._core()

logger.trace("FSM closed")
38 changes: 2 additions & 36 deletions nonebot_bison/platform/bilibili/platforms.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import re
import json
from copy import deepcopy
from functools import wraps
from enum import Enum, unique
from typing import NamedTuple
from typing_extensions import Self
from typing import TypeVar, NamedTuple
from collections.abc import Callable, Awaitable

from yarl import URL
from nonebot import logger
from httpx import AsyncClient
from httpx import URL as HttpxURL
from pydantic import Field, BaseModel, ValidationError
from nonebot.compat import type_validate_json, type_validate_python

Expand All @@ -19,6 +16,7 @@
from nonebot_bison.utils import text_similarity, decode_unicode_escapes
from nonebot_bison.types import Tag, Target, RawPost, ApiError, Category

from .retry import ApiCode352Error, retry_for_352
from .scheduler import BilibiliSite, BililiveSite, BiliBangumiSite
from ..platform import NewMessage, StatusChange, CategoryNotSupport, CategoryNotRecognize
from .models import (
Expand All @@ -38,38 +36,6 @@
LiveRecommendMajor,
)

B = TypeVar("B", bound="Bilibili")
MAX_352_RETRY_COUNT = 3


class ApiCode352Error(Exception):
def __init__(self, url: HttpxURL) -> None:
msg = f"api {url} error"
super().__init__(msg)


def retry_for_352(func: Callable[[B, Target], Awaitable[list[DynRawPost]]]):
retried_times = 0

@wraps(func)
async def wrapper(bls: B, *args, **kwargs):
nonlocal retried_times
try:
res = await func(bls, *args, **kwargs)
except ApiCode352Error as e:
if retried_times < MAX_352_RETRY_COUNT:
retried_times += 1
logger.warning(f"获取动态列表失败,尝试刷新cookie: {retried_times}/{MAX_352_RETRY_COUNT}")
await bls.ctx.refresh_client()
return [] # 返回空列表
else:
raise ApiError(e.args[0])
else:
retried_times = 0
return res

return wrapper


class _ProcessedText(NamedTuple):
title: str
Expand Down
Loading
Loading