From c2c98a0adc7f80cc90a23affa1348385a52ad852 Mon Sep 17 00:00:00 2001 From: HwVanICI Date: Tue, 23 Dec 2025 19:23:06 -0500 Subject: [PATCH] Implement get_device_stats for train controller --- areal/controller/train_controller.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index d2045bdee..d07f6637e 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -487,6 +487,9 @@ def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta): self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True + def get_device_stats(self): + return self._custom_function_call("get_device_stats") + def prepare_batch( self, dataloader: StatefulDataLoader,