Skip to content

Commit 59d1f97

Browse files
authored
Merge pull request #92 from uriahf/65-create-a-clean-marimo-version-of-walkthrough_the_aj_estimateqmd
docs: close #91
2 parents b5acdc9 + 79b033d commit 59d1f97

File tree

1 file changed

+92
-248
lines changed

1 file changed

+92
-248
lines changed

docs/walkthrough_aj_estimate.qmd

Lines changed: 92 additions & 248 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,73 @@ warning: false
77
---
88

99
```{python}
10-
from lifelines import AalenJohansenFitter
10+
import polars as pl
11+
import pandas as pd
12+
import numpy as np
13+
from lifelines import AalenJohansenFitter, CoxPHFitter, WeibullAFTFitter
14+
15+
df_time_to_cancer_dx = pd.read_csv(
16+
"https://raw.githubusercontent.com/ddsjoberg/dca-tutorial/main/data/df_time_to_cancer_dx.csv"
17+
)
18+
```
19+
20+
21+
```{python}
22+
1123
import numpy as np
1224
from itertools import product
1325
import itertools
1426
from rtichoke.helpers.sandbox_observable_helpers import *
15-
from lifelines import CoxPHFitter
16-
from lifelines import WeibullAFTFitter
1727
import polars as pl
1828
print("Polars version:", pl.__version__)
1929
2030
import pandas as pd
2131
import pickle
2232
23-
with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\probs_dict.pkl', 'rb') as file:
24-
probs_dict = pickle.load(file)
33+
cph = CoxPHFitter()
34+
thin_model = CoxPHFitter()
35+
aft_model = WeibullAFTFitter()
2536
26-
with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\reals_dict.pkl', 'rb') as file:
27-
reals_dict = pickle.load(file)
37+
cox_formula = "age + famhistory + marker"
38+
thin_formula = "age + marker"
39+
aft_formula = "age + marker"
2840
29-
with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\times_dict.pkl', 'rb') as file:
30-
times_dict = pickle.load(file)
41+
cph.fit(
42+
df_time_to_cancer_dx,
43+
duration_col="ttcancer",
44+
event_col="cancer",
45+
formula=cox_formula,
46+
)
3147
48+
thin_model.fit(
49+
df_time_to_cancer_dx,
50+
duration_col="ttcancer",
51+
event_col="cancer",
52+
formula=thin_formula,
53+
)
54+
55+
aft_model.fit(
56+
df_time_to_cancer_dx,
57+
duration_col="ttcancer",
58+
event_col="cancer",
59+
formula=aft_formula,
60+
)
61+
62+
63+
64+
cph_pred_vals = (1 - cph.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
65+
66+
thin_pred_vals = (1 - thin_model.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
67+
68+
aft_pred_vals = (1 - aft_model.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
69+
70+
probs_dict = {"full": cph_pred_vals, "thin": thin_pred_vals, "aft": aft_pred_vals}
71+
72+
reals_mapping = {"censor": 0, "diagnosed with cancer": 1, "dead other causes": 2}
73+
74+
reals_dict = df_time_to_cancer_dx["cancer_cr"].map(reals_mapping)
75+
76+
times_dict = df_time_to_cancer_dx["ttcancer"]
3277
3378
```
3479

