diff --git a/.isort.cfg b/.isort.cfg
deleted file mode 100644
index 6e5807fa4..000000000
--- a/.isort.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-[settings]
-default_section = THIRDPARTY
-multi_line_output=3
-include_trailing_comma=True
-force_grid_wrap=0
-use_parentheses=True
-line_length=119
-float_to_top=True
-known_first_party =
-    fast_llm
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 548a4edcb..94e9e4e33 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
           - --unsafe
       - id: check-added-large-files
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.1
+    rev: v3.21.2
     hooks:
       - id: pyupgrade
         args:
@@ -31,7 +31,7 @@ repos:
           - app/scripts/utility/shell.py
          - --remove-duplicate-keys
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 7.0.0
     hooks:
       - id: isort
         name: isort (python)
@@ -42,14 +42,14 @@ repos:
         name: isort (pyi)
         types: [pyi]
   - repo: https://github.com/psf/black
-    rev: 24.10.0
+    rev: 26.1.0
     hooks:
       - id: black
         args:
           - "--config"
           - "./pyproject.toml"
   - repo: https://github.com/DavidAnson/markdownlint-cli2
-    rev: v0.16.0
+    rev: v0.20.0
     hooks:
       - id: markdownlint-cli2
         name: markdownlint
diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md
index 655fa29c0..d6d2333e1 100644
--- a/docs/recipes/generate.md
+++ b/docs/recipes/generate.md
@@ -12,7 +12,7 @@ Fast-LLM models support `generate` and `forward` operations through Hugging Face
 
 ---
 
-### 🔧 Generating Text from a Fast-LLM Model
+## 🔧 Generating Text from a Fast-LLM Model
 
 Below is a step-by-step example of how to generate text using a Fast-LLM model checkpoint from Hugging Face Hub.
 
