Hey, would you be willing to provide a sketch of how you arrived at these numbers?
I imagine this is somewhat specific to the T5 11b architecture, but I'm trying to adapt this for GPT-2 and Llama 2, so any pointers or "gotchas" to be aware of would be greatly appreciated. I'm also fine with a rough estimate.
I just want to know that I'm getting reasonably decent utilization of my hardware. For example, if I adopt your code, I'm measuring about 16 TFLOPS on a 4090 when running GPT-2 (base model) in fp32 and about 27 TFLOPS in fp16. I think I have a bug in my bf16 code, but nevertheless that's showing about 34 TFLOPS. So these seem like decent starting points against the 82 TFLOPS advertised for the 4090.
To be sure, here's how I'm measuring:
from transformers import GPT2LMHeadModel, LlamaForCausalLM

def calc_flop(train_config, model):
    """
    Borrowed from: https://github.com/lessw2020/t5_11/blob/8ae276d3d91fc9f9ee2a865d111d263842c2970d/utils/calculations_utils.py#L1
    for T5
    TODO: revise these estimates for GPT2 and Llama
    TODO: confirm that this accounts for the forward and backward pass
    """
    B = train_config.micro_batch_size
    s = train_config.input_length
    # may need to unwrap the model here if it is wrapped by FSDP
    model = getattr(model, '_fsdp_wrapped_module', model)
    if isinstance(model, LlamaForCausalLM):
        l = len(model.model.layers)   # decoder layers live under model.model
        h = model.config.hidden_size  # Llama configs use hidden_size, not n_embd
    elif isinstance(model, GPT2LMHeadModel):
        l = len(model.transformer.h)
        h = model.config.n_embd
    else:
        raise ValueError("Unsupported model type.")
    V = model.config.vocab_size
    # estimated FLOPs per batch: 96 * B * s * l * h^2 * (1 + s/(6h) + V/(16*l*h))
    return 96 * B * s * l * h * h * (1 + s / 6 / h + V / 16 / l / h)

flops_per_batch = calc_flop(train_config, model)
tflops = flops_per_batch * len(train_dataloader) / epoch_end_time / 1e12
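As a quick sanity check on my numbers, here is a self-contained sketch that plugs GPT-2 base's published shape (12 layers, hidden size 768, vocab 50257) into the same formula; the batch size, sequence length, and step time below are made-up placeholders, not measurements:

def flops_per_microbatch(B, s, l, h, V):
    # same formula as calc_flop above, with the shapes passed in directly
    return 96 * B * s * l * h * h * (1 + s / 6 / h + V / 16 / l / h)

# GPT-2 base shapes; B, s, and step_time_s are placeholder values for illustration
B, s, l, h, V = 8, 1024, 12, 768, 50257
step_time_s = 0.5  # wall-clock time for one forward+backward step (made up)

flops = flops_per_microbatch(B, s, l, h, V)
print(f"FLOPs per micro-batch: {flops:.3e}")
print(f"Implied throughput: {flops / step_time_s / 1e12:.1f} TFLOP/s")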
I'm assuming your code measures the FLOPs for a forward and backward pass over a batch of data. Please correct me if I have that wrong or if I'm approaching this the wrong way.
Estimating 6 × the number of parameters × the number of tokens (based on this analysis) gets me into the ballpark too.
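For reference, here is a minimal sketch of that cross-check, assuming model, train_config, train_dataloader, and epoch_end_time are the same objects as in the snippet above:

# Rough cross-check: ~6 FLOPs per parameter per token for forward + backward.
num_params = sum(p.numel() for p in model.parameters())
tokens_per_step = train_config.micro_batch_size * train_config.input_length

approx_flops_per_step = 6 * num_params * tokens_per_step
approx_tflops = approx_flops_per_step * len(train_dataloader) / epoch_end_time / 1e12
print(f"~{approx_tflops:.1f} TFLOP/s by the 6 * params * tokens rule of thumb")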