@@ -7,28 +7,73 @@ warning: false
77---
88
99``` {python}
10- from lifelines import AalenJohansenFitter
10+ import polars as pl
11+ import pandas as pd
12+ import numpy as np
13+ from lifelines import AalenJohansenFitter, CoxPHFitter, WeibullAFTFitter
14+
15+ df_time_to_cancer_dx = pd.read_csv(
16+ "https://raw.githubusercontent.com/ddsjoberg/dca-tutorial/main/data/df_time_to_cancer_dx.csv"
17+ )
18+ ```
19+
20+
21+ ``` {python}
22+
1123import numpy as np
1224from itertools import product
1325import itertools
1426from rtichoke.helpers.sandbox_observable_helpers import *
15- from lifelines import CoxPHFitter
16- from lifelines import WeibullAFTFitter
1727import polars as pl
1828print("Polars version:", pl.__version__)
1929
2030import pandas as pd
2131import pickle
2232
23- with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\probs_dict.pkl', 'rb') as file:
24- probs_dict = pickle.load(file)
33+ cph = CoxPHFitter()
34+ thin_model = CoxPHFitter()
35+ aft_model = WeibullAFTFitter()
2536
26- with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\reals_dict.pkl', 'rb') as file:
27- reals_dict = pickle.load(file)
37+ cox_formula = "age + famhistory + marker"
38+ thin_formula = "age + marker"
39+ aft_formula = "age + marker"
2840
29- with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\times_dict.pkl', 'rb') as file:
30- times_dict = pickle.load(file)
41+ cph.fit(
42+ df_time_to_cancer_dx,
43+ duration_col="ttcancer",
44+ event_col="cancer",
45+ formula=cox_formula,
46+ )
3147
48+ thin_model.fit(
49+ df_time_to_cancer_dx,
50+ duration_col="ttcancer",
51+ event_col="cancer",
52+ formula=thin_formula,
53+ )
54+
55+ aft_model.fit(
56+ df_time_to_cancer_dx,
57+ duration_col="ttcancer",
58+ event_col="cancer",
59+ formula=aft_formula,
60+ )
61+
62+
63+
64+ cph_pred_vals = (1 - cph.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
65+
66+ thin_pred_vals = (1 - thin_model.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
67+
68+ aft_pred_vals = (1 - aft_model.predict_survival_function(df_time_to_cancer_dx[['age', 'famhistory', 'marker']], times=[1.5])).iloc[0, :].values
69+
70+ probs_dict = {"full": cph_pred_vals, "thin": thin_pred_vals, "aft": aft_pred_vals}
71+
72+ reals_mapping = {"censor": 0, "diagnosed with cancer": 1, "dead other causes": 2}
73+
74+ reals_dict = df_time_to_cancer_dx["cancer_cr"].map(reals_mapping)
75+
76+ times_dict = df_time_to_cancer_dx["ttcancer"]
3277
3378```
3479
@@ -39,7 +84,7 @@ with open(r'C:\Users\I\Documents\GitHub\rtichoke_python\times_dict.pkl', 'rb') a
3984
4085
4186
42- fixed_time_horizons = [1, 3, 5]
87+ fixed_time_horizons = [1.0 , 3.0 , 5.0 ]
4388stratified_by = ["probability_threshold", "ppcr"]
4489by=0.1
4590
@@ -71,222 +116,49 @@ list_data_to_adjust_polars = create_list_data_to_adjust_polars(
71116
72117### New extract aj estimate by assumptions polars
73118
74- #### One polars dataframe
75-
76- ``` {python}
77-
78- example_polars_df = list_data_to_adjust_polars.get('full').select(pl.col("strata"), pl.col("reals"), pl.col("times"))
79-
80- fixed_time_horizons = [1, 3, 5]
81-
82-
83- ```
84-
85-
86119## Create aj_estimates_data
87120
88- ## Create aj_data
89-
90- ``` {python}
91-
92- fixed_time_horizons = [2, 4]
93-
94- aj_estimates_per_strata_adj_adjneg = create_aj_data(example_polars_df, "adjusted", "adjusted_as_negative", fixed_time_horizons)
95-
96- aj_estimates_per_strata_excl_adjneg = create_aj_data(example_polars_df, "excluded", "adjusted_as_negative", fixed_time_horizons)
97-
98- aj_estimates_per_strata_adj_adjcens = create_aj_data(example_polars_df, "adjusted", "adjusted_as_censored", fixed_time_horizons)
99-
100-
101- ```
102-
103- ## AJ estimates per assumptions
104-
105121``` {python}
106122
107- # 1 adjusted - adjusted_as_negative
108-
109- aj_estimates_per_strata_adj_adjneg = example_polars_df.group_by("strata").map_groups(
110- lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizons)).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
111-
112-
113-
114- # 2 excluded - adjusted as negative
115-
116- exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
117-
118- aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
119- pl.col("real_censored_est").cast(pl.Float64)
120- )
121-
122- non_censored_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")>0)
123-
124-
125- aj_estimates_per_strata_noncensored = pl.concat(
126- [
127- non_censored_data
128- .filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
129- .group_by("strata")
130- .map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
131- for fixed_time_horizon in fixed_time_horizons
132- ],
133- how="vertical"
134- )
135-
136- aj_estimates_per_strata_excl_adjneg = aj_estimates_per_strata_noncensored.join(
137- aj_estimates_per_strata_censored,
138- on = ['strata', 'fixed_time_horizon']
139- ).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
140-
141-
142- # 3 adjusted - adjusted as censored
143-
144-
145- aj_estimates_per_strata_adj_adjcens = example_polars_df.with_columns([
146- pl.when(
147- (pl.col("reals") ==2)
148- ).then(pl.lit(0))
149- .otherwise(pl.col("reals"))
150- .alias("reals")
151- ]).group_by("strata").map_groups(
152- lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizons)).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "adjusted_as_censored"}), how = 'cross')
153-
154- # 4 excluded - adjusted as censored
155-
156- exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
157-
158- aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
159- pl.col("real_censored_est").cast(pl.Float64)
160- )
161-
162- non_censored_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")>0).with_columns([
163- pl.when(
164- (pl.col("reals") ==2)
165- ).then(pl.lit(0))
166- .otherwise(pl.col("reals"))
167- .alias("reals")
168- ])
169-
170-
171- aj_estimates_per_strata_noncensored = pl.concat(
172- [
173- non_censored_data
174- .filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
175- .group_by("strata")
176- .map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
177- for fixed_time_horizon in fixed_time_horizons
178- ],
179- how="vertical"
180- )
181-
182- aj_estimates_per_strata_excl_adjcens = aj_estimates_per_strata_noncensored.join(
183- aj_estimates_per_strata_censored,
184- on = ['strata', 'fixed_time_horizon']
185- ).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "adjusted_as_negative"}), how = 'cross')
186-
187-
188-
189- ## 5 adjusted - excluded
190-
191- exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
192-
193- aj_estimates_per_strata_competing = exploded_data.filter((pl.col("reals")==2) & (pl.col("times") < pl.col("fixed_time_horizon"))).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_competing_est"}).with_columns(
194- pl.col("real_competing_est").cast(pl.Float64)
195- )
196-
197- non_competing_data = exploded_data.filter((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")!=2).with_columns([
198- pl.when(
199- (pl.col("reals") ==2)
200- ).then(pl.lit(0))
201- .otherwise(pl.col("reals"))
202- .alias("reals")
203- ])
204-
205-
206- aj_estimates_per_strata_noncompeting = pl.concat(
207- [
208- non_competing_data
209- .filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
210- .group_by("strata")
211- .map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
212- for fixed_time_horizon in fixed_time_horizons
213- ],
214- how="vertical"
215- ).select(pl.exclude("real_competing_est"))
216-
217- aj_estimates_per_strata_adj_excl = aj_estimates_per_strata_competing.join(
218- aj_estimates_per_strata_noncompeting,
219- on = ['strata', 'fixed_time_horizon']
220- ).join(pl.DataFrame({"real_censored_est": 0.0, "censoring_assumption": "adjusted", "competing_assumption": "excluded"}), how = 'cross').select(
221- ['strata',
222- 'fixed_time_horizon',
223- 'real_negatives_est',
224- 'real_positives_est',
225- 'real_competing_est',
226- 'real_censored_est',
227- 'censoring_assumption',
228- 'competing_assumption']
229- )
230-
231-
232- ## 6 excluded - excluded
233-
234-
235- exploded_data = example_polars_df.with_columns(fixed_time_horizon = pl.lit([1,3,5])).explode("fixed_time_horizon")
236-
237- aj_estimates_per_strata_censored = exploded_data.filter((pl.col("times") < pl.col("fixed_time_horizon")) & pl.col("reals")==0).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_censored_est"}).with_columns(
238- pl.col("real_censored_est").cast(pl.Float64)
239- )
240-
241- aj_estimates_per_strata_competing = exploded_data.filter((pl.col("reals")==2) & (pl.col("times") < pl.col("fixed_time_horizon"))).group_by(["strata", "fixed_time_horizon"]).count().rename({"count": "real_competing_est"}).with_columns(
242- pl.col("real_competing_est").cast(pl.Float64)
243- )
244-
245-
246- non_censored_non_competing_data = exploded_data.filter(((pl.col("times") >= pl.col("fixed_time_horizon")) | pl.col("reals")==1))
247-
123+ fixed_time_horizons = [1.0, 3.0, 5.0]
124+
125+ assumption_sets = [
126+ {
127+ "censoring_assumption": "adjusted",
128+ "competing_assumption": "adjusted_as_negative",
129+ },
130+ {
131+ "censoring_assumption": "excluded",
132+ "competing_assumption": "adjusted_as_negative",
133+ },
134+ {
135+ "censoring_assumption": "adjusted",
136+ "competing_assumption": "adjusted_as_censored",
137+ },
138+ {
139+ "censoring_assumption": "excluded",
140+ "competing_assumption": "adjusted_as_censored",
141+ },
142+ {"censoring_assumption": "adjusted", "competing_assumption": "excluded"},
143+ {"censoring_assumption": "excluded", "competing_assumption": "excluded"},
144+ ]
145+
146+ # aj_estimates_data = extract_aj_estimate_by_assumptions(
147+ # example_polars_df,
148+ # assumption_sets=assumption_sets,
149+ # fixed_time_horizons=fixed_time_horizons,
150+ # )
248151
249- aj_estimates_per_strata_noncompeting_noncompeting = pl.concat(
250- [
251- non_censored_non_competing_data
252- .filter(pl.col("fixed_time_horizon") == fixed_time_horizon)
253- .group_by("strata")
254- .map_groups(lambda group: extract_aj_estimate_for_strata(group, fixed_time_horizon))
255- for fixed_time_horizon in fixed_time_horizons
256- ],
257- how="vertical"
258- )
259152
260- aj_estimates_per_strata_excl_excl = aj_estimates_per_strata_competing.join(aj_estimates_per_strata_censored, on = ['strata', 'fixed_time_horizon']).join(
261- aj_estimates_per_strata_noncompeting,
262- on = ['strata', 'fixed_time_horizon']
263- ).join(pl.DataFrame({"censoring_assumption": "excluded", "competing_assumption": "excluded"}), how = 'cross').select(
264- ['strata',
265- 'fixed_time_horizon',
266- 'real_negatives_est',
267- 'real_positives_est',
268- 'real_competing_est',
269- 'real_censored_est',
270- 'censoring_assumption',
271- 'competing_assumption']
153+ aj_estimates_data = create_adjusted_data(
154+ list_data_to_adjust_polars,
155+ assumption_sets=assumption_sets,
156+ fixed_time_horizons=fixed_time_horizons
272157)
273158
274- ## combine all
275-
276- aj_estimates_data = pl.concat(
277- [
278- aj_estimates_per_strata_adj_adjneg,
279- aj_estimates_per_strata_adj_adjcens,
280- aj_estimates_per_strata_adj_excl,
281- aj_estimates_per_strata_excl_adjneg,
282- aj_estimates_per_strata_excl_adjcens,
283- aj_estimates_per_strata_excl_excl
284- ]
285- ).unpivot( index = ["strata", "fixed_time_horizon", "censoring_assumption", "competing_assumption"] , variable_name = "reals_labels", value_name = "reals_estimate")
286-
287-
288159```
289160
161+
290162### Check strata values
291163
292164``` {python}
@@ -322,38 +194,7 @@ print(result.filter(pl.col("is_in_df2") == False))
322194
323195``` {python}
324196
325- reals_enum_dtype = aj_data_combinations.schema["reals_labels"]
326- censoring_assumptions_enum_dtype = aj_data_combinations.schema["censoring_assumption"]
327- competing_assumptions_enum_dtype = aj_data_combinations.schema["competing_assumption"]
328-
329- strata_enum_dtype = aj_data_combinations.schema["strata"]
330-
331-
332- aj_estimates_data = aj_estimates_data.with_columns([
333- pl.col("strata")
334- ]).with_columns(
335- pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype)
336- ).with_columns(
337- pl.col("censoring_assumption").cast(censoring_assumptions_enum_dtype)
338- ).with_columns(
339- pl.col("competing_assumption").cast(competing_assumptions_enum_dtype)
340- ).with_columns(
341- pl.col("strata").cast(strata_enum_dtype)
342- )
343-
344- ```
345-
346- ``` {python}
347-
348-
349- final_adjusted_data_polars = aj_data_combinations.with_columns([
350- pl.col("strata")
351- ]).join(
352- aj_estimates_data,
353- on = ['strata', 'fixed_time_horizon', 'censoring_assumption', 'competing_assumption', 'reals_labels'],
354- how = 'left'
355- )
356-
197+ final_adjusted_data_polars = cast_and_join_adjusted_data(aj_data_combinations, aj_estimates_data)
357198
358199```
359200
@@ -441,6 +282,9 @@ Plot.plot({
441282 domain: ["real_positives", "real_competing", "real_negatives", "real_censored"],
442283 range: ["#009e73", "#9DB4C0", "#FAC8CD", "#E3F09B"],
443284 legend: true
285+ },
286+ style: {
287+ background: "none"
444288 }
445289})
446290
0 commit comments