class ClusterLabel(SetPredictor):
    """Cluster-based conformal prediction for multiclass classification.

    K-means is fitted on patient embeddings (training + calibration) and every
    cluster gets its own calibration threshold, computed from the conformity
    scores (predicted probability of the true class) of the calibration
    samples that fall into that cluster. At inference time every test sample
    is assigned to its nearest cluster centre and a class enters the
    prediction set when its probability reaches that cluster's threshold.

    This approach is simpler than KDE-based methods and serves as a baseline
    for more advanced personalized conformal prediction approaches.

    Args:
        model: A trained base model that supports embedding extraction
            (must support `embed=True` in forward pass)
        alpha: Target miscoverage rate(s). Can be:
            - scalar: marginal coverage P(Y not in C(X)) <= alpha
            - array: class-conditional P(Y not in C(X) | Y=k) <= alpha[k]
        n_clusters: Number of K-means clusters. Default is 5.
        random_state: Random seed for K-means clustering. Default is 42.
        debug: Whether to use debug mode (forwarded to
            `prepare_numpy_dataset`; also prints the fitted thresholds)

    Raises:
        NotImplementedError: If `model.mode` is not "multiclass".
        ValueError: If `n_clusters` is smaller than 1.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset, split_by_sample_conformal
        >>> from pyhealth.datasets import get_dataloader
        >>> from pyhealth.models import ContraWR
        >>> from pyhealth.tasks import EEGEventsTUEV
        >>> from pyhealth.calib.predictionset.cluster import ClusterLabel
        >>> from pyhealth.calib.utils import extract_embeddings
        >>> from pyhealth.trainer import Trainer, get_metrics_fn
        >>>
        >>> # Prepare data
        >>> dataset = TUEVDataset(root="path/to/tuev")
        >>> sample_dataset = dataset.set_task(EEGEventsTUEV())
        >>> train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        ...     sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42
        ... )
        >>>
        >>> # Train model
        >>> model = ContraWR(dataset=sample_dataset)
        >>> # ... training code ...
        >>>
        >>> # Extract embeddings for clustering
        >>> train_embeddings = extract_embeddings(model, train_ds, batch_size=32)
        >>> cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32)
        >>>
        >>> # Create and calibrate cluster-based predictor
        >>> cluster_predictor = ClusterLabel(model=model, alpha=0.1, n_clusters=5)
        >>> cluster_predictor.calibrate(
        ...     cal_dataset=cal_ds,
        ...     train_embeddings=train_embeddings,
        ...     cal_embeddings=cal_embeddings,
        ... )
        >>>
        >>> # Evaluate
        >>> test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
        >>> y_true, y_prob, _, extra = Trainer(model=cluster_predictor).inference(
        ...     test_loader, additional_outputs=["y_predset"]
        ... )
        >>> metrics = get_metrics_fn(cluster_predictor.mode)(
        ...     y_true, y_prob, metrics=["accuracy", "miscoverage_ps"],
        ...     y_predset=extra["y_predset"]
        ... )
    """

    def __init__(
        self,
        model: BaseModel,
        alpha: Union[float, np.ndarray],
        n_clusters: int = 5,
        random_state: int = 42,
        debug: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)

        if model.mode != "multiclass":
            raise NotImplementedError(
                "ClusterLabel only supports multiclass classification"
            )
        if n_clusters < 1:
            raise ValueError(
                f"n_clusters must be a positive integer, got {n_clusters}"
            )

        self.mode = self.model.mode

        # Freeze model parameters: the wrapped model is only used for inference.
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.device = model.device
        self.debug = debug

        # Any scalar (Python/NumPy float or int, 0-d array) means marginal
        # coverage and is stored as a plain float; anything else is treated as
        # a per-class array. Without this, an int or 0-d alpha would become a
        # 0-d array and break the per-class branch in calibrate().
        if np.ndim(alpha) == 0:
            alpha = float(alpha)
        else:
            alpha = np.asarray(alpha)
        self.alpha = alpha

        # Store clustering parameters
        self.n_clusters = n_clusters
        self.random_state = random_state

        # Will be set during calibration
        self.kmeans_model = None
        self.cluster_thresholds = None  # Dict mapping cluster_id -> threshold(s)

    def calibrate(
        self,
        cal_dataset: IterableDataset,
        train_embeddings: Optional[np.ndarray] = None,
        cal_embeddings: Optional[np.ndarray] = None,
    ):
        """Calibrate cluster-specific thresholds.

        This method:
        1. Combines train and calibration embeddings for clustering
        2. Fits K-means on the combined embeddings
        3. Assigns calibration samples to clusters
        4. Computes cluster-specific calibration thresholds

        Args:
            cal_dataset: Calibration set
            train_embeddings: Pre-computed training embeddings of shape
                (n_train, embedding_dim). Required: the signature keeps the
                ``None`` default, but passing ``None`` raises ``ValueError``.
            cal_embeddings: Optional pre-computed calibration embeddings
                of shape (n_cal, embedding_dim). If not provided, they are
                extracted from ``cal_dataset`` with ``extract_embeddings``.

        Raises:
            ValueError: If ``train_embeddings`` is missing, if a per-class
                ``alpha`` does not have one entry per class, or if the number
                of calibration embeddings differs from the number of
                calibration predictions.
        """
        # Fail fast, before any model inference is run.
        if train_embeddings is None:
            raise ValueError(
                "train_embeddings must be provided. "
                "Extract embeddings from training set using extract_embeddings()."
            )
        train_embeddings = np.asarray(train_embeddings)

        # Get predictions and true labels from calibration set
        cal_dataset_dict = prepare_numpy_dataset(
            self.model,
            cal_dataset,
            ["y_prob", "y_true"],
            debug=self.debug,
        )
        y_prob = cal_dataset_dict["y_prob"]
        y_true = cal_dataset_dict["y_true"]
        N, K = y_prob.shape

        class_conditional = not isinstance(self.alpha, float)
        if class_conditional and len(self.alpha) != K:
            raise ValueError(
                f"alpha must have length {K} for class-conditional "
                f"coverage, got {len(self.alpha)}"
            )

        # Extract embeddings if not provided
        if cal_embeddings is None:
            print("Extracting embeddings from calibration set...")
            cal_embeddings = extract_embeddings(
                self.model, cal_dataset, batch_size=32, device=self.device
            )
        else:
            cal_embeddings = np.asarray(cal_embeddings)

        # Each calibration prediction must have exactly one embedding,
        # otherwise the cluster masks below cannot index the conformity scores.
        if len(cal_embeddings) != N:
            raise ValueError(
                f"cal_embeddings has {len(cal_embeddings)} rows but the "
                f"calibration set produced {N} predictions; they must match."
            )

        # Combine embeddings for clustering
        print(
            f"Combining embeddings: train={train_embeddings.shape}, "
            f"cal={cal_embeddings.shape}"
        )
        all_embeddings = np.concatenate([train_embeddings, cal_embeddings], axis=0)
        print(f"Total embeddings for clustering: {all_embeddings.shape}")

        # Fit K-means on combined embeddings
        print(f"Fitting K-means with {self.n_clusters} clusters...")
        self.kmeans_model = KMeans(
            n_clusters=self.n_clusters,
            random_state=self.random_state,
            n_init=10,
        )
        self.kmeans_model.fit(all_embeddings)

        # Calibration rows sit after the training rows in all_embeddings,
        # so their labels are the tail of labels_.
        cal_cluster_labels = self.kmeans_model.labels_[len(train_embeddings):]

        # minlength keeps the printed counts aligned with cluster ids even
        # when the highest-numbered clusters hold no calibration sample.
        counts = np.bincount(cal_cluster_labels, minlength=self.n_clusters)
        print(f"Cluster assignments: {counts}")

        # Compute conformity scores (probabilities of true class)
        conformity_scores = y_prob[np.arange(N), y_true]

        # Compute cluster-specific thresholds
        self.cluster_thresholds = {}
        for cluster_id in range(self.n_clusters):
            cluster_mask = cal_cluster_labels == cluster_id
            cluster_scores = conformity_scores[cluster_mask]

            if len(cluster_scores) == 0:
                print(
                    f"Warning: No calibration samples in cluster {cluster_id}, "
                    "using -inf threshold (include all classes)"
                )
                self.cluster_thresholds[cluster_id] = (
                    np.full(K, -np.inf) if class_conditional else -np.inf
                )
                continue

            if not class_conditional:
                # Marginal coverage: single threshold per cluster
                self.cluster_thresholds[cluster_id] = _query_quantile(
                    cluster_scores, self.alpha
                )
                continue

            # Class-conditional coverage: one threshold per class per cluster
            cluster_y_true = y_true[cluster_mask]
            per_class = []
            for k in range(K):
                class_scores = cluster_scores[cluster_y_true == k]
                if len(class_scores) > 0:
                    t_k = _query_quantile(class_scores, self.alpha[k])
                else:
                    # No calibration examples for this class in this cluster
                    print(
                        f"Warning: No calibration examples for class {k} "
                        f"in cluster {cluster_id}, using -inf threshold"
                    )
                    t_k = -np.inf
                per_class.append(t_k)
            self.cluster_thresholds[cluster_id] = np.array(per_class)

        if self.debug:
            print(f"Cluster thresholds: {self.cluster_thresholds}")

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation with cluster-specific prediction set construction.

        Every sample of the batch is assigned to its own nearest cluster and
        compared against that cluster's threshold(s).

        Returns:
            Dictionary with all results from base model, plus:
                - y_predset: Boolean tensor indicating which classes
                  are in the prediction set

        Raises:
            RuntimeError: If ``calibrate()`` has not been called.
            ValueError: If the model output does not contain ``"embed"``.
        """
        if self.kmeans_model is None or self.cluster_thresholds is None:
            raise RuntimeError(
                "Model must be calibrated before inference. "
                "Call calibrate() first."
            )

        # Get base model prediction
        pred = self.model(**kwargs)

        # Second pass with embed=True to obtain the embeddings.
        # NOTE(review): if the model returns y_prob together with "embed",
        # a single pass would be enough - confirm against the BaseModel contract.
        embed_output = self.model(**{**kwargs, "embed": True})
        if "embed" not in embed_output:
            raise ValueError(
                f"Model {type(self.model).__name__} does not return "
                "embeddings. Make sure the model supports the "
                "embed=True flag in its forward() method."
            )

        embeddings = embed_output["embed"].detach().cpu().numpy()
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        # One cluster id per sample. Using only the first id would apply the
        # first sample's threshold to the whole batch.
        cluster_ids = self.kmeans_model.predict(embeddings)

        # Shape (B,) for marginal thresholds, (B, K) for class-conditional ones.
        thresholds = np.stack(
            [
                np.asarray(self.cluster_thresholds[int(cid)], dtype=np.float64)
                for cid in cluster_ids
            ]
        )
        y_prob = pred["y_prob"]
        thresholds = torch.as_tensor(
            thresholds, dtype=y_prob.dtype, device=y_prob.device
        )
        if thresholds.ndim == 1:
            # Broadcast each sample's scalar threshold over its classes.
            thresholds = thresholds.unsqueeze(-1)

        # Construct prediction set using the per-sample cluster thresholds
        pred["y_predset"] = y_prob >= thresholds

        return pred


if __name__ == "__main__":
    # Demonstration of cluster-based conformal prediction.
    from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal
    from pyhealth.models import ContraWR
    from pyhealth.tasks import EEGEventsTUEV
    from pyhealth.trainer import Trainer, get_metrics_fn

    # Setup data and model
    dataset = TUEVDataset(root="downloads/tuev/v2.0.1/edf", subset="both")
    sample_dataset = dataset.set_task(EEGEventsTUEV())
    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42
    )

    model = ContraWR(dataset=sample_dataset)
    # ... Train the model here ...

    # Extract embeddings
    train_embeddings = extract_embeddings(model, train_ds, batch_size=32)
    cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32)

    # Create and calibrate cluster-based predictor (the class defined above;
    # re-importing it from the package would load a second copy of it)
    cluster_predictor = ClusterLabel(model=model, alpha=0.1, n_clusters=5)
    cluster_predictor.calibrate(
        cal_dataset=cal_ds,
        train_embeddings=train_embeddings,
        cal_embeddings=cal_embeddings,
    )

    # Evaluate
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    y_true, y_prob, _, extra = Trainer(model=cluster_predictor).inference(
        test_loader, additional_outputs=["y_predset"]
    )
    metrics = get_metrics_fn(cluster_predictor.mode)(
        y_true, y_prob, metrics=["accuracy", "miscoverage_ps"],
        y_predset=extra["y_predset"]
    )
    print(f"Results: {metrics}")
--git a/pyhealth/calib/predictionset/cluster/__init__.py b/pyhealth/calib/predictionset/cluster/__init__.py new file mode 100644 index 00000000..eacd8996 --- /dev/null +++ b/pyhealth/calib/predictionset/cluster/__init__.py @@ -0,0 +1,5 @@ +"""Cluster-based prediction set methods""" + +from pyhealth.calib.predictionset.cluster.cluster_label import ClusterLabel + +__all__ = ["ClusterLabel"] From 95682f1b60a6cd46b6100bb0f456aea1108dae9c Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 25 Jan 2026 14:29:18 -0600 Subject: [PATCH 3/4] script for eval --- .../conformal_eeg/tuev_kmeans_conformal.py | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 examples/conformal_eeg/tuev_kmeans_conformal.py diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py new file mode 100644 index 00000000..82aa27e3 --- /dev/null +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -0,0 +1,203 @@ +"""K-means Cluster-Based Conformal Prediction (ClusterLabel) on TUEV EEG Events using ContraWR. + +This script: +1) Loads the TUEV dataset and applies the EEGEventsTUEV task. +2) Splits into train/val/cal/test using split conformal protocol. +3) Trains a ContraWR model. +4) Extracts embeddings for training and calibration splits using embed=True. +5) Calibrates a ClusterLabel prediction-set predictor (K-means clustering). +6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split. + +Example (from repo root): + python examples/conformal_eeg/tuev_kmeans_conformal.py --root downloads/tuev/v2.0.1/edf --n-clusters 5 + +Notes: +- ClusterLabel uses K-means clustering on embeddings to compute cluster-specific thresholds. +- Different K values can be tested to find optimal cluster count. 
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the TUEV ClusterLabel example.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior of a plain
            ``parse_args()`` call.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="K-means cluster-based conformal prediction (ClusterLabel) on TUEV EEG events using ContraWR."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="downloads/tuev/v2.0.1/edf",
        help="Path to TUEV edf/ folder.",
    )
    parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=3)
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%"; a bare "%" makes --help raise ValueError.
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.1,
        help="Miscoverage rate (e.g., 0.1 => 90%% target coverage).",
    )
    parser.add_argument(
        "--ratios",
        type=float,
        nargs=4,
        default=(0.6, 0.1, 0.15, 0.15),
        metavar=("TRAIN", "VAL", "CAL", "TEST"),
        help="Split ratios for train/val/cal/test. Must sum to 1.0.",
    )
    parser.add_argument(
        "--n-clusters",
        type=int,
        default=5,
        help="Number of K-means clusters for cluster-specific thresholds.",
    )
    parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.",
    )
    return parser.parse_args(argv)


