From 0e7e873f69a589edb1516c486ca3ec5a63be6472 Mon Sep 17 00:00:00 2001 From: bernard karaba Date: Mon, 2 Jun 2025 12:20:10 +0300 Subject: [PATCH] add YOLOv10-style Equinox detection head for object detection outputs --- hackathon/objectdetection.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py index 76af579..feedff2 100644 --- a/hackathon/objectdetection.py +++ b/hackathon/objectdetection.py @@ -9,6 +9,7 @@ from einops import reduce from loguru import logger from tqdm import tqdm +from typing import Callable from hackathon.augmentations import ( create_global_crops, @@ -238,3 +239,33 @@ def evaluate(dataset, params, static, key, config, seed): logger.info("Evaluation completed!") return params + + + +class YOLOv10Head(eqx.Module): + conv: eqx.nn.Conv2d + output_channels: int + activation: Callable = eqx.static_field() + + def __init__(self, in_channels: int, num_classes: int, num_anchors: int = 1, key=None): + self.output_channels = num_anchors * (num_classes + 5) # 4 bbox + 1 obj + C classes + conv_key = key or jax.random.PRNGKey(0) + + self.conv = eqx.nn.Conv2d( + in_channels=in_channels, + out_channels=self.output_channels, + kernel_size=1, + stride=1, + padding=0, + key=conv_key + ) + self.activation = lambda x: x # raw output; optional sigmoid/softmax comes later + + def __call__(self, x): + """ + x: (B, C_in, H, W) + return: (B, H, W, A * (C + 5)) + """ + x = self.conv(x) # (B, output_channels, H, W) + x = jnp.transpose(x, (0, 2, 3, 1)) # -> (B, H, W, output_channels) + return self.activation(x)