
Conversation

@romitjain
Collaborator

@romitjain romitjain commented Dec 1, 2025

This PR adds support for the ODM dataset in the padding-free use case. It is slightly involved, so I will walk through all the possible cases below to give the complete picture.

We delegate the handling of the dataloader to accelerate, which in turn takes different paths based on the knobs we select. (ref)

A thing to note - we are dealing with an IterableDataset and a padding-free collator. The padding-free collator takes a list of samples, e.g. (s0, s1) if the batch size is 2, and converts it into a tensor of size (1, |s0| + |s1|). Effectively, the batch boundary information is lost.
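
As a concrete illustration, here is a toy sketch of what such a flattening collation does to two samples (this is not the actual TRL/transformers collator, just the idea):

```python
import torch

def flattening_collate(samples):
    """Toy padding-free collation: concatenate samples along the sequence
    dimension instead of padding them to a common length."""
    input_ids, position_ids, labels = [], [], []
    for s in samples:
        ids = s["input_ids"]
        input_ids.extend(ids)
        # position ids restart at 0 for every sample; this is how the
        # per-sample boundaries can later be recovered
        position_ids.extend(range(len(ids)))
        labels.extend(s["labels"])
    return {
        "input_ids": torch.tensor([input_ids]),       # shape (1, |s0| + |s1|)
        "position_ids": torch.tensor([position_ids]),
        "labels": torch.tensor([labels]),
    }

s0 = {"input_ids": [101, 7, 8, 102], "labels": [-100, 7, 8, -100]}
s1 = {"input_ids": [101, 9, 102], "labels": [-100, 9, -100]}
batch = flattening_collate([s0, s1])
print(batch["input_ids"].shape)  # torch.Size([1, 7]) -- batch dimension of 1
```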

For accelerate, we have two knobs to control the dataloader - split_batches and dispatch_batches. Toy scenario: per_device_batch_size is 2 and num_processes is 8.
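
For reference, in the accelerate API these knobs can be passed via DataLoaderConfiguration (availability depends on the accelerate version; this only shows where the flags live, not the exact wiring in this repo):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Case 3 configuration: every rank iterates the stream itself and keeps
# only its own slice of each super-batch.
dataloader_config = DataLoaderConfiguration(
    split_batches=False,
    dispatch_batches=False,
)
accelerator = Accelerator(dataloader_config=dataloader_config)
```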

Case 1: split_batches = True and dispatch_batches = True (Current ODM default)

Rank 0 accumulates the complete batch (with per_device_batch_size effectively set to 16), applies the collator, and then dispatches shards to each process. The issue is that after the collator, the batch becomes (1, total tokens across 16 samples), which cannot be sharded across ranks. Instead, each rank ends up seeing something like (0, seq_len). Hence, it fails.

Case 2: split_batches = True and dispatch_batches = False

Works very similarly to Case 3, except that per_device_batch_size needs to be updated to 16.

Case 3: split_batches = False and dispatch_batches = False (Works for non iterable datasets, proposed solution for ODM dataset)

Every rank accumulates the full super-batch of num_processes batches (8 batches of 2 samples, so 16 samples in total) and indexes its own samples. Each rank sees (s0, s1, ..., s15) and slices out its relevant samples, e.g. rank 0 picks (s0, s1), rank 1 picks (s2, s3), and so on. Only after this does it apply the collator. This works for non-iterable datasets even with padding free, since each rank can convert its own slice to (1, |si| + |sj|).

For iterable datasets, this requires the dataset to be deterministic for each rank: every rank must receive the same order of samples (s0, s1, ..., s15), otherwise the splitting would be incorrect across ranks. If each rank sees a different order, or 16 completely different samples, we cannot guarantee that the dataset is consumed correctly.
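
A minimal sketch of the Case 3 slicing, assuming the toy scenario above (per_device_batch_size = 2, num_processes = 8); the real slicing happens inside accelerate's dataloader wrappers, not in user code:

```python
def rank_slice(super_batch, rank, per_device_batch_size):
    """Each rank sees the same 16 samples and keeps only its own slice."""
    start = rank * per_device_batch_size
    return super_batch[start : start + per_device_batch_size]

super_batch = [f"s{i}" for i in range(16)]  # must be identical on every rank
print(rank_slice(super_batch, rank=0, per_device_batch_size=2))  # ['s0', 's1']
print(rank_slice(super_batch, rank=1, per_device_batch_size=2))  # ['s2', 's3']
# Only after slicing does each rank apply the padding-free collator,
# so the (1, |si| + |sj|) tensor never needs to be sharded.
```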

Case 4: split_batches = False and dispatch_batches = True

Rank 0 accumulates the complete batch (so 2 * 8 samples), applies the collator to each sub-batch, and concatenates them, which requires every sub-batch to have the same size, as also noted here.

Unfortunately, this will also not work for our use case because the collator does not pad to a fixed length: there is no reason for s0, s1, s2, ... to have the same length.
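
A quick hypothetical reproduction of why the Case 4 concatenation fails with a padding-free collator (not the actual accelerate code path):

```python
import torch

# Flattened sub-batches produced by the padding-free collator on rank 0:
# (1, |s0| + |s1|) and (1, |s2| + |s3|) -- the lengths differ in general.
b0 = torch.zeros(1, 7, dtype=torch.long)
b1 = torch.zeros(1, 9, dtype=torch.long)

try:
    torch.cat([b0, b1], dim=0)  # requires equal sizes on dim 1
except RuntimeError as e:
    print("cannot concatenate:", e)
```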

Signed-off-by: romitjain <romit@ibm.com>
@romitjain romitjain marked this pull request as ready for review December 1, 2025 10:35
@kmehant
Collaborator

kmehant commented Dec 1, 2025

A few complexities:

(1) Padding free uses lists for all of these fields, not just labels - https://github.com/huggingface/trl/blob/c7d172bfc471788841038ae5a1f382f57f91b8ac/trl/trainer/sft_trainer.py#L186-L188

(2) Setting padding_free would set the collator for the individual datasets, but our wrapper ODM dataset would be unaware of this. Being unaware is problematic because the structure of the batch changes completely: it is flattened on dim 0 into a list. So the solution would be for the preparation of the padding_free batch to be handled by our ODM dataset, and we should not let TRL do anything.

Alternative to (2)
I have another alternative which could be simpler. We can let the individual datasets collate for padding_free as done here: https://github.com/huggingface/trl/blob/c7d172bfc471788841038ae5a1f382f57f91b8ac/trl/trainer/sft_trainer.py#L186-L188. But in our ODM dataset, when preparing the batch, we would prepare it the same way and be aware that each individual batch structure is also flattened. This would require fewer changes.

@romitjain
Collaborator Author

@kmehant

  1. In fms-hf-tuning, the collator selected is DataCollatorForSeq2Seq (ref)
  2. fms-accel patches it to use DataCollatorWithFlattening (ref)
  3. The issue is that DataCollatorWithFlattening only accepts labels as a list. I have raised this in transformers (DataCollatorWithFlattening only accepts labels as list huggingface/transformers#42502); a rough workaround sketch follows this list
  4. The ODM dataset should not worry about padding_free IMO, since it's a dataset. The collator should be responsible for all the logic that manipulates the batch
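
Given that DataCollatorWithFlattening only accepts labels as Python lists, a workaround on the sample side could look roughly like this; the helper name and exact field handling are illustrative, not necessarily what this PR does:

```python
def to_collator_friendly(sample):
    """Ensure the fields the flattening collator indexes are plain Python
    lists (not tensors/arrays) before they reach DataCollatorWithFlattening."""
    return {
        key: value.tolist() if hasattr(value, "tolist") else list(value)
        for key, value in sample.items()
        if key in ("input_ids", "labels", "position_ids")
    }
```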

@kmehant
Collaborator

kmehant commented Dec 1, 2025

@romitjain

Given your (1) and (2), you can disregard my (2), and I agree on (4).
(3) Can you fill me in on what the issue was with labels being a list?

@kmehant
Collaborator

kmehant commented Dec 2, 2025

@romitjain Thanks for the details!

Effectively, the batch information is lost.

You can reconstruct individual batches using position ids.

From the details, would the solution be to introduce a knob at DataLoaderDispatcher in accelerate to apply the collator post-dispatch instead of pre-dispatch? That should fix this, WDYT?

Signed-off-by: romitjain <romit@ibm.com>
@romitjain romitjain changed the title fix: Converting sample labels to list fix: ODM support for padding free collator Dec 2, 2025
@romitjain
Collaborator Author

@kmehant

You can reconstruct individual batches using position ids.

Yes, but that would require too much plumbing. Accelerate would need to know that the collator used was for padding free, or guess from the presence of position-id keys in the batch. It would then have to deconstruct the batch, send shards to the individual ranks, and let each rank flatten its shard again.
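
For context on why this is plumbing-heavy: the reconstruction would rely on position_ids restarting at 0 for every packed sample, roughly like this (a sketch of the idea, not code from accelerate or this PR):

```python
import torch

def split_flattened(batch):
    """Split a (1, total_tokens) flattened batch back into individual samples
    by detecting where position_ids reset to 0."""
    pos = batch["position_ids"][0]
    starts = (pos == 0).nonzero(as_tuple=True)[0].tolist()
    bounds = starts + [pos.shape[0]]
    return [
        {k: v[0, bounds[i]:bounds[i + 1]] for k, v in batch.items()}
        for i in range(len(starts))
    ]

batch = {
    "input_ids": torch.tensor([[101, 7, 8, 102, 101, 9, 102]]),
    "position_ids": torch.tensor([[0, 1, 2, 3, 0, 1, 2]]),
}
print([s["input_ids"].tolist() for s in split_flattened(batch)])
# [[101, 7, 8, 102], [101, 9, 102]]
```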

From the details, would the solution to introduce a knob at DataLoaderDispatcher in accelerate to apply collator post dispatch or pre dispatch.

I am not 100% sure. As I see it, DataLoaderDispatcher fetches samples from the base dataloader, and the base dataloader already collates the samples, so it won't be possible to skip that.

Signed-off-by: romitjain <romit@ibm.com>
@romitjain
Collaborator Author

Verified that each rank produces the same sequence (the easiest way to verify is to add the accelerator process index to the log_to_file calls).
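
A minimal sketch of that verification idea, assuming a plain file logger (the actual log_to_file helper may look different):

```python
from accelerate import Accelerator

accelerator = Accelerator()

def log_sample(path, sample_id):
    # Tagging each line with the process index makes it easy to diff
    # the sample order seen by every rank.
    with open(path, "a") as f:
        f.write(f"rank={accelerator.process_index} sample={sample_id}\n")
```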

In a sample run that I did (100 steps, so 800+ samples produced), the sampling weights at the end of the run are exactly the same for every rank.

      samples_produced_so_far  sampling_interval  total_categories                           current_sampling_weights  ...                                            rewards                           count  action rank
6510                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    0
6511                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    1
6507                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    2
6508                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    3
6505                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    4
6504                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    5
6506                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    6
6509                      809                  1                 7  [1.096585325986057, 1.0810652187708683, 1.0912...  ...  [12.644085884094238, 2.8629062175750732, 20.25...  [208, 56, 352, 24, 64, 16, 56]  update    7

For the resume functionality, I did a smaller run: 5 eval steps, save every 5 steps, 10 steps total. The samples still match until the end, but I was not able to reproduce results similar to #155 even on the main branch. Likely I am missing something in the runtime params. Will have to double-check, but I could confirm that the dataset RNG seeds are getting loaded.

@romitjain
Collaborator Author

Loss comparison with main branch

[loss comparison plot]

The loss curves move in the same direction but don't match exactly, likely due to the small batch size (per-device batch size = 1).

@kmehant kmehant self-requested a review December 2, 2025 16:03
@kmehant kmehant merged commit ecf4918 into foundation-model-stack:main Dec 2, 2025
9 checks passed