[Conformal EEG] K-means clustering #795
Pull request overview
This PR adds K-means clustering-based conformal prediction (ClusterLabel) for multiclass classification in EEG analysis. The method groups patients into clusters using K-means on embeddings and computes cluster-specific calibration thresholds to improve prediction set efficiency compared to global thresholds.
Changes:
- Implemented ClusterLabel class that performs K-means clustering on patient embeddings and applies cluster-specific calibration thresholds
- Added comprehensive test suite covering initialization, calibration, and prediction
- Included example script demonstrating usage on TUEV EEG dataset with ContraWR model
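The idea in the overview, one calibration threshold per cluster instead of a single global one, can be sketched with NumPy alone. This is a minimal illustration under stated assumptions, not the PR's implementation: `conformal_threshold` and `cluster_thresholds` are hypothetical helpers, the quantile rule is one common finite-sample choice, and the cluster labels are assumed to have been produced already by K-means.

```python
import numpy as np


def conformal_threshold(scores: np.ndarray, alpha: float) -> float:
    """Finite-sample lower quantile of conformity scores (hypothetical helper).

    Returns -inf when there are too few scores, so that every class
    ends up in the prediction set.
    """
    scores = np.sort(scores)
    loc = int(np.floor(alpha * (len(scores) + 1))) - 1
    return -np.inf if loc < 0 else float(scores[loc])


def cluster_thresholds(scores, cluster_labels, n_clusters, alpha):
    """One threshold per cluster, computed only from that cluster's scores."""
    return np.array([
        conformal_threshold(scores[cluster_labels == c], alpha)
        for c in range(n_clusters)
    ])


# Toy calibration data: true-class probabilities and (assumed) K-means labels.
cal_scores = np.array([0.9, 0.8, 0.7, 0.95, 0.85, 0.3, 0.4, 0.2, 0.5, 0.6])
cal_clusters = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

per_cluster = cluster_thresholds(cal_scores, cal_clusters, n_clusters=2, alpha=0.2)
global_t = conformal_threshold(cal_scores, alpha=0.2)

print("per-cluster thresholds:", per_cluster)
print("global threshold:", global_t)
```

In this toy data the confident cluster gets a higher threshold than the global one, so its prediction sets can be smaller, while the less confident cluster keeps a low threshold.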
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| pyhealth/calib/predictionset/cluster/cluster_label.py | Core implementation of cluster-based conformal prediction with K-means clustering |
| pyhealth/calib/predictionset/cluster/__init__.py | Module initialization exporting ClusterLabel class |
| pyhealth/calib/predictionset/__init__.py | Updated to export ClusterLabel from cluster submodule |
| tests/core/test_cluster_label.py | Test suite covering initialization, calibration, and forward pass with various configurations |
| examples/conformal_eeg/tuev_kmeans_conformal.py | Example script demonstrating ClusterLabel usage on TUEV EEG events dataset |
    alpha = np.asarray(alpha)
    self.alpha = alpha

    # Store clustering parameters
Copilot
AI
Jan 26, 2026
Missing input validation for n_clusters parameter. The parameter should be validated to ensure it's a positive integer (n_clusters > 0). Invalid values like 0 or negative numbers would cause sklearn's KMeans to raise an error during calibration, but it's better to fail fast with a clear error message during initialization.
Suggested change:

    - # Store clustering parameters
    + # Store clustering parameters
    + if not isinstance(n_clusters, int) or n_clusters <= 0:
    +     raise ValueError(
    +         f"n_clusters must be a positive integer, got {n_clusters!r}"
    +     )
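One possible variation on this check, offered as a sketch rather than as what the PR should adopt: `isinstance(n_clusters, int)` accepts `True` (since `bool` subclasses `int`) and rejects NumPy integer scalars. Checking against `numbers.Integral` and excluding `bool` handles both cases. `validate_n_clusters` is a hypothetical helper name.

```python
import numbers


def validate_n_clusters(n_clusters):
    """Hypothetical helper: return n_clusters as a plain int or raise ValueError."""
    # bool is a subclass of int, so it has to be rejected explicitly.
    is_integer = isinstance(n_clusters, numbers.Integral) and not isinstance(
        n_clusters, bool
    )
    if not is_integer or n_clusters <= 0:
        raise ValueError(
            f"n_clusters must be a positive integer, got {n_clusters!r}"
        )
    return int(n_clusters)


print(validate_n_clusters(5))
```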
    cluster_id = self.kmeans_model.predict(sample_embedding)[0]

    # Get cluster-specific threshold
    cluster_threshold = self.cluster_thresholds[cluster_id]

    # Convert to tensor if needed
    if isinstance(cluster_threshold, np.ndarray):
        cluster_threshold = torch.tensor(
            cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
        )
    else:
        cluster_threshold = torch.tensor(
            cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
        )

    # Construct prediction set using cluster-specific threshold
    pred["y_predset"] = pred["y_prob"] >= cluster_threshold
Copilot
AI
Jan 26, 2026
Critical batch handling bug: The forward method only uses the first sample's cluster assignment for the entire batch. At line 293, cluster_id = self.kmeans_model.predict(sample_embedding)[0] extracts only the first element from the cluster predictions array. When processing a batch with multiple samples, this causes all samples in the batch to use the same threshold (from the first sample's cluster), leading to incorrect prediction sets for all but the first sample.
The method needs to handle batches properly by:
- Predicting cluster assignments for all samples in the batch
- Retrieving the appropriate threshold for each sample based on its cluster
- Applying per-sample thresholds to construct prediction sets
This will require either processing samples individually in a loop or implementing vectorized batch operations with per-sample threshold application.
Suggested change:

    - cluster_id = self.kmeans_model.predict(sample_embedding)[0]
    - # Get cluster-specific threshold
    - cluster_threshold = self.cluster_thresholds[cluster_id]
    - # Convert to tensor if needed
    - if isinstance(cluster_threshold, np.ndarray):
    -     cluster_threshold = torch.tensor(
    -         cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
    -     )
    - else:
    -     cluster_threshold = torch.tensor(
    -         cluster_threshold, device=self.device, dtype=pred["y_prob"].dtype
    -     )
    - # Construct prediction set using cluster-specific threshold
    - pred["y_predset"] = pred["y_prob"] >= cluster_threshold
    + # Predict cluster assignments for all samples in the batch
    + cluster_ids = self.kmeans_model.predict(sample_embedding)
    + # Get cluster-specific thresholds for each sample
    + cluster_thresholds = np.array(
    +     [self.cluster_thresholds[cid] for cid in cluster_ids]
    + )
    + # Convert thresholds to tensor on the correct device/dtype
    + cluster_thresholds = torch.as_tensor(
    +     cluster_thresholds, device=self.device, dtype=pred["y_prob"].dtype
    + )
    + # Reshape thresholds to broadcast over any non-batch dimensions of y_prob
    + if pred["y_prob"].ndim > 1:
    +     # (batch_size,) -> (batch_size, 1, ..., 1)
    +     view_shape = (cluster_thresholds.shape[0],) + (1,) * (pred["y_prob"].ndim - 1)
    +     cluster_thresholds = cluster_thresholds.view(view_shape)
    + # Construct prediction set using per-sample, cluster-specific thresholds
    + pred["y_predset"] = pred["y_prob"] >= cluster_thresholds
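The difference between using the first sample's threshold and using per-sample thresholds can be seen on a toy batch. The sketch below uses NumPy instead of torch to stay self-contained; the threshold values, probabilities, and cluster ids are made up for illustration, and only the marginal-coverage case (one scalar threshold per cluster) is shown.

```python
import numpy as np

# Per-cluster scalar thresholds, indexed by cluster id (illustrative values).
thresholds = np.array([0.5, 0.2])

# A batch of 3 samples over 3 classes, with one cluster id per sample.
y_prob = np.array([
    [0.6, 0.3, 0.1],
    [0.6, 0.3, 0.1],
    [0.1, 0.25, 0.65],
])
cluster_ids = np.array([0, 1, 1])

# Behaviour flagged in the comment: the first sample's threshold
# is applied to every row of the batch.
first_only = y_prob >= thresholds[cluster_ids[0]]

# Per-sample behaviour: look up one threshold per row, then add a trailing
# axis so shape (batch,) broadcasts against (batch, n_classes).
per_sample = y_prob >= thresholds[cluster_ids][:, None]

print(first_only)
print(per_sample)
```

Rows 2 and 3 belong to cluster 1, so they gain an extra class once their own threshold is used. In the class-conditional case the stored thresholds are arrays of length K, so the looked-up result would already have shape (batch, K) and the trailing axis would not be needed.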
    # Get base model prediction
    pred = self.model(**kwargs)

    # Extract embedding for this sample to assign to cluster
    embed_kwargs = {**kwargs, "embed": True}
    embed_output = self.model(**embed_kwargs)
    if "embed" not in embed_output:
        raise ValueError(
            f"Model {type(self.model).__name__} does not return "
            "embeddings. Make sure the model supports the "
            "embed=True flag in its forward() method."
        )
Copilot
AI
Jan 26, 2026
Performance issue: The forward method calls the model twice for each prediction - once at line 276 to get probabilities and once at line 280 to get embeddings. This doubles the computational cost and memory usage during inference.
The model should be called only once with embed=True to get both predictions and embeddings in a single forward pass. The returned dictionary should contain both 'y_prob' (or similar prediction keys) and 'embed'. This would require verifying that models support returning both outputs simultaneously when embed=True is set.
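A minimal sketch of the single-pass idea, assuming the wrapped model returns both `y_prob` and `embed` when called with `embed=True` (which, as the comment notes, would need to be verified for each model). `ToyModel` and `predict_with_embedding` are hypothetical names used only for this illustration.

```python
import numpy as np


class ToyModel:
    """Hypothetical stand-in for a model whose forward accepts embed=True."""

    def __init__(self):
        self.calls = 0  # counts forward passes

    def __call__(self, x, embed=False):
        self.calls += 1
        out = {"y_prob": np.full((len(x), 2), 0.5)}
        if embed:
            out["embed"] = np.asarray(x, dtype=float)
        return out


def predict_with_embedding(model, **kwargs):
    """Run one forward pass and return (prediction dict, embeddings)."""
    pred = model(**{**kwargs, "embed": True})
    missing = [key for key in ("y_prob", "embed") if key not in pred]
    if missing:
        raise ValueError(
            f"Model {type(model).__name__} did not return {missing} "
            "when called with embed=True."
        )
    return pred, pred["embed"]


model = ToyModel()
pred, embedding = predict_with_embedding(model, x=[[1.0, 2.0], [3.0, 4.0]])
print(model.calls, embedding.shape)
```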
        print("Extracting embeddings from calibration set...")
        cal_embeddings = extract_embeddings(
            self.model, cal_dataset, batch_size=32, device=self.device
        )
    else:
        cal_embeddings = np.asarray(cal_embeddings)

    if train_embeddings is None:
        raise ValueError(
            "train_embeddings must be provided. "
            "Extract embeddings from training set using extract_embeddings()."
        )
    else:
        train_embeddings = np.asarray(train_embeddings)

    # Combine embeddings for clustering
    print(f"Combining embeddings: train={train_embeddings.shape}, cal={cal_embeddings.shape}")
    all_embeddings = np.concatenate([train_embeddings, cal_embeddings], axis=0)
    print(f"Total embeddings for clustering: {all_embeddings.shape}")

    # Fit K-means on combined embeddings
    print(f"Fitting K-means with {self.n_clusters} clusters...")
    self.kmeans_model = KMeans(
        n_clusters=self.n_clusters,
        random_state=self.random_state,
        n_init=10,
    )
    self.kmeans_model.fit(all_embeddings)

    # Assign calibration samples to clusters
    # Note: cal_embeddings start at index len(train_embeddings) in all_embeddings
    cal_start_idx = len(train_embeddings)
    cal_cluster_labels = self.kmeans_model.labels_[cal_start_idx:]

    print(f"Cluster assignments: {np.bincount(cal_cluster_labels)}")

    # Compute conformity scores (probabilities of true class)
    conformity_scores = y_prob[np.arange(N), y_true]

    # Compute cluster-specific thresholds
    self.cluster_thresholds = {}
    for cluster_id in range(self.n_clusters):
        cluster_mask = cal_cluster_labels == cluster_id
        cluster_scores = conformity_scores[cluster_mask]

        if len(cluster_scores) == 0:
            print(
                f"Warning: No calibration samples in cluster {cluster_id}, "
                "using -inf threshold (include all classes)"
            )
            if isinstance(self.alpha, float):
                self.cluster_thresholds[cluster_id] = -np.inf
            else:
                self.cluster_thresholds[cluster_id] = np.array(
                    [-np.inf] * K
                )
        else:
            if isinstance(self.alpha, float):
                # Marginal coverage: single threshold per cluster
                t = _query_quantile(cluster_scores, self.alpha)
                self.cluster_thresholds[cluster_id] = t
            else:
                # Class-conditional coverage: one threshold per class per cluster
                if len(self.alpha) != K:
                    raise ValueError(
                        f"alpha must have length {K} for class-conditional "
                        f"coverage, got {len(self.alpha)}"
                    )
                t = []
                for k in range(K):
                    class_mask = (y_true[cluster_mask] == k)
                    if np.sum(class_mask) > 0:
                        class_scores = cluster_scores[class_mask]
                        t_k = _query_quantile(class_scores, self.alpha[k])
                    else:
                        # If no calibration examples for this class in this cluster
                        print(
                            f"Warning: No calibration examples for class {k} "
                            f"in cluster {cluster_id}, using -inf threshold"
                        )
                        t_k = -np.inf
                    t.append(t_k)
                self.cluster_thresholds[cluster_id] = np.array(t)

    if self.debug:
        print(f"Cluster thresholds: {self.cluster_thresholds}")
Copilot
AI
Jan 26, 2026
The print statements in the calibrate method should use a logging framework or be controlled by the debug flag. Lines 174, 190, 192, 195, 208, 220-222, and 250-252 unconditionally print to stdout, which can clutter output in production usage. Consider using Python's logging module or only printing when self.debug is True, consistent with line 259 which gates debug output.
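As a sketch of the logging-based alternative, the progress and warning messages could go through a module-level logger so that callers decide what is shown. `report_cluster_sizes` is a hypothetical helper, not a function from the PR; the message texts follow the ones in the diff.

```python
import logging

# Module-level logger; applications configure handlers and levels themselves.
logger = logging.getLogger(__name__)


def report_cluster_sizes(counts):
    """Log per-cluster calibration counts and return the ids of empty clusters."""
    logger.info("Cluster assignments: %s", list(counts))
    empty = [cluster_id for cluster_id, n in enumerate(counts) if n == 0]
    for cluster_id in empty:
        logger.warning(
            "No calibration samples in cluster %d, "
            "using -inf threshold (include all classes)",
            cluster_id,
        )
    return empty


empty_clusters = report_cluster_sizes([3, 0, 2])
```

With no logging configuration, Python's default behaviour drops the info message and still emits the warning on stderr, which keeps stdout clean.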
    if sample_embedding.ndim == 1:
        sample_embedding = sample_embedding.reshape(1, -1)
Copilot
AI
Jan 26, 2026
The condition at line 290 checking if embeddings are 1D appears to be attempting to handle single-sample inputs, but this conflicts with the batch-based usage throughout the codebase (as evidenced by the tests using batch_size=2 and the example script using batch_size=32). This suggests unclear design intent - the method should be designed to handle batches consistently. If single-sample support is needed, it should be handled as a batch of size 1, not as a special case with different dimensionality.
Suggested change:

    - if sample_embedding.ndim == 1:
    -     sample_embedding = sample_embedding.reshape(1, -1)
    + # Ensure embeddings are always treated as a batch (even for single samples)
    + sample_embedding = np.atleast_2d(sample_embedding)
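For reference, a short check of how `np.atleast_2d` treats a single embedding versus an existing batch; the array values are arbitrary.

```python
import numpy as np

single = np.array([0.1, 0.2, 0.3])                    # shape (3,): no batch axis
batch = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # shape (2, 3)

single_2d = np.atleast_2d(single)  # becomes a batch of size 1
batch_2d = np.atleast_2d(batch)    # already 2-D, shape unchanged

print(single_2d.shape, batch_2d.shape)
```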
    cal_embeddings = extract_embeddings(
        self.model, cal_dataset, batch_size=32, device=self.device
Copilot
AI
Jan 26, 2026
The batch_size parameter is hardcoded to 32 at line 176. This should use a configurable parameter or match the batch size used elsewhere in the method. Consider adding a batch_size parameter to the calibrate method or using a class attribute to control this, especially since users may want to adjust it based on available memory.
Suggested change:

    - cal_embeddings = extract_embeddings(
    -     self.model, cal_dataset, batch_size=32, device=self.device
    + batch_size = getattr(self, "batch_size", 32)
    + cal_embeddings = extract_embeddings(
    +     self.model, cal_dataset, batch_size=batch_size, device=self.device
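An alternative to the `getattr` lookup, sketched under assumptions: the batch size could be an explicit constructor argument with an optional per-call override. `CalibratorSketch` and `extract_embeddings_stub` are hypothetical stand-ins; the real `extract_embeddings` takes a model and device as well and is not reproduced here.

```python
import numpy as np


def extract_embeddings_stub(dataset, batch_size):
    """Hypothetical stand-in: process the data in chunks of batch_size."""
    chunks = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
    return np.concatenate(chunks, axis=0), len(chunks)


class CalibratorSketch:
    """Hypothetical class showing a configurable batch size."""

    def __init__(self, batch_size=32):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        self.batch_size = batch_size

    def calibrate(self, cal_dataset, batch_size=None):
        # A per-call value overrides the instance default.
        effective = self.batch_size if batch_size is None else batch_size
        return extract_embeddings_stub(cal_dataset, effective)


data = np.zeros((10, 4))
calibrator = CalibratorSketch(batch_size=4)
embeddings, n_batches = calibrator.calibrate(data)
_, n_batches_override = calibrator.calibrate(data, batch_size=5)
print(n_batches, n_batches_override)
```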
This PR adds K-means clustering-based conformal prediction.