From f2a35e5756f8de336d829626b2d7fa80ea5e8410 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 26 Jan 2026 16:24:35 -0600 Subject: [PATCH 1/2] Add multimodal RNN support --- docs/api/models/pyhealth.models.RNN.rst | 5 + .../mortality_mimic4_multimodal_rnn.py | 176 +++++++++++ pyhealth/models/__init__.py | 2 +- pyhealth/models/rnn.py | 232 ++++++++++++++- tests/core/test_multimodal_rnn.py | 278 ++++++++++++++++++ 5 files changed, 690 insertions(+), 3 deletions(-) create mode 100644 examples/mortality_prediction/mortality_mimic4_multimodal_rnn.py create mode 100644 tests/core/test_multimodal_rnn.py diff --git a/docs/api/models/pyhealth.models.RNN.rst b/docs/api/models/pyhealth.models.RNN.rst index d90f2d48a..3f264c176 100644 --- a/docs/api/models/pyhealth.models.RNN.rst +++ b/docs/api/models/pyhealth.models.RNN.rst @@ -10,6 +10,11 @@ The separate callable RNNLayer and the complete RNN model. :show-inheritance: .. autoclass:: pyhealth.models.RNN + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.MultimodalRNN :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/examples/mortality_prediction/mortality_mimic4_multimodal_rnn.py b/examples/mortality_prediction/mortality_mimic4_multimodal_rnn.py new file mode 100644 index 000000000..1360ddb46 --- /dev/null +++ b/examples/mortality_prediction/mortality_mimic4_multimodal_rnn.py @@ -0,0 +1,176 @@ +""" +Mortality Prediction on MIMIC-IV with MultimodalRNN + +This example demonstrates how to use the MultimodalRNN model with mixed +input modalities for in-hospital mortality prediction on MIMIC-IV. + +The MultimodalRNN model can handle: +- Sequential features (diagnoses, procedures, lab timeseries) → RNN processing +- Non-sequential features (demographics, static measurements) → Direct embedding + +This example shows: +1. Loading MIMIC-IV data with mixed feature types +2. Applying a mortality prediction task +3. Training a MultimodalRNN model with both sequential and non-sequential inputs +4. 
Evaluating the model performance +""" + +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.models import MultimodalRNN +from pyhealth.tasks import InHospitalMortalityMIMIC4 +from pyhealth.trainer import Trainer + + +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + print("=" * 60) + print("STEP 1: Loading MIMIC-IV Dataset") + print("=" * 60) + + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=["diagnoses_icd", "procedures_icd", "labevents"], + dev=True, # Use development mode for faster testing + num_workers=4, + ) + base_dataset.stats() + + # STEP 2: Apply mortality prediction task with multimodal features + print("\n" + "=" * 60) + print("STEP 2: Setting Mortality Prediction Task") + print("=" * 60) + + # Use the InHospitalMortalityMIMIC4 task + # This task will create sequential features from diagnoses, procedures, and labs + task = InHospitalMortalityMIMIC4() + sample_dataset = base_dataset.set_task( + task, + num_workers=4, + ) + + print(f"\nTotal samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + if len(sample_dataset) > 0: + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + for key in sample_dataset.input_schema.keys(): + if key in sample: + if isinstance(sample[key], (list, tuple)): + print(f" {key}: length {len(sample[key])}") + else: + print(f" {key}: {type(sample[key])}") + print(f" Mortality: {sample.get('mortality', 'N/A')}") + + # STEP 3: Split dataset + print("\n" + "=" * 60) + print("STEP 3: Splitting Dataset") + print("=" * 60) + + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + print(f"Test samples: {len(test_dataset)}") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False) + + # STEP 4: Initialize MultimodalRNN model + print("\n" + "=" * 60) + print("STEP 4: Initializing MultimodalRNN Model") + print("=" * 60) + + model = MultimodalRNN( + dataset=sample_dataset, + embedding_dim=128, + hidden_dim=128, + rnn_type="GRU", + num_layers=2, + dropout=0.3, + bidirectional=False, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"Model initialized with {num_params:,} parameters") + + # Print feature classification + print(f"\nSequential features (RNN processing): {model.sequential_features}") + print(f"Non-sequential features (direct embedding): {model.non_sequential_features}") + + # Calculate expected embedding dimensions + seq_dim = len(model.sequential_features) * model.hidden_dim + non_seq_dim = len(model.non_sequential_features) * model.embedding_dim + total_dim = seq_dim + non_seq_dim + print(f"\nPatient representation dimension:") + print(f" Sequential contribution: {seq_dim}") + print(f" Non-sequential contribution: {non_seq_dim}") + print(f" Total: {total_dim}") + + # STEP 5: Train the model + print("\n" + "=" * 60) + print("STEP 5: Training Model") + print("=" * 60) + + trainer = Trainer( + model=model, + device="cuda:0", # Change to "cpu" if no GPU available + metrics=["pr_auc", 
"roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=10, + monitor="roc_auc", + optimizer_params={"lr": 1e-3}, + ) + + # STEP 6: Evaluate on test set + print("\n" + "=" * 60) + print("STEP 6: Evaluating on Test Set") + print("=" * 60) + + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Demonstrate model predictions + print("\n" + "=" * 60) + print("STEP 7: Sample Predictions") + print("=" * 60) + + import torch + + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print(f"\nBatch size: {output['y_prob'].shape[0]}") + print(f"First 10 predicted probabilities:") + for i, (prob, true_label) in enumerate( + zip(output['y_prob'][:10], output['y_true'][:10]) + ): + print(f" Sample {i+1}: prob={prob.item():.4f}, true={int(true_label.item())}") + + # Summary + print("\n" + "=" * 60) + print("SUMMARY: MultimodalRNN Training Complete") + print("=" * 60) + print(f"Model: MultimodalRNN") + print(f"Dataset: MIMIC-IV") + print(f"Task: In-Hospital Mortality Prediction") + print(f"Sequential features: {len(model.sequential_features)}") + print(f"Non-sequential features: {len(model.non_sequential_features)}") + print(f"Best validation ROC-AUC: {max(results.get('roc_auc', 0), 0):.4f}") + print("=" * 60) + diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 3c0b5384d..821af48f8 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -17,7 +17,7 @@ from .mlp import MLP from .molerec import MoleRec, MoleRecLayer from .retain import RETAIN, RETAINLayer -from .rnn import RNN, RNNLayer +from .rnn import MultimodalRNN, RNN, RNNLayer from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index e89d58bd1..160abee8c 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -6,6 +6,18 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel +from pyhealth.processors import ( + DeepNestedFloatsProcessor, + DeepNestedSequenceProcessor, + MultiHotProcessor, + NestedFloatsProcessor, + NestedSequenceProcessor, + SequenceProcessor, + StageNetProcessor, + StageNetTensorProcessor, + TensorProcessor, + TimeseriesProcessor, +) from .embedding import EmbeddingModel @@ -235,7 +247,9 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: embedded = self.embedding_model(kwargs) for feature_key in self.feature_keys: x = embedded[feature_key] - mask = (x.sum(dim=-1) != 0).int() + # Use abs() before sum to catch edge cases where embeddings sum to 0 + # despite being valid values (e.g., [1.0, -1.0]) + mask = (torch.abs(x).sum(dim=-1) != 0).int() _, x = self.rnn[feature_key](x, mask) patient_emb.append(x) @@ -250,3 +264,217 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if kwargs.get("embed", False): results["embed"] = patient_emb return results + + +class MultimodalRNN(BaseModel): + """Multimodal RNN model that handles both sequential and non-sequential features. 
+ + This model extends the vanilla RNN to support mixed input modalities: + - Sequential features (sequences, timeseries) go through RNN layers + - Non-sequential features (multi-hot, tensor) bypass RNN and use embeddings directly + + The model automatically classifies input features based on their processor types: + - Sequential processors (apply RNN): SequenceProcessor, NestedSequenceProcessor, + DeepNestedSequenceProcessor, NestedFloatsProcessor, DeepNestedFloatsProcessor, + TimeseriesProcessor + - Non-sequential processors (embeddings only): MultiHotProcessor, TensorProcessor, + StageNetProcessor, StageNetTensorProcessor + + For sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies RNNLayer to get sequential representations + 3. Extracts the last hidden state + + For non-sequential features, the model: + 1. Embeds the input using EmbeddingModel + 2. Applies mean pooling if needed to reduce to 2D + 3. Uses the embedding directly + + All feature representations are concatenated and passed through a final + fully connected layer for predictions. + + Args: + dataset (SampleDataset): the dataset to train the model. It is used to query + certain information such as the set of all tokens and processor types. + embedding_dim (int): the embedding dimension. Default is 128. + hidden_dim (int): the hidden dimension for RNN layers. Default is 128. + **kwargs: other parameters for the RNN layer (e.g., rnn_type, num_layers, + dropout, bidirectional). + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["cond-33", "cond-86"], # sequential + ... "demographics": ["asian", "male"], # multi-hot + ... "vitals": [120.0, 80.0, 98.6], # tensor + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": ["cond-12", "cond-52"], # sequential + ... "demographics": ["white", "female"], # multi-hot + ... "vitals": [110.0, 75.0, 98.2], # tensor + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={ + ... "conditions": "sequence", + ... "demographics": "multi_hot", + ... "vitals": "tensor", + ... }, + ... output_schema={"label": "binary"}, + ... dataset_name="test" + ... ) + >>> + >>> from pyhealth.datasets import get_dataloader + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> + >>> model = MultimodalRNN(dataset=dataset, embedding_dim=128, hidden_dim=64) + >>> + >>> data_batch = next(iter(train_loader)) + >>> + >>> ret = model(**data_batch) + >>> print(ret) + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...) 
+ } + """ + + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + hidden_dim: int = 128, + **kwargs + ): + super(MultimodalRNN, self).__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + + # validate kwargs for RNN layer + if "input_size" in kwargs: + raise ValueError("input_size is determined by embedding_dim") + if "hidden_size" in kwargs: + raise ValueError("hidden_size is determined by hidden_dim") + + assert len(self.label_keys) == 1, "Only one label key is supported" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + + # Classify features as sequential or non-sequential + self.sequential_features = [] + self.non_sequential_features = [] + + self.rnn = nn.ModuleDict() + for feature_key in self.feature_keys: + processor = dataset.input_processors[feature_key] + if self._is_sequential_processor(processor): + self.sequential_features.append(feature_key) + # Create RNN for this feature + self.rnn[feature_key] = RNNLayer( + input_size=embedding_dim, + hidden_size=hidden_dim, + **kwargs + ) + else: + self.non_sequential_features.append(feature_key) + + # Calculate final concatenated dimension + final_dim = (len(self.sequential_features) * hidden_dim + + len(self.non_sequential_features) * embedding_dim) + output_size = self.get_output_size() + self.fc = nn.Linear(final_dim, output_size) + + def _is_sequential_processor(self, processor) -> bool: + """Check if processor represents sequential data. + + Sequential processors are those that benefit from RNN processing, + including sequences of codes and timeseries data. + + Note: + StageNetProcessor and StageNetTensorProcessor are excluded as they + are specialized for the StageNet model architecture and should be + treated as non-sequential for standard RNN processing. + + Args: + processor: The processor instance to check. + + Returns: + bool: True if processor is sequential, False otherwise. + """ + return isinstance(processor, ( + SequenceProcessor, + NestedSequenceProcessor, + DeepNestedSequenceProcessor, + NestedFloatsProcessor, + DeepNestedFloatsProcessor, + TimeseriesProcessor, + )) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation handling mixed modalities. + + The label `kwargs[self.label_key]` is a list of labels for each patient. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + Dict[str, torch.Tensor]: A dictionary with the following keys: + - loss: a scalar tensor representing the loss. + - y_prob: a tensor representing the predicted probabilities. + - y_true: a tensor representing the true labels. + - logit: a tensor representing the logits. + - embed (optional): a tensor representing the patient embeddings if requested. 
+ """ + patient_emb = [] + embedded = self.embedding_model(kwargs) + + # Process sequential features through RNN + for feature_key in self.sequential_features: + x = embedded[feature_key] + # Use abs() before sum to catch edge cases where embeddings sum to 0 + # despite being valid values (e.g., [1.0, -1.0]) + mask = (torch.abs(x).sum(dim=-1) != 0).int() + _, last_hidden = self.rnn[feature_key](x, mask) + patient_emb.append(last_hidden) + + # Process non-sequential features (use embeddings directly) + for feature_key in self.non_sequential_features: + x = embedded[feature_key] + # If multi-dimensional, aggregate (mean pooling) + while x.dim() > 2: + x = x.mean(dim=1) + patient_emb.append(x) + + # Concatenate all representations + patient_emb = torch.cat(patient_emb, dim=1) + # (patient, label_size) + logits = self.fc(patient_emb) + + # Calculate loss and predictions + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits + } + if kwargs.get("embed", False): + results["embed"] = patient_emb + return results diff --git a/tests/core/test_multimodal_rnn.py b/tests/core/test_multimodal_rnn.py new file mode 100644 index 000000000..d2fe5fa0b --- /dev/null +++ b/tests/core/test_multimodal_rnn.py @@ -0,0 +1,278 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MultimodalRNN + + +class TestMultimodalRNN(unittest.TestCase): + """Test cases for the MultimodalRNN model.""" + + def setUp(self): + """Set up test data and model with mixed feature types.""" + # Samples with mixed sequential and non-sequential features + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80"], # sequential + "procedures": ["proc-12", "proc-45"], # sequential + "demographics": ["asian", "male", "smoker"], # multi-hot + "vitals": [120.0, 80.0, 98.6, 16.0], # tensor + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12", "cond-52"], # sequential + "procedures": ["proc-23"], # sequential + "demographics": ["white", "female"], # multi-hot + "vitals": [110.0, 75.0, 98.2, 18.0], # tensor + "label": 0, + }, + ] + + # Define input and output schemas with mixed types + self.input_schema = { + "conditions": "sequence", # sequential + "procedures": "sequence", # sequential + "demographics": "multi_hot", # non-sequential + "vitals": "tensor", # non-sequential + } + self.output_schema = {"label": "binary"} + + # Create dataset + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + # Create model + self.model = MultimodalRNN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the MultimodalRNN model initializes correctly.""" + self.assertIsInstance(self.model, MultimodalRNN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(self.model.hidden_dim, 128) + self.assertEqual(len(self.model.feature_keys), 4) + + # Check that features are correctly classified + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("procedures", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + # Check that 
RNN layers are only created for sequential features + self.assertIn("conditions", self.model.rnn) + self.assertIn("procedures", self.model.rnn) + self.assertNotIn("demographics", self.model.rnn) + self.assertNotIn("vitals", self.model.rnn) + + self.assertEqual(self.model.label_key, "label") + + def test_model_forward(self): + """Test that the MultimodalRNN model forward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check output structure + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + # Check tensor shapes + self.assertEqual(ret["y_prob"].shape[0], 2) # batch size + self.assertEqual(ret["y_true"].shape[0], 2) # batch size + self.assertEqual(ret["logit"].shape[0], 2) # batch size + + # Check that loss is a scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the MultimodalRNN model backward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + ret = self.model(**data_batch) + + # Backward pass + ret["loss"].backward() + + # Check that at least one parameter has gradients + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the MultimodalRNN model returns embeddings when requested.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check that embeddings are returned + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + + # Check embedding dimension + # 2 sequential features * hidden_dim + 2 non-sequential features * embedding_dim + expected_embed_dim = ( + len(self.model.sequential_features) * self.model.hidden_dim + + len(self.model.non_sequential_features) * self.model.embedding_dim + ) + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_custom_hyperparameters(self): + """Test MultimodalRNN model with custom hyperparameters.""" + model = MultimodalRNN( + dataset=self.dataset, + embedding_dim=64, + hidden_dim=32, + rnn_type="LSTM", + num_layers=2, + dropout=0.3, + bidirectional=True, + ) + + self.assertEqual(model.embedding_dim, 64) + self.assertEqual(model.hidden_dim, 32) + + # Test forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_sequential_features(self): + """Test MultimodalRNN with only sequential features (like vanilla RNN).""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86"], + "procedures": ["proc-12", "proc-45"], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12"], + "procedures": ["proc-23"], + "label": 0, + }, + ] + + dataset = 
create_sample_dataset( + samples=samples, + input_schema={"conditions": "sequence", "procedures": "sequence"}, + output_schema={"label": "binary"}, + dataset_name="test_seq_only", + ) + + model = MultimodalRNN(dataset=dataset, hidden_dim=64) + + # Check that all features are sequential + self.assertEqual(len(model.sequential_features), 2) + self.assertEqual(len(model.non_sequential_features), 0) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_only_non_sequential_features(self): + """Test MultimodalRNN with only non-sequential features (like MLP).""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "demographics": ["asian", "male", "smoker"], + "vitals": [120.0, 80.0, 98.6, 16.0], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "demographics": ["white", "female"], + "vitals": [110.0, 75.0, 98.2, 18.0], + "label": 0, + }, + ] + + dataset = create_sample_dataset( + samples=samples, + input_schema={"demographics": "multi_hot", "vitals": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="test_non_seq_only", + ) + + model = MultimodalRNN(dataset=dataset, hidden_dim=64) + + # Check that all features are non-sequential + self.assertEqual(len(model.sequential_features), 0) + self.assertEqual(len(model.non_sequential_features), 2) + + # Test forward pass + train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_sequential_processor_classification(self): + """Test that _is_sequential_processor correctly identifies processor types.""" + from pyhealth.processors import ( + MultiHotProcessor, + SequenceProcessor, + TensorProcessor, + TimeseriesProcessor, + ) + + # Test with actual processor instances + seq_proc = SequenceProcessor() + self.assertTrue(self.model._is_sequential_processor(seq_proc)) + + # Create simple multi-hot processor + multihot_proc = MultiHotProcessor() + self.assertFalse(self.model._is_sequential_processor(multihot_proc)) + + # Tensor processor + tensor_proc = TensorProcessor() + self.assertFalse(self.model._is_sequential_processor(tensor_proc)) + + +if __name__ == "__main__": + unittest.main() + From b0b29da4a5ed4575d27a6b575b5dbb68c544f6de Mon Sep 17 00:00:00 2001 From: John Wu Date: Tue, 27 Jan 2026 10:54:06 -0600 Subject: [PATCH 2/2] add @todos for later --- pyhealth/models/rnn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index 160abee8c..3b4306ed7 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -248,6 +248,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: for feature_key in self.feature_keys: x = embedded[feature_key] # Use abs() before sum to catch edge cases where embeddings sum to 0 + # @TODO bug with 0 embedding sum can still persist if the embedding is all 0s but the mask is not all 0s. 
             # despite being valid values (e.g., [1.0, -1.0])
             mask = (torch.abs(x).sum(dim=-1) != 0).int()
             _, x = self.rnn[feature_key](x, mask)
@@ -447,6 +448,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
             x = embedded[feature_key]
             # Use abs() before sum to catch edge cases where embeddings sum to 0
             # despite being valid values (e.g., [1.0, -1.0])
+            # @TODO: a valid (non-padding) position can still be masked out here if its embedding happens to be exactly all zeros.
             mask = (torch.abs(x).sum(dim=-1) != 0).int()
             _, last_hidden = self.rnn[feature_key](x, mask)
             patient_emb.append(last_hidden)
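
The masking @TODO above can be illustrated outside the model: the sketch below shows, in plain PyTorch, why inferring the padding mask from embedded values can misfire and how a mask derived from the raw token indices avoids it. This is a minimal, self-contained illustration and not part of the patch; PADDING_IDX = 0, the toy tokens tensor, and the embedding sizes are assumptions made only for this example.

import torch
import torch.nn as nn

# Assumption for this sketch: padding positions are encoded as index 0.
PADDING_IDX = 0

tokens = torch.tensor([[5, 3, 0, 0],   # two real codes, two padding slots
                       [7, 0, 0, 0]])  # one real code, three padding slots
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=PADDING_IDX)
x = embedding(tokens)  # (batch, seq_len, embedding_dim)

# Heuristic used in the patch: infer the mask from the embedded values.
# A real token whose embedding happens to be exactly all zeros would be
# misclassified as padding by this check.
mask_from_embeddings = (torch.abs(x).sum(dim=-1) != 0).int()

# More robust alternative: derive the mask from the token indices themselves,
# so it depends only on which positions are padding, not on learned weights.
mask_from_tokens = (tokens != PADDING_IDX).int()

print(mask_from_embeddings)
print(mask_from_tokens)  # identical here; only this one is immune to all-zero embeddings

Wiring an index-based mask into RNN/MultimodalRNN would mean passing the raw index tensors (or precomputed masks) into forward alongside the embedded features, which is a larger interface change and presumably why the embedding-based heuristic is kept and flagged as a @TODO for now.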