diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 04825d273..cdce01b40 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -167,7 +167,8 @@ def train(
             "`--padding_free` argument was called with `packing=True`, "
             "Trainer should not perform packing when using `--padding_free`"
         )
-
+    if fast_moe_config is not None and fast_moe_config.fast_moe is None:
+        fast_moe_config = None
     if fast_moe_config is not None:
         # Checking for unsupported modules with Scatter MoE for LoRA
         # Only raise an error for `all-linear`