12 changes: 12 additions & 0 deletions .github/workflows/build_and_test_maxtext.yml
@@ -150,6 +150,7 @@ jobs:
fail-fast: false
matrix:
image_type: ["py312"]
worker_group: [1, 2]
with:
device_type: tpu
device_name: v6e-4
@@ -160,6 +161,8 @@ jobs:
tf_force_gpu_allow_growth: false
container_resource_option: "--privileged"
is_scheduled_run: ${{ github.event_name == 'schedule' }}
worker_group: ${{ matrix.worker_group }}
total_workers: 2
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

maxtext_tpu_integration_tests:
@@ -170,6 +173,7 @@ jobs:
fail-fast: false
matrix:
image_type: ["py312"]
worker_group: [1, 2]
with:
device_type: tpu
device_name: v6e-4
@@ -180,6 +184,8 @@ jobs:
tf_force_gpu_allow_growth: false
container_resource_option: "--privileged"
is_scheduled_run: ${{ github.event_name == 'schedule' }}
worker_group: ${{ matrix.worker_group }}
total_workers: 2
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

maxtext_tpu_pathways_unit_tests:
@@ -231,6 +237,7 @@ jobs:
matrix:
image_type: ["py312"]
cuda: ["cuda12"]
worker_group: [1, 2]
with:
device_type: ${{ matrix.cuda }}
device_name: a100-40gb-4
@@ -241,6 +248,8 @@ jobs:
tf_force_gpu_allow_growth: true
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
is_scheduled_run: ${{ github.event_name == 'schedule' }}
worker_group: ${{ matrix.worker_group }}
total_workers: 2
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

maxtext_gpu_integration_tests:
@@ -252,6 +261,7 @@ jobs:
matrix:
image_type: ["py312"]
cuda: ["cuda12"]
worker_group: [1, 2]
with:
device_type: ${{ matrix.cuda }}
device_name: a100-40gb-4
@@ -262,6 +272,8 @@ jobs:
tf_force_gpu_allow_growth: true
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
is_scheduled_run: ${{ github.event_name == 'schedule' }}
worker_group: ${{ matrix.worker_group }}
total_workers: 2
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

all_tests_passed:
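The `worker_group`/`total_workers` inputs added to each job above shard a test suite across two parallel runners via the `worker_group: [1, 2]` matrix axis. A minimal sketch of how such sharding typically works, assuming a simple index-modulo split (the helper name and selection scheme are illustrative, not MaxText's actual logic):

```python
def select_for_worker(tests, worker_group, total_workers):
    """Keep every test whose index falls on this worker.

    worker_group is 1-indexed, matching the matrix values [1, 2] above.
    """
    return [t for i, t in enumerate(tests) if i % total_workers == worker_group - 1]


suite = ["test_train", "test_decode", "test_rl", "test_ckpt"]
group_1 = select_for_worker(suite, worker_group=1, total_workers=2)
group_2 = select_for_worker(suite, worker_group=2, total_workers=2)
# Together the two groups cover the suite exactly once:
# group_1 -> ["test_train", "test_rl"], group_2 -> ["test_decode", "test_ckpt"]
```

Because every test index maps to exactly one residue class modulo `total_workers`, the two worker groups partition the suite with no overlap and no gaps, which is what lets the `all_tests_passed` gate simply require both shards to succeed.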
14 changes: 12 additions & 2 deletions docs/tutorials/posttraining/rl.md
@@ -127,8 +135,16 @@
export HF_TOKEN=<Hugging Face access token>
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory

export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)

export CHIPS_PER_VM=<the number of chips per VM> # depends on hardware, for v5p this is 4, for v6e this is 8
```

For the value of `CHIPS_PER_VM` on different TPU hardware, refer to the official documentation:

- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, `CHIPS_PER_VM=8`)
- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, `CHIPS_PER_VM=4`)
- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, `CHIPS_PER_VM=8`)
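The mapping above can also be looked up programmatically rather than set by hand; a small sketch, where the helper and the `TPU_TYPE`-style input are hypothetical conveniences, not MaxText settings:

```python
import os

# Single-host chips-per-VM values, mirroring the TPU documentation linked above.
CHIPS_PER_VM_BY_TPU = {"v5e": 8, "v5p": 4, "v6e": 8}


def chips_per_vm(tpu_type: str) -> int:
    """Return the chips-per-VM count for a known single-host TPU generation."""
    try:
        return CHIPS_PER_VM_BY_TPU[tpu_type]
    except KeyError:
        raise ValueError(f"Unknown TPU type: {tpu_type!r}") from None


# Export the value the same way the shell snippet above does.
os.environ["CHIPS_PER_VM"] = str(chips_per_vm("v6e"))
```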

## Get your model checkpoint

### Option 1: Using an existing MaxText checkpoint
@@ -159,7 +167,8 @@
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
load_parameters_path=${MAXTEXT_CKPT_PATH} \
run_name=${RUN_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=${HF_TOKEN}
hf_access_token=${HF_TOKEN} \
chips_per_vm=${CHIPS_PER_VM}
```

The overview of what this run will do is as follows:
@@ -183,7 +192,8 @@
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
run_name=${RUN_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=${HF_TOKEN} \
loss_algo=gspo-token
loss_algo=gspo-token \
chips_per_vm=${CHIPS_PER_VM}
```

The overview of what this run will do is as follows:
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
@@ -1735,6 +1735,7 @@ class MaxTextConfig(
# Reinforcement Learning
RLHardware,
VLLM,
RL,
RLDataset,
RLEvaluation,
Reward,