From e0262d7cbd37787f7cbce9d06b85fc8ad6dd34fd Mon Sep 17 00:00:00 2001 From: bernard karaba Date: Mon, 2 Jun 2025 17:35:12 +0300 Subject: [PATCH] align train_step with new object dict structure and loss input format --- hackathon/objectdetection.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..da3d3f5 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -133,12 +133,19 @@ def loss_fn(params, static, images, objects, key): @eqx.filter_jit def train_step(params, static, optimizer, opt_state, batch, key): - images = batch["global_crops"][:, 0, :, :, :] - objects = batch["objects"][:, 0, :, :, :] - classes = batch["classes"] + images = batch["global_crops"][:, 0] # shape: (B, H, W, C) + images = nhwc_to_nchw(images) # → (B, C, H, W) + + object_data = batch["objects"] # dict of lists/tensors + + targets = { + "bbox_target": object_data["bboxes"], # shape: (B, H, W, 4) + "class_target": object_data["labels"], # shape: (B, H, W, C) + "object_mask": object_data["object_mask"], # shape: (B, H, W, 1) + } def loss_for_step(p): - return loss_fn(p, static, images, objects, classes, key) + return loss_fn(p, static, images, targets, key) loss, grads = jax.value_and_grad(loss_for_step)(params)