Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def update(self, pose: torch.Tensor, w: int, h: int) -> None:
self._detections = dict(boxes=bboxes, scores=torch.ones(num_det))
self._age += 1

def next_frame(self) -> None:
"""Increment the frame counter and set detections to None (to handle no detections)"""
self._detections = None
self._age += 1


@dataclass
class TopDownConfig:
Expand Down Expand Up @@ -118,28 +123,42 @@ class PyTorchRunner(BaseRunner):
path: The path to the model to run inference with.
device: The device on which to run inference, e.g. "cpu", "cuda", "cuda:0"
precision: The precision of the model. One of "FP16" or "FP32".
single_animal: This option is only available for single-animal pose estimation
single_animal: bool | None, default=True
Set to True if the model is a single-animal model, False if it is a multi-animal model.
If set to None, single_animal mode will be inferred from the model configuration.
This option is introduced for single-animal pose estimation
models. It makes the code behave in exactly the same way as DeepLabCut-Live
with version < 3.0.0. This ensures backwards compatibility with any
Processors that were implemented.
dynamic: Whether to use dynamic cropping.
top_down_config: Only for top-down models running with a detector.

returns:
pose: The pose of the animal(s) in the frame.
shape:
(n_bodyparts, 3) if single_animal is True
(n_individuals, n_bodyparts, 3) if single_animal is False.
If no detections are found, the pose consists of zeros.

Raises:
ValueError: If the model is not loaded. Call load_model() or init_inference() before calling get_pose().
"""

def __init__(
self,
path: str | Path,
device: str = "auto",
precision: Literal["FP16", "FP32"] = "FP32",
single_animal: bool = True,
single_animal: bool | None = None,
dynamic: dict | dynamic_cropping.DynamicCropper | None = None,
top_down_config: dict | TopDownConfig | None = None,
) -> None:
super().__init__(path)
self.device = _parse_device(device)
self.precision = precision
self.single_animal = single_animal

self.n_individuals = None
self.n_bodyparts = None
self.cfg = None
self.detector = None
self.model = None
Expand Down Expand Up @@ -173,6 +192,10 @@ def close(self) -> None:

@torch.inference_mode()
def get_pose(self, frame: np.ndarray) -> np.ndarray:
if self.model is None:
raise ValueError(
"Model not loaded. Call load_model() or init_inference() before calling get_pose()."
)
c, h, w = frame.shape
tensor = torch.from_numpy(frame).permute(2, 0, 1) # CHW, still on CPU

Expand All @@ -191,7 +214,16 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:

frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
if len(frame_batch) == 0:
offsets_and_scales = [(0, 0), 1]
# Determine output shape based on single_animal parameter and n_individuals
if self.single_animal or self.n_individuals == 1:
zero_pose = np.zeros((self.n_bodyparts, 3))
else:
zero_pose = np.zeros((self.n_individuals, self.n_bodyparts, 3))
# Update skip_frames even when returning early to maintain frame counter
if self.top_down_config.skip_frames is not None:
self.top_down_config.skip_frames.next_frame()
return zero_pose

tensor = frame_batch # still CHW, batched

if self.dynamic is not None:
Expand Down Expand Up @@ -259,6 +291,16 @@ def load_model(self) -> None:
raw_data = torch.load(self.path, map_location="cpu", weights_only=True)

self.cfg = raw_data["config"]

# Infer n_bodyparts and n_individuals from model configuration
individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1'])
bodyparts = self.cfg.get("metadata", {}).get("bodyparts", [])
self.n_individuals = len(individuals)
self.n_bodyparts = len(bodyparts)
# If single_animal is not set, infer it from n_individuals in model configuration
if self.single_animal is None:
self.single_animal = self.n_individuals == 1

self.model = models.PoseModel.build(self.cfg["model"])
self.model.load_state_dict(raw_data["pose"])
self.model = self.model.to(self.device)
Expand Down