6 changes: 5 additions & 1 deletion fast_llm/core/distributed.py
@@ -185,7 +185,11 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta
@contextlib.contextmanager
def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]:
"""Use the generator as default, for ops that don't support a generator argument."""
default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()]
default_generator: torch.Generator = (
torch.cuda.default_generators[generator.device.index]
if generator.device.type == "cuda"
else torch.default_generator
)
assert generator is not default_generator
old_state = default_generator.get_state()
default_generator.set_state(generator.get_state())
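A quick illustration of the revised `set_generator` behaviour — a minimal sketch assuming only standard PyTorch and the import path shown above: with a CPU generator, the state of `torch.default_generator` is swapped in for the duration of the block instead of the current CUDA generator's.

```python
import torch

from fast_llm.core.distributed import set_generator

# Generator-less ops (e.g. torch.randn) now draw from the wrapped CPU generator.
cpu_generator = torch.Generator(device="cpu").manual_seed(1234)
with set_generator(cpu_generator):
    sample_a = torch.randn(3)

cpu_generator.manual_seed(1234)
with set_generator(cpu_generator):
    sample_b = torch.randn(3)

assert torch.equal(sample_a, sample_b)  # reproducible through the wrapped generator
```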
80 changes: 52 additions & 28 deletions fast_llm/core/kernels.py
@@ -12,30 +12,35 @@
from amp_C import multi_tensor_scale as _multi_tensor_scale # noqa
from apex.multi_tensor_apply import multi_tensor_applier as _multi_tensor_applier # noqa

_apex_available = True
_apex_available = torch.cuda.is_available()  # the Apex multi-tensor kernels only operate on CUDA tensors
except ImportError:
_apex_available = False


def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor:
assert _apex_available
norm, _ = _multi_tensor_applier(
_multi_tensor_l2norm,
noop_flag,
[tensors],
False, # no per-parameter norm
)
if _apex_available:
norm, _ = _multi_tensor_applier(
_multi_tensor_l2norm,
noop_flag,
[tensors],
False, # no per-parameter norm
)
else:
# Pure-PyTorch fallback (no Apex or no CUDA); the no-op flag is not checked here.
norm = sum(torch.norm(tensor) ** 2 for tensor in tensors) ** 0.5
return norm


def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None:
assert _apex_available
_multi_tensor_applier(
_multi_tensor_scale,
noop_flag,
[tensors, tensors],
scale,
)
if _apex_available:
_multi_tensor_applier(
_multi_tensor_scale,
noop_flag,
[tensors, tensors],
scale,
)
else:
# Pure-PyTorch fallback; scales in place and ignores the no-op flag.
for tensor in tensors:
tensor.mul_(scale)
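A small sanity check for the two fallback paths above — a sketch that assumes Apex is not installed (so `_apex_available` is false and the pure-PyTorch branches run) and that the tensors live on CPU:

```python
import torch

from fast_llm.core.kernels import l2_norm, scale_

tensors = [torch.randn(4, 4), torch.randn(8)]
noop_flag = torch.zeros(1, dtype=torch.int32)  # only consulted by the Apex kernels

# The fallback norm should equal the norm of all elements flattened together.
expected = torch.linalg.vector_norm(torch.cat([t.flatten() for t in tensors]))
torch.testing.assert_close(l2_norm(tensors, noop_flag), expected)

# scale_ multiplies every tensor in place.
reference = [t * 0.5 for t in tensors]
scale_(tensors, noop_flag, 0.5)
for scaled, ref in zip(tensors, reference):
    torch.testing.assert_close(scaled, ref)
```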


# TODO: Same as torch._fused_adam_?
@@ -52,16 +57,35 @@ def fused_adam(
eps: float,
step: int,
) -> None:
_multi_tensor_applier(
_multi_tensor_adam,
noop_flag,
[grads, params, exp_avgs, exp_avg_sqs],
lr,
beta1,
beta2,
eps,
step,
1, # adamw
1, # bias correction
wd,
)
if _apex_available:
_multi_tensor_applier(
_multi_tensor_adam,
noop_flag,
[grads, params, exp_avgs, exp_avg_sqs],
lr,
beta1,
beta2,
eps,
step,
1, # adamw
1, # bias correction
wd,
)
else:
import torch.optim.adamw as adamw

adamw.adamw(
params,
grads,
exp_avgs,
exp_avg_sqs,
[],  # max_exp_avg_sqs: empty (as torch.optim.AdamW passes it) since amsgrad=False
lr=lr,
beta1=beta1,
beta2=beta2,
eps=eps,
state_steps=list(torch.full([len(params)], step - 1, dtype=torch.int64, device=params[0].device).unbind()),  # the functional AdamW increments each step before use, so pass step - 1 to match the Apex bias correction
weight_decay=wd,
amsgrad=False,
maximize=False,
)
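The fallback branch routes the update through the functional AdamW in `torch.optim.adamw`, which is intended to match the decoupled-weight-decay update the Apex kernel performs. A minimal CPU smoke test — a sketch assuming Apex is absent and that `step` is the 1-based optimizer step, with keyword arguments so it does not depend on `fused_adam`'s parameter order:

```python
import torch

from fast_llm.core.kernels import fused_adam

params = [torch.randn(10)]
grads = [torch.randn(10)]
exp_avgs = [torch.zeros(10)]
exp_avg_sqs = [torch.zeros(10)]
noop_flag = torch.zeros(1, dtype=torch.int32)  # only consulted by the Apex kernel
before = params[0].clone()

fused_adam(
    params=params,
    grads=grads,
    exp_avgs=exp_avgs,
    exp_avg_sqs=exp_avg_sqs,
    noop_flag=noop_flag,
    lr=1e-3,
    beta1=0.9,
    beta2=0.95,
    wd=0.1,
    eps=1e-8,
    step=1,
)
assert not torch.equal(params[0], before)  # parameters updated in place
```

A stricter check would compare the result against one step of `torch.optim.AdamW` with the same hyperparameters, which is the parity this fallback aims for.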
8 changes: 2 additions & 6 deletions fast_llm/engine/checkpoint/convert.py
@@ -8,7 +8,6 @@
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig
from fast_llm.engine.config_utils.runnable import RunnableConfig
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
from fast_llm.functional.config import TritonConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
@@ -64,8 +63,8 @@ def _convert_model_partial(
logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...")
model = model_class.from_pretrained(
self.input,
{("distributed", "use_cuda"): not self.use_cpu},
mode=StageMode.weights,
use_cpu=self.use_cpu,
stage_filter=stage_filter,
)
logger.info(f"Saving {output.format} checkpoint to {output.path}...")
@@ -78,9 +77,6 @@ def run(self):
# TODO: Set logging in tests
logging.getLogger().setLevel(logging.INFO)
self.to_logs()
# Disable Triton to convert model on CPU
if self.use_cpu:
TritonConfig.TRITON_ENABLED = False
# Skip on exist_ok=False if the model has already been processed
if not self.exist_ok and (self.output.path / "ok").exists():
logger.info(
@@ -100,8 +96,8 @@
# Create a dummy version to determine the stage split.
model = model_class.from_pretrained(
self.input.to_copy({"model_weights": False}),
{("distributed", "use_cuda"): not self.use_cpu},
mode=StageMode.off_device,
use_cpu=self.use_cpu,
)
stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage)
num_stages = len(model.stages)
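With the Triton toggle gone, a CPU conversion is expressed entirely through the distributed-config override shown above. A hypothetical convenience wrapper (not part of this diff) that mirrors the call in `_convert_model_partial`:

```python
from fast_llm.engine.multi_stage.config import StageMode


def load_on_cpu(model_class, checkpoint_load_config):
    # Hypothetical helper: force a CPU-only setup by overriding the
    # ("distributed", "use_cuda") field, exactly as _convert_model_partial does.
    return model_class.from_pretrained(
        checkpoint_load_config,
        {("distributed", "use_cuda"): False},
        mode=StageMode.weights,
    )
```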
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -101,7 +101,7 @@ def configure_logging(
def get_run(self, distributed: "Distributed") -> "Run":
from fast_llm.functional.config import TritonConfig

TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels
TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda
TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels
run = Run(config=self, distributed=distributed)
set_global_variables(not self.run.torch_dynamo_enable)
5 changes: 5 additions & 0 deletions fast_llm/engine/distributed/config.py
@@ -227,6 +227,11 @@ class DistributedConfig(Config):
hint=FieldHint.optional,
valid=check_field(Assert.gt, 0),
)
use_cuda: bool = Field(
default=True,
desc="Enable CUDA device.",
hint=FieldHint.expert,
)
seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional)
# TODO: Rename to compute_dtype (not just for training), move elsewhere
compute_dtype: DataType = Field(
23 changes: 12 additions & 11 deletions fast_llm/engine/distributed/distributed.py
@@ -27,7 +27,7 @@ def __init__(
world_size: int | None = None,
local_world_size: int | None = None,
timeout: float = 60,
use_cpu: bool = False,
use_cuda: bool = True,
init_method: str = "env://",
backend: DistributedBackend = DistributedBackend.nccl,
):
@@ -38,19 +38,20 @@
DistributedConfig.default_local_world_size if local_world_size is None else local_world_size
)
self._timeout = timeout
self._use_cpu = use_cpu
self._use_cuda = use_cuda
self._backend = backend
self._process_groups = {}

if self._use_cpu:
if backend == DistributedBackend.nccl:
Assert.eq(self._world_size, 1)
self._device = torch.device("cpu")
else:
if self._use_cuda:
assert torch.cuda.is_available()
Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
torch.cuda.init()
self._device = torch.device(self._rank % self._local_world_size)
torch.cuda.set_device(self._device)
else:
if backend == DistributedBackend.nccl:
Assert.eq(self._world_size, 1)
self._device = torch.device("cpu")

if self._world_size > 1:
if self._rank == 0:
@@ -153,7 +154,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
TODO: Clarify cpu support.
"""

def __init__(self, config: DistributedConfig, use_cpu: bool = False):
def __init__(self, config: DistributedConfig):
super().__init__(config)
assert self._config.reference_config is None

@@ -164,15 +165,15 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
self._config.world_size,
self._config.local_world_size,
self._config.timeout,
use_cpu,
self._config.use_cuda,
backend=self._config.backend,
)
else:
self._pool = _default_pool
Assert.geq(self._pool.world_size, self._config.world_size)
Assert.eq(self._pool.rank, self._config.rank)
Assert.geq(self._pool.local_world_size, self._config.local_world_size)
Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu")
Assert.eq(self._pool.backend, self._config.backend)

self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
@@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None:
self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED)

def __del__(self):
if self._local_pool:
if getattr(self, "_local_pool", False) and hasattr(self, "_pool"):
self._pool.shutdown()
2 changes: 0 additions & 2 deletions fast_llm/engine/inference/huggingface.py
@@ -85,7 +85,6 @@ def from_pretrained(
optimizer_state_names: tuple[str, ...] | None = None,
# setup: bool = True,
mode: StageMode = StageMode.training,
use_cpu: bool = False,
stage_filter: set | None = None,
**kwargs,
) -> typing.Self:
@@ -104,7 +103,6 @@
optimizer_state_names=optimizer_state_names,
setup=True,
mode=mode,
use_cpu=use_cpu,
stage_filter=stage_filter,
)

3 changes: 1 addition & 2 deletions fast_llm/engine/multi_stage/fast_llm_model.py
@@ -48,7 +48,6 @@ def from_pretrained(
optimizer_state_names: tuple[str, ...] | None = None,
setup: bool = True,
mode: StageMode = StageMode.training,
use_cpu: bool = False,
stage_filter: set | None = None,
) -> typing.Self:
metadata = cls.config_class.load_metadata(pretrained_config)
@@ -69,7 +68,7 @@
)

if setup:
model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode)
model.setup(Distributed(config.distributed), mode=mode)

if mode.on_device:
if pretrained_config.model_weights:
18 changes: 18 additions & 0 deletions fast_llm/engine/schedule/config.py
@@ -191,3 +191,21 @@ class EventType(str, enum.Enum):
send = "send"
recv = "recv"
pipe_wait_compute = "pipe_wait_compute"


class MockStream:
"""Minimal stand-in for `torch.cuda.Stream` when running without CUDA; all methods are no-ops."""

stream_id: int = 0

def wait_stream(self, stream):
pass

def __eq__(self, other):
return isinstance(other, MockStream)


class MockEvent:
"""Minimal stand-in for `torch.cuda.Event` when running without CUDA; all methods are no-ops."""
def record(self, stream=None):
pass

def wait(self):
pass
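These placeholders mirror the small subset of the `torch.cuda.Stream` / `torch.cuda.Event` interface that the schedule uses, so CPU-only runs can keep the same call sites. A sketch of the intended usage; the selection helper below is hypothetical and not part of this diff:

```python
import torch

from fast_llm.engine.schedule.config import MockEvent, MockStream


def get_stream_and_event(use_cuda: bool):
    # Hypothetical helper: real CUDA objects when available, inert mocks otherwise.
    if use_cuda and torch.cuda.is_available():
        return torch.cuda.Stream(), torch.cuda.Event()
    return MockStream(), MockEvent()


stream, event = get_stream_and_event(use_cuda=False)
event.record(stream)        # no-op
event.wait()                # no-op
stream.wait_stream(stream)  # no-op
assert stream == MockStream()  # any two MockStreams compare equal
```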