diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index d2045bdee..d07f6637e 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -487,6 +487,9 @@ def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta): self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True + def get_device_stats(self): + return self._custom_function_call("get_device_stats") + def prepare_batch( self, dataloader: StatefulDataLoader,