From 22cb2700954cd1d5e23dbe4108f8a9f41e4a9de4 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 12 Nov 2025 23:03:19 +0000 Subject: [PATCH] fix: Add TypeVar("C") for decorators to maintain typing --- .../src/pyearthtools/data/indexes/decorators.py | 12 +++++++----- .../pyearthtools/data/modifications/decorator.py | 8 +++++--- .../utils/src/pyearthtools/utils/decorators.py | 14 +++++++------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/packages/data/src/pyearthtools/data/indexes/decorators.py b/packages/data/src/pyearthtools/data/indexes/decorators.py index 80c08ab9..365f0843 100644 --- a/packages/data/src/pyearthtools/data/indexes/decorators.py +++ b/packages/data/src/pyearthtools/data/indexes/decorators.py @@ -24,7 +24,7 @@ import inspect from pathlib import Path -from typing import Any +from typing import Any, TypeVar, Callable from pyearthtools.data.indexes.utilities import spellcheck, open_static from pyearthtools.utils.decorators import alias_arguments @@ -38,6 +38,8 @@ "variable_modifications", ] +C = TypeVar("C", bound=Callable[..., Any]) + def _check_required_arguments(default: dict[str, inspect.Parameter], kwargs: dict, function_object: object): """ @@ -184,7 +186,7 @@ def _check_structure(structure: str | Path | dict[str, Any], arguments: dict[str def check_arguments( struc: str | Path | dict[str, Any] | None = None, **valid_arguments: list[Any] | tuple[Any, ...] | str, -): +) -> Callable[[C], C]: """ Check Arguments before passing to function, @@ -221,7 +223,7 @@ def func(variable, argument = 'default'): """ - def internal_function(func): + def internal_function(func: C) -> C: ## Get function signature signature = inspect.signature(func) ## Get all params @@ -302,7 +304,7 @@ def deprecated_arguments( deprecation: dict[str, str | None] | str | None = None, *arg_deprecations: str, **extra_deprecations: str | None, -): +) -> Callable[[C], C]: """ Warn a user if they attempt to use a deprecated argument, and remove it from the call. @@ -317,7 +319,7 @@ def deprecated_arguments( deprecation.update({k: None for k in arg_deprecations}) deprecation.update(extra_deprecations) - def internal_func(func): + def internal_func(func: C) -> C: @functools.wraps(func) def warn_on_deprecated(*args, **kwargs): for depr in deprecation.keys(): diff --git a/packages/data/src/pyearthtools/data/modifications/decorator.py b/packages/data/src/pyearthtools/data/modifications/decorator.py index 89734149..a6f5a078 100644 --- a/packages/data/src/pyearthtools/data/modifications/decorator.py +++ b/packages/data/src/pyearthtools/data/modifications/decorator.py @@ -24,7 +24,7 @@ import json import functools -from typing import Callable, Any, Optional, Type, Union +from typing import Callable, Any, Optional, Type, Union, TypeVar import xarray as xr from pyearthtools.data.indexes import TimeDataIndex @@ -37,6 +37,8 @@ __all__ = ["variable_modifications", "Modification"] +C = TypeVar("C", bound=Callable[..., Any]) + def _args_to_kwargs(func, args: tuple[Any, ...]) -> dict[str, Any]: """ @@ -226,7 +228,7 @@ def variable_modifications( *, remove_variables: bool = False, skip_if_invalid_class: bool = False, -): +) -> Callable[[C], C]: """ Allow modifications of variables dynamically, @@ -286,7 +288,7 @@ def variable_modifications( If using this decorator with `check_arguments` put this one above it, and with `alias_arguments` put it below. """ - def internal(func: Callable): + def internal(func: C) -> C: @functools.wraps(func) def wrapper(*args, **kwargs): # Get variables to parse modifications from diff --git a/packages/utils/src/pyearthtools/utils/decorators.py b/packages/utils/src/pyearthtools/utils/decorators.py index fd8635ab..5b8679d9 100644 --- a/packages/utils/src/pyearthtools/utils/decorators.py +++ b/packages/utils/src/pyearthtools/utils/decorators.py @@ -17,7 +17,9 @@ import functools import warnings -from typing import Any, Callable +from typing import Any, Callable, TypeVar + +C = TypeVar("C", bound=Callable[..., Any]) class classproperty(property): @@ -35,7 +37,7 @@ def invert_dictionary_list(dictionary: dict) -> dict: return return_dict -def alias_arguments(**aliases: str | list[str]) -> Callable: +def alias_arguments(**aliases: str | list[str]) -> Callable[[C], C]: """ Setup aliases for parameters @@ -58,7 +60,7 @@ def function(response): """ - def internal_function(func: Callable) -> Callable: + def internal_function(func: C) -> C: # Force all aliases to be list for k, v in aliases.items(): if isinstance(v, (list, tuple)): @@ -92,7 +94,7 @@ def wrapper(*args, **kwargs): return internal_function -def BackwardsCompatibility(new_func: Callable[[Any], Any]): +def BackwardsCompatibility(new_func: C) -> Callable[[C], C]: """ Allows for the renaming of a functionality, and subsequent backwards compatilbility. @@ -111,9 +113,7 @@ def BackwardsCompatibility(new_func: Callable[[Any], Any]): """ - @functools.wraps(new_func) - def decorator(func): - @functools.wraps(new_func) + def decorator(func: C) -> C: def wrapped(*args, **kwargs): warnings.warn( f"{func.__name__} has been removed in favour of {new_func.__name__}, please switch over.",