fix: ODM support for padding free collator #165
Conversation
Signed-off-by: romitjain <romit@ibm.com>
A few complexities: (1) padding free uses lists for all of these, not just labels - https://github.com/huggingface/trl/blob/c7d172bfc471788841038ae5a1f382f57f91b8ac/trl/trainer/sft_trainer.py#L186-L188 (2) setting padding_free would set the collator for the individual datasets, but our wrapper ODM dataset would be unaware of this. Being unaware is problematic because the structure of the batch changes completely, since it is flattened on dim 0 into a list. So the solution would be that preparation of the padding-free batch is handled by our ODM dataset, and we do not let trl do anything. Alternative to (2)
Given (1) and (2), you can disregard my (2) and agree on (4).
@romitjain Thanks for the details!
You can reconstruct individual batches using position ids. Given the details, would the solution be to introduce a knob at DataLoaderDispatcher in accelerate to apply the collator post dispatch or pre dispatch? That should fix this, WDYT?
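For illustration, a minimal sketch (a hypothetical helper, not something in this PR or in accelerate) of how sample boundaries can be recovered from a flattened padding-free batch, using the fact that `position_ids` restart at 0 for every packed sequence:

```python
import torch

def split_padding_free_batch(input_ids: torch.Tensor, position_ids: torch.Tensor):
    """Recover individual samples from a padding-free batch of shape (1, total_tokens).

    Assumes position_ids restart at 0 at the start of every packed sequence,
    which is how the flattening collator encodes sample boundaries.
    """
    flat_ids, flat_pos = input_ids[0], position_ids[0]
    # A new sample starts wherever the position id drops back to 0.
    starts = (flat_pos == 0).nonzero(as_tuple=True)[0].tolist()
    ends = starts[1:] + [flat_ids.numel()]
    return [flat_ids[s:e] for s, e in zip(starts, ends)]

# Two samples of lengths 3 and 2 flattened into a single row.
input_ids = torch.tensor([[11, 12, 13, 21, 22]])
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
assert [len(s) for s in split_padding_free_batch(input_ids, position_ids)] == [3, 2]
```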
Signed-off-by: romitjain <romit@ibm.com>
Yes, but that would require too much plumbing. Accelerate would need to know that the collator used was for padding free, or guess from the presence of position id keys in the batch, and then deconstruct the batch, send it to individual ranks, and let each rank flatten it again.
I am not 100% sure. As I see it,
Signed-off-by: romitjain <romit@ibm.com>
Verified that each rank produces the same sequence (the easiest way to verify is to add the accelerator process index in ...). In a sample run that I did (100 steps, so 800+ samples produced), the sampling weights are exactly the same for every rank at the end of the run. For the resume functionality, I did a smaller run: 5 eval steps, save every 5 steps, 10 total steps. The samples still match till the end, but I was not able to reproduce results similar to #155, even on the main branch. Likely I am missing something in the runtime params and will have to double check, but I could confirm that the dataset rng seeds are getting loaded.
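For context, the rank-consistency check can be as simple as fingerprinting the drawn samples on each rank and comparing the printed values (a hypothetical sketch, not the exact script used for the run above):

```python
import hashlib
from accelerate import Accelerator

accelerator = Accelerator()

def log_sample_fingerprint(sample_ids):
    """Hash the sequence of drawn sample ids so ranks can be compared at a glance."""
    digest = hashlib.sha256(",".join(map(str, sample_ids)).encode()).hexdigest()[:12]
    print(f"rank={accelerator.process_index} samples_fingerprint={digest}")

# If dataset iteration is deterministic per rank, every rank prints the same digest.
```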

This PR adds support for the ODM dataset in the padding-free use case. It is slightly involved, so I will go through all the possible cases below to give the complete picture.
We delegate the handling of the dataloader to accelerate, which in turn takes different paths based on the knobs we select. (ref)
A thing to note - we are dealing with an `IterableDataset` and a padding-free collator. The padding-free collator takes a list of samples, e.g. (s0, s1) if the batch size is 2, and converts it into a tensor of size (1, |s0| + |s1|). Effectively, the batch information is lost.
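As a concrete illustration, a minimal sketch of the flattening behaviour (not TRL's actual collator implementation):

```python
import torch

def flatten_padding_free(samples: list[list[int]]) -> dict[str, torch.Tensor]:
    """Concatenate all samples into one row of shape (1, |s0| + |s1| + ...)
    and restart the position ids for each sample."""
    input_ids = [tok for s in samples for tok in s]
    position_ids = [p for s in samples for p in range(len(s))]
    return {
        "input_ids": torch.tensor([input_ids]),        # (1, total_tokens)
        "position_ids": torch.tensor([position_ids]),  # (1, total_tokens)
    }

batch = flatten_padding_free([[11, 12, 13], [21, 22]])
print(batch["input_ids"].shape)  # torch.Size([1, 5]) -- the batch dimension is gone
```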
For accelerate, we have two knobs to control the dataloader - `split_batches` and `dispatch_batches`. Toy scenario: `per_device_batch_size` is 2 and `num_processes` is 8.
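For reference, this is roughly how the two knobs map onto accelerate's dataloader configuration (a sketch; the exact wiring inside the trainer may differ, and older accelerate versions accept these flags directly on `Accelerator`):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Toy scenario from above: per_device_batch_size = 2, num_processes = 8.
dataloader_config = DataLoaderConfiguration(
    split_batches=False,     # whether rank 0 draws one big batch that gets split
    dispatch_batches=False,  # whether rank 0 alone iterates and dispatches batches
)
accelerator = Accelerator(dataloader_config=dataloader_config)
# dataloader = accelerator.prepare(DataLoader(odm_dataset, batch_size=2, collate_fn=...))
```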
Case 1: `split_batches = True` and `dispatch_batches = True` (current ODM default)

Rank 0 accumulates the complete batch (where `per_device_batch_size` is set to 16), applies the collator and then sends it to each process. The issue is that after the collator, the batch becomes (1, total tokens across 16 samples), which is not possible to shard for the other ranks. Instead, each rank will see something like (0, seq_len). Hence, it fails.
Case 2: `split_batches = True` and `dispatch_batches = False`

Works very similarly to Case 3, just that it needs `per_device_batch_size` updated to 16.
Case 3: `split_batches = False` and `dispatch_batches = False` (works for non-iterable datasets, proposed solution for the ODM dataset)

Every rank accumulates the `num_processes` batch (8 "super" batches of 2 samples each, so 16 samples in total) and indexes its own samples. Each rank has (s0, s1, ... s15) and slices out its relevant samples, e.g. rank 0 would pick (0, 1), rank 1 would pick (2, 3), and so on. After this, it applies the collator. This works for non-iterable datasets even with padding free, since each rank can convert its slice to (1, |si| + |sj|). A sketch of this flow is shown after this case.

For iterable datasets, this requires us to make sure that the dataset is deterministic for each rank. Each rank should receive the same order of samples (s0, s1, ... s15); otherwise the splitting would be incorrect across ranks. Each rank might see a different order or completely different 16 samples, and we would not be able to make sure that the dataset is correctly consumed.
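A minimal sketch of the Case 3 flow on a single rank (hypothetical names; the real slicing happens inside accelerate's dataloader sharding):

```python
import torch

def case3_local_batch(super_batch: list[list[int]], rank: int, per_device_batch_size: int) -> torch.Tensor:
    """Every rank sees the same 16-sample super batch (s0 ... s15), keeps only its
    own slice, and only then applies the padding-free collation."""
    start = rank * per_device_batch_size
    local_samples = super_batch[start:start + per_device_batch_size]
    # Padding-free collation of just the local slice -> shape (1, |si| + |sj|)
    return torch.tensor([[tok for s in local_samples for tok in s]])

# Toy super batch of 16 variable-length samples, identical on every rank.
super_batch = [[i] * (i % 4 + 1) for i in range(16)]
print(case3_local_batch(super_batch, rank=0, per_device_batch_size=2).shape)  # rank 0 -> (s0, s1)
print(case3_local_batch(super_batch, rank=1, per_device_batch_size=2).shape)  # rank 1 -> (s2, s3)
```

This only stays consistent if every rank draws the identical super batch, which is exactly the determinism requirement described above.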
Case 4: `split_batches = False` and `dispatch_batches = True`

Rank 0 accumulates the complete batch (so 2 * 8 samples), applies the collator, and concatenates the results, which enforces that the entire batch has the same size, which is also noted here. Unfortunately, this will also not work for our use case because the collator does not pad to a fixed length; there is no reason for s0, s1, s2, ... to have the same lengths.