diff --git a/fast_llm/config.py b/fast_llm/config.py
index 4339fa401..5411a2078 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -243,9 +243,9 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class[
-    T: Config
-](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
+def config_class[T: Config](
+    registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None
+) -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -715,9 +715,7 @@ def to_copy[
     def __repr__(self):
         return self.to_logs(log_fn=str)
 
-    def to_logs[
-        T
-    ](
+    def to_logs[T](
         self,
         verbose: int | None = FieldVerboseLevel.core,
         log_fn: typing.Callable[[str], T] = logger.info,
@@ -1048,9 +1046,7 @@ def config(self) -> ConfigType:
         return self._config
 
 
-def set_nested_dict_value[
-    KeyType, ValueType
-](
+def set_nested_dict_value[KeyType, ValueType](
     d: dict[KeyType, ValueType],
     keys: KeyType | tuple[KeyType, ...],
     value: ValueType,
@@ -1094,9 +1090,9 @@ def set_nested_dict_value[
         raise NotImplementedError(update_type)
 
 
-def get_nested_dict_value[
-    KeyType, ValueType
-](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
+def get_nested_dict_value[KeyType, ValueType](
+    d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
+) -> ValueType:
     if isinstance(keys, tuple):
         for key in keys:
             d = d[key]
@@ -1105,9 +1101,9 @@ def get_nested_dict_value[
         return d[keys]
 
 
-def pop_nested_dict_value[
-    KeyType, ValueType
-](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
+def pop_nested_dict_value[KeyType, ValueType](
+    d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
+) -> ValueType:
     if isinstance(keys, tuple):
         for key in keys[:-1]:
             d = d[key]
diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py
index 77507afa8..2c6c8105f 100644
--- a/fast_llm/engine/config_utils/run.py
+++ b/fast_llm/engine/config_utils/run.py
@@ -240,9 +240,9 @@ def is_main_rank() -> bool:
     return DistributedConfig.default_rank == _MAIN_RANK
 
 
-def log_main_rank[
-    T
-](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
+def log_main_rank[T](
+    *message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
+) -> T:
     if is_main_rank():
         return log(*message, log_fn=log_fn, join=join)
 
@@ -251,9 +251,9 @@ def is_model_parallel_main_rank() -> bool:
     return is_main_rank() if _run is None else _run._is_model_parallel_main_rank  # Noqa
 
 
-def log_model_parallel_main_rank[
-    T
-](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
+def log_model_parallel_main_rank[T](
+    *message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
+) -> T:
     if is_model_parallel_main_rank():
         return log(*message, log_fn=log_fn)
 
@@ -262,8 +262,8 @@ def is_pipeline_parallel_main_rank() -> bool:
     return is_main_rank() if _run is None else _run._is_pipeline_parallel_main_rank  # Noqa
 
 
-def log_pipeline_parallel_main_rank[
-    T
-](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
+def log_pipeline_parallel_main_rank[T](
+    *message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
+) -> T:
     if is_pipeline_parallel_main_rank():
         return log(*message, log_fn=log_fn)
diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py
index 163a9459c..58c490cb9 100644
--- a/fast_llm/engine/config_utils/runnable.py
+++ b/fast_llm/engine/config_utils/runnable.py
@@ -108,9 +108,7 @@ def _get_runnable(self) -> typing.Callable[[], None]:
     def run(self) -> None:
         self._get_runnable()()
 
-    def _show[
-        T
-    ](
+    def _show[T](
         self,
         verbose: int = FieldVerboseLevel.core,
         *,
diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py
index 5976a477f..c7ab610b2 100644
--- a/fast_llm/engine/distributed/config.py
+++ b/fast_llm/engine/distributed/config.py
@@ -410,9 +410,9 @@ def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
     def get_distributed_dim(self, name: str) -> DistributedDim:
         return self.distributed_dims[name]
 
-    def _log_on_rank[
-        T
-    ](self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info):
+    def _log_on_rank[T](
+        self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
+    ):
         if rank is None or self.rank == rank:
             return log(*message, log_fn=log_fn)
 
diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py
index cea5f6ee2..3e8e31cea 100644
--- a/fast_llm/functional/autograd.py
+++ b/fast_llm/functional/autograd.py
@@ -10,9 +10,7 @@
 
 
 # TODO: Improve type hint (use protocol?)
-def wrap_forward_backward[
-    OutputType, ContextType
-](
+def wrap_forward_backward[OutputType, ContextType](
     forward: typing.Callable[..., tuple[OutputType, ContextType]],
     backward: typing.Callable[[OutputType, ContextType], typing.Any],
 ) -> typing.Callable[..., OutputType]:
diff --git a/fast_llm/logging.py b/fast_llm/logging.py
index 5a2ff2dac..84b945a67 100644
--- a/fast_llm/logging.py
+++ b/fast_llm/logging.py
@@ -123,9 +123,7 @@ def format_metrics(
 
 
 @torch._dynamo.disable  # noqa
-def log_tensor[
-    T
-](
+def log_tensor[T](
     name: str,
     tensor: torch.Tensor,
     *,
@@ -133,7 +131,7 @@ def log_tensor[
     level: int = 2,
     storage: bool = False,
     log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
-) -> (T | None):
+) -> T | None:
     if level < 1:
         return
     tensor = tensor.detach()
@@ -219,9 +217,7 @@ def log_tensor[
 
 
 @torch._dynamo.disable  # noqa
-def log_grad[
-    T
-](
+def log_grad[T](
     name: str,
     tensor: torch.Tensor,
     *,
@@ -244,9 +240,7 @@ def log_grad[
 
 
 @torch._dynamo.disable  # noqa
-def log_distributed_tensor[
-    T
-](
+def log_distributed_tensor[T](
     name: str,
     tensor: torch.Tensor,
     *,
@@ -257,7 +251,7 @@ def log_distributed_tensor[
     global_: bool = True,
     log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
     meta: TensorMeta,
-) -> (T | None):
+) -> T | None:
     if level <= 0:
         return
     if global_:
@@ -278,9 +272,7 @@ def log_distributed_tensor[
 
 
 @torch._dynamo.disable  # noqa
-def log_distributed_grad[
-    T
-](
+def log_distributed_grad[T](
     name: str,
     tensor: torch.Tensor,
     *,
@@ -292,7 +284,7 @@ def log_distributed_grad[
     global_: bool = True,
     log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
     meta: TensorMeta,
-) -> (T | None):
+) -> T | None:
     if level <= 0:
         return
     tensor.register_hook(
@@ -311,9 +303,7 @@ def log_distributed_grad[
 
 
 @torch._dynamo.disable  # noqa
-def log_generator[
-    T
-](
+def log_generator[T](
     name,
     generator: torch.Tensor | torch.Generator | None = None,
     log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
@@ -328,9 +318,7 @@ def log_generator[
 _global_max_reserved = 0
 
 
-def log_memory_usage[
-    T
-](
+def log_memory_usage[T](
     header: str | None = None,
     log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
     reset_stats: bool = True,
diff --git a/fast_llm/utils.py b/fast_llm/utils.py
index 2518785dd..29fd5a155 100644
--- a/fast_llm/utils.py
+++ b/fast_llm/utils.py
@@ -261,9 +261,9 @@ def __getitem__(self, key: KeyType) -> ValueType:
         return super().__getitem__(key)()
 
 
-def log[
-    T
-](*message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
+def log[T](
+    *message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
+) -> T:
     message = join.join([str(m() if callable(m) else m) for m in message])
     logged = log_fn(message)
     if isinstance(logged, BaseException):
diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py
index 66c419dfd..dbae2963c 100644
--- a/fast_llm_external_models/apriel2/convert.py
+++ b/fast_llm_external_models/apriel2/convert.py
@@ -41,9 +41,13 @@
     compose_configs,
 )
 from fast_llm_external_models.apriel2.conversion import llava as llava_converter
-from fast_llm_external_models.apriel2.conversion import plan_surgery
+from fast_llm_external_models.apriel2.conversion import (
+    plan_surgery,
+)
 from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter
-from fast_llm_external_models.apriel2.conversion import strip_init_fields
+from fast_llm_external_models.apriel2.conversion import (
+    strip_init_fields,
+)
 
 # Allow running as script or module
 if __name__ == "__main__":
diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
index bb4fe8bc6..942e35520 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
@@ -795,10 +795,7 @@ def test_vs_qwen3next(
         always uses initial_state=None, ignoring cached recurrent state.
         """
         from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig
-        from transformers.models.qwen3_next.modeling_qwen3_next import (
-            Qwen3NextDynamicCache,
-            Qwen3NextGatedDeltaNet,
-        )
+        from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache, Qwen3NextGatedDeltaNet
 
         from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache, Apriel2GatedDeltaNet
 
@@ -826,7 +823,9 @@ def test_vs_qwen3next(
         # Create models with same weights
         torch.manual_seed(seed)
         qwen_gdn = Qwen3NextGatedDeltaNet(qwen3_config, layer_idx=0).to(device="cuda", dtype=test_dtype)
-        apriel_gdn = Apriel2GatedDeltaNet(hidden_size, gdn_mixer_config, layer_idx=0).to(device="cuda", dtype=test_dtype)
+        apriel_gdn = Apriel2GatedDeltaNet(hidden_size, gdn_mixer_config, layer_idx=0).to(
+            device="cuda", dtype=test_dtype
+        )
 
         # Transfer weights using conversion plan
         plan = plan_qwen3next_gdn_to_apriel2(
@@ -1013,6 +1012,7 @@ def test_chunked_vs_recurrent(
             msg=f"GDN chunked vs recurrent mode (prefill={prefill_len}, decode={decode_steps})",
         )
 
+
 # =============================================================================
 # SECTION 2: EQUIVALENCE TESTS - KimiDeltaAttention
 # =============================================================================
@@ -1046,10 +1046,8 @@ def test_vs_fla(
         from fla.layers.kda import KimiDeltaAttention as FLA_KDA
         from fla.models.utils import Cache as FLACache
 
-        from fast_llm_external_models.apriel2.modeling_apriel2 import (
-            Apriel2Cache,
-            KimiDeltaAttention as Apriel2_KDA,
-        )
+        from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache
+        from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA
 
         num_heads, head_dim = kda_config
         seq_len = prefill_len + decode_steps + prefill2_len
@@ -1272,4 +1270,3 @@ def test_chunked_vs_recurrent(
             atol=atol * 5,
             msg=f"KDA chunked vs recurrent mode (prefill={prefill_len}, decode={decode_steps})",
         )
-
diff --git a/pyproject.toml b/pyproject.toml
index c7d3ffd23..8488623d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,3 +8,16 @@ testpaths = [
     "fast_llm_external_models/tests"  # External models tests
 ]
 norecursedirs = ["Megatron-LM"]
+
+[tool.isort]
+profile = "black"
+default_section = "THIRDPARTY"
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+line_length = 119
+float_to_top = "True"
+known_first_party = [
+    "fast_llm",
+]
diff --git a/setup.cfg b/setup.cfg
index 005ae5a8a..867e1da29 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,9 +8,9 @@ packages =
 include_package_data = True
 python_requires = >=3.12
 install_requires =
-    requests>=2.32.4
-    PyYAML>=6.0.2
-    pybind11>=2.13.6
+    requests>=2.32.5
+    PyYAML>=6.0.3
+    pybind11>=3.0.1
     packaging>=25.0
 
 [options.extras_require]
@@ -19,11 +19,10 @@ install_requires =
 # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
 CORE =
     # Available through the nvidia base image
-    torch>=2.7.0
-    # Numpy major needs to match torch
-    numpy>=1.26.4,<2.0.0
+    torch>=2.9.0
+    numpy>=2.1.0
     # Used for checkpoints
-    safetensors>=0.5.3
+    safetensors>=0.6.2
     # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
     flash-attn==2.7.4.post1
     # Dropless MoE kernel is broken with triton >= 3.2.0 and needs a rewrite (also limited to 32 experts).
@@ -34,7 +33,7 @@ CORE =
 # Small packages required for some optional features and tools.
 OPTIONAL =
     # Weights and biases
-    wandb>=0.20.1
+    wandb>=0.24.0
     # Hydra
     hydra-core>=1.3.2
     omegaconf>=2.3.0
@@ -43,10 +42,10 @@ OPTIONAL =
 
 # Huggingface tools
 HUGGINGFACE =
-    transformers>=4.53.2
+    transformers>=4.57.3,<5.0.0
     hf-transfer>=0.1.9
-    datasets>=3.6.0
-    huggingface-hub>=0.32.6
+    datasets>=4.4.1
+    huggingface-hub>=0.36.0
 
 # Required to run SSMs
 # To install on cpu environment (ex. for IDE support):
@@ -60,20 +59,19 @@ SSM =
 GENERATION =
     lm_eval>=0.4.9
 
-
 # Required for supporting vision inputs
 VISION =
     # Vision Tools
     webp>=0.4.0
     pillow-simd>=9.5.0
-    torchvision>=0.20.0
+    torchvision>=0.24.0
 
 DEV =
     # Pre-commit git hook
-    pre-commit>=4.2.0
+    pre-commit>=4.5.1
     # Required for testing
-    pytest>=8.4.0
-    pytest-xdist>=3.7.0
+    pytest>=9.0.2
+    pytest-xdist>=3.8.0
     # Somehow needed for Megatron to work with base image 24.11
     setuptools>=80.9.0
     # Dependency manager needs colorama to show colors.
diff --git a/tests/conftest.py b/tests/conftest.py
index 64b0df6e2..f93eec215 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -39,7 +39,6 @@
 # Import all dynamic classes.
 import fast_llm.cli  # isort: skip
 
-
 logger = logging.getLogger(__name__)
 
 manager: DependencyManager | None = None
diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py
index cf66bed10..777214aae 100644
--- a/tests/layers/test_ssm.py
+++ b/tests/layers/test_ssm.py
@@ -11,7 +11,7 @@
 from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig
 from fast_llm.layers.ssm.kda import _kda_available
 from fast_llm.utils import Assert
-from tests.utils.utils import get_stage, requires_cuda
+from tests.utils.utils import get_stage
 
 try:
     from fast_llm_external_models.apriel2.modeling_apriel2 import (
diff --git a/tools/push_model.py b/tools/push_model.py
index 39a3b9141..cf66e70a2 100644
--- a/tools/push_model.py
+++ b/tools/push_model.py
@@ -29,7 +29,6 @@
 
 from fast_llm.engine.checkpoint.convert import ConvertConfig  # isort:skip
 
-
 logger = logging.getLogger(__name__)