diff --git a/benchmarks/recipes/user_configs.py b/benchmarks/recipes/user_configs.py
index 15255cc46b..79c33d3c05 100644
--- a/benchmarks/recipes/user_configs.py
+++ b/benchmarks/recipes/user_configs.py
@@ -53,6 +53,7 @@ class UserConfig:
   zone: str = "us-east5-b"
   device_type: str = "v6e-256"
   priority: str = "medium"
+  base_output_directory: str = None
 
   # Images for env
   server_image: str = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
@@ -97,7 +98,10 @@ def __post_init__(self):
         self.worker_flags,
     )
     self.headless_workload_name = f"{self.user[:3]}-headless"
-    self.base_output_directory = f"gs://{self.user}-{self.region}/{self.user}-"
+    self.base_output_directory = (
+        self.base_output_directory
+        or f"gs://{self.user}-{self.region}/{self.user}-"
+    )
     device_base_type = self.device_type.split("-", maxsplit=1)[0]
 
     self.models = build_user_models(
@@ -124,4 +128,7 @@
     selected_model_framework=["pathways"],
     selected_model_names=["llama3_1_8b_8192"],
     priority="medium",
+    base_output_directory=None,  # GCS Bucket path
+    # Optional parameters, useful for single controller data loading optimizations
+    # proxy_flags="--sidecar_name=external",
 )
diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt
index c40252cfc1..66a6c3521d 100644
--- a/dependencies/requirements/base_requirements/requirements.txt
+++ b/dependencies/requirements/base_requirements/requirements.txt
@@ -41,6 +41,7 @@ tensorflow
 tiktoken
 tokamax
 transformers
+uvloop
 qwix
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
diff --git a/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/dependencies/requirements/generated_requirements/cuda12-requirements.txt
index 09483af80c..364465d935 100644
--- a/dependencies/requirements/generated_requirements/cuda12-requirements.txt
+++ b/dependencies/requirements/generated_requirements/cuda12-requirements.txt
@@ -155,7 +155,7 @@ opt-einsum>=3.4.0
 optax>=0.2.6
 optree>=0.18.0
 optype>=0.14.0
-orbax-checkpoint>=0.11.28
+orbax-checkpoint>=0.11.33
 packaging>=25.0
 pandas>=2.3.3
 parameterized>=0.9.0
@@ -245,6 +245,7 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
+uvloop>=0.19.0
 virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1
diff --git a/dependencies/requirements/generated_requirements/tpu-requirements.txt b/dependencies/requirements/generated_requirements/tpu-requirements.txt
index 2c59290d60..58ebd5555f 100644
--- a/dependencies/requirements/generated_requirements/tpu-requirements.txt
+++ b/dependencies/requirements/generated_requirements/tpu-requirements.txt
@@ -149,7 +149,7 @@ opt-einsum>=3.4.0
 optax>=0.2.6
 optree>=0.18.0
 optype>=0.14.0
-orbax-checkpoint>=0.11.28
+orbax-checkpoint>=0.11.33
 packaging>=25.0
 pandas>=2.3.3
 parameterized>=0.9.0
@@ -237,6 +237,7 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
+uvloop>=0.19.0
 virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1
diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py
index ba1d325d20..558df97d4e 100644
--- a/src/maxtext/common/checkpointing.py
+++ b/src/maxtext/common/checkpointing.py
@@ -217,6 +217,9 @@ def create_orbax_checkpoint_manager(
     enable_continuous_checkpointing: bool = False,
     max_num_checkpoints_to_keep: int = 10,
     checkpoint_storage_concurrent_gb: int = 96,
+    enable_single_controller: bool = False,
+    colocated_python_checkpointing: bool = False,
+    enable_single_replica_ckpt_restoring: bool = False,
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:
@@ -269,6 +272,17 @@ def create_orbax_checkpoint_manager(
       logger=orbax_logger,
   )
 
+  # Use Colocated Python checkpointing optimization (Single Controller only).
+  if enable_single_controller and colocated_python_checkpointing:
+    max_logging.log("Registering colocated python array handler")
+    checkpointing_impl = ocp.pathways.CheckpointingImpl.from_options(
+        use_colocated_python=True,
+    )
+    ocp.pathways.register_type_handlers(
+        use_single_replica_array_handler=enable_single_replica_ckpt_restoring,
+        checkpointing_impl=checkpointing_impl,
+    )
+
   max_logging.log("Checkpoint manager created!")
   return manager
 
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 19a3d1d4d0..ff49b9ede2 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -76,6 +76,9 @@ enable_orbax_v1: False
 checkpoint_conversion_fn: none
 # optional checkpoint context to use for loading. options: "orbax", "safetensors"
 source_checkpoint_layout: "orbax"
+
+# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
+colocated_python_checkpointing: False
 
 ############################### end checkpointing ##################################
 
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 799a125736..f70265b800 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -314,6 +314,7 @@ class Checkpointing(BaseModel):
       True, description="If True, saves a final checkpoint upon training completion."
   )
   enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
+  colocated_python_checkpointing: bool = Field(False, description="If True, enables colocated python checkpointing.")
 
 
 class OrbaxStorage(BaseModel):
diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py
index eb5656a9fc..cb550fdf43 100644
--- a/src/maxtext/utils/train_utils.py
+++ b/src/maxtext/utils/train_utils.py
@@ -61,7 +61,7 @@ def create_training_tools(config, model, mesh):
   # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
   use_ocdbt = config.checkpoint_storage_use_ocdbt
   use_zarr3 = config.checkpoint_storage_use_zarr3
-  if config.enable_single_controller:
+  if config.enable_single_controller and not config.colocated_python_checkpointing:
     use_ocdbt, use_zarr3 = False, False
 
   checkpoint_dir = ""
@@ -79,6 +79,9 @@ def create_training_tools(config, model, mesh):
       config.enable_continuous_checkpointing,
       config.max_num_checkpoints_to_keep,
       config.checkpoint_storage_concurrent_gb,
+      config.enable_single_controller,
+      config.colocated_python_checkpointing,
+      config.enable_single_replica_ckpt_restoring,
   )
 
   return init_rng, checkpoint_manager, learning_rate_schedule, tx
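
Reviewer note: taken together, these changes register Orbax's colocated Python type handlers only when both enable_single_controller and colocated_python_checkpointing are set, and in that case they stop forcing OCDBT/zarr3 off in train_utils. As a hedged sketch (not part of the patch), the config overrides needed to exercise the new path would look like the following; enable_single_controller and enable_single_replica_ckpt_restoring are pre-existing MaxText options referenced by this diff, and the values are illustrative only:

# Sketch: enable the experimental colocated Python checkpointing path.
# register_type_handlers runs only when both of the first two flags are True.
enable_single_controller: True
colocated_python_checkpointing: True
# Optional: restore from a single replica and broadcast to the rest.
enable_single_replica_ckpt_restoring: True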