
FLOPS Calculation #5

@ToddMorrill


Hey, would you be willing to provide a sketch of how you arrived at these numbers?

I imagine this is somewhat specific to the T5-11B architecture, but I'm trying to adapt it for GPT-2 and Llama 2, so any pointers or "gotchas" to be aware of would be greatly appreciated. I'm also fine with a rough estimate.

I just want to confirm that I'm getting reasonably decent utilization out of my hardware. For example, adapting your code, I'm measuring about 16 TFLOPS on a 4090 when running GPT-2 (base model) in fp32 and about 27 TFLOPS in fp16. I think I have a bug in my bf16 code, but nevertheless it's showing about 34 TFLOPS. These seem like decent starting points against the 82 TFLOPS advertised for the 4090.

To be sure, here's how I'm measuring:

from transformers import GPT2LMHeadModel, LlamaForCausalLM


def calc_flop(train_config, model):
    """
    Borrowed from: https://github.com/lessw2020/t5_11/blob/8ae276d3d91fc9f9ee2a865d111d263842c2970d/utils/calculations_utils.py#L1
    for T5

    TODO: revise these estimates for GPT2 and Llama
    TODO: confirm that this accounts for the forward and backward pass
    """
    B = train_config.micro_batch_size
    s = train_config.input_length
    # unwrap the model if it has been wrapped by FSDP
    model = getattr(model, '_fsdp_wrapped_module', model)
    if isinstance(model, LlamaForCausalLM):
        # the decoder layers live on the inner LlamaModel, not the LM head wrapper
        l = len(model.model.layers)
        h = model.config.hidden_size
    elif isinstance(model, GPT2LMHeadModel):
        l = len(model.transformer.h)
        h = model.config.n_embd
    else:
        raise ValueError("Unsupported model type.")
    V = model.config.vocab_size
    # 96 * B * s * l * h^2 * (1 + s/(6h) + V/(16*l*h)), per the repo linked above
    return 96 * B * s * l * h * h * (1 + s / 6 / h + V / 16 / l / h)


flops_per_batch = calc_flop(train_config, model)
# total estimated FLOPs for the epoch divided by wall-clock seconds, in TFLOP/s
tflops = flops_per_batch * len(train_dataloader) / epoch_end_time / 1e12
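
For what it's worth, here's how I'm turning that throughput into a utilization number (just a sketch; peak_tflops is whatever the datasheet advertises for the dtype you train in, and 82 is the 4090 FP32 figure I quoted above):

# Rough utilization estimate against advertised peak throughput.
# peak_tflops is dtype/hardware dependent; 82 is the 4090 FP32 number above.
def utilization(measured_tflops, peak_tflops=82.0):
    return measured_tflops / peak_tflops

print(f"utilization: {utilization(tflops):.1%}")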

I'm assuming your code measures the FLOPs for a forward and backward pass over a batch of data. Please correct me if I have that wrong or if I'm approaching this the wrong way.
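
One way I've been thinking about sanity-checking the analytic number is PyTorch's built-in FLOP counter (a sketch assuming a recent PyTorch and that the plain, unwrapped model fits on a single device; it only counts matmul-style ops, so I wouldn't expect it to match the formula exactly, just the same order of magnitude):

import torch
from torch.utils.flop_counter import FlopCounterMode

# Count FLOPs for one forward + backward on a dummy batch shaped like my
# training batches. Dimensions come from the same train_config as above.
input_ids = torch.randint(
    0, model.config.vocab_size,
    (train_config.micro_batch_size, train_config.input_length),
).to(model.device)
with FlopCounterMode(display=False) as counter:
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()
print(f"counted: {counter.get_total_flops():.3e} vs analytic: {flops_per_batch:.3e}")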

6 x the number of parameters x the number of tokens (based on this analysis) gets me into the ballpark too.
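
Concretely, that cross-check looks like this (a sketch; 6 FLOPs per parameter per token is the usual forward + backward back-of-the-envelope estimate and ignores the attention term that scales with sequence length):

# ~6 FLOPs per parameter per token for a forward + backward pass.
num_params = sum(p.numel() for p in model.parameters())
tokens_per_batch = train_config.micro_batch_size * train_config.input_length
approx_flops_per_batch = 6 * num_params * tokens_per_batch

print(f"formula estimate:        {flops_per_batch:.3e}")
print(f"6 * N * tokens estimate: {approx_flops_per_batch:.3e}")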
