From 13702d9c3c4d3f555c9827ac4f61cda61ba64c38 Mon Sep 17 00:00:00 2001 From: bernard karaba Date: Mon, 2 Jun 2025 16:41:06 +0300 Subject: [PATCH] use eqx.combine and jax.vmap in loss_fn to compute detection outputs from model --- hackathon/objectdetection.py | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..5c23159 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -238,3 +238,43 @@ def evaluate(dataset, params, static, key, config, seed): logger.info("Evaluation completed!") return params + +def loss_fn(params, static, inputs, targets, num_classes, lambda_box=5.0, lambda_obj=1.0, lambda_cls=1.0): + """ + params, static: from eqx.partition + inputs: batch of images [B, C, H, W] + targets: dict of ground truth values + """ + model = eqx.combine(params, static) + + + preds = jax.vmap(model)(inputs) # output: [B, H, W, A * (5 + C)] + + obj_mask = targets["object_mask"] + bbox_target = targets["bbox_target"] + class_target = targets["class_target"] + + pred_bbox = preds[..., :4] + pred_obj = preds[..., 4:5] + pred_cls = preds[..., 5:] + + + bbox_loss = jnp.abs(pred_bbox - bbox_target) * obj_mask + bbox_loss = jnp.sum(bbox_loss) / (jnp.sum(obj_mask) + 1e-6) + + + obj_loss = optax.sigmoid_binary_cross_entropy(pred_obj, obj_mask) + obj_loss = jnp.mean(obj_loss) + + + cls_loss = optax.sigmoid_binary_cross_entropy(pred_cls, class_target) + cls_loss = jnp.sum(cls_loss * obj_mask) / (jnp.sum(obj_mask) + 1e-6) + + total_loss = lambda_box * bbox_loss + lambda_obj * obj_loss + lambda_cls * cls_loss + + return total_loss, { + "total_loss": total_loss, + "bbox_loss": bbox_loss, + "obj_loss": obj_loss, + "cls_loss": cls_loss + }