def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy and PyTorch (all CUDA devices if available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main() -> None:
    """Train ContraWR on TUEV, calibrate ClusterLabel and report test metrics.

    Raises:
        FileNotFoundError: If the ``--root`` directory does not exist.
        RuntimeError: If the task produces no samples.
    """
    args = parse_args()
    set_seed(args.seed)

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    root = Path(args.root)
    if not root.exists():
        raise FileNotFoundError(
            f"TUEV root not found: {root}. "
            "Pass --root to point to your downloaded TUEV edf/ directory."
        )

    print("=" * 80)
    print("STEP 1: Load TUEV + build task dataset")
    print("=" * 80)
    dataset = TUEVDataset(root=str(root), subset=args.subset)
    sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache")

    print(f"Task samples: {len(sample_dataset)}")
    print(f"Input schema: {sample_dataset.input_schema}")
    print(f"Output schema: {sample_dataset.output_schema}")

    if len(sample_dataset) == 0:
        raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")

    print("\n" + "=" * 80)
    print("STEP 2: Split train/val/cal/test")
    print("=" * 80)
    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
    )
    print(f"Train: {len(train_ds)}")
    print(f"Val: {len(val_ds)}")
    print(f"Cal: {len(cal_ds)}")
    print(f"Test: {len(test_ds)}")

    train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
    # An empty validation split disables validation monitoring below.
    val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
    test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)

    print("\n" + "=" * 80)
    print("STEP 3: Train ContraWR")
    print("=" * 80)
    model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device)
    trainer = Trainer(model=model, device=device, enable_logging=False)

    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=args.epochs,
        monitor="accuracy" if val_loader is not None else None,
    )

    print("\nBase model performance on test set:")
    y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader)
    base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"])
    for metric, value in base_metrics.items():
        print(f" {metric}: {value:.4f}")

    print("\n" + "=" * 80)
    print("STEP 4: K-means Cluster-Based Conformal Prediction (ClusterLabel)")
    print("=" * 80)
    print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")
    print(f"Number of clusters: {args.n_clusters}")

    print("Extracting embeddings for training split...")
    train_embeddings = extract_embeddings(model, train_ds, batch_size=args.batch_size, device=device)
    print(f" train_embeddings shape: {train_embeddings.shape}")

    print("Extracting embeddings for calibration split...")
    cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device)
    print(f" cal_embeddings shape: {cal_embeddings.shape}")

    cluster_predictor = ClusterLabel(
        model=model,
        alpha=float(args.alpha),
        n_clusters=args.n_clusters,
        random_state=args.seed,
    )
    print("Calibrating ClusterLabel predictor (fits K-means and computes cluster-specific thresholds)...")
    cluster_predictor.calibrate(
        cal_dataset=cal_ds,
        train_embeddings=train_embeddings,
        cal_embeddings=cal_embeddings,
    )

    print("Evaluating ClusterLabel predictor on test set...")
    # Pass the same device as the training Trainer so that --device is
    # honoured for evaluation as well.
    y_true, y_prob, _loss, extra = Trainer(model=cluster_predictor, device=device).inference(
        test_loader, additional_outputs=["y_predset"]
    )

    cluster_metrics = get_metrics_fn("multiclass")(
        y_true,
        y_prob,
        metrics=["accuracy", "miscoverage_ps"],
        y_predset=extra["y_predset"],
    )

    predset = extra["y_predset"]
    if isinstance(predset, np.ndarray):
        predset_t = torch.tensor(predset)
    else:
        predset_t = predset
    # Average number of classes per prediction set (efficiency measure).
    avg_set_size = predset_t.float().sum(dim=1).mean().item()

    # The metric may come back as a scalar or as an array; reduce it to a float.
    miscoverage = cluster_metrics["miscoverage_ps"]
    if isinstance(miscoverage, np.ndarray):
        miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean())
    else:
        miscoverage = float(miscoverage)

    print("\nClusterLabel Results:")
    print(f" Accuracy: {cluster_metrics['accuracy']:.4f}")
    print(f" Empirical miscoverage: {miscoverage:.4f}")
    print(f" Empirical coverage: {1 - miscoverage:.4f}")
    print(f" Average set size: {avg_set_size:.2f}")
    print(f" Number of clusters: {args.n_clusters}")