@@ -39,7 +84,7 @@ with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\times_dict.pkl', 'rb') a
3984
4085
4186
42-
fixed_time_horizons = [1, 3, 5]
87+
fixed_time_horizons = [1.0, 3.0, 5.0]
4388
stratified_by = ["probability_threshold", "ppcr"]
4489
by=0.1
4590
@@ -71,222 +116,49 @@ list_data_to_adjust_polars = create_list_data_to_adjust_polars(
71116

72117
### New extract aj estimate by assumptions polars
73118

74-
#### One polars dataframe
75-
76-
```{python}
77-
78-
example_polars_df = list_data_to_adjust_polars.get('full').select(pl.col("strata"), pl.col("reals"), pl.col("times"))
79-
80-
fixed_time_horizons = [1, 3, 5]
81-
82-
83-
```
84-
85-
86119
## Create aj_estimates_data
87120

88-
## Create aj_data
89-
90-
```{python}
91-
92-
fixed_time_horizons = [2, 4]
93-
94-
aj_estimates_per_strata_adj_adjneg = create_aj_data(example_polars_df, "adjusted", "adjusted_as_negative", fixed_time_horizons)
95-
96-
aj_estimates_per_strata_excl_adjneg = create_aj_data(example_polars_df, "excluded", "adjusted_as_negative", fixed_time_horizons)
97-
98-
aj_estimates_per_strata_adj_adjcens = create_aj_data(example_polars_df, "adjusted", "adjusted_as_censored", fixed_time_horizons)
99-
100-
101-
```
102-
103-
## AJ estimates per assumptions
104-
105121
```{python}
106122
107-
# 1 adjusted - adjusted_as_negative
108-
109-
aj_estimates_per_strata_adj_adjneg = example_polars_df.group_by("strata").map_groups(
110-
lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizons)).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
111-
112-
113-
114-
# 2 excluded - adjusted as negative
115-
116-
exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
117-
118-
aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
119-
pl.col("real_censored_est").cast(pl.Float64)
120-
)
121-
122-
non_censored_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")>0)
123-
124-
125-
aj_estimates_per_strata_noncensored = pl.concat(
126-
[
127-
non_censored_data
128-
.filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
129-
.group_by("strata")
130-
.map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
131-
for fixed_time_horizon in fixed_time_horizons
132-
],
133-
how="vertical"
134-
)
135-
136-
aj_estimates_per_strata_excl_adjneg = aj_estimates_per_strata_noncensored.join(
137-
aj_estimates_per_strata_censored,
138-
on = ['strata', 'fixed_time_horizon']
139-
).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
140-
141-
142-
# 3 adjusted - adjusted as censored
143-
144-
145-
aj_estimates_per_strata_adj_adjcens = example_polars_df.with_columns([
146-
pl.when(
147-
(pl.col("reals") ==2)
148-
).then(pl.lit(0))
149-
.otherwise(pl.col("reals"))
150-
.alias("reals")
151-
]).group_by("strata").map_groups(
152-
lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizons)).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "adjusted_as_censored"}), how = 'cross')
153-
154-
# 4 excluded - adjusted as censored
155-
156-
exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
157-
158-
aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
159-
pl.col("real_censored_est").cast(pl.Float64)
160-
)
161-
162-
non_censored_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")>0).with_columns([
163-
pl.when(
164-
(pl.col("reals") ==2)
165-
).then(pl.lit(0))
166-
.otherwise(pl.col("reals"))
167-
.alias("reals")
168-
])
169-
170-
171-
aj_estimates_per_strata_noncensored = pl.concat(
172-
[
173-
non_censored_data
174-
.filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
175-
.group_by("strata")
176-
.map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
177-
for fixed_time_horizon in fixed_time_horizons
178-
],
179-
how="vertical"
180-
)
181-
182-
aj_estimates_per_strata_excl_adjcens = aj_estimates_per_strata_noncensored.join(
183-
aj_estimates_per_strata_censored,
184-
on = ['strata', 'fixed_time_horizon']
185-
).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
186-
187-
188-
189-
## 5 adjusted - excluded
190-
191-
exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
192-
193-
aj_estimates_per_strata_competing = exploded_data.filter((pl.col("reals")==2) & (pl.col("times") < pl.col("fixed_time_horizon"))).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_competing_est"}).with_columns(
194-
pl.col("real_competing_est").cast(pl.Float64)
195-
)
196-
197-
non_competing_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")!=2).with_columns([
198-
pl.when(
199-
(pl.col("reals") ==2)
200-
).then(pl.lit(0))
201-
.otherwise(pl.col("reals"))
202-
.alias("reals")
203-
])
204-
205-
206-
aj_estimates_per_strata_noncompeting = pl.concat(
207-
[
208-
non_competing_data
209-
.filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
210-
.group_by("strata")
211-
.map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
212-
for fixed_time_horizon in fixed_time_horizons
213-
],
214-
how="vertical"
215-
).select(pl.exclude("real_competing_est"))
216-
217-
aj_estimates_per_strata_adj_excl = aj_estimates_per_strata_competing.join(
218-
aj_estimates_per_strata_noncompeting,
219-
on = ['strata', 'fixed_time_horizon']
220-
).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "excluded"}), how = 'cross').select(
221-
['strata',
222-
'fixed_time_horizon',
223-
'real_negatives_est',
224-
'real_positives_est',
225-
'real_competing_est',
226-
'real_censored_est',
227-
'censoring_assumption',
228-
'competing_assumption']
229-
)
230-
231-
232-
## 6 excluded - excluded
233-
234-
235-
exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
236-
237-
aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
238-
pl.col("real_censored_est").cast(pl.Float64)
239-
)
240-
241-
aj_estimates_per_strata_competing = exploded_data.filter((pl.col("reals")==2) & (pl.col("times") < pl.col("fixed_time_horizon"))).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_competing_est"}).with_columns(
242-
pl.col("real_competing_est").cast(pl.Float64)
243-
)
244-
245-
246-
non_censored_non_competing_data = exploded_data.filter(((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")==1))
247-
123+
fixed_time_horizons = [1.0, 3.0, 5.0]
124+
125+
assumption_sets = [
126+
{
127+
"censoring_assumption": "adjusted",
128+
"competing_assumption": "adjusted_as_negative",
129+
},
130+
{
131+
"censoring_assumption": "excluded",
132+
"competing_assumption": "adjusted_as_negative",
133+
},
134+
{
135+
"censoring_assumption": "adjusted",
136+
"competing_assumption": "adjusted_as_censored",
137+
},
138+
{
139+
"censoring_assumption": "excluded",
140+
"competing_assumption": "adjusted_as_censored",
141+
},
142+
{"censoring_assumption": "adjusted", "competing_assumption": "excluded"},
143+
{"censoring_assumption": "excluded", "competing_assumption": "excluded"},
144+
]
145+
146+
# aj_estimates_data = extract_aj_estimate_by_assumptions(
147+
# example_polars_df,
148+
# assumption_sets=assumption_sets,
149+
# fixed_time_horizons=fixed_time_horizons,
150+
# )
248151
249-
aj_estimates_per_strata_noncompeting_noncompeting = pl.concat(
250-
[
251-
non_censored_non_competing_data
252-
.filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
253-
.group_by("strata")
254-
.map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
255-
for fixed_time_horizon in fixed_time_horizons
256-
],
257-
how="vertical"
258-
)
259152
260-
aj_estimates_per_strata_excl_excl = aj_estimates_per_strata_competing.join(aj_estimates_per_strata_censored, on = ['strata', 'fixed_time_horizon']).join(
261-
aj_estimates_per_strata_noncompeting,
262-
on = ['strata', 'fixed_time_horizon']
263-
).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "excluded"}), how = 'cross').select(
264-
['strata',
265-
'fixed_time_horizon',
266-
'real_negatives_est',
267-
'real_positives_est',
268-
'real_competing_est',
269-
'real_censored_est',
270-
'censoring_assumption',
271-
'competing_assumption']
153+
aj_estimates_data = create_adjusted_data(
154+
list_data_to_adjust_polars,
155+
assumption_sets=assumption_sets,
156+
fixed_time_horizons=fixed_time_horizons
272157
)
273158
274-
## combine all
275-
276-
aj_estimates_data = pl.concat(
277-
[
278-
aj_estimates_per_strata_adj_adjneg,
279-
aj_estimates_per_strata_adj_adjcens,
280-
aj_estimates_per_strata_adj_excl,
281-
aj_estimates_per_strata_excl_adjneg,
282-
aj_estimates_per_strata_excl_adjcens,
283-
aj_estimates_per_strata_excl_excl
284-
]
285-
).unpivot( index = ["strata", "fixed_time_horizon", "censoring_assumption", "competing_assumption"] , variable_name = "reals_labels", value_name = "reals_estimate")
286-
287-
288159
```
289160

161+
290162
### Check strata values
291163

292164
```{python}
@@ -322,38 +194,7 @@ print(result.filter(pl.col("is_in_df2") == False))
322194

323195
```{python}
324196
325-
reals_enum_dtype = aj_data_combinations.schema["reals_labels"]
326-
censoring_assumptions_enum_dtype = aj_data_combinations.schema["censoring_assumption"]
327-
competing_assumptions_enum_dtype = aj_data_combinations.schema["competing_assumption"]
328-
329-
strata_enum_dtype = aj_data_combinations.schema["strata"]
330-
331-
332-
aj_estimates_data = aj_estimates_data.with_columns([
333-
pl.col("strata")
334-
]).with_columns(
335-
pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype)
336-
).with_columns(
337-
pl.col("censoring_assumption").cast(censoring_assumptions_enum_dtype)
338-
).with_columns(
339-
pl.col("competing_assumption").cast(competing_assumptions_enum_dtype)
340-
).with_columns(
341-
pl.col("strata").cast(strata_enum_dtype)
342-
)
343-
344-
```
345-
346-
```{python}
347-
348-
349-
final_adjusted_data_polars = aj_data_combinations.with_columns([
350-
pl.col("strata")
351-
]).join(
352-
aj_estimates_data,
353-
on = ['strata', 'fixed_time_horizon', 'censoring_assumption', 'competing_assumption', 'reals_labels'],
354-
how = 'left'
355-
)
356-
197+
final_adjusted_data_polars = cast_and_join_adjusted_data(aj_data_combinations, aj_estimates_data)
357198
358199
```
359200

@@ -441,6 +282,9 @@ Plot.plot({
441282
domain: ["real_positives", "real_competing", "real_negatives", "real_censored"],
442283
range: ["#009e73", "#9DB4C0", "#FAC8CD", "#E3F09B"],
443284
legend: true
285+
},
286+
style: {
287+
background: "none"
444288
}
445289
})
446290

0 commit comments

Comments
 (0)