Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/api/tasks/pyhealth.tasks.readmission_prediction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic4_fn
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn2
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_omop_fn

.. autoclass:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionOMOP
    :members:
    :undoc-members:
    :show-inheritance:
44 changes: 44 additions & 0 deletions examples/readmission/readmission_mimic3_rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import tempfile

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import RNN
from pyhealth.tasks import ReadmissionPredictionMIMIC3
from pyhealth.trainer import Trainer

# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    # Use the TemporaryDirectory as a context manager so the cache directory is
    # removed deterministically when the block exits, rather than relying on
    # implicit cleanup when the object is finalized.
    with tempfile.TemporaryDirectory() as cache_dir:
        base_dataset = MIMIC3Dataset(
            root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
            tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
            cache_dir=cache_dir,
        )
        base_dataset.stats()

        # Must include minors to get any readmission samples on the synthetic dataset
        sample_dataset = base_dataset.set_task(
            ReadmissionPredictionMIMIC3(exclude_minors=False)
        )

        # Close the sample dataset even if training or evaluation fails, and do
        # so before the temporary cache directory is deleted.
        try:
            train_dataset, val_dataset, test_dataset = split_by_patient(
                sample_dataset, [0.8, 0.1, 0.1]
            )
            train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
            val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
            test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

            model = RNN(
                dataset=sample_dataset,
            )

            trainer = Trainer(model=model)
            trainer.train(
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                epochs=1,
                monitor="roc_auc",
            )

            trainer.evaluate(test_dataloader)
        finally:
            sample_dataset.close()
43 changes: 43 additions & 0 deletions examples/readmission/readmission_omop_rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import tempfile

from pyhealth.datasets import OMOPDataset, get_dataloader, split_by_patient
from pyhealth.models import RNN
from pyhealth.tasks import ReadmissionPredictionOMOP
from pyhealth.trainer import Trainer

# Since PyHealth uses multiprocessing, it is best practice to use a main guard.
if __name__ == '__main__':
    # Use the TemporaryDirectory as a context manager so the cache directory is
    # removed deterministically when the block exits, rather than relying on
    # implicit cleanup when the object is finalized.
    with tempfile.TemporaryDirectory() as cache_dir:
        base_dataset = OMOPDataset(
            root="https://physionet.org/files/mimic-iv-demo-omop/0.9/1_omop_data_csv",
            tables=[
                "person",
                "visit_occurrence",
                "condition_occurrence",
                "procedure_occurrence",
                "drug_exposure",
            ],
            cache_dir=cache_dir,
        )
        base_dataset.stats()

        sample_dataset = base_dataset.set_task(ReadmissionPredictionOMOP())

        # Close the sample dataset even if training or evaluation fails, and do
        # so before the temporary cache directory is deleted.
        try:
            train_dataset, val_dataset, test_dataset = split_by_patient(
                sample_dataset, [0.8, 0.1, 0.1]
            )
            train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
            val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
            test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

            model = RNN(
                dataset=sample_dataset,
            )

            trainer = Trainer(model=model)
            trainer.train(
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                epochs=1,
                monitor="roc_auc",
            )

            trainer.evaluate(test_dataloader)
        finally:
            sample_dataset.close()
