diff --git a/twitchio/ext/commands/bot.py b/twitchio/ext/commands/bot.py index fdfada4a..d5aa4b71 100644 --- a/twitchio/ext/commands/bot.py +++ b/twitchio/ext/commands/bot.py @@ -51,7 +51,7 @@ from twitchio.user import User from .components import Component - from .types_ import AutoBotOptions, BotOptions + from .types_ import AutoBotOptions, BotOptions, BotT PrefixT: TypeAlias = str | Iterable[str] | Callable[["Bot", "ChatMessage"], Coroutine[Any, Any, str | Iterable[str]]] @@ -433,7 +433,7 @@ async def test(ctx: CustomContext) -> None: async def _process_commands( self, payload: ChatMessage | ChannelPointsRedemptionAdd | ChannelPointsRedemptionUpdate ) -> None: - ctx: Context = self.get_context(payload) + ctx = self.get_context(payload) await self.invoke(ctx) async def process_commands( @@ -441,7 +441,7 @@ async def process_commands( ) -> None: await self._process_commands(payload) - async def invoke(self, ctx: Context) -> None: + async def invoke(self, ctx: Context[BotT]) -> None: try: await ctx.invoke() except CommandError as e: @@ -482,7 +482,7 @@ async def event_command_error(self, payload: CommandErrorPayload) -> None: msg = f'Ignoring exception in command "{payload.context.command}":\n' logger.error(msg, exc_info=payload.exception) - async def before_invoke(self, ctx: Context) -> None: + async def before_invoke(self, ctx: Context[BotT]) -> None: """A pre invoke hook for all commands that have been added to the bot. Commands from :class:`~.commands.Component`'s are included, however if you wish to control them separately, @@ -513,7 +513,7 @@ async def before_invoke(self, ctx: Context) -> None: The context associated with command invocation, before being passed to the command. """ - async def after_invoke(self, ctx: Context) -> None: + async def after_invoke(self, ctx: Context[BotT]) -> None: """A post invoke hook for all commands that have been added to the bot. Commands from :class:`~.commands.Component`'s are included, however if you wish to control them separately, @@ -544,7 +544,7 @@ async def after_invoke(self, ctx: Context) -> None: The context associated with command invocation, after being passed through the command. """ - async def global_guard(self, ctx: Context, /) -> bool: + async def global_guard(self, ctx: Context[BotT], /) -> bool: """|coro| A global guard applied to all commmands added to the bot. diff --git a/twitchio/ext/commands/components.py b/twitchio/ext/commands/components.py index 0690c344..c8a6afbd 100644 --- a/twitchio/ext/commands/components.py +++ b/twitchio/ext/commands/components.py @@ -37,7 +37,7 @@ from collections.abc import Callable from .context import Context - from .types_ import ComponentOptions + from .types_ import BotT, ComponentOptions __all__ = ("Component",) @@ -262,13 +262,13 @@ async def component_teardown(self) -> None: This method is intended to be overwritten, by default it does nothing. """ - async def component_before_invoke(self, ctx: Context) -> None: + async def component_before_invoke(self, ctx: Context[BotT]) -> None: """Hook called before a :class:`~.commands.Command` in this Component is invoked. Similar to :meth:`~.commands.Bot.before_invoke` but only applies to commands in this Component. """ - async def component_after_invoke(self, ctx: Context) -> None: + async def component_after_invoke(self, ctx: Context[BotT]) -> None: """Hook called after a :class:`~.commands.Command` has successfully invoked in this Component. Similar to :meth:`~.commands.Bot.after_invoke` but only applies to commands in this Component. diff --git a/twitchio/ext/commands/context.py b/twitchio/ext/commands/context.py index e21e0a6a..6ed4d390 100644 --- a/twitchio/ext/commands/context.py +++ b/twitchio/ext/commands/context.py @@ -25,12 +25,13 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias from twitchio.models.eventsub_ import ChannelPointsRedemptionAdd, ChannelPointsRedemptionUpdate, ChatMessage from .core import CommandErrorPayload, ContextType, RewardCommand, RewardStatus from .exceptions import * +from .types_ import BotT from .view import StringView @@ -50,7 +51,7 @@ PrefixT: TypeAlias = str | Iterable[str] | Callable[[Bot, ChatMessage], Coroutine[Any, Any, str | Iterable[str]]] -class Context: +class Context(Generic[BotT]): """The Context class constructed when a message or reward redemption in the respective events is received and processed in a :class:`~.commands.Bot`. @@ -77,10 +78,10 @@ def __init__( self, payload: ChatMessage | ChannelPointsRedemptionAdd | ChannelPointsRedemptionUpdate, *, - bot: Bot, + bot: BotT, ) -> None: self._payload: ChatMessage | ChannelPointsRedemptionAdd | ChannelPointsRedemptionUpdate = payload - self._bot: Bot = bot + self._bot = bot self._component: Component | None = None self._prefix: str | None = None @@ -220,7 +221,7 @@ def channel(self) -> PartialUser: return self.broadcaster @property - def bot(self) -> Bot: + def bot(self) -> BotT: """Property returning the :class:`~.commands.Bot` object.""" return self._bot diff --git a/twitchio/ext/commands/converters.py b/twitchio/ext/commands/converters.py index 9889ca6c..cd176f36 100644 --- a/twitchio/ext/commands/converters.py +++ b/twitchio/ext/commands/converters.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from .bot import Bot from .context import Context + from .types_ import BotT __all__ = ("_BaseConverter",) @@ -67,7 +68,7 @@ def _bool(self, arg: str) -> bool: return result - async def _user(self, context: Context, arg: str) -> User: + async def _user(self, context: Context[BotT], arg: str) -> User: arg = arg.lower() users: list[User] msg: str = 'Failed to convert "{}" to User. A User with the ID or login could not be found.' diff --git a/twitchio/ext/commands/cooldowns.py b/twitchio/ext/commands/cooldowns.py index c937c9bc..8ae4c3ac 100644 --- a/twitchio/ext/commands/cooldowns.py +++ b/twitchio/ext/commands/cooldowns.py @@ -36,6 +36,7 @@ import twitchio from .context import Context + from .types_ import BotT __all__ = ("BaseCooldown", "Bucket", "BucketType", "Cooldown", "GCRACooldown") @@ -65,7 +66,7 @@ class BucketType(enum.Enum): channel = 2 chatter = 3 - def get_key(self, payload: twitchio.ChatMessage | Context) -> Any: + def get_key(self, payload: twitchio.ChatMessage | Context[BotT]) -> Any: if self is BucketType.user: return payload.chatter.id @@ -75,7 +76,7 @@ def get_key(self, payload: twitchio.ChatMessage | Context) -> Any: elif self is BucketType.chatter: return (payload.broadcaster.id, payload.chatter.id) - def __call__(self, payload: twitchio.ChatMessage | Context) -> Any: + def __call__(self, payload: twitchio.ChatMessage | Context[BotT]) -> Any: return self.get_key(payload) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 6af5fc94..0cd2c0ce 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -67,6 +67,7 @@ from twitchio.user import Chatter from .context import Context + from .types_ import BotT P = ParamSpec("P") else: @@ -180,8 +181,8 @@ class CommandErrorPayload: __slots__ = ("context", "exception") - def __init__(self, *, context: Context, exception: CommandError) -> None: - self.context: Context = context + def __init__(self, *, context: Context[BotT], exception: CommandError) -> None: + self.context = context self.exception: CommandError = exception @@ -197,7 +198,7 @@ class Command(Generic[Component_T, P]): def __init__( self, - callback: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + callback: Callable[Concatenate[Component_T, Context[Any], P], Coro] | Callable[Concatenate[Context[Any], P], Coro], *, name: str, **kwargs: Unpack[CommandOptions], @@ -206,7 +207,7 @@ def __init__( self.callback = callback self._aliases: list[str] = kwargs.get("aliases", []) self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(callback, "__command_guards__", []) - self._buckets: list[Bucket[Context]] = getattr(callback, "__command_cooldowns__", []) + self._buckets: list[Bucket[Context[Any]]] = getattr(callback, "__command_cooldowns__", []) self._guards_after_parsing = kwargs.get("guards_after_parsing", False) self._cooldowns_first = kwargs.get("cooldowns_before_guards", False) @@ -216,8 +217,8 @@ def __init__( self._parent: Group[Component_T, P] | None = kwargs.get("parent") self._bypass_global_guards: bool = kwargs.get("bypass_global_guards", False) - self._before_hook: Callable[[Component_T, Context], Coro] | Callable[[Context], Coro] | None = None - self._after_hook: Callable[[Component_T, Context], Coro] | Callable[[Context], Coro] | None = None + self._before_hook: Callable[[Component_T, Context[Any]], Coro] | Callable[[Context[Any]], Coro] | None = None + self._after_hook: Callable[[Component_T, Context[Any]], Coro] | Callable[[Context[Any]], Coro] | None = None self._help: str = callback.__doc__ or "" self.__doc__ = self._help @@ -228,7 +229,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return self._name - async def __call__(self, context: Context) -> Any: + async def __call__(self, context: Context[BotT]) -> Any: callback = self._callback(self._injected, context) if self._injected else self._callback(context) # type: ignore return await callback # type: ignore will fix later @@ -327,7 +328,9 @@ def guards(self) -> list[Callable[..., bool] | Callable[..., CoroC]]: return self._guards @property - def callback(self) -> Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro]: + def callback( + self, + ) -> Callable[Concatenate[Component_T, Context[Any], P], Coro] | Callable[Concatenate[Context[Any], P], Coro]: """Property returning the coroutine callback used in invocation. E.g. the function you wrap with :func:`.command`. """ @@ -335,7 +338,7 @@ def callback(self) -> Callable[Concatenate[Component_T, Context, P], Coro] | Cal @callback.setter def callback( - self, func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro] + self, func: Callable[Concatenate[Component_T, Context[Any], P], Coro] | Callable[Concatenate[Context[Any], P], Coro] ) -> None: self._callback = func unwrap = unwrap_function(func) @@ -349,7 +352,7 @@ def callback( self._params: dict[str, inspect.Parameter] = get_signature_parameters(func, globalns) def _convert_literal_type( - self, context: Context, param: inspect.Parameter, args: tuple[Any, ...], *, raw: str | None + self, context: Context[BotT], param: inspect.Parameter, args: tuple[Any, ...], *, raw: str | None ) -> Any: name: str = param.name result: Any = MISSING @@ -370,7 +373,9 @@ def _convert_literal_type( return result - async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any: + async def _do_conversion( + self, context: Context[BotT], param: inspect.Parameter, *, annotation: Any, raw: str | None + ) -> Any: name: str = param.name if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union: @@ -435,7 +440,7 @@ async def _do_conversion(self, context: Context, param: inspect.Parameter, *, an return result - async def _parse_arguments(self, context: Context) -> ...: + async def _parse_arguments(self, context: Context[BotT]) -> ...: context._view.skip_ws() params: list[inspect.Parameter] = list(self._params.values()) @@ -501,7 +506,7 @@ async def _guard_runner(self, guards: list[Callable[..., bool] | Callable[..., C if result is not True: raise GuardFailure(exc_msg, guard=guard) - async def run_guards(self, ctx: Context, *, with_cooldowns: bool = False) -> None: + async def run_guards(self, ctx: Context[BotT], *, with_cooldowns: bool = False) -> None: """|coro| Method which allows a :class:`~twitchio.ext.commands.Command` to run and check all associated Guards, including @@ -532,7 +537,7 @@ async def run_guards(self, ctx: Context, *, with_cooldowns: bool = False) -> Non """ await self._run_guards(ctx, with_cooldowns=with_cooldowns) - async def _run_guards(self, context: Context, *, with_cooldowns: bool = True) -> None: + async def _run_guards(self, context: Context[BotT], *, with_cooldowns: bool = True) -> None: if with_cooldowns and self._cooldowns_first: await self._run_cooldowns(context) @@ -551,7 +556,7 @@ async def _run_guards(self, context: Context, *, with_cooldowns: bool = True) -> if with_cooldowns and not self._cooldowns_first: await self._run_cooldowns(context) - async def _run_cooldowns(self, context: Context) -> None: + async def _run_cooldowns(self, context: Context[BotT]) -> None: type_ = "group" if isinstance(self, Group) else "command" for bucket in self._buckets: @@ -569,7 +574,7 @@ async def _run_cooldowns(self, context: Context) -> None: cooldown=cooldown, ) - async def _invoke(self, context: Context) -> Any: + async def _invoke(self, context: Context[BotT]) -> Any: context._component = self._injected if not self._guards_after_parsing: @@ -616,7 +621,7 @@ async def _invoke(self, context: Context) -> Any: except Exception as e: raise CommandInvokeError(msg=str(e), original=e) from e - async def invoke(self, context: Context) -> Any: + async def invoke(self, context: Context[BotT]) -> Any: try: return await self._invoke(context) except CommandError as e: @@ -625,7 +630,7 @@ async def invoke(self, context: Context) -> Any: error = CommandInvokeError(str(e), original=e) await self._dispatch_error(context, error) - async def _dispatch_error(self, context: Context, exception: CommandError) -> None: + async def _dispatch_error(self, context: Context[BotT], exception: CommandError) -> None: payload = CommandErrorPayload(context=context, exception=exception) if self._error is not None: @@ -761,7 +766,7 @@ class RewardCommand(Command[Component_T, P]): def __init__( self, - callback: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + callback: Callable[Concatenate[Component_T, Context[Any], P], Coro] | Callable[Concatenate[Context[Any], P], Coro], *, reward_id: str, invoke_when: RewardStatus, @@ -999,7 +1004,7 @@ async def hi_command(ctx: commands.Context) -> None: """ def wrapper( - func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + func: Callable[Concatenate[Component_T, Context[BotT], P], Coro] | Callable[Concatenate[Context[BotT], P], Coro], ) -> Command[Any, ...]: if isinstance(func, Command): raise ValueError(f'Callback "{func._callback}" is already a Command.') # type: ignore @@ -1082,7 +1087,7 @@ async def reward_test(self, ctx: commands.Context, *, user_input: str) -> None: """ def wrapper( - func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + func: Callable[Concatenate[Component_T, Context[BotT], P], Coro] | Callable[Concatenate[Context[BotT], P], Coro], ) -> RewardCommand[Any, ...]: if isinstance(func, (Command, RewardCommand)): raise ValueError(f'Callback "{func._callback}" is already a Command.') # type: ignore @@ -1163,7 +1168,7 @@ async def socials_twitch(ctx: commands.Context) -> None: """ def wrapper( - func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + func: Callable[Concatenate[Component_T, Context[BotT], P], Coro] | Callable[Concatenate[Context[BotT], P], Coro], ) -> Group[Any, ...]: if isinstance(func, Command): raise ValueError(f'Callback "{func._callback.__name__}" is already a Command.') # type: ignore @@ -1203,7 +1208,7 @@ def walk_commands(self) -> Generator[Command[Component_T, P] | Group[Component_T if isinstance(command, Group): yield from command.walk_commands() - async def _invoke(self, context: Context) -> None: + async def _invoke(self, context: Context[BotT]) -> None: view = context._view view.skip_ws() trigger = view.get_word() @@ -1229,7 +1234,7 @@ async def _invoke(self, context: Context) -> None: raise CommandNotFound(f'The sub-command "{trigger}" for group "{self._name}" was not found.') - async def invoke(self, context: Context) -> None: + async def invoke(self, context: Context[BotT]) -> None: try: return await self._invoke(context) except CommandError as e: @@ -1292,7 +1297,7 @@ async def socials_twitch(ctx: commands.Context) -> None: """ def wrapper( - func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + func: Callable[Concatenate[Component_T, Context[BotT], P], Coro] | Callable[Concatenate[Context[BotT], P], Coro], ) -> Command[Any, ...]: new = command(name=name, aliases=aliases, extras=extras, parent=self, **kwargs)(func) @@ -1362,7 +1367,7 @@ async def socials_discord_two(ctx: commands.Context) -> None: """ def wrapper( - func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro], + func: Callable[Concatenate[Component_T, Context[BotT], P], Coro] | Callable[Concatenate[Context[BotT], P], Coro], ) -> Command[Any, ...]: new = group(name=name, aliases=aliases, extras=extras, parent=self, **kwargs)(func) @@ -1466,7 +1471,7 @@ def is_owner() -> Any: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: return context.chatter.id == context.bot.owner_id return guard(predicate) @@ -1491,7 +1496,7 @@ def is_staff() -> Any: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: if context.type is ContextType.REWARD: raise TypeError("This Guard can not be used on a RewardCommand instance.") @@ -1517,7 +1522,7 @@ def is_broadcaster() -> Any: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: return context.chatter.id == context.broadcaster.id return guard(predicate) @@ -1545,7 +1550,7 @@ def is_moderator() -> Any: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: if context.type is ContextType.REWARD: raise TypeError("This Guard can not be used on a RewardCommand instance.") @@ -1577,7 +1582,7 @@ def is_vip() -> Any: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: if context.type is ContextType.REWARD: raise TypeError("This Guard can not be used on a RewardCommand instance.") @@ -1621,7 +1626,7 @@ async def test(self, ctx: commands.Context) -> None: The guard predicate returned ``False`` and prevented the chatter from using the command. """ - def predicate(context: Context) -> bool: + def predicate(context: Context[BotT]) -> bool: if context.type is ContextType.REWARD: raise TypeError("This Guard can not be used on a RewardCommand instance.") @@ -1724,7 +1729,7 @@ async def custom_key(ctx: commands.Context) -> typing.Hashable | None: async def hello(ctx: commands.Context) -> None: ... """ - bucket_: Bucket[Context] = Bucket.from_cooldown(base=base, key=key, **kwargs) # type: ignore + bucket_: Bucket[Context[BotT]] = Bucket.from_cooldown(base=base, key=key, **kwargs) # type: ignore def wrapper(func: Any) -> Any: nonlocal bucket_ diff --git a/twitchio/ext/commands/types_.py b/twitchio/ext/commands/types_.py index 9f6a92da..8dbf460f 100644 --- a/twitchio/ext/commands/types_.py +++ b/twitchio/ext/commands/types_.py @@ -22,17 +22,27 @@ SOFTWARE. """ -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, Union from twitchio.types_.options import AutoClientOptions, ClientOptions if TYPE_CHECKING: + from .bot import AutoBot, Bot from .components import Component + from .context import Context Component_T = TypeVar("Component_T", bound="Component | None") +ContextT = TypeVar("ContextT", bound="Context[Any]") +ContextT_co = TypeVar("ContextT_co", bound="Context[Any]", covariant=True) + +_Bot = Union["Bot", "AutoBot"] +BotT = TypeVar("BotT", bound=_Bot, covariant=True) + class CommandOptions(TypedDict, total=False): aliases: list[str]