diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..da3d3f5 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -133,12 +133,19 @@ def loss_fn(params, static, images, objects, key): @eqx.filter_jit def train_step(params, static, optimizer, opt_state, batch, key): - images = batch["global_crops"][:, 0, :, :, :] - objects = batch["objects"][:, 0, :, :, :] - classes = batch["classes"] + images = batch["global_crops"][:, 0] # shape: (B, H, W, C) + images = nhwc_to_nchw(images) # → (B, C, H, W) + + object_data = batch["objects"] # dict of lists/tensors + + targets = { + "bbox_target": object_data["bboxes"], # shape: (B, H, W, 4) + "class_target": object_data["labels"], # shape: (B, H, W, C) + "object_mask": object_data["object_mask"], # shape: (B, H, W, 1) + } def loss_for_step(p): - return loss_fn(p, static, images, objects, classes, key) + return loss_fn(p, static, images, targets, key) loss, grads = jax.value_and_grad(loss_for_step)(params)