
Commit 99ac466

Fabio Salerno authored and committed
add the training/ directory
1 parent 922652a commit 99ac466

File tree

17 files changed: +589 -0 lines changed

training/README.md

Lines changed: 82 additions & 0 deletions
# Fine-tuning

- This folder contains the training scripts used for fine-tuning StarCoder2. Additionally, we disclose the training stats.

- The fine-tuning dataset can be retrieved at this link: https://huggingface.co/datasets/AISE-TUDelft/memtune-tuning_data

- The fine-tuned models can be retrieved at this link: https://huggingface.co/collections/AISE-TUDelft/llm4code-memtune-678a2838766dd16037a8bfe0 (a short loading sketch follows this list)
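As a convenience, here is a minimal sketch of pulling the data and a fine-tuned checkpoint from the Hugging Face Hub. The dataset repository, its `20k` configuration, and the `train`/`valid` split names are the ones used in `training/scoder15b/train.py`; the model repository name below is a placeholder, so substitute an actual repository from the collection linked above.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fine-tuning data: "20k" configuration, train and validation splits.
train_split = load_dataset("AISE-TUDelft/memtune-tuning_data", name="20k", split="train")
valid_split = load_dataset("AISE-TUDelft/memtune-tuning_data", name="20k", split="valid")

# A fine-tuned model from the collection.
# NOTE: "<fine-tuned-model>" is a placeholder, not a real repository name.
model_id = "AISE-TUDelft/<fine-tuned-model>"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```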
## Fine-tuning Setup

### Hardware Configuration
- 32 CPU cores
- 32 GB RAM
- Multiple NVIDIA A100 GPUs (80 GB memory each; see the loading sketch after this list):
  - StarCoder2-3B: 2 GPUs
  - StarCoder2-7B: 4 GPUs
  - StarCoder2-15B: 6 GPUs
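The training script shards a single model across the available GPUs rather than replicating it, which is why the GPU counts above are per model. A minimal sketch of the loading step, mirroring the `device_map='auto'` call in `training/scoder15b/train.py`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-15b"  # base model used in training/scoder15b/train.py

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" (requires the accelerate package) spreads the model's layers
# across all visible GPUs, so one copy of the model occupies several A100s.
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
```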
### Software Stack
- NVIDIA Driver: 555.42.02
- CUDA Version: 12.5
- Transformers Version: 4.41.1
- Torch Version: 2.3.0+cu121

### Training Configuration
- Context Window: 1024 tokens
- Learning Rate: 3e-5
- Optimizer: Adafactor with linear scheduler
- Batch Sizes (effective, including gradient accumulation; see the arithmetic sketch after this list):
  - 3B model: 24
  - 7B model: 24
  - 15B model: 25
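Because the model is sharded across GPUs with `device_map='auto'` (model parallelism rather than data parallelism), the effective batch size is simply the per-device batch size multiplied by the gradient accumulation steps. A minimal check of that arithmetic using the 15B values from `training/scoder15b/train.py`:

```python
# Effective batch size under model parallelism:
# per-device batch size x gradient accumulation steps (no data-parallel replicas).
per_device_train_batch_size = 5  # "batch" in training/scoder15b/train.py
gradient_accumulation_steps = 5

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 25, matching the 15B entry above
```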
### Training Duration
Approximate training times per model:
- StarCoder2-3B: 25 hours
- StarCoder2-7B: 55 hours
- StarCoder2-15B: 110 hours

## Training Process
- Training duration: 3 epochs
- Checkpoints saved after each epoch (see the resume sketch after this list)
- GPU memory and training time were key factors in determining:
  - Optimizer selection
  - Training file configuration
  - Batch size parameters
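Since a checkpoint is written at the end of every epoch (`save_strategy = "epoch"` in the training script), an interrupted run can be picked up from the latest checkpoint in the output directory. A minimal sketch, reusing the `trainer` object constructed in `training/scoder15b/train.py`:

```python
# Resume from the most recent checkpoint under output_dir ("./epochs")
# instead of restarting from the base StarCoder2 weights.
trainer.train(resume_from_checkpoint=True)
```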
Training was conducted using resources provided by the [Delft High-Performance Computing Centre](https://doc.dhpc.tudelft.nl/delftblue/).

## Training stats

### StarCoder2-3B

**Evaluation loss**:
![](/training/train-stats/StarCoder2-3B/eval-loss.png)

**Training loss**:
![](/training/train-stats/StarCoder2-3B/train-loss.png)

**Learning rate**:
![](/training/train-stats/StarCoder2-3B/train-learning_rate.png)

### StarCoder2-7B

**Evaluation loss**:
![](/training/train-stats/StarCoder2-7B/eval-loss.png)

**Training loss**:
![](/training/train-stats/StarCoder2-7B/train-loss.png)

**Learning rate**:
![](/training/train-stats/StarCoder2-7B/train-learning_rate.png)

### StarCoder2-15B

**Evaluation loss**:
![](/training/train-stats/StarCoder2-15B/eval-loss.png)

**Training loss**:
![](/training/train-stats/StarCoder2-15B/train-loss.png)

**Learning rate**:
![](/training/train-stats/StarCoder2-15B/train-learning_rate.png)

training/scoder15b/train.py

Lines changed: 152 additions & 0 deletions
"""
This script is used to fine-tune StarCoder2 family models on a Java dataset for the code completion task.
"""
import torch
from datasets import load_dataset, disable_caching
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, set_seed

# parallel processing
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=16)
from tqdm import tqdm
tqdm.pandas()

# utility
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
import os

"""
Setting the variables.
"""

disable_caching()

set_seed(42)

wproject = "name"  # W&B project name
run_name = "run-name"  # name of the W&B run (optional)
# per-device training batch size
batch = 5
# Base model and tokenizer to load from the HF Hub
checkpoint = "bigcode/starcoder2-15b"
# Select the column of interest from the dataset
text_column = 'content'

# training
max_length = 1024
# model parallel: shard the model across all visible GPUs
device_map = 'auto'

# wandb setup
import wandb
wandb.login()
os.environ["WANDB_PROJECT"] = wproject  # wandb project name

"""
Loading the model and tokenizer
"""
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # set the pad token to the end-of-sequence token

# model
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map=device_map)


"""
Loading and preprocessing the data
"""
# LINK FOR THE DATASET: https://huggingface.co/datasets/AISE-TUDelft/memtune-tuning_data
# Load the data (each call returns a single Dataset, since split is specified)
dataset_train_20 = load_dataset("AISE-TUDelft/memtune-tuning_data", name="20k", split='train')
dataset_valid_20 = load_dataset("AISE-TUDelft/memtune-tuning_data", name="20k", split='valid')

# Pick the columns of interest
train_20 = dataset_train_20.select_columns(text_column)
validation_20 = dataset_valid_20.select_columns(text_column)

# Tokenize the sequences
# Note: StarCoder2's native context window is much longer; sequences are
# truncated/padded to max_length (1024) tokens here.
def tokenize_input(batch):
    return tokenizer(batch[text_column], padding="max_length", truncation=True, max_length=max_length, return_tensors='pt')

training_20 = train_20.map(tokenize_input, batched=True, num_proc=64, remove_columns=text_column)
validating_20 = validation_20.map(tokenize_input, batched=True, num_proc=64, remove_columns=text_column)

"""
Training initialization
"""
# Data collator for causal language modeling (mlm=False)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors='pt'
)


# Args
output_dir = "./epochs"
overwrite_output_dir = False

per_device_train_batch_size = batch
per_device_eval_batch_size = batch
gradient_accumulation_steps = 5

optim = "adafactor"
adam_beta1 = 0.9
weight_decay = 0.1

learning_rate = 3e-5
lr_scheduler_type = "linear"
warmup_steps = 50

num_train_epochs = 3
eval_steps = 0.08  # 200  # each epoch two evaluations
eval_strategy = "steps"  # default is "no"
save_strategy = "epoch"  # default is "steps"

logging_steps = 1
report_to = "wandb"

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    save_strategy=save_strategy,
    eval_strategy=eval_strategy,

    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,

    per_device_eval_batch_size=per_device_eval_batch_size,
    eval_steps=eval_steps,

    optim=optim,
    adam_beta1=adam_beta1,
    weight_decay=weight_decay,

    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_steps=warmup_steps,

    logging_steps=logging_steps,
    report_to=report_to,
    run_name=run_name,
    seed=42)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=training_20,
    eval_dataset=validating_20
)

# Training
trainer.train()

training/scoder15b/tune.sh

Lines changed: 19 additions & 0 deletions
#!/bin/bash
#SBATCH --job-name=15b6g25b
#SBATCH --partition=gpu-a100
#SBATCH --time=110:00:00
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=80G
#SBATCH --gpus=8

# Deployment purposes
# This script is used to run .py files on the cluster

# Set conda env:
unset CONDA_SHLVL
source "$(conda info --base)/etc/profile.d/conda.sh"

conda activate memenv
python3 train.py
conda deactivate
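A usage sketch for submitting this job, assuming a SLURM cluster that provides the `gpu-a100` partition and a conda environment named `memenv`, and that the command is run from the script's own directory:

```bash
# Submit the fine-tuning job to SLURM; output goes to the default slurm-<jobid>.out file.
cd training/scoder15b
sbatch tune.sh

# Check the status of the submitted job.
squeue -u "$USER"
```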
