Переработана система Middleware
This commit is contained in:
parent
e0e7bb53b6
commit
c8dd896691
@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Optional
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable, Dict, List, TYPE_CHECKING, Optional
|
||||
from asyncio.exceptions import TimeoutError as AsyncioTimeoutError
|
||||
|
||||
from aiohttp import ClientConnectorError
|
||||
@ -120,6 +121,17 @@ class Dispatcher:
|
||||
self.bot._me = me
|
||||
|
||||
logger_dp.info(f'Бот: @{me.username} first_name={me.first_name} id={me.user_id}')
|
||||
|
||||
def build_middleware_chain(
|
||||
self,
|
||||
middlewares: list[BaseMiddleware],
|
||||
handler: Callable[[Any, dict[str, Any]], Awaitable[Any]]
|
||||
) -> Callable[[Any, dict[str, Any]], Awaitable[Any]]:
|
||||
|
||||
for mw in reversed(middlewares):
|
||||
handler = functools.partial(mw, handler)
|
||||
|
||||
return handler
|
||||
|
||||
def include_routers(self, *routers: 'Router'):
|
||||
|
||||
@ -169,38 +181,18 @@ class Dispatcher:
|
||||
new_ctx = MemoryContext(chat_id, user_id)
|
||||
self.contexts.append(new_ctx)
|
||||
return new_ctx
|
||||
|
||||
async def process_middlewares(
|
||||
self,
|
||||
middlewares: List[BaseMiddleware],
|
||||
event_object: UpdateUnion,
|
||||
result_data_kwargs: Dict[str, Any]
|
||||
):
|
||||
|
||||
async def call_handler(self, handler, event_object, data):
|
||||
|
||||
"""
|
||||
Последовательно обрабатывает middleware цепочку.
|
||||
|
||||
:param middlewares: Список middleware.
|
||||
:param event_object: Объект события.
|
||||
:param result_data_kwargs: Аргументы, передаваемые обработчику.
|
||||
:return: Изменённые аргументы или None.
|
||||
Правка аргументов конечной функции хендлера и ее вызов
|
||||
"""
|
||||
|
||||
for middleware in middlewares:
|
||||
result = await middleware.process_middleware(
|
||||
event_object=event_object,
|
||||
result_data_kwargs=result_data_kwargs
|
||||
)
|
||||
|
||||
if result is None or result is False:
|
||||
return
|
||||
|
||||
elif result is True:
|
||||
continue
|
||||
|
||||
result_data_kwargs.update(result)
|
||||
func_args = handler.func_event.__annotations__.keys()
|
||||
kwargs_filtered = {k: v for k, v in data.items() if k in func_args}
|
||||
|
||||
await handler.func_event(event_object, **kwargs_filtered)
|
||||
|
||||
return result_data_kwargs
|
||||
|
||||
async def handle(self, event_object: UpdateUnion):
|
||||
|
||||
@ -232,12 +224,6 @@ class Dispatcher:
|
||||
if not filter_attrs(event_object, *router.filters):
|
||||
continue
|
||||
|
||||
kwargs = await self.process_middlewares(
|
||||
middlewares=router.middlewares,
|
||||
event_object=event_object,
|
||||
result_data_kwargs=kwargs
|
||||
)
|
||||
|
||||
for handler in router.event_handlers:
|
||||
|
||||
if not handler.update_type == event_object.update_type:
|
||||
@ -252,20 +238,19 @@ class Dispatcher:
|
||||
|
||||
func_args = handler.func_event.__annotations__.keys()
|
||||
|
||||
kwargs = await self.process_middlewares(
|
||||
middlewares=handler.middlewares,
|
||||
event_object=event_object,
|
||||
result_data_kwargs=kwargs
|
||||
)
|
||||
if isinstance(router, Router):
|
||||
full_middlewares = self.middlewares + router.middlewares + handler.middlewares
|
||||
elif isinstance(router, Dispatcher):
|
||||
full_middlewares = self.middlewares + handler.middlewares
|
||||
|
||||
if not kwargs:
|
||||
continue
|
||||
handler_chain = self.build_middleware_chain(
|
||||
full_middlewares,
|
||||
functools.partial(self.call_handler, handler)
|
||||
)
|
||||
|
||||
kwargs_filtered = {k: v for k, v in kwargs.items() if k in func_args}
|
||||
|
||||
for key in kwargs.copy().keys():
|
||||
if key not in func_args:
|
||||
del kwargs[key]
|
||||
|
||||
await handler.func_event(event_object, **kwargs)
|
||||
await handler_chain(event_object, kwargs_filtered)
|
||||
|
||||
logger_dp.info(f'Обработано: {router_id} | {process_info}')
|
||||
|
||||
|
@ -1,27 +1,10 @@
|
||||
from typing import Any, Dict
|
||||
from ..types.updates import UpdateUnion
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
class BaseMiddleware:
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
async def process_middleware(
|
||||
self,
|
||||
result_data_kwargs: Dict[str, Any],
|
||||
event_object: UpdateUnion
|
||||
):
|
||||
|
||||
# пока что заглушка
|
||||
if result_data_kwargs is None:
|
||||
return {}
|
||||
|
||||
kwargs_temp = {'data': result_data_kwargs.copy()}
|
||||
|
||||
for key in kwargs_temp.copy().keys():
|
||||
if key not in self.__call__.__annotations__.keys(): # type: ignore
|
||||
del kwargs_temp[key]
|
||||
|
||||
result: Dict[str, Any] = await self(event_object, **kwargs_temp) # type: ignore
|
||||
|
||||
return result
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[Any, dict[str, Any]], Awaitable[Any]],
|
||||
event_object: Any,
|
||||
data: dict[str, Any]
|
||||
) -> Any:
|
||||
return await handler(event_object, data)
|
Loading…
x
Reference in New Issue
Block a user