Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion benchmarks/recipes/user_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class UserConfig:
delete: bool = False
max_restarts: int = 0
temp_key: str = None
base_output_directory: str = None

def __post_init__(self):
"""Automatically generate derived attributes after the object is created."""
Expand All @@ -97,7 +98,7 @@ def __post_init__(self):
self.worker_flags,
)
self.headless_workload_name = f"{self.user[:3]}-headless"
self.base_output_directory = f"gs://{self.user}-{self.region}/{self.user}-"
self.base_output_directory = self.base_output_directory or f"gs://{self.user}-{self.region}/{self.user}-"

device_base_type = self.device_type.split("-", maxsplit=1)[0]
self.models = build_user_models(
Expand All @@ -124,4 +125,7 @@ def __post_init__(self):
selected_model_framework=["pathways"],
selected_model_names=["llama3_1_8b_8192"],
priority="medium",
base_output_directory=None, # GCS Bucket path
# Optional parameters, useful for single controller data loading optimizations
# proxy_flags="--sidecar_name=external",
)
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tensorflow
tiktoken
tokamax
transformers
uvloop
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.18.0
optype>=0.14.0
orbax-checkpoint>=0.11.28
orbax-checkpoint>=0.11.33
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
Expand Down Expand Up @@ -245,6 +245,7 @@ tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.38.0
uvloop>=0.19.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, we usually don't recommend directly editing dependencies under the generated_requirements folder. These two txt files are generated from base_requirements as in this guide. You need to edit the base requirements, then run seed-env to generate a new set of generated requirements.

Your current patch can work temporarily, but if someone else generates new requirement files, your current change will be lost without notice.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created b/486268025 to try to figure out the issue, but it looks like when I follow that procedure, we get into a severe dependency hell with cloud-tpu-diagnostics and some other subsequent libraries.

I think for the purposes of my checkin, the only thing strictly needed actually is just that orbax be upgraded to version 0.11.33 or greater. Uvloop comes from Orbax.

virtualenv>=20.35.4
wadler-lindig>=0.1.7
websockets>=15.0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.18.0
optype>=0.14.0
orbax-checkpoint>=0.11.28
orbax-checkpoint>=0.11.33
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
Expand Down Expand Up @@ -237,6 +237,7 @@ tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.38.0
uvloop>=0.19.0
virtualenv>=20.35.4
wadler-lindig>=0.1.7
websockets>=15.0.1
Expand Down
14 changes: 14 additions & 0 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def create_orbax_checkpoint_manager(
enable_continuous_checkpointing: bool = False,
max_num_checkpoints_to_keep: int = 10,
checkpoint_storage_concurrent_gb: int = 96,
enable_single_controller: bool = False,
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
Expand Down Expand Up @@ -269,6 +272,17 @@ def create_orbax_checkpoint_manager(
logger=orbax_logger,
)

# Use Colocated Python checkpointing optimization (Single Controller only).
if enable_single_controller and colocated_python_checkpointing:
max_logging.log("Registering colocated python array handler")
checkpointing_impl = ocp.pathways.CheckpointingImpl.from_options(
use_colocated_python=True,
)
ocp.pathways.register_type_handlers(
use_single_replica_array_handler=enable_single_replica_ckpt_restoring,
checkpointing_impl=checkpointing_impl,
)

max_logging.log("Checkpoint manager created!")
return manager

Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ enable_orbax_v1: False
checkpoint_conversion_fn: none
# optional checkpoint context to use for loading. options: "orbax", "safetensors"
source_checkpoint_layout: "orbax"

# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
colocated_python_checkpointing: False
############################### end checkpointing ##################################


Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ class Checkpointing(BaseModel):
True, description="If True, saves a final checkpoint upon training completion."
)
enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
colocated_python_checkpointing: bool = Field(False, description="If True, enables colocated python checkpointing.")


class OrbaxStorage(BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def create_training_tools(config, model, mesh):
# TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
use_ocdbt = config.checkpoint_storage_use_ocdbt
use_zarr3 = config.checkpoint_storage_use_zarr3
if config.enable_single_controller:
if config.enable_single_controller and not config.colocated_python_checkpointing:
use_ocdbt, use_zarr3 = False, False

checkpoint_dir = ""
Expand All @@ -79,6 +79,9 @@ def create_training_tools(config, model, mesh):
config.enable_continuous_checkpointing,
config.max_num_checkpoints_to_keep,
config.checkpoint_storage_concurrent_gb,
config.enable_single_controller,
config.colocated_python_checkpointing,
config.enable_single_replica_ckpt_restoring,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx
Expand Down
Loading