10 changes: 0 additions & 10 deletions .isort.cfg

This file was deleted.

8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- --unsafe
- id: check-added-large-files
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
rev: v3.21.2
hooks:
- id: pyupgrade
args:
@@ -31,7 +31,7 @@ repos:
- app/scripts/utility/shell.py
- --remove-duplicate-keys
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 7.0.0
hooks:
- id: isort
name: isort (python)
@@ -42,14 +42,14 @@
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 26.1.0
hooks:
- id: black
args:
- "--config"
- "./pyproject.toml"
- repo: https://github.com/DavidAnson/markdownlint-cli2
rev: v0.16.0
rev: v0.20.0
hooks:
- id: markdownlint-cli2
name: markdownlint
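The hook bumps above (pyupgrade v3.19.1 → v3.21.2, isort 5.13.2 → 7.0.0, black 24.10.0 → 26.1.0, markdownlint-cli2 v0.16.0 → v0.20.0) appear to be what drives the purely mechanical re-wrapping in the Python files further down: the newer formatter keeps a PEP 695 type-parameter list on the def line and wraps the argument list instead. A minimal sketch of the pattern, illustrative only and not taken from the repository:

# Old wrapping, as in the removed lines: for an over-long signature,
# the type-parameter list itself was split across lines.
def first_item[
    T
](items: list[T], fallback: T | None = None, strict: bool = False) -> T:
    return items[0]


# New wrapping, as in the added lines: type parameters stay on the def line,
# and the arguments wrap instead.
def first_item[T](
    items: list[T], fallback: T | None = None, strict: bool = False
) -> T:
    return items[0]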
2 changes: 1 addition & 1 deletion docs/recipes/generate.md
@@ -12,7 +12,7 @@ Fast-LLM models support `generate` and `forward` operations through Hugging Face

---

### 🔧 Generating Text from a Fast-LLM Model
## 🔧 Generating Text from a Fast-LLM Model

Below is a step-by-step example of how to generate text using a Fast-LLM model checkpoint from Hugging Face Hub.

26 changes: 11 additions & 15 deletions fast_llm/config.py
@@ -243,9 +243,9 @@ def _process_config_class(cls: type["Config"]):
return cls


def config_class[
T: Config
](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
def config_class[T: Config](
registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None
) -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
@@ -715,9 +715,7 @@ def to_copy[
def __repr__(self):
return self.to_logs(log_fn=str)

def to_logs[
T
](
def to_logs[T](
self,
verbose: int | None = FieldVerboseLevel.core,
log_fn: typing.Callable[[str], T] = logger.info,
@@ -1048,9 +1046,7 @@ def config(self) -> ConfigType:
return self._config


def set_nested_dict_value[
KeyType, ValueType
](
def set_nested_dict_value[KeyType, ValueType](
d: dict[KeyType, ValueType],
keys: KeyType | tuple[KeyType, ...],
value: ValueType,
@@ -1094,9 +1090,9 @@ def set_nested_dict_value[
raise NotImplementedError(update_type)


def get_nested_dict_value[
KeyType, ValueType
](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
def get_nested_dict_value[KeyType, ValueType](
d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
) -> ValueType:
if isinstance(keys, tuple):
for key in keys:
d = d[key]
@@ -1105,9 +1101,9 @@ def get_nested_dict_value[
return d[keys]


def pop_nested_dict_value[
KeyType, ValueType
](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType:
def pop_nested_dict_value[KeyType, ValueType](
d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]
) -> ValueType:
if isinstance(keys, tuple):
for key in keys[:-1]:
d = d[key]
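The config.py hunks above re-wrap signatures only; the nested-dict helpers behave as before. A hedged usage sketch of get_nested_dict_value, based solely on the body visible in this hunk (the set/pop variants may take further parameters not shown here), with made-up dictionary contents:

from fast_llm.config import get_nested_dict_value

nested = {"model": {"hidden_size": 1024}}

# A tuple of keys is walked level by level; a single key indexes directly.
assert get_nested_dict_value(nested, ("model", "hidden_size")) == 1024
assert get_nested_dict_value(nested, "model") == {"hidden_size": 1024}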
18 changes: 9 additions & 9 deletions fast_llm/engine/config_utils/run.py
@@ -240,9 +240,9 @@ def is_main_rank() -> bool:
return DistributedConfig.default_rank == _MAIN_RANK


def log_main_rank[
T
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
def log_main_rank[T](
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
) -> T:
if is_main_rank():
return log(*message, log_fn=log_fn, join=join)

@@ -251,9 +251,9 @@ def is_model_parallel_main_rank() -> bool:
return is_main_rank() if _run is None else _run._is_model_parallel_main_rank # Noqa


def log_model_parallel_main_rank[
T
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
def log_model_parallel_main_rank[T](
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
) -> T:
if is_model_parallel_main_rank():
return log(*message, log_fn=log_fn)

@@ -262,8 +262,8 @@ def is_pipeline_parallel_main_rank() -> bool:
return is_main_rank() if _run is None else _run._is_pipeline_parallel_main_rank # Noqa


def log_pipeline_parallel_main_rank[
T
](*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info) -> T:
def log_pipeline_parallel_main_rank[T](
*message, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
) -> T:
if is_pipeline_parallel_main_rank():
return log(*message, log_fn=log_fn)
4 changes: 1 addition & 3 deletions fast_llm/engine/config_utils/runnable.py
@@ -108,9 +108,7 @@ def _get_runnable(self) -> typing.Callable[[], None]:
def run(self) -> None:
self._get_runnable()()

def _show[
T
](
def _show[T](
self,
verbose: int = FieldVerboseLevel.core,
*,
6 changes: 3 additions & 3 deletions fast_llm/engine/distributed/config.py
@@ -410,9 +410,9 @@ def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
def get_distributed_dim(self, name: str) -> DistributedDim:
return self.distributed_dims[name]

def _log_on_rank[
T
](self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info):
def _log_on_rank[T](
self, *message, rank: int | None = None, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info
):
if rank is None or self.rank == rank:
return log(*message, log_fn=log_fn)

4 changes: 1 addition & 3 deletions fast_llm/functional/autograd.py
@@ -10,9 +10,7 @@


# TODO: Improve type hint (use protocol?)
def wrap_forward_backward[
OutputType, ContextType
](
def wrap_forward_backward[OutputType, ContextType](
forward: typing.Callable[..., tuple[OutputType, ContextType]],
backward: typing.Callable[[OutputType, ContextType], typing.Any],
) -> typing.Callable[..., OutputType]:
30 changes: 9 additions & 21 deletions fast_llm/logging.py
@@ -123,17 +123,15 @@ def format_metrics(


@torch._dynamo.disable # noqa
def log_tensor[
T
](
def log_tensor[T](
name: str,
tensor: torch.Tensor,
*,
scale: float = 1.0,
level: int = 2,
storage: bool = False,
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
) -> (T | None):
) -> T | None:
if level < 1:
return
tensor = tensor.detach()
@@ -219,9 +217,7 @@ def log_tensor[


@torch._dynamo.disable # noqa
def log_grad[
T
](
def log_grad[T](
name: str,
tensor: torch.Tensor,
*,
@@ -244,9 +240,7 @@ def log_grad[


@torch._dynamo.disable # noqa
def log_distributed_tensor[
T
](
def log_distributed_tensor[T](
name: str,
tensor: torch.Tensor,
*,
@@ -257,7 +251,7 @@ def log_distributed_tensor[
global_: bool = True,
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
meta: TensorMeta,
) -> (T | None):
) -> T | None:
if level <= 0:
return
if global_:
@@ -278,9 +272,7 @@ def log_distributed_tensor[


@torch._dynamo.disable # noqa
def log_distributed_grad[
T
](
def log_distributed_grad[T](
name: str,
tensor: torch.Tensor,
*,
@@ -292,7 +284,7 @@ def log_distributed_grad[
global_: bool = True,
log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info,
meta: TensorMeta,
) -> (T | None):
) -> T | None:
if level <= 0:
return
tensor.register_hook(
@@ -311,9 +303,7 @@ def log_distributed_grad[


@torch._dynamo.disable # noqa
def log_generator[
T
](
def log_generator[T](
name,
generator: torch.Tensor | torch.Generator | None = None,
log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
@@ -328,9 +318,7 @@ def log_generator[
_global_max_reserved = 0


def log_memory_usage[
T
](
def log_memory_usage[T](
header: str | None = None,
log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info,
reset_stats: bool = True,
6 changes: 3 additions & 3 deletions fast_llm/utils.py
@@ -261,9 +261,9 @@ def __getitem__(self, key: KeyType) -> ValueType:
return super().__getitem__(key)()


def log[
T
](*message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", ") -> T:
def log[T](
*message: typing.Any, log_fn: type[BaseException] | typing.Callable[[str], T] = logger.info, join: str = ", "
) -> T:
message = join.join([str(m() if callable(m) else m) for m in message])
logged = log_fn(message)
if isinstance(logged, BaseException):
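For context on the log helper re-wrapped above: messages may be plain values or zero-argument callables (evaluated lazily when the message string is built), join separates them, and log_fn may be a logging callable or an exception type; the visible body suggests an exception instance returned by log_fn is then raised. A hedged usage sketch under those assumptions, with a made-up metric function:

import logging

from fast_llm.utils import log

logger = logging.getLogger(__name__)


def compute_loss() -> float:  # hypothetical stand-in for an expensive metric
    return 0.123


# The lambda is only evaluated when the joined message is built.
log("epoch", 3, lambda: f"loss={compute_loss():.3f}", log_fn=logger.info, join=" | ")

# log("invalid configuration", log_fn=ValueError)  # presumably raises ValueError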
8 changes: 6 additions & 2 deletions fast_llm_external_models/apriel2/convert.py
@@ -41,9 +41,13 @@
compose_configs,
)
from fast_llm_external_models.apriel2.conversion import llava as llava_converter
from fast_llm_external_models.apriel2.conversion import plan_surgery
from fast_llm_external_models.apriel2.conversion import (
plan_surgery,
)
from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter
from fast_llm_external_models.apriel2.conversion import strip_init_fields
from fast_llm_external_models.apriel2.conversion import (
strip_init_fields,
)

# Allow running as script or module
if __name__ == "__main__":
@@ -795,10 +795,7 @@ def test_vs_qwen3next(
always uses initial_state=None, ignoring cached recurrent state.
"""
from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDynamicCache,
Qwen3NextGatedDeltaNet,
)
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache, Qwen3NextGatedDeltaNet

from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache, Apriel2GatedDeltaNet

@@ -826,7 +823,9 @@ def test_vs_qwen3next(
# Create models with same weights
torch.manual_seed(seed)
qwen_gdn = Qwen3NextGatedDeltaNet(qwen3_config, layer_idx=0).to(device="cuda", dtype=test_dtype)
apriel_gdn = Apriel2GatedDeltaNet(hidden_size, gdn_mixer_config, layer_idx=0).to(device="cuda", dtype=test_dtype)
apriel_gdn = Apriel2GatedDeltaNet(hidden_size, gdn_mixer_config, layer_idx=0).to(
device="cuda", dtype=test_dtype
)

# Transfer weights using conversion plan
plan = plan_qwen3next_gdn_to_apriel2(
@@ -1013,6 +1012,7 @@ def test_chunked_vs_recurrent(
msg=f"GDN chunked vs recurrent mode (prefill={prefill_len}, decode={decode_steps})",
)


# =============================================================================
# SECTION 2: EQUIVALENCE TESTS - KimiDeltaAttention
# =============================================================================
@@ -1046,10 +1046,8 @@ def test_vs_fla(
from fla.layers.kda import KimiDeltaAttention as FLA_KDA
from fla.models.utils import Cache as FLACache

from fast_llm_external_models.apriel2.modeling_apriel2 import (
Apriel2Cache,
KimiDeltaAttention as Apriel2_KDA,
)
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache
from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA

num_heads, head_dim = kda_config
seq_len = prefill_len + decode_steps + prefill2_len
@@ -1272,4 +1270,3 @@ def test_chunked_vs_recurrent(
atol=atol * 5,
msg=f"KDA chunked vs recurrent mode (prefill={prefill_len}, decode={decode_steps})",
)

13 changes: 13 additions & 0 deletions pyproject.toml
@@ -8,3 +8,16 @@ testpaths = [
"fast_llm_external_models/tests" # External models tests
]
norecursedirs = ["Megatron-LM"]

[tool.isort]
profile = "black"
default_section = "THIRDPARTY"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 119
float_to_top = "True"
known_first_party = [
"fast_llm",
]
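These options evidently replace the deleted .isort.cfg from the first hunk of this PR, spelling out the Black-compatible profile: vertical-hanging-indent wrapping (multi_line_output = 3), trailing commas, parentheses, a 119-character line length, float_to_top to hoist stray imports, and fast_llm treated as first-party. A hedged sketch of the import layout these settings should produce for an over-long import; the module and symbol names are hypothetical:

# Unknown modules default to THIRDPARTY and sort ahead of first-party fast_llm.
import torch

from fast_llm.some_long_module import (  # hypothetical module, for illustration
    first_symbol,
    second_symbol,
    third_symbol,
)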