Skip to content

Commit 60d38da

Browse files
authored
Feat(VL Training)Support Qwen2.5-VL training and freeze MLLM modules (#3153)
1 parent 17b0de7 commit 60d38da

File tree

11 files changed

+410
-50
lines changed

11 files changed

+410
-50
lines changed

docs/zh/datasets_format_zh.md

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
### 1.1. 在线数据流
1010

11-
#### 1.1.1. erniekit格式
11+
#### 1.1.1. erniekit 格式
1212

1313
使用 `erniekit` 格式需要在 `train(/eval)_dataset_type` 处指定为 `erniekit`
1414

15-
erniekit格式:每条数据都是一个字典,包含以下字段:
15+
erniekit 格式:每条数据都是一个字典,包含以下字段:
1616

1717
- `text` : `str, List(str)`
1818

@@ -30,11 +30,11 @@ wget https://paddleformers.bj.bcebos.com/datasets/pt_data.tar.gz
3030
mkdir -p data/pt && tar -xf pt_data.tar.gz -C data/pt/
3131
```
3232

33-
#### 1.1.2. messages格式
33+
#### 1.1.2. messages 格式
3434

3535
使用 `messages` 格式需要在 `train(/eval)_dataset_type` 处指定为 `messages`
3636

37-
messages格式:每条数据都是一个字典,包含以下字段:
37+
messages 格式:每条数据都是一个字典,包含以下字段:
3838

3939
- `messages` : `List(Dict)`
4040

@@ -49,7 +49,7 @@ messages格式:每条数据都是一个字典,包含以下字段:
4949

5050
我们也可以选择使用离线的比特预训练数据流,更节省内存。
5151

52-
为了方便测试,我们也提供了[离线预训练demo数据集](https://paddleformers.bj.bcebos.com/datasets/pretrain_offline_data.tar.gz)可以直接使用:
52+
为了方便测试,我们也提供了[离线预训练 demo 数据集](https://paddleformers.bj.bcebos.com/datasets/pretrain_offline_data.tar.gz)可以直接使用:
5353

5454
```shell
5555
wget https://paddleformers.bj.bcebos.com/datasets/pretrain_offline_data.tar.gz
@@ -60,7 +60,7 @@ tar -xf pretrain_offline_data.tar.gz -C data/pre-training/
6060

6161
下载一个文本数据集,例如 https://modelscope.cn/datasets/BazingaLyn/mini_pretrain_dataset
6262

63-
格式需为jsonl,每行格式例如BazingaLyn/mini_pretrain_dataset/pretrain_hq_v7.jsonl:
63+
格式需为 jsonl,每行格式例如 BazingaLyn/mini_pretrain_dataset/pretrain_hq_v7.jsonl:
6464
```text
6565
{"text": "番茄炒蛋\n材料:\n鸡蛋3个、番茄1个、油、盐、糖、水淀粉\n做法:..."}
6666
{"text": "请描述一下如何正确规划个人理财。正确规划个人理财需要以下几个步骤..."}
@@ -82,35 +82,35 @@ python -u examples/tools/create_pretraining_data.py \
8282
```
8383

8484
- 参数说明
85-
85+
8686
| 参数名 | 类型 | 说明 |
8787
|--------------------|----------- |-----------------|
8888
| `--model_name_or_path` | string | 模型路径 |
8989
| `--data_format` | string | 支持的文件格式,当前只支持 JSON |
90-
| `--input_path` | string | 输入的json文件的路径 |
91-
| `--append_eos` | store_true | 是否在document的结尾添加eos token |
90+
| `--input_path` | string | 输入的 json 文件的路径 |
91+
| `--append_eos` | store_true | 是否在 document 的结尾添加 eos token |
9292
| `--output_prefix` | str | 输出文件的前缀 |
9393
| `--workers` | int | 运行的进程数 |
9494
| `--log_interval` | int | 打印日志间隔 |
95-
| `--data_impl` | str | 制作的数据集类型,默认为mmap,也可以选择lazy |
95+
| `--data_impl` | str | 制作的数据集类型,默认为 mmap,也可以选择 lazy |
9696

97-
## 2. SFT数据流
97+
## 2. SFT 数据流
9898

99-
### erniekit格式
99+
### erniekit 格式
100100

101101
使用 `erniekit` 格式需要在 `train(/eval)_dataset_type` 处指定为 `erniekit`
102102

103-
SFT数据流中,每条数据都是一个字典,包含以下字段:
103+
SFT 数据流中,每条数据都是一个字典,包含以下字段:
104104

105105
- `src` : `str, List(str)`, 模型的输入指令(instruction)、提示(prompt),模型应该执行的任务。
106106
- `tgt` : `str, List(str)`, 模型的输出。
107107
- `system(optional)` : 系统配置
108108
- `label(optional)`: Training flag (1=参与训练, 0=不参与训练)
109-
- `is_system(optional)` : 标志src的第一条数据是否是system
109+
- `is_system(optional)` : 标志 src 的第一条数据是否是 system
110110

111111
Notes:
112112
* `src``tgt` 为支持多轮对话的列表(List)对象
113-
* 每个训练样本均为JSON格式,多个样本以换行符分隔
113+
* 每个训练样本均为 JSON 格式,多个样本以换行符分隔
114114

115115
样例数据:
116116
```json
@@ -136,21 +136,21 @@ mkdir -p data/sft && tar -xf alpaca_demo.gz -C data/sft/ --strip-components=1
136136
```
137137

138138

139-
### messages格式
139+
### messages 格式
140140

141141
使用 `messages` 格式需要在 `train(/eval)_dataset_type` 处指定为 `messages`
142142

143-
SFT数据流中,每条数据都是一个字典,包含以下字段:
143+
SFT 数据流中,每条数据都是一个字典,包含以下字段:
144144

145-
- `messages` : `List(Dict)`, 每个字典包含 `role``content``tool_calls(optional)` 三种key
145+
- `messages` : `List(Dict)`, 每个字典包含 `role``content``tool_calls(optional)` 三种 key
146146
- `role` 的值可以选择 `system`, `user`, `assistant``tool(optional)`
147147
- `content`为具体的对话内容。
148148
- `tool_calls(optional)` 为申请工具调用。
149149
- `tools(optional)` : `List(Dict)`, 表示工具信息。
150150
- `label(optional)`: Training flag (1=参与训练, 0=不参与训练)
151151

152152
Notes:
153-
* 每个训练样本均为JSON格式,多个样本以换行符分隔
153+
* 每个训练样本均为 JSON 格式,多个样本以换行符分隔
154154

155155
样例数据:
156156

@@ -166,9 +166,9 @@ Notes:
166166
]
167167
```
168168

169-
- 注意:在 `examples/data/sft_think-train.jsonl``examples/data/sft_think-eval.jsonl` 中提供的demo数据集来自由nvidia发布的 [OpenCodeReasoning数据集](https://huggingface.co/datasets/nvidia/OpenCodeReasoning)。该数据集需要遵循 Creative Commons Attribution 4.0 International License (CC BY 4.0) 协议。
169+
- 注意:在 `examples/data/sft_think-train.jsonl``examples/data/sft_think-eval.jsonl` 中提供的 demo 数据集来自由 nvidia 发布的 [OpenCodeReasoning 数据集](https://huggingface.co/datasets/nvidia/OpenCodeReasoning)。该数据集需要遵循 Creative Commons Attribution 4.0 International License (CC BY 4.0) 协议。
170170

171-
用于function call训练的demo数据
171+
用于 function call 训练的 demo 数据
172172

173173
```json
174174
[
@@ -194,23 +194,23 @@ wget https://paddleformers.bj.bcebos.com/datasets/sft_function_call_demo.tar.gz
194194
mkdir -p data/sft && tar -zxf sft_function_call_demo.tar.gz -C data/sft/
195195
```
196196

197-
## 3. DPO数据流
197+
## 3. DPO 数据流
198198

199-
### erniekit格式
199+
### erniekit 格式
200200

201201
使用 `erniekit` 格式需要在 `train(/eval)_dataset_type` 处指定为 `erniekit`
202202

203-
DPO数据流中,每条数据都是一个字典,包含以下字段:
203+
DPO 数据流中,每条数据都是一个字典,包含以下字段:
204204

205205
- `system(optional)`: 系统配置
206206
- `src` : `str, List(str)`, 用户对话内容
207-
- `tgt` : `str, List(str)`, 系统回复内容(比src少一个
207+
- `tgt` : `str, List(str)`, 系统回复内容(比 src 少一个
208208
- `response` : `str, List(str)`, 包含 chosen 和 rejected 回复。
209209
- `sort` : `List(int)`, sort 值用于区分 response 中 chosen 和 rejected(sort 值小的是 rejected,sort 值大的是 chosen)。
210-
- `is_system(optional)` : 标志src的第一条数据是否是system
210+
- `is_system(optional)` : 标志 src 的第一条数据是否是 system
211211

212212
Notes:
213-
* 每个训练样本均为JSON格式,多个样本以换行符分隔
213+
* 每个训练样本均为 JSON 格式,多个样本以换行符分隔
214214

215215
样例数据:
216216

@@ -249,7 +249,7 @@ mkdir -p data/dpo && tar -zxf ultrafeedback_binarized.tar.gz -C data/dpo/ --stri
249249

250250
使用 `messages` 格式需要在 `train(/eval)_dataset_type` 处指定为 `messages`
251251

252-
DPO数据流中,每条数据都是一个字典,包含以下字段:
252+
DPO 数据流中,每条数据都是一个字典,包含以下字段:
253253
- `messages` : `List(dict)`, 对话历史列表。
254254
- 普通轮次:包含 `role` (`"user"``"assistant"`) 和 `content` (`str`) 字段。
255255
- 偏好/非偏好轮次(用于偏好学习):包含以下两个关键字段,用于表示对同一用户查询的不同系统回复的偏好排序。
@@ -258,7 +258,7 @@ DPO数据流中,每条数据都是一个字典,包含以下字段:
258258
- `tools` : `List(dict)`, 对话中可能用到的工具(函数)的定义列表。
259259
- `label` : `List(int)`, 用于区分 `preferred_output``non_preferred_output` 的排序标签。其中 0 对应 `non_preferred_output` (rejected), 1 对应 `preferred_output` (chosen)。
260260

261-
详细的数据格式可见[function call说明](https://github.com/PaddlePaddle/PaddleFormers/blob/develop/examples/best_practices/function_call.md)
261+
详细的数据格式可见[function call 说明](https://github.com/PaddlePaddle/PaddleFormers/blob/develop/examples/best_practices/function_call.md)
262262

263263
样例数据
264264
```json
@@ -329,13 +329,13 @@ wget https://paddleformers.bj.bcebos.com/datasets/dpo_function_call_1k.tar.gz
329329
mkdir -p data/dpo_fc && tar -zxf dpo_function_call_1k.tar.gz -C data/dpo_fc/
330330
```
331331

332-
## 4. 多模 SFT数据流
332+
## 4. 多模 SFT 数据流
333333

334-
### erniekit格式
334+
### erniekit 格式
335335

336336
使用 `erniekit` 格式需要在 `train(/eval)_dataset_type` 处指定为 `erniekit`
337337

338-
SFT数据流中,每条数据都是一个字典,包含以下字段:
338+
SFT 数据流中,每条数据都是一个字典,包含以下字段:
339339

340340
* `text_info`: 纯文本的列表,每个元素包含一个 `text` 和一个 `tag`
341341
* `text`: 来自使用者的问题或系统回复的文字内容
@@ -397,18 +397,18 @@ SFT数据流中,每条数据都是一个字典,包含以下字段:
397397
}
398398
```
399399

400-
为了方便测试,我们也提供了用于快速训练的demo数据,请根据您的需要下载[数据](https://paddleformers.bj.bcebos.com/datasets/DoclingMatix.tar.gz),并将其解压缩到`tests/fixtures/dummy/sft-vl/`
400+
为了方便测试,我们也提供了用于快速训练的 demo 数据,请根据您的需要下载[数据](https://paddleformers.bj.bcebos.com/datasets/DoclingMatix.tar.gz),并将其解压缩到`tests/fixtures/dummy/sft-vl/`
401401

402402
```shell
403403
wget https://paddleformers.bj.bcebos.com/datasets/DoclingMatix.tar.gz
404-
tar -xf DoclingMatix.tar.gz -C tests/fixtures/dummy/sft-vl/ --strip-components=1
404+
tar -xf DoclingMatix.tar.gz -C tests/fixtures/dummy/sft-vl/
405405
```
406406

407-
### messages格式
407+
### messages 格式
408408

409409
使用 `messages` 格式需要在 `train(/eval)_dataset_type` 处指定为 `messages`
410410

411-
多模messages格式需要在纯文messages格式的基础上加上`images``videos``audios`几个key,用于传入多模态资源的`url`或者`path`,同时在`messages`中插入`<image>``<video>``<audio>`标签来表述插入多模态数据的位置:
411+
多模 messages 格式需要在纯文 messages 格式的基础上加上`images``videos``audios`几个 key,用于传入多模态资源的`url`或者`path`,同时在`messages`中插入`<image>``<video>``<audio>`标签来表述插入多模态数据的位置:
412412

413413
纯文:
414414
```json

examples/config/sft-vl/full.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ save_strategy: steps
3333
logging_steps: 1
3434
gradient_accumulation_steps: 4
3535
logging_dir: ./vdl_log
36-
output_dir: ./checkpoints/Qwen2.5-VL-sft-full
36+
output_dir: ./checkpoints/qwen2.5-vl-sft-full
3737
disable_tqdm: true
3838
eval_accumulation_steps: 16
3939

@@ -48,4 +48,7 @@ sharding: stage2
4848
recompute: true
4949
bf16: true
5050
fp16_opt_level: O2
51-
unified_checkpoint: true
51+
unified_checkpoint: false
52+
save_checkpoint_format: "flex_checkpoint"
53+
load_checkpoint_format: "flex_checkpoint"
54+
freeze_config: freeze_vision freeze_aligner
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ train_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
55
train_dataset_prob: "1.0"
66
eval_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
77
eval_dataset_prob: "1.0"
8-
max_seq_len: 8192
8+
max_seq_len: 32768
99
packing: true
1010
mix_strategy: concat
1111
template_backend: custom
1212
template: qwen2_vl
1313

1414
### model
15-
model_name_or_path: Qwen2.5-VL-3B-Instruct
15+
model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
1616
attn_impl: flashmask
1717

1818
### finetuning
@@ -33,7 +33,7 @@ save_strategy: steps
3333
logging_steps: 1
3434
gradient_accumulation_steps: 4
3535
logging_dir: ./vdl_log
36-
output_dir: ./checkpoints/Qwen2.5-VL-sft-full-tp-pp
36+
output_dir: ./checkpoints/qwen2.5-vl-sft-full-tp
3737
disable_tqdm: true
3838
eval_accumulation_steps: 16
3939

@@ -43,12 +43,13 @@ learning_rate: 1.0e-5
4343

4444
# performance
4545
tensor_parallel_degree: 2
46-
pipeline_parallel_degree: 2
46+
pipeline_parallel_degree: 1
4747
sequence_parallel: true
4848
sharding: stage1
4949
recompute: true
5050
bf16: true
5151
fp16_opt_level: O2
5252
unified_checkpoint: false
5353
save_checkpoint_format: "flex_checkpoint"
54-
load_checkpoint_format: "flex_checkpoint"
54+
load_checkpoint_format: "flex_checkpoint"
55+
freeze_config: freeze_vision freeze_aligner

examples/config/sft-vl/lora.yaml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
### data
2+
train_dataset_type: erniekit
3+
eval_dataset_type: erniekit
4+
train_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
5+
train_dataset_prob: "1.0"
6+
eval_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
7+
eval_dataset_prob: "1.0"
8+
max_seq_len: 8192
9+
packing: false
10+
mix_strategy: concat
11+
template_backend: custom
12+
template: qwen2_vl
13+
14+
### model
15+
model_name_or_path: Qwen2.5-VL-3B-Instruct
16+
attn_impl: flashmask
17+
lora: true
18+
lora_rank: 8
19+
20+
### finetuning
21+
# base
22+
stage: VL-SFT
23+
fine_tuning: lora
24+
seed: 23
25+
do_train: true
26+
do_eval: true
27+
per_device_eval_batch_size: 1
28+
per_device_train_batch_size: 1
29+
num_train_epochs: 1
30+
max_steps: -1
31+
eval_steps: 100
32+
evaluation_strategy: steps
33+
save_steps: 100
34+
save_strategy: steps
35+
logging_steps: 1
36+
gradient_accumulation_steps: 4
37+
logging_dir: ./vdl_log
38+
output_dir: ./checkpoints/qwen2.5-vl-sft-lora
39+
disable_tqdm: true
40+
eval_accumulation_steps: 16
41+
42+
# train
43+
warmup_steps: 20
44+
learning_rate: 1.0e-4
45+
46+
# performance
47+
tensor_parallel_degree: 1
48+
pipeline_parallel_degree: 1
49+
sharding: stage2
50+
recompute: true
51+
bf16: true
52+
fp16_opt_level: O2
53+
unified_checkpoint: false
54+
save_checkpoint_format: "flex_checkpoint"
55+
load_checkpoint_format: "flex_checkpoint"
56+
freeze_config: freeze_vision freeze_aligner
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
### data
2+
train_dataset_type: erniekit
3+
eval_dataset_type: erniekit
4+
train_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
5+
train_dataset_prob: "1.0"
6+
eval_dataset_path: ./tests/fixtures/dummy/sft-vl/train.jsonl
7+
eval_dataset_prob: "1.0"
8+
max_seq_len: 32768
9+
packing: true
10+
mix_strategy: concat
11+
template_backend: custom
12+
template: qwen2_vl
13+
14+
### model
15+
model_name_or_path: Qwen2.5-VL-3B-Instruct
16+
attn_impl: flashmask
17+
lora: true
18+
lora_rank: 8
19+
20+
### finetuning
21+
# base
22+
stage: VL-SFT
23+
fine_tuning: lora
24+
seed: 23
25+
do_train: true
26+
do_eval: true
27+
per_device_eval_batch_size: 1
28+
per_device_train_batch_size: 1
29+
num_train_epochs: 1
30+
max_steps: -1
31+
eval_steps: 100
32+
evaluation_strategy: steps
33+
save_steps: 100
34+
save_strategy: steps
35+
logging_steps: 1
36+
gradient_accumulation_steps: 4
37+
logging_dir: ./vdl_log
38+
output_dir: ./checkpoints/qwem2.5-vl-sft-lora-tp
39+
disable_tqdm: true
40+
eval_accumulation_steps: 16
41+
42+
# train
43+
warmup_steps: 20
44+
learning_rate: 1.0e-4
45+
46+
# performance
47+
tensor_parallel_degree: 2
48+
pipeline_parallel_degree: 1
49+
sequence_parallel: true
50+
sharding: stage2
51+
recompute: true
52+
bf16: true
53+
fp16_opt_level: O2
54+
unified_checkpoint: false
55+
save_checkpoint_format: "flex_checkpoint"
56+
load_checkpoint_format: "flex_checkpoint"
57+
freeze_config: freeze_vision freeze_aligner

0 commit comments

Comments
 (0)