if __name__ == "__main__":
    main()
class TestClusterLabel(unittest.TestCase):
    """Unit tests for the ClusterLabel prediction-set constructor."""

    def setUp(self):
        """Build a six-sample, three-class dataset and an MLP base model."""
        # (conditions, procedures, label); patient/visit ids follow the row index.
        rows = [
            (["cond-33", "cond-86", "cond-80", "cond-12"], [1.0, 2.0, 3.5, 4.0], 0),
            (["cond-33", "cond-86", "cond-80"], [5.0, 2.0, 3.5, 4.0], 1),
            (["cond-10", "cond-20", "cond-30"], [2.0, 3.0, 4.5, 5.0], 2),
            (["cond-40", "cond-50"], [1.5, 2.5, 3.0, 4.5], 0),
            (["cond-60", "cond-70", "cond-80"], [3.0, 4.0, 5.0, 6.0], 1),
            (["cond-90", "cond-100"], [2.5, 3.5, 4.0, 5.5], 2),
        ]
        self.samples = [
            {
                "patient_id": f"patient-{idx}",
                "visit_id": f"visit-{idx}",
                "conditions": conditions,
                "procedures": procedures,
                "label": label,
            }
            for idx, (conditions, procedures, label) in enumerate(rows)
        ]

        # Schemas shared by every test
        self.input_schema = {
            "conditions": "sequence",
            "procedures": "tensor",
        }
        self.output_schema = {"label": "multiclass"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        # Base model wrapped by ClusterLabel in the tests
        self.model = MLP(
            dataset=self.dataset,
            feature_keys=["conditions", "procedures"],
            label_key="label",
            mode="multiclass",
        )
        self.model.eval()

    def _get_embeddings(self, dataset):
        """Return the base model's embeddings for ``dataset`` (CPU)."""
        return extract_embeddings(self.model, dataset, batch_size=32, device="cpu")

    def _calibrate_on_split(self, cluster_model):
        """Calibrate ``cluster_model`` with samples 0-2 as train and 3-5 as calibration."""
        train_dataset = self.dataset.subset([0, 1, 2])
        cal_dataset = self.dataset.subset([3, 4, 5])

        train_embeddings = self._get_embeddings(train_dataset)
        cal_embeddings = self._get_embeddings(cal_dataset)

        cluster_model.calibrate(
            cal_dataset=cal_dataset,
            train_embeddings=train_embeddings,
            cal_embeddings=cal_embeddings,
        )

    def test_initialization(self):
        """Constructor stores its arguments and starts uncalibrated."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.1,
            n_clusters=3,
            random_state=42,
        )

        self.assertIsInstance(predictor, ClusterLabel)
        self.assertEqual(predictor.mode, "multiclass")
        self.assertEqual(predictor.alpha, 0.1)
        self.assertEqual(predictor.n_clusters, 3)
        self.assertEqual(predictor.random_state, 42)
        self.assertIsNone(predictor.kmeans_model)
        self.assertIsNone(predictor.cluster_thresholds)

    def test_initialization_with_array_alpha(self):
        """A list of per-class alphas is stored as a NumPy array."""
        per_class_alpha = [0.1, 0.15, 0.2]
        predictor = ClusterLabel(
            model=self.model,
            alpha=per_class_alpha,
            n_clusters=3,
        )

        self.assertIsInstance(predictor.alpha, np.ndarray)
        np.testing.assert_array_equal(predictor.alpha, per_class_alpha)

    def test_initialization_non_multiclass_raises_error(self):
        """Wrapping a binary model raises NotImplementedError."""
        binary_samples = [
            {
                "patient_id": f"patient-{idx}",
                "visit_id": f"visit-{idx}",
                "conditions": [condition],
                "procedures": [procedure],
                "label": idx,
            }
            for idx, (condition, procedure) in enumerate(
                [("cond-1", 1.0), ("cond-2", 2.0)]
            )
        ]
        binary_dataset = create_sample_dataset(
            samples=binary_samples,
            input_schema={"conditions": "sequence", "procedures": "tensor"},
            output_schema={"label": "binary"},
            dataset_name="test",
        )
        binary_model = MLP(
            dataset=binary_dataset,
            feature_keys=["conditions"],
            label_key="label",
            mode="binary",
        )

        with self.assertRaises(NotImplementedError):
            ClusterLabel(
                model=binary_model,
                alpha=0.1,
                n_clusters=2,
            )

    def test_calibrate_marginal(self):
        """A scalar alpha yields one scalar threshold per cluster."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.3,
            n_clusters=3,
            random_state=42,
        )
        self._calibrate_on_split(predictor)

        # K-means has been fitted with the requested cluster count
        self.assertIsNotNone(predictor.kmeans_model)
        self.assertEqual(predictor.kmeans_model.n_clusters, 3)

        # A threshold dictionary with one entry per cluster exists
        thresholds = predictor.cluster_thresholds
        self.assertIsNotNone(thresholds)
        self.assertIsInstance(thresholds, dict)
        self.assertEqual(len(thresholds), 3)

        for cluster_id in range(3):
            self.assertIn(cluster_id, thresholds)
            self.assertIsInstance(thresholds[cluster_id], (float, np.floating))

    def test_calibrate_class_conditional(self):
        """Per-class alphas yield one threshold per class in every cluster."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=[0.2, 0.25, 0.3],
            n_clusters=2,
            random_state=42,
        )
        self._calibrate_on_split(predictor)

        self.assertIsNotNone(predictor.cluster_thresholds)
        for cluster_id in predictor.cluster_thresholds:
            threshold = predictor.cluster_thresholds[cluster_id]
            self.assertIsInstance(threshold, np.ndarray)
            self.assertEqual(len(threshold), 3)  # 3 classes

    def test_forward_returns_predset(self):
        """forward() adds a boolean y_predset shaped like y_prob."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.2,
            n_clusters=3,
            random_state=42,
        )
        self._calibrate_on_split(predictor)

        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = predictor(**batch)

        for key in ("y_predset", "y_prob", "y_true"):
            self.assertIn(key, output)

        self.assertEqual(output["y_predset"].dtype, torch.bool)
        self.assertEqual(output["y_predset"].shape, output["y_prob"].shape)

    def test_prediction_sets_nonempty(self):
        """Every prediction set contains at least one class."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.3,
            n_clusters=2,
            random_state=42,
        )
        self._calibrate_on_split(predictor)

        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)

        with torch.no_grad():
            for batch in loader:
                output = predictor(**batch)
                set_sizes = output["y_predset"].sum(dim=1)
                self.assertTrue(
                    torch.all(set_sizes > 0), "Some prediction sets are empty"
                )

    def test_calibrate_requires_train_embeddings(self):
        """calibrate() raises ValueError when train_embeddings is None."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.2,
            n_clusters=3,
        )

        cal_dataset = self.dataset.subset([3, 4, 5])
        cal_embeddings = self._get_embeddings(cal_dataset)

        with self.assertRaises(ValueError):
            predictor.calibrate(
                cal_dataset=cal_dataset,
                train_embeddings=None,
                cal_embeddings=cal_embeddings,
            )

    def test_forward_before_calibration_raises_error(self):
        """forward() raises RuntimeError on an uncalibrated predictor."""
        predictor = ClusterLabel(
            model=self.model,
            alpha=0.2,
            n_clusters=3,
        )

        loader = get_dataloader(self.dataset, batch_size=1, shuffle=False)
        batch = next(iter(loader))

        with self.assertRaises(RuntimeError):
            with torch.no_grad():
                predictor(**batch)

    def test_different_cluster_counts(self):
        """Calibration works for several cluster counts."""
        for n_clusters in [2, 3, 4]:
            predictor = ClusterLabel(
                model=self.model,
                alpha=0.2,
                n_clusters=n_clusters,
                random_state=42,
            )
            self._calibrate_on_split(predictor)

            self.assertEqual(predictor.kmeans_model.n_clusters, n_clusters)
            self.assertEqual(len(predictor.cluster_thresholds), n_clusters)

    def test_model_device_handling(self):
        """Predictor and its outputs live on the base model's device type."""
        expected_device = self.model.device

        predictor = ClusterLabel(
            model=self.model,
            alpha=0.2,
            n_clusters=3,
            random_state=42,
        )
        self._calibrate_on_split(predictor)

        self.assertEqual(predictor.device.type, expected_device.type)

        loader = get_dataloader(self.dataset, batch_size=1, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = predictor(**batch)
            self.assertEqual(output["y_predset"].device.type, expected_device.type)


if __name__ == "__main__":
    unittest.main()