diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..f24620f 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -238,3 +238,42 @@ def evaluate(dataset, params, static, key, config, seed): logger.info("Evaluation completed!") return params + +def loss_fn(preds, targets, num_classes, lambda_box=5.0, lambda_obj=1.0, lambda_cls=1.0): + """ + preds: (B, H, W, A * (5 + C)) raw outputs + targets: dict with keys: + - "object_mask": (B, H, W, 1) → 1 if object exists + - "bbox_target": (B, H, W, 4) + - "class_target": (B, H, W, C) + """ + + obj_mask = targets["object_mask"] + bbox_target = targets["bbox_target"] + class_target = targets["class_target"] + + pred_bbox = preds[..., :4] + pred_obj = preds[..., 4:5] + pred_cls = preds[..., 5:] + + + bbox_loss = jnp.abs(pred_bbox - bbox_target) * obj_mask + bbox_loss = jnp.sum(bbox_loss) / jnp.sum(obj_mask + 1e-6) + + + obj_loss = optax.sigmoid_binary_cross_entropy(pred_obj, obj_mask) + obj_loss = jnp.mean(obj_loss) + + + cls_loss = optax.sigmoid_binary_cross_entropy(pred_cls, class_target) + cls_loss = jnp.sum(cls_loss * obj_mask) / jnp.sum(obj_mask + 1e-6) + + total_loss = lambda_box * bbox_loss + lambda_obj * obj_loss + lambda_cls * cls_loss + + return total_loss, { + "total_loss": total_loss, + "bbox_loss": bbox_loss, + "obj_loss": obj_loss, + "cls_loss": cls_loss + } +