diff --git a/src/streamdiffusion/controlnet/preprocessors/constants.py b/src/streamdiffusion/controlnet/preprocessors/constants.py new file mode 100644 index 00000000..9137f6f1 --- /dev/null +++ b/src/streamdiffusion/controlnet/preprocessors/constants.py @@ -0,0 +1,107 @@ +"""Shared constant mappings and colour tables for MediaPipe-based preprocessors. + +Keeping these heavy structures in a separate module avoids re-parsing them for + every process import of the main preprocessing classes and keeps those files + more readable. +""" +from __future__ import annotations + +# MediaPipe pose (33-kp) → OpenPose body (25-kp) mapping +MEDIAPIPE_TO_OPENPOSE_MAP = { + # 0 = Nose is identical; 1 = Neck is derived later, etc. + 1: None, # Neck (calculated from shoulders) + 2: 12, # RShoulder → RightShoulder + 3: 14, # RElbow → RightElbow + 4: 16, # RWrist → RightWrist + 5: 11, # LShoulder → LeftShoulder + 6: 13, # LElbow → LeftElbow + 7: 15, # LWrist → LeftWrist + 8: None, # MidHip (calculated from hips) + 9: 24, # RHip → RightHip + 10: 26, # RKnee → RightKnee + 11: 28, # RAnkle → RightAnkle + 12: 23, # LHip → LeftHip + 13: 25, # LKnee → LeftKnee + 14: 27, # LAnkle → LeftAnkle + 19: 31, # LBigToe → LeftFootIndex + 20: 31, # LSmallToe → LeftFootIndex (approx.) + 21: 29, # LHeel → LeftHeel + 22: 32, # RBigToe → RightFootIndex + 23: 32, # RSmallToe → RightFootIndex (approx.) + 24: 30, # RHeel → RightHeel +} + +# OpenPose limb pairs used for body skeleton rendering +OPENPOSE_LIMB_SEQUENCE = [ + [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], + [1, 8], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], + [13, 14], [14, 19], [19, 20], [14, 21], [11, 22], [22, 23], [11, 24], +] + +# Standard OpenPose colours (BGR) +OPENPOSE_COLORS = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [255, 0, 0], [255, 85, 0], + [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], +] + +# OpenPose face (70-kp) connection pairs +OPENPOSE_FACE_CONNECTIONS = [ + # Jawline + (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), + (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), + # Left & Right eyebrows + (17, 18), (18, 19), (19, 20), (20, 21), + (22, 23), (23, 24), (24, 25), (25, 26), + # Nose bridge / lower + (27, 28), (28, 29), (29, 30), + (31, 32), (32, 33), (33, 34), (34, 35), + # Eyes + (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 36), + (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42), + # Lips outer + (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), + (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48), + # Lips inner + (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60), + # Pupils (68-69) + (68, 68), (69, 69) +] + +# Colour per face-connection group (BGR) +FACE_COLORS = list( + [(255, 255, 255)] * 16 + # jaw + [(0, 255, 0)] * 4 + # right brow + [(0, 255, 0)] * 4 + # left brow + [(255, 0, 255)] * 3 + # nose bridge + [(255, 0, 255)] * 4 + # nose lower + [(0, 0, 255)] * 6 + # right eye + [(0, 0, 255)] * 6 + # left eye + [(255, 0, 0)] * 12 + # outer lips + [(255, 0, 0)] * 8 + # inner lips + [(255, 0, 0)] * 2 # pupils +) + +# MediaPipe 468-kp → OpenPose 70-kp face mapping +MEDIAPIPE_TO_OPENPOSE_FACE_MAP = { + # Jawline + 0: 127, 1: 234, 2: 93, 3: 132, 4: 58, 5: 172, 6: 136, 7: 150, + 8: 152, 9: 400, 10: 365, 11: 397, 12: 435, 13: 401, 14: 323, 15: 454, 16: 356, + # Eyebrows + 17: 55, 
18: 65, 19: 52, 20: 53, 21: 46, + 22: 285, 23: 295, 24: 282, 25: 283, 26: 276, + # Nose + 27: 168, 28: 197, 29: 5, 30: 4, + 31: 166, 32: 44, 33: 19, 34: 457, 35: 455, + # Eyes + 36: 33, 37: 160, 38: 158, 39: 155, 40: 145, 41: 163, + 42: 463, 43: 385, 44: 388, 45: 263, 46: 373, 47: 381, + # Lips outer + 48: 185, 49: 39, 50: 37, 51: 0, 52: 267, 53: 270, 54: 409, + 55: 321, 56: 314, 57: 17, 58: 181, 59: 146, + # Lips inner + 60: 78, 61: 81, 62: 13, 63: 311, 64: 409, 65: 402, 66: 14, 67: 178, + # Pupils + 68: 468, 69: 473, +} diff --git a/src/streamdiffusion/controlnet/preprocessors/mediapipe_landmarkers.py b/src/streamdiffusion/controlnet/preprocessors/mediapipe_landmarkers.py new file mode 100644 index 00000000..9f67acb0 --- /dev/null +++ b/src/streamdiffusion/controlnet/preprocessors/mediapipe_landmarkers.py @@ -0,0 +1,138 @@ +import os +from typing import Optional, Any +import logging + +import cv2 + +logger = logging.getLogger(__name__) +import mediapipe as mp +import numpy as np +from mediapipe.tasks.python import vision +from mediapipe.tasks.python.core.base_options import BaseOptions +from mediapipe.tasks.python.core.base_options import BaseOptions as _BaseOptions + +# Enum for delegate selection (CPU/GPU) +Delegate = _BaseOptions.Delegate + +# Global cache to reuse MediaPipe detector instances across wrappers +_DETECTOR_CACHE: dict[tuple[str, str, str], object] = {} # key: (wrapper, model_path, delegate) -> detector + +# Assume models are downloaded to a specific path +MODELS_PATH = os.path.join(os.path.dirname(__file__), "mediapipe_models") +FACE_LANDMARKER_MODEL = os.path.join(MODELS_PATH, "face_landmarker.task") +HAND_LANDMARKER_MODEL = os.path.join(MODELS_PATH, "hand_landmarker.task") +POSE_LANDMARKER_MODEL = os.path.join(MODELS_PATH, "pose_landmarker_full.task") + + +class _OptionBuilderMixin: + """Mixin that builds task options with overridable DEFAULT_PARAMS and OPTIONS_CLS.""" + + OPTIONS_CLS: type | None = None # to be set by subclass + DEFAULT_PARAMS: dict = {} + + @classmethod + def build_options(cls, base_options: BaseOptions, running_mode: vision.RunningMode, **overrides): + params = {**cls.DEFAULT_PARAMS, **overrides} + if cls.OPTIONS_CLS is None: + raise NotImplementedError("Subclasses must define OPTIONS_CLS") + return cls.OPTIONS_CLS(base_options=base_options, running_mode=running_mode, **params) + + +class BaseLandmarker(_OptionBuilderMixin): + OPTIONS_CLS = None # subclasses define + DEFAULT_PARAMS = {} + + def __init__( + self, + model_path: str, + running_mode: vision.RunningMode = vision.RunningMode.IMAGE, + delegate: str = "cpu", + **kwargs, + ): + if not os.path.exists(model_path): + raise FileNotFoundError(f"MediaPipe model file not found at {model_path}") + + # Select CPU/GPU delegate + delegate_enum = Delegate.GPU if delegate.lower() == "gpu" else Delegate.CPU + base_options = BaseOptions(model_asset_path=model_path, delegate=delegate_enum) + + self.options = self.build_options(base_options, running_mode, **kwargs) + self.detector = self._get_detector(model_path, delegate_enum) + + def _create_options(self, base_options: BaseOptions, running_mode: vision.RunningMode, **kwargs): + raise NotImplementedError + + def _create_detector(self, options): + raise NotImplementedError + + def _get_detector(self, model_path, delegate_enum): + cache_key = (self.__class__.__name__, model_path, delegate_enum) + if cache_key in _DETECTOR_CACHE: + return _DETECTOR_CACHE[cache_key] + detector = self._create_detector(self.options) + _DETECTOR_CACHE[cache_key] = detector + 
return detector + + def detect(self, image: np.ndarray) -> Any: + """Run landmark detection and return MediaPipe result. + All errors are caught and logged as warnings to avoid crashing pipelines. + """ + try: + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + return self.detector.detect(mp_image) + except Exception as e: # pylint: disable=broad-except + logger.warning("%s.detect failed: %s", self.__class__.__name__, e) + return None + + def close(self): + """Close the underlying detector and remove from cache if no longer referenced.""" + # Find cache key(s) pointing to this detector + keys_to_remove = [k for k, v in _DETECTOR_CACHE.items() if v is self.detector] + for k in keys_to_remove: + _DETECTOR_CACHE.pop(k, None) + self.detector.close() + + +class FaceLandmarkerWrapper(BaseLandmarker): + OPTIONS_CLS = vision.FaceLandmarkerOptions + DEFAULT_PARAMS = { + "output_face_blendshapes": False, + "output_facial_transformation_matrixes": False, + "num_faces": 1, + "min_face_detection_confidence": 0.5, + "min_face_presence_confidence": 0.5, + "min_tracking_confidence": 0.5, + } + + def _create_detector(self, options): + return vision.FaceLandmarker.create_from_options(options) + + +class HandLandmarkerWrapper(BaseLandmarker): + OPTIONS_CLS = vision.HandLandmarkerOptions + DEFAULT_PARAMS = { + "num_hands": 2, + "min_hand_detection_confidence": 0.5, + "min_hand_presence_confidence": 0.5, + "min_tracking_confidence": 0.5, + } + + def _create_detector(self, options): + return vision.HandLandmarker.create_from_options(options) + + +class PoseLandmarkerWrapper(BaseLandmarker): + OPTIONS_CLS = vision.PoseLandmarkerOptions + DEFAULT_PARAMS = { + "output_segmentation_masks": False, + "num_poses": 1, + "min_pose_detection_confidence": 0.5, + "min_pose_presence_confidence": 0.5, + "min_tracking_confidence": 0.5, + } + + def _create_detector(self, options): + logger.debug("PoseLandmarkerWrapper: Creating detector with options: %s", options) + detector = vision.PoseLandmarker.create_from_options(options) + logger.debug("PoseLandmarkerWrapper: Detector created successfully.") + return detector diff --git a/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/face_landmarker.task b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/face_landmarker.task new file mode 100644 index 00000000..c50c845d Binary files /dev/null and b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/face_landmarker.task differ diff --git a/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/hand_landmarker.task b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/hand_landmarker.task new file mode 100644 index 00000000..0d53faf3 Binary files /dev/null and b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/hand_landmarker.task differ diff --git a/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/pose_landmarker_full.task b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/pose_landmarker_full.task new file mode 100644 index 00000000..300a181e Binary files /dev/null and b/src/streamdiffusion/controlnet/preprocessors/mediapipe_models/pose_landmarker_full.task differ diff --git a/src/streamdiffusion/controlnet/preprocessors/mediapipe_pose.py b/src/streamdiffusion/controlnet/preprocessors/mediapipe_pose.py index fe4f3834..58845ede 100644 --- a/src/streamdiffusion/controlnet/preprocessors/mediapipe_pose.py +++ b/src/streamdiffusion/controlnet/preprocessors/mediapipe_pose.py @@ -3,174 +3,213 @@ 
import cv2 from PIL import Image, ImageDraw from typing import Union, Optional, List, Tuple, Dict +import logging +import time +logger = logging.getLogger(__name__) + +# Heavy constant maps moved to separate module for memory efficiency +from .constants import ( + MEDIAPIPE_TO_OPENPOSE_MAP, + OPENPOSE_LIMB_SEQUENCE, + OPENPOSE_COLORS, + OPENPOSE_FACE_CONNECTIONS, + FACE_COLORS, + MEDIAPIPE_TO_OPENPOSE_FACE_MAP, +) from .base import BasePreprocessor try: import mediapipe as mp + + from .mediapipe_landmarkers import ( + FaceLandmarkerWrapper, + HandLandmarkerWrapper, + PoseLandmarkerWrapper, + FACE_LANDMARKER_MODEL, + HAND_LANDMARKER_MODEL, + POSE_LANDMARKER_MODEL, + ) MEDIAPIPE_AVAILABLE = True except ImportError: MEDIAPIPE_AVAILABLE = False -# MediaPipe to OpenPose keypoint mapping -# MediaPipe has 33 keypoints, OpenPose has 25 keypoints -# Reference: https://github.com/Atif-Anwer/Mediapipe-to-OpenPose-JSON -MEDIAPIPE_TO_OPENPOSE_MAP = { - # OpenPose format (25 keypoints): - # 0: Nose, 1: Neck, 2: RShoulder, 3: RElbow, 4: RWrist, - # 5: LShoulder, 6: LElbow, 7: LWrist, 8: MidHip, 9: RHip, - # 10: RKnee, 11: RAnkle, 12: LHip, 13: LKnee, 14: LAnkle, - # 15: REye, 16: LEye, 17: REar, 18: LEar, 19: LBigToe, - # 20: LSmallToe, 21: LHeel, 22: RBigToe, 23: RSmallToe, 24: RHeel - - 0: 0, # Nose -> Nose - 1: None, # Neck (calculated from shoulders) - 2: 12, # RShoulder -> RightShoulder - 3: 14, # RElbow -> RightElbow - 4: 16, # RWrist -> RightWrist - 5: 11, # LShoulder -> LeftShoulder - 6: 13, # LElbow -> LeftElbow - 7: 15, # LWrist -> LeftWrist - 8: None, # MidHip (calculated from hips) - 9: 24, # RHip -> RightHip - 10: 26, # RKnee -> RightKnee - 11: 28, # RAnkle -> RightAnkle - 12: 23, # LHip -> LeftHip - 13: 25, # LKnee -> LeftKnee - 14: 27, # LAnkle -> LeftAnkle - 15: 5, # REye -> RightEye - 16: 2, # LEye -> LeftEye - 17: 8, # REar -> RightEar - 18: 7, # LEar -> LeftEar - 19: 31, # LBigToe -> LeftFootIndex - 20: 31, # LSmallToe -> LeftFootIndex (approximation) - 21: 29, # LHeel -> LeftHeel - 22: 32, # RBigToe -> RightFootIndex - 23: 32, # RSmallToe -> RightFootIndex (approximation) - 24: 30 # RHeel -> RightHeel -} - -# OpenPose connections for proper skeleton rendering -OPENPOSE_LIMB_SEQUENCE = [ - [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], - [1, 8], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], - [13, 14], [1, 0], [0, 15], [15, 17], [0, 16], [16, 18], - [14, 19], [19, 20], [14, 21], [11, 22], [22, 23], [11, 24] -] - -# Standard OpenPose colors (BGR format) - matching actual OpenPose output -OPENPOSE_COLORS = [ - [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], - [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], - [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], - [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0], [255, 85, 0], - [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0] -] - - class MediaPipePosePreprocessor(BasePreprocessor): """ - MediaPipe-based pose preprocessor for ControlNet that outputs OpenPose-style annotations - - Converts MediaPipe's 33 keypoints to OpenPose's 25 keypoints format and renders - them in the standard OpenPose style for ControlNet compatibility. - - Improvements inspired by TouchDesigner MediaPipe plugin: - - Better confidence filtering - - Temporal smoothing for jitter reduction - - Improved multi-pose support preparation + MediaPipe-based pose preprocessor for ControlNet that outputs OpenPose-style annotations. 
+ + This preprocessor uses the latest MediaPipe Solutions API to perform modular detection of + pose, face, and hand landmarks. It converts the detected keypoints into an OpenPose-compatible + format for use with ControlNet. + + Features: + - Modular detection: Enable or disable pose, face, and hand detection independently. + - OpenPose compatibility: Converts MediaPipe landmarks to a 25-keypoint OpenPose skeleton. + - Temporal smoothing: Reduces jitter in video streams for more stable animations. """ - - def __init__(self, - detect_resolution: int = 512, - image_resolution: int = 512, - min_detection_confidence: float = 0.5, - min_tracking_confidence: float = 0.5, - model_complexity: int = 1, - static_image_mode: bool = True, - draw_hands: bool = True, - draw_face: bool = False, # Simplified - disable face by default - line_thickness: int = 2, - circle_radius: int = 4, - confidence_threshold: float = 0.3, # TouchDesigner-style confidence filtering - enable_smoothing: bool = True, # TouchDesigner-inspired smoothing - smoothing_factor: float = 0.7, # Smoothing strength - **kwargs): + + def __init__( + self, + detect_resolution: int = 512, + image_resolution: int = 512, + enable_pose: bool = True, + enable_face: bool = True, + enable_hands: bool = True, + line_thickness: int = 2, + circle_radius: int = 4, + confidence_threshold: float = 0.3, + enable_smoothing: bool = True, + smoothing_factor: float = 0.7, + pose_options: Optional[Dict] = None, + face_options: Optional[Dict] = None, + hand_options: Optional[Dict] = None, + **kwargs, + ): """ - Initialize MediaPipe pose preprocessor with TouchDesigner-inspired improvements - + Initializes the MediaPipePosePreprocessor. + Args: - detect_resolution: Resolution for pose detection - image_resolution: Output image resolution - min_detection_confidence: Minimum confidence for detection - min_tracking_confidence: Minimum confidence for tracking - model_complexity: MediaPipe model complexity (0, 1, or 2) - static_image_mode: Treat each image independently - draw_hands: Whether to draw hand poses - draw_face: Whether to draw face landmarks - line_thickness: Thickness of skeleton lines - circle_radius: Radius of joint circles - confidence_threshold: Minimum confidence for rendering keypoints - enable_smoothing: Enable temporal smoothing - smoothing_factor: Smoothing strength (0-1, higher = more smoothing) - **kwargs: Additional parameters + detect_resolution: The resolution for landmark detection. + image_resolution: The output image resolution. + enable_pose: Whether to enable pose detection. + enable_face: Whether to enable face landmark detection. + enable_hands: Whether to enable hand landmark detection. + line_thickness: The thickness of the drawn skeleton lines. + circle_radius: The radius of the drawn keypoint circles. + confidence_threshold: The minimum confidence score for a keypoint to be rendered. + enable_smoothing: Whether to apply temporal smoothing to the keypoints. + smoothing_factor: The strength of the temporal smoothing (0-1). + pose_options: Custom options for the PoseLandmarker. + face_options: Custom options for the FaceLandmarker. + hand_options: Custom options for the HandLandmarker. """ if not MEDIAPIPE_AVAILABLE: raise ImportError( "MediaPipe is required for MediaPipe pose preprocessing. 
" "Install it with: pip install mediapipe" ) - + super().__init__( detect_resolution=detect_resolution, image_resolution=image_resolution, - min_detection_confidence=min_detection_confidence, - min_tracking_confidence=min_tracking_confidence, - model_complexity=model_complexity, - static_image_mode=static_image_mode, - draw_hands=draw_hands, - draw_face=draw_face, + enable_pose=enable_pose, + enable_face=enable_face, + enable_hands=enable_hands, line_thickness=line_thickness, circle_radius=circle_radius, confidence_threshold=confidence_threshold, enable_smoothing=enable_smoothing, smoothing_factor=smoothing_factor, - **kwargs + pose_options=pose_options, + face_options=face_options, + hand_options=hand_options, + **kwargs, ) - - self._detector = None - self._current_options = None - # TouchDesigner-style smoothing buffers - self._smoothing_buffers = {} - - @property - def detector(self): - """Lazy loading of the MediaPipe Holistic detector""" - new_options = { - 'min_detection_confidence': self.params.get('min_detection_confidence', 0.5), - 'min_tracking_confidence': self.params.get('min_tracking_confidence', 0.5), - 'model_complexity': self.params.get('model_complexity', 1), - 'static_image_mode': self.params.get('static_image_mode', True), - } - - # Initialize or update detector if needed - if self._detector is None or self._current_options != new_options: - if self._detector is not None: - self._detector.close() - - print(f"MediaPipePosePreprocessor.detector: Initializing MediaPipe Holistic detector") - self._detector = mp.solutions.holistic.Holistic( - static_image_mode=new_options['static_image_mode'], - model_complexity=new_options['model_complexity'], - enable_segmentation=False, - refine_face_landmarks=False, # Keep simple - min_detection_confidence=new_options['min_detection_confidence'], - min_tracking_confidence=new_options['min_tracking_confidence'], + + self.enable_pose = enable_pose + self.enable_face = enable_face + self.enable_hands = enable_hands + + self.pose_detector = None + self.face_detector = None + self.hand_detector = None + + logger.debug("Initializing Pose Landmarker...") + if self.enable_pose: + self.pose_detector = PoseLandmarkerWrapper( + model_path=POSE_LANDMARKER_MODEL, **(pose_options or {}) ) - self._current_options = new_options - - return self._detector + logger.debug("Pose Landmarker initialized. Type: %s", type(self.pose_detector)) + else: + logger.debug("Pose Landmarker disabled.") + + logger.debug("Initializing Face Landmarker...") + if self.enable_face: + self.face_detector = FaceLandmarkerWrapper( + model_path=FACE_LANDMARKER_MODEL, **(face_options or {}) + ) + logger.debug("Face Landmarker initialized.") + else: + logger.debug("Face Landmarker disabled.") + + logger.debug("Initializing Hand Landmarker...") + if self.enable_hands: + self.hand_detector = HandLandmarkerWrapper( + model_path=HAND_LANDMARKER_MODEL, **(hand_options or {}) + ) + + # Buffer storing previous smoothed keypoints per unique pose id + self._smoothing_buffers: Dict[str, List[List[float]]] = {} + + # Copy ctor args to explicit attributes so helpers avoid hidden `self.params` + self.enable_smoothing = enable_smoothing + self.smoothing_factor = smoothing_factor + self._face_idx = np.fromiter( + [MEDIAPIPE_TO_OPENPOSE_FACE_MAP[i] for i in range(70)], dtype=np.int32 + ) + + def __call__(self, input_image: Union[Image.Image, np.ndarray], **kwargs) -> Image.Image: + """ + Process an input image to detect and draw pose, face, and hand landmarks. 
+ + Args: + input_image: The input image in PIL or NumPy format. + **kwargs: Additional keyword arguments. + + Returns: + A PIL Image with the detected landmarks drawn. + """ + if not MEDIAPIPE_AVAILABLE: + raise ImportError("MediaPipe is not installed") + + # Convert incoming image to BGR once and keep that space for the whole pipeline. + if isinstance(input_image, Image.Image): + # PIL images are RGB; convert directly to BGR numpy array + input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR) + elif isinstance(input_image, np.ndarray): + if input_image.shape[2] == 4: + # RGBA → BGR + input_image = cv2.cvtColor(input_image, cv2.COLOR_RGBA2BGR) + # else assume already BGR (OpenCV default) + + detect_resolution = self.detect_resolution + image_resolution = self.image_resolution + + image_resized = cv2.resize(input_image, (detect_resolution, detect_resolution)) + + canvas = np.zeros_like(image_resized) + + if self.enable_pose and self.pose_detector: + pose_results = self.pose_detector.detect(image_resized) + if pose_results and pose_results.pose_landmarks: + for landmarks in pose_results.pose_landmarks: + openpose_keypoints = self._mediapipe_to_openpose( + landmarks, detect_resolution, detect_resolution + ) + if self.enable_smoothing: + openpose_keypoints = self._apply_smoothing(openpose_keypoints) + canvas = self._draw_openpose_skeleton(canvas, openpose_keypoints) + + if self.enable_face and self.face_detector: + face_results = self.face_detector.detect(image_resized) + if face_results and face_results.face_landmarks: + for landmarks in face_results.face_landmarks: + canvas = self._draw_face_keypoints(canvas, landmarks) + + if self.enable_hands and self.hand_detector: + hand_results = self.hand_detector.detect(image_resized) + if hand_results and hand_results.hand_landmarks: + for i, landmarks in enumerate(hand_results.hand_landmarks): + is_left = hand_results.handedness[i][0].category_name == 'Left' + canvas = self._draw_hand_keypoints(canvas, landmarks, is_left) + + if image_resolution != detect_resolution: + canvas = cv2.resize(canvas, (image_resolution, image_resolution), interpolation=cv2.INTER_AREA) + + # Convert BGR canvas back to RGB for PIL output + canvas_rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB) + return Image.fromarray(canvas_rgb) - def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str = "default") -> List[List[float]]: + def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str | None = None) -> List[List[float]]: """ Apply TouchDesigner-inspired temporal smoothing @@ -181,15 +220,25 @@ def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str = "default Returns: Smoothed keypoints """ - if not self.params.get('enable_smoothing', True) or not keypoints: + # Fast-exit if smoothing disabled or missing keypoints + if not self.enable_smoothing or not keypoints: return keypoints - - smoothing_factor = self.params.get('smoothing_factor', 0.7) - - # Initialize buffer for this pose if needed + + # Derive a stable pose_id if none supplied – use a simple hash of the first visible + # landmark positions so multi-person frames don’t overwrite each other. 
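+        # If no history exists yet for the derived id, the raw keypoints are stored as the
+        # first buffer entry and returned unchanged; smoothing only takes effect on later
+        # frames that resolve to the same id.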
+ if pose_id is None: + try: + first_kp = next(pt for pt in keypoints if pt[2] > 0.1) + pose_id = f"{hash((round(first_kp[0],2), round(first_kp[1],2)))}" + except StopIteration: + pose_id = "default" + + # Initialise history buffer lazily if pose_id not in self._smoothing_buffers: self._smoothing_buffers[pose_id] = keypoints.copy() return keypoints + + smoothing_factor = self.smoothing_factor # Apply exponential smoothing (simplified 1-euro filter style) smoothed = [] @@ -208,17 +257,19 @@ def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str = "default self._smoothing_buffers[pose_id] = smoothed return smoothed - def _mediapipe_to_openpose(self, mediapipe_landmarks: List, image_width: int, image_height: int) -> List[List[float]]: + def _mediapipe_to_openpose( + self, mediapipe_landmarks: List, image_width: int, image_height: int + ) -> List[List[float]]: """ - Convert MediaPipe landmarks to OpenPose format - + Convert MediaPipe landmarks to OpenPose format. + Args: - mediapipe_landmarks: MediaPipe pose landmarks - image_width: Image width - image_height: Image height - + mediapipe_landmarks: A list of MediaPipe pose landmarks. + image_width: The width of the image. + image_height: The height of the image. + Returns: - OpenPose keypoints in [x, y, confidence] format + A list of OpenPose keypoints in [x, y, confidence] format. """ if not mediapipe_landmarks: return [] @@ -260,16 +311,18 @@ def _mediapipe_to_openpose(self, mediapipe_landmarks: List, image_width: int, im return openpose_keypoints - def _draw_openpose_skeleton(self, image: np.ndarray, keypoints: List[List[float]]) -> np.ndarray: + def _draw_openpose_skeleton( + self, image: np.ndarray, keypoints: List[List[float]] + ) -> np.ndarray: """ - Draw OpenPose-style skeleton on image - + Draw an OpenPose-style skeleton on an image. + Args: - image: Input image - keypoints: OpenPose keypoints - + image: The input image as a NumPy array. + keypoints: A list of OpenPose keypoints. + Returns: - Image with skeleton drawn + The image with the skeleton drawn on it. """ if not keypoints or len(keypoints) != 25: return image @@ -356,7 +409,55 @@ def _draw_hand_keypoints(self, image: np.ndarray, hand_landmarks: List, is_left_ return image - def process(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: + def _draw_face_keypoints(self, image: np.ndarray, face_landmarks: List) -> np.ndarray: + """ + Draw face landmarks in OpenPose style. + + Args: + image: Input image canvas. + face_landmarks: MediaPipe face landmarks (468 points). + + Returns: + Image with face skeleton drawn. 
+ """ + if not face_landmarks: + return image + + h, w = image.shape[:2] + line_thickness = self.params.get('line_thickness', 2) + confidence_threshold = self.params.get('confidence_threshold', 0.3) + + # Vectorised conversion of 468 face landmarks (x,y) and mapping to 70-point OpenPose order + pts = np.stack([(lm.x * w, lm.y * h) for lm in face_landmarks], axis=0).astype(np.float32) + + # Map to 70-point OpenPose order in a single take + openpose_pts = pts[self._face_idx] # (70,2) + + # Draw connections + # Draw connections using the mapped 70-point array + for i, (start_idx, end_idx) in enumerate(OPENPOSE_FACE_CONNECTIONS): + if start_idx < openpose_pts.shape[0] and end_idx < openpose_pts.shape[0]: + start_point = tuple(openpose_pts[start_idx].astype(int)) + end_point = tuple(openpose_pts[end_idx].astype(int)) + color = FACE_COLORS[i % len(FACE_COLORS)] + cv2.line(image, start_point, end_point, color, line_thickness) + + return image + + # DEPRECATED - old method, keep for reference + def process(self, image: Union[Image.Image, np.ndarray]): + """ + Apply MediaPipe pose detection and create OpenPose-style annotation + + Args: + image: Input image + + Returns: + PIL Image with OpenPose-style pose skeleton on black background + """ + return self(image) + + def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: """ Apply MediaPipe pose detection and create OpenPose-style annotation @@ -374,19 +475,29 @@ def process(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) # Convert to RGB numpy array for MediaPipe - rgb_image = cv2.cvtColor(np.array(image_resized), cv2.COLOR_BGR2RGB) + rgb_image = np.asarray(image_resized) # Already RGB, avoid extra conversion + + pose_results = None + hand_results = None + face_results = None + + if self.enable_pose and self.pose_detector: + pose_results = self.pose_detector.detect(rgb_image) - # Run MediaPipe detection - results = self.detector.process(rgb_image) + if self.enable_hands and self.hand_detector: + hand_results = self.hand_detector.detect(rgb_image) + + if self.enable_face and self.face_detector: + face_results = self.face_detector.detect(rgb_image) # Create black background for pose annotation pose_image = np.zeros((detect_resolution, detect_resolution, 3), dtype=np.uint8) # Draw pose skeleton if detected - if results.pose_landmarks: + if pose_results and pose_results.pose_landmarks: # Convert MediaPipe to OpenPose format openpose_keypoints = self._mediapipe_to_openpose( - results.pose_landmarks.landmark, + pose_results.pose_landmarks[0], # Assuming single person detection for this path detect_resolution, detect_resolution ) @@ -399,19 +510,25 @@ def process(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: # Draw hands if enabled draw_hands = self.params.get('draw_hands', True) - if draw_hands: - if results.left_hand_landmarks: - pose_image = self._draw_hand_keypoints( - pose_image, results.left_hand_landmarks.landmark, is_left_hand=True - ) - - if results.right_hand_landmarks: - pose_image = self._draw_hand_keypoints( - pose_image, results.right_hand_landmarks.landmark, is_left_hand=False + if draw_hands and self.enable_hands and hand_results and hand_results.hand_landmarks: + for i, landmarks_list in enumerate(hand_results.hand_landmarks): + if hand_results.handedness and i < len(hand_results.handedness): + is_left = hand_results.handedness[i][0].category_name == 'Left' + pose_image = self._draw_hand_keypoints( + 
pose_image, landmarks_list, is_left_hand=is_left + ) + + + # Draw face if enabled + draw_face = self.params.get('draw_face', True) + if draw_face and self.enable_face and face_results and face_results.face_landmarks: + for landmarks_list in face_results.face_landmarks: + pose_image = self._draw_face_keypoints( + pose_image, landmarks_list ) # Convert back to PIL - pose_pil = Image.fromarray(cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB)) + pose_pil = Image.fromarray(pose_image[:, :, ::-1]) # BGR -> RGB with channel flip # Resize to target resolution image_resolution = self.params.get('image_resolution', 512) @@ -437,7 +554,7 @@ def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor: def reset_smoothing_buffers(self): """Reset smoothing buffers (useful for new sequences)""" - print("MediaPipePosePreprocessor.reset_smoothing_buffers: Clearing smoothing buffers") + logger.info("MediaPipePosePreprocessor.reset_smoothing_buffers: Clearing smoothing buffers") self._smoothing_buffers.clear() def __del__(self):
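Usage sketch (illustrative only, not part of the patch): the snippet below shows how the refactored preprocessor is expected to be driven after this change. It assumes the src/ layout above is installed as the streamdiffusion package and that the bundled .task models shipped in this diff are present; the input/output file names and the GPU delegate choice are placeholders.

from PIL import Image
from streamdiffusion.controlnet.preprocessors.mediapipe_pose import MediaPipePosePreprocessor

# Each enabled detector loads its bundled .task model at construction time and is
# cached in _DETECTOR_CACHE, so re-creating the preprocessor with the same options is cheap.
preprocessor = MediaPipePosePreprocessor(
    detect_resolution=512,
    image_resolution=512,
    enable_pose=True,
    enable_face=True,
    enable_hands=False,                # skip hand landmarks to cut per-frame latency
    enable_smoothing=True,
    smoothing_factor=0.7,
    pose_options={"delegate": "gpu"},  # forwarded to PoseLandmarkerWrapper -> BaseOptions
)

frame = Image.open("frame_0001.png").convert("RGB")  # placeholder input frame
control_image = preprocessor(frame)                  # OpenPose-style skeleton on a black canvas
control_image.save("pose_control.png")

# Clear the temporal-smoothing history before starting a new clip:
preprocessor.reset_smoothing_buffers()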