Skip to content

Commit dc3bfdb

Browse files
committed
Create initial AJ report similar to R
1 parent c765dfc commit dc3bfdb

File tree

3 files changed

+24397
-31
lines changed

3 files changed

+24397
-31
lines changed

docs/walkthrough_aj_estimate.qmd

Lines changed: 229 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
---
22
title: "Hello, Quarto"
33
format: html
4+
echo: false
5+
message: false
6+
warning: false
47
---
58

69
## Markdown
@@ -16,9 +19,11 @@ Markdown is an easy to read and write text format:
1619
Here is a Python code cell:
1720

1821
```{python}
19-
22+
from lifelines import AalenJohansenFitter
2023
import numpy as np
2124
from itertools import product
25+
import itertools
26+
2227
import pandas as pd
2328
from lifelines import CoxPHFitter
2429
@@ -34,6 +39,8 @@ def extract_aj_estimate(data_to_adjust, fixed_time_horizons):
3439
pd.DataFrame: DataFrame with Aalen-Johansen estimates
3540
"""
3641
import numpy as np
42+
43+
# print(f"data_to_adjust: {data_to_adjust}")
3744
3845
# Ensure 'strata' column exists
3946
if 'strata' not in data_to_adjust.columns:
@@ -144,48 +151,60 @@ def extract_aj_estimate(data_to_adjust, fixed_time_horizons):
144151
145152
def extract_crude_estimate(data_to_adjust):
    """
    Compute the crude estimate by counting occurrences of 'reals' within
    each combination of 'strata' and 'fixed_time_horizon'.

    Args:
        data_to_adjust (pd.DataFrame): Data containing 'strata', 'reals',
            and 'fixed_time_horizon' columns.

    Returns:
        pd.DataFrame: One row per (strata, reals, fixed_time_horizon)
            combination with an integer 'reals_estimate' count; combinations
            absent from the input appear with a count of 0.
    """
    # Group by strata, reals, and fixed_time_horizon, then count occurrences.
    # dropna=False keeps groups whose key contains NaN instead of dropping them.
    crude_estimate = (
        data_to_adjust
        .groupby(["strata", "reals", "fixed_time_horizon"], dropna=False)
        .size()
        .reset_index(name="reals_estimate")
    )

    unique_strata = data_to_adjust["strata"].unique()
    unique_time_horizons = data_to_adjust["fixed_time_horizon"].unique()
    unique_reals = data_to_adjust["reals"].unique()

    # Full Cartesian grid of observed key values so that combinations that
    # never occurred are represented explicitly (mirrors tidyr::complete in R).
    all_combinations = pd.DataFrame(
        list(itertools.product(unique_strata, unique_reals, unique_time_horizons)),
        columns=["strata", "reals", "fixed_time_horizon"]
    )

    # Left-join the observed counts onto the full grid; unseen combinations
    # get NaN, which is filled with 0.  The merge promotes the integer counts
    # to float, so cast back to int to keep a true count column.
    crude_estimate = (
        all_combinations
        .merge(crude_estimate, on=["strata", "reals", "fixed_time_horizon"], how="left")
        .fillna({"reals_estimate": 0})
        .astype({"reals_estimate": int})
    )

    return crude_estimate
196+
173197
def add_cutoff_strata(data, by):
    """Attach the two stratification columns used by later steps.

    Adds to *data* (mutated in place and also returned):
      * 'strata_probability_threshold' — interval bins of 'probs' built from
        create_breaks_values, with the lowest bin closed on the left.
      * 'strata_ppcr' — descending quantile rank of 'probs' rescaled onto a
        grid with step `by`, stored as a string.

    NOTE(review): assumes `by` evenly divides 1 (e.g. 0.01) — confirm with callers.
    """
    bin_edges = create_breaks_values(data["probs"], "probability_threshold", by)
    data["strata_probability_threshold"] = pd.cut(
        data["probs"], bins=bin_edges, include_lowest=True
    )
    # Negate probs so the highest probabilities land in quantile bin 1.
    ranks = pd.qcut(-data["probs"], q=int(1 / by), labels=False) + 1
    data["strata_ppcr"] = ((ranks) / (1 / by)).astype(str)
    return data
182206
183-
def create_breaks_values(probs_vec, stratified_by, by):
184-
if stratified_by != "probability_threshold":
185-
breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1/by) + 1))
186-
else:
187-
breaks = np.round(np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1]))
188-
return breaks
207+
189208
190209
def create_strata_combinations(stratified_by, by):
191210
if stratified_by == "probability_threshold":
@@ -195,7 +214,7 @@ def create_strata_combinations(stratified_by, by):
195214
mid_point = upper_bound - by / 2
196215
include_lower_bound = lower_bound == 0
197216
include_upper_bound = upper_bound != 0
198-
strata = [f"{'[' if lb else '('}{l},{u}{']' if ub else ')'}" for lb, l, u, ub in zip(include_lower_bound, lower_bound, upper_bound, include_upper_bound)]
217+
strata = [f"{'[' if lb else '('}{l}, {u}{']' if ub else ')'}" for lb, l, u, ub in zip(include_lower_bound, lower_bound, upper_bound, include_upper_bound)]
199218
chosen_cutoff = upper_bound
200219
elif stratified_by == "ppcr":
201220
strata = create_breaks_values(None, "probability_threshold", by)[1:]
@@ -288,7 +307,6 @@ stratified_by = ["probability_threshold", "ppcr"]
288307
# Placeholder for create_aj_data_combinations
289308
aj_data_combinations = create_aj_data_combinations(list(probs_cox.keys()), fixed_time_horizons, stratified_by, 0.01)
290309
291-
aj_data_combinations
292310
293311
# Create reference groups
294312
data_to_adjust = pd.DataFrame({
@@ -301,6 +319,24 @@ data_to_adjust = pd.DataFrame({
301319
# # Placeholder for add_cutoff_strata function
302320
data_to_adjust = add_cutoff_strata(data_to_adjust, by=0.01)
303321
322+
def pivot_longer_strata(data):
    """Reshape wide 'strata_*' columns into long format.

    Every column whose name starts with "strata_" is melted into a pair of
    columns: 'stratified_by' (the original column name with the "strata_"
    prefix removed) and 'strata' (its value).  All other columns are carried
    along unchanged.  Mirrors tidyr::pivot_longer(names_prefix = "strata_").
    """
    frame = data.copy()  # keep the caller's DataFrame untouched

    strata_cols = [c for c in frame.columns if c.startswith("strata_")]
    other_cols = [c for c in frame.columns if not c.startswith("strata_")]

    long_frame = frame.melt(
        id_vars=other_cols,
        value_vars=strata_cols,
        var_name="stratified_by",
        value_name="strata",
    )

    # Strip the prefix so values read e.g. "ppcr" rather than "strata_ppcr".
    long_frame["stratified_by"] = long_frame["stratified_by"].str.replace("strata_", "")

    return long_frame
337+
338+
data_to_adjust = pivot_longer_strata(data_to_adjust)
339+
304340
data_to_adjust["reals"] = data_to_adjust["reals"].replace({
305341
0: "real_negatives",
306342
2: "real_competing",
@@ -313,11 +349,11 @@ list_data_to_adjust = {k: v for k, v in data_to_adjust.groupby("reference_group"
313349
314350
# # Define assumption sets
315351
assumption_sets = [
316-
{"competing": "excluded", "censored": "excluded"},
317-
{"competing": "adjusted_as_negative", "censored": "adjusted"},
318-
{"competing": "adjusted_as_censored", "censored": "adjusted"},
319-
{"competing": "excluded", "censored": "adjusted"},
320-
{"competing": "adjusted_as_negative", "censored": "excluded"}
352+
{"competing": "excluded", "censored": "excluded"}#,
353+
# {"competing": "adjusted_as_negative", "censored": "adjusted"},
354+
# {"competing": "adjusted_as_censored", "censored": "adjusted"},
355+
# {"competing": "excluded", "censored": "adjusted"},
356+
# {"competing": "adjusted_as_negative", "censored": "excluded"}
321357
]
322358
323359
def update_administrative_censoring(data_to_adjust):
@@ -340,14 +376,27 @@ def update_administrative_censoring(data_to_adjust):
340376
def extract_aj_estimate_by_assumptions(data_to_adjust, fixed_time_horizons,
341377
censoring_assumption="excluded",
342378
competing_assumption="excluded"):
379+
380+
# print('censoring_assumption')
381+
# print(censoring_assumption)
382+
383+
# print('competing assumption')
384+
# print(competing_assumption)
385+
386+
343387
if censoring_assumption == "excluded" and competing_assumption == "excluded":
388+
389+
344390
aj_estimate_data = (
345391
data_to_adjust
346392
.assign(fixed_time_horizon=lambda df: df.apply(lambda x: fixed_time_horizons, axis=1))
347393
.explode("fixed_time_horizon")
348394
.pipe(update_administrative_censoring)
349395
.pipe(extract_crude_estimate)
350396
)
397+
398+
# print('aj_estimate-data')
399+
# print(aj_estimate_data)
351400
352401
elif censoring_assumption == "excluded" and competing_assumption == "adjusted_as_negative":
353402
aj_estimate_data_excluded = (
@@ -435,3 +484,152 @@ list_data_to_adjust
435484
```
436485

