diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..5c23159 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -238,3 +238,43 @@ def evaluate(dataset, params, static, key, config, seed): logger.info("Evaluation completed!") return params + +def loss_fn(params, static, inputs, targets, num_classes, lambda_box=5.0, lambda_obj=1.0, lambda_cls=1.0): + """ + params, static: from eqx.partition + inputs: batch of images [B, C, H, W] + targets: dict of ground truth values + """ + model = eqx.combine(params, static) + + + preds = jax.vmap(model)(inputs) # output: [B, H, W, A * (5 + C)] + + obj_mask = targets["object_mask"] + bbox_target = targets["bbox_target"] + class_target = targets["class_target"] + + pred_bbox = preds[..., :4] + pred_obj = preds[..., 4:5] + pred_cls = preds[..., 5:] + + + bbox_loss = jnp.abs(pred_bbox - bbox_target) * obj_mask + bbox_loss = jnp.sum(bbox_loss) / (jnp.sum(obj_mask) + 1e-6) + + + obj_loss = optax.sigmoid_binary_cross_entropy(pred_obj, obj_mask) + obj_loss = jnp.mean(obj_loss) + + + cls_loss = optax.sigmoid_binary_cross_entropy(pred_cls, class_target) + cls_loss = jnp.sum(cls_loss * obj_mask) / (jnp.sum(obj_mask) + 1e-6) + + total_loss = lambda_box * bbox_loss + lambda_obj * obj_loss + lambda_cls * cls_loss + + return total_loss, { + "total_loss": total_loss, + "bbox_loss": bbox_loss, + "obj_loss": obj_loss, + "cls_loss": cls_loss + }