39 changes: 0 additions & 39 deletions examples/readmission_mimic3_rnn.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
readmission_prediction_eicu_fn,
readmission_prediction_eicu_fn2,
readmission_prediction_mimic4_fn,
readmission_prediction_omop_fn,
ReadmissionPredictionOMOP,
)
from .sleep_staging import (
sleep_staging_isruc_fn,
Expand Down
160 changes: 94 additions & 66 deletions pyhealth/tasks/readmission_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pyhealth.data import Event, Patient
from pyhealth.tasks import BaseTask


class ReadmissionPredictionMIMIC3(BaseTask):
"""
Readmission prediction on the MIMIC3 dataset.
Expand Down Expand Up @@ -311,67 +312,106 @@ def readmission_prediction_eicu_fn2(patient: Patient, time_window=5):
return samples


class ReadmissionPredictionOMOP(BaseTask):
    """
    Readmission prediction on the OMOP dataset.

    This task aims at predicting whether the patient will be readmitted into hospital within
    a specified number of days based on clinical information from the current visit.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for the task input.
        output_schema (Dict[str, str]): The schema for the task output.
    """
    task_name: str = "ReadmissionPredictionOMOP"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"readmission": "binary"}

    def __init__(self, window: timedelta = timedelta(days=15), exclude_minors: bool = True) -> None:
        """
        Initializes the task object.

        Args:
            window (timedelta): If two admissions are closer than this window, it is
                considered a readmission. Defaults to 15 days.
            exclude_minors (bool): Whether to exclude visits where the patient was
                under 18 years old. Defaults to True.
        """
        self.window = window
        self.exclude_minors = exclude_minors

    def __call__(self, patient: Patient) -> List[Dict]:
        """
        Generates binary classification data samples for a single patient.

        Visits with no conditions OR no procedures OR no drugs are excluded from the
        output but are still used to calculate readmission for prior visits. The
        patient's last visit never produces a sample because it has no following
        visit to compare against.

        Args:
            patient (Patient): A patient object.

        Returns:
            List[Dict]: A list containing a dictionary for each patient visit with:
                - 'visit_id': OMOP visit_occurrence_id.
                - 'patient_id': OMOP person_id.
                - 'conditions': OMOP condition_occurrence table condition_concept_id attribute.
                - 'procedures': OMOP procedure_occurrence table procedure_concept_id attribute.
                - 'drugs': OMOP drug_exposure table drug_concept_id attribute.
                - 'readmission': binary label.
        """
        persons: List[Event] = patient.get_events(event_type="person")
        assert len(persons) == 1

        dob = None
        if self.exclude_minors:
            person = persons[0]
            # Month and day of birth may be empty; fall back to the first of the
            # month / first month of the year in that case.
            dob = datetime(
                int(person.year_of_birth),
                int(person.month_of_birth) if person.month_of_birth else 1,
                int(person.day_of_birth) if person.day_of_birth else 1,
            )

        admissions: List[Event] = patient.get_events(event_type="visit_occurrence")
        if len(admissions) < 2:
            return []

        samples = []
        # Pair each admission with the one that follows it; the last admission is
        # skipped since it has no "next" admission.
        for visit, next_visit in zip(admissions, admissions[1:]):
            if self.exclude_minors:
                age = visit.timestamp.year - dob.year
                # The birthday has not been reached yet in the admission year.
                if (visit.timestamp.month, visit.timestamp.day) < (dob.month, dob.day):
                    age -= 1
                if age < 18:
                    continue

            # Restrict each clinical table to events recorded for this visit.
            visit_filter = ("visit_occurrence_id", "==", visit.visit_occurrence_id)

            condition_events = patient.get_events(
                event_type="condition_occurrence", filters=[visit_filter]
            )
            conditions = [event.condition_concept_id for event in condition_events]
            if len(conditions) == 0:
                continue

            procedure_events = patient.get_events(
                event_type="procedure_occurrence", filters=[visit_filter]
            )
            procedures = [event.procedure_concept_id for event in procedure_events]
            if len(procedures) == 0:
                continue

            drug_events = patient.get_events(
                event_type="drug_exposure", filters=[visit_filter]
            )
            drugs = [event.drug_concept_id for event in drug_events]
            if len(drugs) == 0:
                continue

            # The gap is measured from this visit's end to the next visit's timestamp.
            discharge_time = datetime.strptime(visit.visit_end_datetime, "%Y-%m-%d %H:%M:%S")
            readmission = int((next_visit.timestamp - discharge_time) < self.window)

            samples.append(
                {
                    "visit_id": visit.visit_occurrence_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures,
                    "drugs": drugs,
                    "readmission": readmission,
                }
            )

        return samples


if __name__ == "__main__":
Expand Down Expand Up @@ -409,15 +449,3 @@ def readmission_prediction_omop_fn(patient: Patient, time_window=15):
sample_dataset = base_dataset.set_task(task_fn=readmission_prediction_eicu_fn2)
sample_dataset.stat()
print(sample_dataset.available_keys)

from pyhealth.datasets import OMOPDataset

base_dataset = OMOPDataset(
root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
dev=True,
refresh_cache=False,
)
sample_dataset = base_dataset.set_task(task_fn=readmission_prediction_omop_fn)
sample_dataset.stat()
print(sample_dataset.available_keys)
5 changes: 5 additions & 0 deletions test-resources/omop/condition_occurrence.csv
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ condition_occurrence_id,person_id,condition_concept_id,condition_start_date,cond
7253832832107915505,-775517641933593374,4337543,2196-02-24,2196-02-24 12:15:00,2196-03-04,2196-03-04 14:02:00,32821,,,-6400688276878690493,,570 ,44825528,,
287698808176663180,-775517641933593374,439777,2196-02-24,2196-02-24 12:15:00,2196-03-04,2196-03-04 14:02:00,32821,,,-6400688276878690493,,2859 ,44825282,,
-2692851526029372798,-775517641933593374,444044,2196-02-24,2196-02-24 12:15:00,2196-03-04,2196-03-04 14:02:00,32821,,,-6400688276878690493,,5845 ,44833659,,
1,1,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,1,,,,,
2,1,1,2018-01-15,2018-01-15 12:00:00,2018-01-16,2018-01-16 12:00:00,,,,2,,,,,
3,1,1,2018-01-21,2018-01-21 12:00:00,2018-01-22,2018-01-22 12:00:00,,,,3,,,,,
4,1,1,2018-01-23,2018-01-23 12:00:00,2018-01-24,2018-01-24 12:00:00,,,,4,,,,,
5,2,1,2018-01-23,2018-01-23 12:00:00,2018-01-24,2018-01-24 12:00:00,,,,5,,,,,
5 changes: 5 additions & 0 deletions test-resources/omop/drug_exposure.csv
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ drug_exposure_id,person_id,drug_concept_id,drug_exposure_start_date,drug_exposur
-6172173649014275634,-4873075614181207858,19133574,2147-12-20,2147-12-20 08:00:00,2147-12-21,2147-12-21 20:00:00,,32838,,,1.0,,,4132161,,,1621955522922802717,,68682036790,45860229,PO,CAP
7591752385060720828,-4873075614181207858,35605482,2147-12-19,2147-12-19 16:00:00,2147-12-21,2147-12-21 20:00:00,,32838,,,1.0,,,4171047,,,1621955522922802717,,70860077602,44416283,IV,VIAL
2810626833483061033,-4873075614181207858,40223140,2147-12-19,2147-12-19 16:00:00,2147-12-21,2147-12-21 20:00:00,,32838,,,0.5,,,4171047,,,1621955522922802717,,76045000410,44854903,IV,SYR
1,1,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,,1,,,,,
2,1,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,,2,,,,,
3,1,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,,3,,,,,
4,1,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,,4,,,,,
5,2,1,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,,5,,,,,
2 changes: 2 additions & 0 deletions test-resources/omop/person.csv
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ person_id,gender_concept_id,year_of_birth,month_of_birth,day_of_birth,birth_date
8692405834444096922,8507,2119,,,,8527,0,,,,10007058,M,0,WHITE,2000001404,,0
-4873075614181207858,8507,2075,,,,8527,0,,,,10004457,M,0,WHITE,2000001404,,0
-5829006308524050971,8507,2086,,,,8527,0,,,,10038999,M,0,WHITE,2000001404,,0
1,,2000,1,1,,,,,,,,,,,,,
2,,2000,1,1,,,,,,,,,,,,,
5 changes: 5 additions & 0 deletions test-resources/omop/procedure_occurrence.csv
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ procedure_occurrence_id,person_id,procedure_concept_id,procedure_date,procedure_
-7536943892209702328,-5829006308524050971,4021323,2131-05-28,2131-05-28 08:00:00,32817,0,,,-7351550472011464089,,223795,2000030019,
-7938165663568576869,-8970844422700220177,4021323,2148-09-12,2148-09-12 12:00:00,32817,0,,,2213171355741333981,,223795,2000030019,
4925796009207523980,4668337230155062633,4021323,2116-12-04,2116-12-04 12:00:00,32817,0,,,2572864492537913938,,223795,2000030019,
1,1,1,2017-12-31,2017-12-31 23:59:59,,,,,1,,,,
2,1,1,2018-01-15,2018-01-15 12:00:00,,,,,2,,,,
3,1,1,2018-01-21,2018-01-21 12:00:00,,,,,3,,,,
4,1,1,2018-01-23,2018-01-23 12:00:00,,,,,4,,,,
5,2,1,2018-01-23,2018-01-23 12:00:00,,,,,5,,,,
5 changes: 5 additions & 0 deletions test-resources/omop/visit_occurrence.csv
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ visit_occurrence_id,person_id,visit_concept_id,visit_start_date,visit_start_date
1621955522922802717,-4873075614181207858,8883,2147-12-19,2147-12-19 00:00:00,2147-12-21,2147-12-21 16:10:00,32817,,,10004457|28108313,2000001808,38004207,PHYSICIAN REFERRAL,8863,SKILLED NURSING FACILITY,
-6400688276878690493,-775517641933593374,262,2196-02-24,2196-02-24 12:15:00,2196-03-04,2196-03-04 14:02:00,32817,,,10004235|24181354,2000001809,8717,TRANSFER FROM HOSPITAL,8863,SKILLED NURSING FACILITY,
-3292958454000269853,2631971469928551627,262,2119-10-26,2119-10-26 06:00:00,2119-11-06,2119-11-06 12:30:00,32817,,,10026354|24547356,2000001809,8717,TRANSFER FROM HOSPITAL,8863,SKILLED NURSING FACILITY,
1,1,,2017-12-31,2017-12-31 23:59:59,2018-01-01,2018-01-01 12:00:00,,,,,,,,,,
2,1,,2018-01-15,2018-01-15 12:00:00,2018-01-16,2018-01-16 12:00:00,,,,,,,,,,
3,1,,2018-01-21,2018-01-21 12:00:00,2018-01-22,2018-01-22 12:00:00,,,,,,,,,,
4,1,,2018-01-23,2018-01-23 12:00:00,2018-01-24,2018-01-24 12:00:00,,,,,,,,,,
5,2,,2018-01-23,2018-01-23 12:00:00,2018-01-24,2018-01-24 12:00:00,,,,,,,,,,
4 changes: 0 additions & 4 deletions tests/core/test_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ def setUp(self):
Path(__file__).parent.parent.parent / "test-resources" / "omop"
)

# Check if test data exists
if not self.test_data_path.exists():
self.skipTest("OMOP test data not found in test-resources/omop/")

# Load dataset with all available tables
self.tables = [
"person",
Expand Down
Loading