diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml
index 9a2d778bb6..683e1e568e 100644
--- a/.github/workflows/build_and_test_maxtext.yml
+++ b/.github/workflows/build_and_test_maxtext.yml
@@ -150,6 +150,7 @@ jobs:
       fail-fast: false
       matrix:
         image_type: ["py312"]
+        worker_group: [1, 2]
     with:
       device_type: tpu
       device_name: v6e-4
@@ -160,6 +161,8 @@ jobs:
       tf_force_gpu_allow_growth: false
       container_resource_option: "--privileged"
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 2
       maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
 
   maxtext_tpu_integration_tests:
@@ -170,6 +173,7 @@ jobs:
       fail-fast: false
       matrix:
         image_type: ["py312"]
+        worker_group: [1, 2]
     with:
       device_type: tpu
       device_name: v6e-4
@@ -180,6 +184,8 @@ jobs:
       tf_force_gpu_allow_growth: false
       container_resource_option: "--privileged"
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 2
       maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
 
   maxtext_tpu_pathways_unit_tests:
@@ -231,6 +237,7 @@ jobs:
       matrix:
         image_type: ["py312"]
         cuda: ["cuda12"]
+        worker_group: [1, 2]
     with:
       device_type: ${{ matrix.cuda }}
       device_name: a100-40gb-4
@@ -241,6 +248,8 @@ jobs:
       tf_force_gpu_allow_growth: true
       container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 2
       maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
 
   maxtext_gpu_integration_tests:
@@ -252,6 +261,7 @@ jobs:
       matrix:
         image_type: ["py312"]
         cuda: ["cuda12"]
+        worker_group: [1, 2]
     with:
       device_type: ${{ matrix.cuda }}
       device_name: a100-40gb-4
@@ -262,6 +272,8 @@ jobs:
       tf_force_gpu_allow_growth: true
       container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
      is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 2
       maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
 
   all_tests_passed:
diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md
index f77af73a80..c13e199478 100644
--- a/docs/tutorials/posttraining/rl.md
+++ b/docs/tutorials/posttraining/rl.md
@@ -127,8 +127,16 @@ export HF_TOKEN=
 export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory
 export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+
+export CHIPS_PER_VM= # depends on hardware; 4 for v5p, 8 for v6e
 ```
 
+For the value of `CHIPS_PER_VM` on different TPU hardware, refer to the official documentation:
+
+- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, chips_per_vm=8)
+- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, chips_per_vm=4)
+- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, chips_per_vm=8)
+
 ## Get your model checkpoint
 
 ### Option 1: Using an existing MaxText checkpoint
 
@@ -159,7 +167,8 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
   load_parameters_path=${MAXTEXT_CKPT_PATH} \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-  hf_access_token=${HF_TOKEN}
+  hf_access_token=${HF_TOKEN} \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows:
@@ -183,7 +192,8 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
   hf_access_token=${HF_TOKEN} \
-  loss_algo=gspo-token
+  loss_algo=gspo-token \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows:
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index dc8f0013df..ef899b205c 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -1735,6 +1735,7 @@ class MaxTextConfig(
     # Reinforcement Learning
     RLHardware,
     VLLM,
+    RL,
     RLDataset,
     RLEvaluation,
     Reward,
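
A note on the new `worker_group` / `total_workers` inputs: the actual test partitioning happens inside the reusable workflow these jobs call, which is not part of this diff. As a rough sketch of the idea, with all names hypothetical, a 1-indexed worker group could select its share of the test files like so:

```
# Illustrative sketch only: round-robin sharding of a test list across
# worker groups, mirroring the new worker_group / total_workers inputs.
# The real selection logic lives in the reusable workflow, not here.
def select_shard(tests: list[str], worker_group: int, total_workers: int) -> list[str]:
  """Return the subset of `tests` owned by the 1-indexed `worker_group`."""
  if not 1 <= worker_group <= total_workers:
    raise ValueError(f"worker_group must be in [1, {total_workers}], got {worker_group}")
  # Sort first so every worker computes the same deterministic partition.
  return [t for i, t in enumerate(sorted(tests)) if i % total_workers == worker_group - 1]


if __name__ == "__main__":
  tests = ["test_attention.py", "test_decode.py", "test_rl.py", "test_train.py"]
  for group in (1, 2):  # matches the matrix entry `worker_group: [1, 2]`
    print(group, select_shard(tests, group, total_workers=2))
```

Each matrix entry runs on its own runner, so splitting into two worker groups should roughly halve per-runner wall time at the cost of one extra v6e-4 or a100-40gb-4 machine per job.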
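
Similarly for the `chips_per_vm` flag now threaded through the RL tutorial: the tutorial only states the per-hardware values (4 for v5p, 8 for v5e and v6e). A minimal sketch of the host arithmetic those values imply, using a hypothetical helper rather than MaxText's actual code:

```
# Hypothetical helper: a slice's total chip count divided by chips per
# VM gives the number of single-host VMs backing it. The per-hardware
# values below are the ones listed in the tutorial.
CHIPS_PER_VM = {"v5e": 8, "v5p": 4, "v6e": 8}


def num_vms(accelerator: str, total_chips: int) -> int:
  """Number of VMs (hosts) backing a slice with `total_chips` chips."""
  per_vm = CHIPS_PER_VM[accelerator]
  if total_chips % per_vm:
    raise ValueError(f"{total_chips} chips does not fill a whole number of {accelerator} VMs")
  return total_chips // per_vm


print(num_vms("v6e", 8))  # 1 -- one v6e host exposes all 8 chips
print(num_vms("v5p", 8))  # 2 -- the same 8 chips span two v5p hosts
```

This is presumably why the tutorial exports `CHIPS_PER_VM` alongside the other run variables: the per-host chip count cannot be inferred from the slice's total chip count alone.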