437486

487+
```{python}
488+
489+
list_data_to_adjust
490+
491+
# Adjust data based on assumptions
492+
adjusted_data_list = []
493+
for reference_group, group_data in list_data_to_adjust.items():
494+
for assumptions in assumption_sets:
495+
# print(f"Processing assumptions: {assumptions}")
496+
# print(f"group_data: {group_data}")
497+
# adjusted_data = extract_aj_estimate_by_assumptions(
498+
# group_data,
499+
# fixed_time_horizons=fixed_time_horizons,
500+
# censoring_assumption="excluded",
501+
# competing_assumption="excluded"
502+
# )
503+
adjusted_data = extract_aj_estimate_by_assumptions(
504+
group_data,
505+
fixed_time_horizons=fixed_time_horizons,
506+
censoring_assumption=assumptions["censored"],
507+
competing_assumption=assumptions["competing"]
508+
)
509+
adjusted_data["reference_group"] = reference_group
510+
adjusted_data_list.append(adjusted_data)
511+
512+
# Combine all adjusted data
513+
final_adjusted_data = pd.concat(adjusted_data_list, ignore_index=True)
514+
515+
aj_data_combinations['strata'] = aj_data_combinations['strata'].astype(str)
516+
517+
final_adjusted_data['strata'] = final_adjusted_data['strata'].astype(str)
518+
519+
aj_data_combinations['reals'] = aj_data_combinations['reals'].astype(str)
520+
521+
final_adjusted_data['reals'] = final_adjusted_data['reals'].astype(str)
522+
523+
categories = ["real_negatives", "real_positives", "real_competing", "real_censored"]
524+
aj_data_combinations['reals'] = pd.Categorical(aj_data_combinations['reals'], categories=categories, ordered=True)
525+
final_adjusted_data['reals'] = pd.Categorical(final_adjusted_data['reals'], categories=categories, ordered=True)
526+
527+
combined_adjusted_data = aj_data_combinations.merge(final_adjusted_data, on=["reference_group", "fixed_time_horizon", "censoring_assumption", "competing_assumption", "reals", "strata"], how='left')
528+
529+
```
530+
531+
```{python}
532+
533+
ojs_define(reference_groups_data = ["thin", "full"])
534+
535+
ojs_define(data = combined_adjusted_data)
536+
537+
```
538+
539+
```{ojs}
540+
541+
//| panel: input
542+
543+
viewof time_horizon = Inputs.range(
544+
[1, 5],
545+
{value: 3, step: 2, label: "Time Horizon:"}
546+
)
547+
548+
viewof reference_group = Inputs.radio(
549+
reference_groups_data, {label: "Reference Group"}
550+
)
551+
552+
viewof stratified_by = Inputs.radio(
553+
["probability_threshold", "ppcr"], {value: "probability_threshold", label: "Stratified By"}
554+
)
555+
556+
viewof censored_assumption = Inputs.radio(
557+
["excluded", "adjusted"], {value: "excluded", label: "Censored Assumption"}
558+
)
559+
560+
viewof competing_assumption = Inputs.radio(
561+
["excluded", "adjusted_as_negative", "adjusted_as_censored", "reals"], {value: "excluded", label: "Competing Assumption"}
562+
)
563+
564+
```
565+
566+
```{ojs}
567+
568+
//cumulative_aj_data_filtered = transpose(cumulative_aj_data).filter(function(subset) {
569+
//
570+
// return time_horizon == subset.fixed_time_horizon &&
571+
// censored_assumption == subset.censored_assumption &&
572+
// competing_assumption == subset.competing_assumption &&
573+
// stratified_by == subset.stratified_by &&
574+
// reference_group === subset.reference_group;
575+
//})
576+
577+
filtered = transpose(data).filter(function(subset) {
578+
579+
return time_horizon == subset.fixed_time_horizon &&
580+
censored_assumption == subset.censoring_assumption &&
581+
competing_assumption == subset.competing_assumption &&
582+
stratified_by === subset.stratified_by &&
583+
reference_group === subset.reference_group;
584+
})
585+
586+
filtered
587+
588+
589+
```
590+
591+
```{ojs}
592+
593+
594+
Plot.plot({
595+
marks: [
596+
Plot.barY(filtered, {
597+
x: "strata",
598+
y: "reals_estimate",
599+
fill: "reals",
600+
tip: true
601+
})
602+
],
603+
color: {
604+
domain: ["real_positives", "real_competing", "real_negatives", "real_censored"],
605+
range: ["#009e73", "#9DB4C0", "#FAC8CD", "#E3F09B"],
606+
legend: true
607+
}
608+
})
609+
610+
```
611+
612+
```{python}
613+
614+
# combined_adjusted_data.dropna(subset=['reals_estimate'])
615+
# #
616+
617+
# Perform left join between aj_data_combinations and final_adjusted_data on 'strata' and 'reals_estimate'
618+
# only when stratified_by == 'probability_threshold' for aj_data_combinations
619+
620+
aj_data_combinations_prob_threshold = aj_data_combinations[aj_data_combinations['stratified_by'] == 'probability_threshold']
621+
622+
# Convert 'strata' columns to strings
623+
aj_data_combinations_prob_threshold['strata'] = aj_data_combinations_prob_threshold['strata'].astype(str)
624+
final_adjusted_data['strata'] = final_adjusted_data['strata'].astype(str)
625+
626+
combined_adjusted_data = aj_data_combinations_prob_threshold.merge(
627+
final_adjusted_data[['strata', 'reals', 'reals_estimate']],
628+
on=['strata', 'reals'],
629+
how='left'
630+
)
631+
632+
633+
aj_data_combinations_prob_threshold[['strata']]
634+
final_adjusted_data[['strata']]
635+
```

0 commit comments

Comments
 (0)