Skip to content

Conversation

@bernardev254
Copy link

Summary

This PR updates the loss_fn in hackathon/objectdetection.py to use the actual model prediction head by combining Equinox parameters and applying the model via jax.vmap.

Changes

  • Uses eqx.combine(params, static) to reconstruct the full model.
  • Applies the model to a batch of inputs using jax.vmap.
  • Computes the composite object detection loss from the model’s outputs.
  • Ensures all loss components (bbox, objectness, classification) are derived from the model’s actual predictions.

Motivation

This change integrates the prediction head into the training loop, enabling accurate loss computation based on live model outputs — a key requirement for end-to-end object detection training.

closes #6

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Integrate prediction head into model's forward pass for loss_fn

1 participant