From d8793d93ad10e34e3ab9cda7a45d88dea925c50c Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 07:43:38 +0200 Subject: [PATCH 01/17] feat: close #246 --- src/rtichoke/calibration/calibration.py | 98 ++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index f9d32f3..52af895 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,12 +2,15 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union # import pandas as pd import plotly.graph_objects as go from plotly.subplots import make_subplots from plotly.graph_objs._figure import Figure +import polars as pl +import numpy as np + # from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r @@ -284,3 +287,96 @@ def create_plotly_curve_from_calibration_curve_list( ) return calibration_curve + + +def _make_deciles_dat_binary( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + n_bins: int = 10, + reference_group_name_if_array: str = "overall", +) -> pl.DataFrame: + if isinstance(reals, dict): + reference_groups_keys = list(reals.keys()) + y_list = [ + np.asarray(reals[reference_group]).ravel() + for reference_group in reference_groups_keys + ] + lengths = np.array([len(y) for y in y_list], dtype=np.int64) + offsets = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(lengths)]) + n_total = int(offsets[-1]) + + frames: list[pl.DataFrame] = [] + for model, p_all in probs.items(): + p_all = np.asarray(p_all).ravel() + if p_all.shape[0] != n_total: + raise ValueError( + f"probs['{model}'] length={p_all.shape[0]} does not match " + f"sum of population sizes={n_total}." + ) + + for i, pop in enumerate(reference_groups_keys): + start = int(offsets[i]) + end = int(offsets[i + 1]) + + frames.append( + pl.DataFrame( + { + "reference_group": pop, + "model": model, + "prob": p_all[start:end].astype(float, copy=False), + "real": y_list[i].astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + else: + y = np.asarray(reals).ravel() + n = y.shape[0] + frames = [] + for model, p in probs.items(): + p = np.asarray(p).ravel() + if p.shape[0] != n: + raise ValueError( + f"probs['{model}'] length={p.shape[0]} does not match reals length={n}." + ) + frames.append( + pl.DataFrame( + { + "reference_group": reference_group_name_if_array, + "model": model, + "prob": p.astype(float, copy=False), + "real": y.astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + labels = [str(i) for i in range(1, n_bins + 1)] + + df = df.with_columns( + [ + pl.col("prob").cast(pl.Float64), + pl.col("real").cast(pl.Float64), + pl.col("prob") + .qcut(n_bins, labels=labels) + .over(["reference_group", "model"]) + .alias("decile"), + ] + ).with_columns(pl.col("decile").cast(pl.Int32)) + + deciles_data = ( + df.group_by(["reference_group", "model", "decile"]) + .agg( + [ + pl.len().alias("n"), + pl.mean("prob").alias("pred_mean"), + pl.mean("real").alias("real_mean"), + ] + ) + .sort(["reference_group", "model", "decile"]) + ) + + return deciles_data From 87702d345fe789a0e8545e4b301573bec2154e25 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 10:24:29 +0200 Subject: [PATCH 02/17] feat: close #248 --- src/rtichoke/calibration/calibration.py | 38 +++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 52af895..c54bbda 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -380,3 +380,41 @@ def _make_deciles_dat_binary( ) return deciles_data + + +def _check_performance_type_by_probs_and_reals( + probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]] +) -> str: + if isinstance(reals, dict) and len(reals) > 1: + return "multiple populations" + if len(probs) > 1: + return "multiple models" + return "one model" + + +def _create_calibration_curve_list( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + color_values: List[str], + size: Optional[int], +) -> Dict[str, Any]: + deciles_data = _make_deciles_dat_binary(probs, reals) + + performance_type = _check_performance_type_by_probs_and_reals + + calibration_curve_list = { + "deciles_dat": deciles_data, + # "smooth_dat": smooth_dat, + # "reference_data": reference_data, + # "histogram_for_calibration": histogram_for_calibration, + # "histogram_opacity": [0.4], + # "axes_ranges": axes_ranges, + # "group_colors_vec": { + # "reference_line": ["#737373"], + # **group_colors_vec, + # }, + "performance_type": [performance_type], + # "size": [(size_value, size_value)], + } + + return calibration_curve_list From d4d3100fc6cc42a311690f70022ebfc518237786 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 11:08:13 +0200 Subject: [PATCH 03/17] feat: close #249 --- src/rtichoke/calibration/calibration.py | 28 ++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index c54bbda..24b3a73 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -127,9 +127,9 @@ def create_plotly_curve_from_calibration_curve_list( calibration_curve.add_trace( go.Scatter( - x=calibration_curve_list["reference_data"]["x"].values.tolist(), - y=calibration_curve_list["reference_data"]["y"].values.tolist(), - hovertext=calibration_curve_list["reference_data"]["text"].values.tolist(), + x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], name="Perfectly Calibrated", legendgroup="Perfectly Calibrated", hoverinfo="text", @@ -400,12 +400,14 @@ def _create_calibration_curve_list( ) -> Dict[str, Any]: deciles_data = _make_deciles_dat_binary(probs, reals) - performance_type = _check_performance_type_by_probs_and_reals + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + + reference_data = _create_reference_data_for_calibration_curve() calibration_curve_list = { "deciles_dat": deciles_data, # "smooth_dat": smooth_dat, - # "reference_data": reference_data, + "reference_data": reference_data, # "histogram_for_calibration": histogram_for_calibration, # "histogram_opacity": [0.4], # "axes_ranges": axes_ranges, @@ -418,3 +420,19 @@ def _create_calibration_curve_list( } return calibration_curve_list + + +def _create_reference_data_for_calibration_curve() -> pl.DataFrame: + x_ref = np.linspace(0, 1, 101) + reference_data = pl.DataFrame({"x": x_ref, "y": x_ref}) + reference_data = reference_data.with_columns( + pl.concat_str( + [ + pl.lit("Perfectly Calibrated
Predicted: "), + pl.col("x").round(3).cast(pl.Utf8), + pl.lit("
Observed: "), + pl.col("y").round(3).cast(pl.Utf8), + ] + ).alias("text") + ) + return reference_data From 3bb78dab852b3c6ba3a4c9b39ebe38ecd9f783dd Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 11:56:14 +0200 Subject: [PATCH 04/17] feat: close #250 --- src/rtichoke/calibration/calibration.py | 171 ++++++++++++++---------- 1 file changed, 104 insertions(+), 67 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 24b3a73..0f6638e 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -91,7 +91,7 @@ def create_calibration_curve( # return calibration_curve -def create_plotly_curve_from_calibration_curve_list( +def _create_plotly_curve_from_calibration_curve_list( calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" ) -> Figure: """Create plotly curve from calibration curve list @@ -129,14 +129,14 @@ def create_plotly_curve_from_calibration_curve_list( go.Scatter( x=calibration_curve_list["reference_data"]["x"], y=calibration_curve_list["reference_data"]["y"], - hovertext=calibration_curve_list["reference_data"]["text"], + # hovertext=calibration_curve_list["reference_data"]["text"], name="Perfectly Calibrated", legendgroup="Perfectly Calibrated", - hoverinfo="text", + # hoverinfo="text", line={ "width": 2, "dash": "dot", - "color": calibration_curve_list["group_colors_vec"]["reference_line"][ + "color": calibration_curve_list["colors_dictionary"]["reference_line"][ 0 ], }, @@ -147,42 +147,37 @@ def create_plotly_curve_from_calibration_curve_list( ) if calibration_type == "discrete": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["deciles_dat"]["x"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - y=calibration_curve_list["deciles_dat"]["y"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["deciles_dat"]["text"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines+markers", - marker={ - "size": 10, - "color": calibration_curve_list["group_colors_vec"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + print(calibration_curve_list["deciles_dat"]) + + for reference_group in calibration_curve_list["colors_dictionary"].keys(): + dec_sub = calibration_curve_list["deciles_dat"].filter( + pl.col("reference_group") == reference_group + ) + + print(dec_sub) + + calibration_curve.add_trace( + go.Scatter( + x=dec_sub.get_column("x").to_list(), + y=dec_sub.get_column("y").to_list(), + # hovertext=dec_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + # hoverinfo="text", + mode="lines+markers", + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) if calibration_type == "smooth": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): + for reference_group in list(calibration_curve_list["colors_dictionary"].keys()): if any( calibration_curve_list["smooth_dat"]["reference_group"] == reference_group @@ -192,22 +187,22 @@ def create_plotly_curve_from_calibration_curve_list( x=calibration_curve_list["smooth_dat"]["x"][ calibration_curve_list["smooth_dat"]["reference_group"] == reference_group - ].values.tolist(), + ], y=calibration_curve_list["smooth_dat"]["y"][ calibration_curve_list["smooth_dat"]["reference_group"] == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["smooth_dat"]["text"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), + ], + # hovertext=calibration_curve_list["smooth_dat"]["text"][ + # calibration_curve_list["smooth_dat"]["reference_group"] + # == reference_group + # ], name=reference_group, legendgroup=reference_group, - hoverinfo="text", + # hoverinfo="text", mode="lines", marker={ "size": 10, - "color": calibration_curve_list["group_colors_vec"][ + "color": calibration_curve_list["colors_dictionary"][ reference_group ][0], }, @@ -216,7 +211,7 @@ def create_plotly_curve_from_calibration_curve_list( col=1, ) - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): + for reference_group in list(calibration_curve_list["colors_dictionary"].keys()): if any( calibration_curve_list["histogram_for_calibration"]["reference_group"] == reference_group @@ -228,26 +223,26 @@ def create_plotly_curve_from_calibration_curve_list( "reference_group" ] == reference_group - ].values.tolist(), + ], y=calibration_curve_list["histogram_for_calibration"]["counts"][ calibration_curve_list["histogram_for_calibration"][ "reference_group" ] == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["histogram_for_calibration"][ - "text" - ][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), + ], + # hovertext=calibration_curve_list["histogram_for_calibration"][ + # "text" + # ][ + # calibration_curve_list["histogram_for_calibration"][ + # "reference_group" + # ] + # == reference_group + # ], name=reference_group, width=0.01, legendgroup=reference_group, - hoverinfo="text", - marker_color=calibration_curve_list["group_colors_vec"][ + # hoverinfo="text", + marker_color=calibration_curve_list["colors_dictionary"][ reference_group ][0], showlegend=False, @@ -372,8 +367,8 @@ def _make_deciles_dat_binary( .agg( [ pl.len().alias("n"), - pl.mean("prob").alias("pred_mean"), - pl.mean("real").alias("real_mean"), + pl.mean("prob").alias("x"), + pl.mean("real").alias("y"), ] ) .sort(["reference_group", "model", "decile"]) @@ -395,8 +390,29 @@ def _check_performance_type_by_probs_and_reals( def _create_calibration_curve_list( probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]], - color_values: List[str], - size: Optional[int], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], ) -> Dict[str, Any]: deciles_data = _make_deciles_dat_binary(probs, reals) @@ -404,6 +420,12 @@ def _create_calibration_curve_list( reference_data = _create_reference_data_for_calibration_curve() + reference_groups = deciles_data["reference_group"].unique().to_list() + + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + calibration_curve_list = { "deciles_dat": deciles_data, # "smooth_dat": smooth_dat, @@ -411,10 +433,7 @@ def _create_calibration_curve_list( # "histogram_for_calibration": histogram_for_calibration, # "histogram_opacity": [0.4], # "axes_ranges": axes_ranges, - # "group_colors_vec": { - # "reference_line": ["#737373"], - # **group_colors_vec, - # }, + "colors_dictionary": colors_dictionary, "performance_type": [performance_type], # "size": [(size_value, size_value)], } @@ -436,3 +455,21 @@ def _create_reference_data_for_calibration_curve() -> pl.DataFrame: ).alias("text") ) return reference_data + + +def _create_colors_dictionary_for_calibration( + reference_groups: List[str], + color_values: List[str], + performance_type: str = "one model", +) -> Dict[str, List[str]]: + if performance_type == "one model": + colors = ["black"] + else: + colors = color_values[: len(reference_groups)] + + return { + "reference_line": ["#BEBEBE"], + **{ + group: [colors[i % len(colors)]] for i, group in enumerate(reference_groups) + }, + } From fa7e7ac5a20a4a0c119c518a8a8810f316fd1cbd Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 12:47:55 +0200 Subject: [PATCH 05/17] feat: close #251 --- src/rtichoke/calibration/calibration.py | 48 +++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 0f6638e..a3bc0fb 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -426,16 +426,21 @@ def _create_calibration_curve_list( reference_groups, color_values, performance_type ) + histogram_for_calibration = _create_histogram_for_calibration(probs) + + limits = _define_limits_for_calibration_plot(deciles_data) + axes_ranges = {"xaxis": limits, "yaxis": limits} + calibration_curve_list = { "deciles_dat": deciles_data, # "smooth_dat": smooth_dat, "reference_data": reference_data, - # "histogram_for_calibration": histogram_for_calibration, + "histogram_for_calibration": histogram_for_calibration, # "histogram_opacity": [0.4], - # "axes_ranges": axes_ranges, + "axes_ranges": axes_ranges, "colors_dictionary": colors_dictionary, "performance_type": [performance_type], - # "size": [(size_value, size_value)], + "size": [(size, size)], } return calibration_curve_list @@ -473,3 +478,40 @@ def _create_colors_dictionary_for_calibration( group: [colors[i % len(colors)]] for i, group in enumerate(reference_groups) }, } + + +def _create_histogram_for_calibration(probs: Dict[str, np.ndarray]) -> pl.DataFrame: + hist_dfs = [] + for group, prob_values in probs.items(): + counts, mids = np.histogram(prob_values, bins=np.arange(0, 1.01, 0.01)) + hist_df = pl.DataFrame( + {"mids": mids[:-1] + 0.005, "counts": counts, "reference_group": group} + ) + hist_df = hist_df.with_columns( + ( + pl.col("counts").cast(str) + + " observations in [" + + (pl.col("mids") - 0.005).round(3).cast(str) + + ", " + + (pl.col("mids") + 0.005).round(3).cast(str) + + "]" + ).alias("text") + ) + hist_dfs.append(hist_df) + + histogram_for_calibration = pl.concat(hist_dfs) + + return histogram_for_calibration + + +def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]: + if deciles_dat.height == 1: + lower_bound, upper_bound = 0.0, 1.0 + else: + lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min()))) + upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max())) + + return [ + lower_bound - (upper_bound - lower_bound) * 0.05, + upper_bound + (upper_bound - lower_bound) * 0.05, + ] From efe238ff65e861a20f1cf361725a2c3eba6fad7d Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 13:57:07 +0200 Subject: [PATCH 06/17] fix: migrate code from pandas to polars and remove `reference_group_name_if_array` argument --- src/rtichoke/calibration/calibration.py | 50 ++++++++++--------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index a3bc0fb..c369f93 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -211,38 +211,23 @@ def _create_plotly_curve_from_calibration_curve_list( col=1, ) - for reference_group in list(calibration_curve_list["colors_dictionary"].keys()): - if any( - calibration_curve_list["histogram_for_calibration"]["reference_group"] - == reference_group - ): + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in calibration_curve_list["group_colors_vec"].keys(): + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue + calibration_curve.add_trace( go.Bar( - x=calibration_curve_list["histogram_for_calibration"]["mids"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ], - y=calibration_curve_list["histogram_for_calibration"]["counts"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ], - # hovertext=calibration_curve_list["histogram_for_calibration"][ - # "text" - # ][ - # calibration_curve_list["histogram_for_calibration"][ - # "reference_group" - # ] - # == reference_group - # ], + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), name=reference_group, width=0.01, legendgroup=reference_group, - # hoverinfo="text", - marker_color=calibration_curve_list["colors_dictionary"][ + hoverinfo="text", + marker_color=calibration_curve_list["group_colors_vec"][ reference_group ][0], showlegend=False, @@ -288,7 +273,6 @@ def _make_deciles_dat_binary( probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]], n_bins: int = 10, - reference_group_name_if_array: str = "overall", ) -> pl.DataFrame: if isinstance(reals, dict): reference_groups_keys = list(reals.keys()) @@ -339,7 +323,7 @@ def _make_deciles_dat_binary( frames.append( pl.DataFrame( { - "reference_group": reference_group_name_if_array, + "reference_group": model, "model": model, "prob": p.astype(float, copy=False), "real": y.astype(float, copy=False), @@ -426,8 +410,12 @@ def _create_calibration_curve_list( reference_groups, color_values, performance_type ) + print("histogram for calibration") + histogram_for_calibration = _create_histogram_for_calibration(probs) + print(histogram_for_calibration) + limits = _define_limits_for_calibration_plot(deciles_data) axes_ranges = {"xaxis": limits, "yaxis": limits} @@ -499,9 +487,9 @@ def _create_histogram_for_calibration(probs: Dict[str, np.ndarray]) -> pl.DataFr ) hist_dfs.append(hist_df) - histogram_for_calibration = pl.concat(hist_dfs) + histogram_for_calibration = pl.concat(hist_dfs) - return histogram_for_calibration + return histogram_for_calibration def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]: From 1c6f216c1597737196b83134face45b7790803a0 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 16:11:27 +0200 Subject: [PATCH 07/17] fix: close #253 --- src/rtichoke/calibration/calibration.py | 35 ++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index c369f93..6b625f4 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -147,9 +147,12 @@ def _create_plotly_curve_from_calibration_curve_list( ) if calibration_type == "discrete": - print(calibration_curve_list["deciles_dat"]) - - for reference_group in calibration_curve_list["colors_dictionary"].keys(): + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + for reference_group in reference_groups: dec_sub = calibration_curve_list["deciles_dat"].filter( pl.col("reference_group") == reference_group ) @@ -176,6 +179,32 @@ def _create_plotly_curve_from_calibration_curve_list( col=1, ) + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in reference_groups: + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue + + calibration_curve.add_trace( + go.Bar( + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), + name=reference_group, + width=0.01, + legendgroup=reference_group, + hoverinfo="text", + marker_color=calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + showlegend=False, + opacity=0.4, + ), + row=2, + col=1, + ) + if calibration_type == "smooth": for reference_group in list(calibration_curve_list["colors_dictionary"].keys()): if any( From e1d9c32e73f51b19100a65c0c4a620026527c152 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 19:10:49 +0200 Subject: [PATCH 08/17] feat: close #252 --- src/rtichoke/calibration/calibration.py | 120 +++++++++++++++++++++--- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 6b625f4..7056faa 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -129,10 +129,10 @@ def _create_plotly_curve_from_calibration_curve_list( go.Scatter( x=calibration_curve_list["reference_data"]["x"], y=calibration_curve_list["reference_data"]["y"], - # hovertext=calibration_curve_list["reference_data"]["text"], + hovertext=calibration_curve_list["reference_data"]["text"], name="Perfectly Calibrated", legendgroup="Perfectly Calibrated", - # hoverinfo="text", + hoverinfo="text", line={ "width": 2, "dash": "dot", @@ -163,10 +163,10 @@ def _create_plotly_curve_from_calibration_curve_list( go.Scatter( x=dec_sub.get_column("x").to_list(), y=dec_sub.get_column("y").to_list(), - # hovertext=dec_sub.get_column("text").to_list(), + hovertext=dec_sub.get_column("text").to_list(), name=reference_group, legendgroup=reference_group, - # hoverinfo="text", + hoverinfo="text", mode="lines+markers", marker={ "size": 10, @@ -221,13 +221,13 @@ def _create_plotly_curve_from_calibration_curve_list( calibration_curve_list["smooth_dat"]["reference_group"] == reference_group ], - # hovertext=calibration_curve_list["smooth_dat"]["text"][ - # calibration_curve_list["smooth_dat"]["reference_group"] - # == reference_group - # ], + hovertext=calibration_curve_list["smooth_dat"]["text"][ + calibration_curve_list["smooth_dat"]["reference_group"] + == reference_group + ], name=reference_group, legendgroup=reference_group, - # hoverinfo="text", + hoverinfo="text", mode="lines", marker={ "size": 10, @@ -382,6 +382,7 @@ def _make_deciles_dat_binary( pl.len().alias("n"), pl.mean("prob").alias("x"), pl.mean("real").alias("y"), + pl.sum("real").alias("n_reals"), ] ) .sort(["reference_group", "model", "decile"]) @@ -431,6 +432,49 @@ def _create_calibration_curve_list( performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + if performance_type != "one model": + deciles_data = deciles_data.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + pl.lit(" ( "), + pl.col("n_reals").cast(pl.Int64).cast(pl.Utf8), + pl.lit(" / "), + pl.col("n").cast(pl.Utf8), + pl.lit(" )"), + ] + ).alias("text") + ) + else: + deciles_data = deciles_data.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + pl.lit(" ( "), + pl.col("n_reals").cast(pl.Int64).cast(pl.Utf8), + pl.lit(" / "), + pl.col("n").cast(pl.Utf8), + pl.lit(" )"), + ] + ).alias("text") + ) + reference_data = _create_reference_data_for_calibration_curve() reference_groups = deciles_data["reference_group"].unique().to_list() @@ -470,15 +514,69 @@ def _create_reference_data_for_calibration_curve() -> pl.DataFrame: pl.concat_str( [ pl.lit("Perfectly Calibrated
Predicted: "), - pl.col("x").round(3).cast(pl.Utf8), + pl.col("x").map_elements(lambda x: f"{x:.3f}", return_dtype=pl.Utf8), pl.lit("
Observed: "), - pl.col("y").round(3).cast(pl.Utf8), + pl.col("y").map_elements(lambda y: f"{y:.3f}", return_dtype=pl.Utf8), ] ).alias("text") ) return reference_data +def _calculate_smooth_curve( + deciles_dat: pl.DataFrame, performance_type: str +) -> pl.DataFrame: + """ + Calculate the smoothed calibration curve using lowess. + """ + smooth_frames = [] + for group in deciles_dat["reference_group"].unique(): + group_data = deciles_dat.filter(pl.col("reference_group") == group) + # Assuming lowess is available from statsmodels + from statsmodels.nonparametric.smoothers_lowess import lowess + + smoothed = lowess(group_data["y"], group_data["x"], frac=0.5) + smooth_df = pl.DataFrame({"x": smoothed[:, 0], "y": smoothed[:, 1]}) + smooth_df = smooth_df.with_columns(pl.lit(group).alias("reference_group")) + smooth_frames.append(smooth_df) + + smooth_dat = pl.concat(smooth_frames) + + if performance_type != "one model": + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + else: + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + return smooth_dat + + def _create_colors_dictionary_for_calibration( reference_groups: List[str], color_values: List[str], From be49a503605766c5e0530f0627738f8b9231df5f Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 22 Dec 2025 19:41:31 +0200 Subject: [PATCH 09/17] feat: close #255 --- src/rtichoke/__init__.py | 6 +-- src/rtichoke/calibration/calibration.py | 51 +++++++------------------ 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index cc5b47a..97fbc14 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -30,9 +30,9 @@ ) from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve -# from rtichoke.calibration.calibration import ( -# create_calibration_curve as create_calibration_curve, -# ) +from rtichoke.calibration.calibration import ( + create_calibration_curve as create_calibration_curve, +) from rtichoke.utility.decision import ( create_decision_curve as create_decision_curve, diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 7056faa..fd1977e 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,7 +2,7 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union # import pandas as pd import plotly.graph_objects as go @@ -15,11 +15,11 @@ def create_calibration_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], calibration_type: str = "discrete", - size: Optional[int] = None, - color_values: Optional[List[str]] = [ + size: int = 600, + color_values: List[str] = [ "#1b9e77", "#d95f02", "#7570b3", @@ -41,7 +41,6 @@ def create_calibration_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: """Creates Calibration Curve @@ -58,37 +57,15 @@ def create_calibration_curve( """ pass - # rtichoke_response = send_requests_to_rtichoke_r( - # dictionary_to_send={ - # "probs": probs, - # "reals": reals, - # "size": size, - # "color_values ": color_values, - # }, - # url_api=url_api, - # endpoint="create_calibration_curve_list", - # ) - - # calibration_curve_list = rtichoke_response.json() - - # calibration_curve_list["deciles_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["deciles_dat"] - # ) - # calibration_curve_list["smooth_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["smooth_dat"] - # ) - # calibration_curve_list["reference_data"] = pd.DataFrame.from_dict( - # calibration_curve_list["reference_data"] - # ) - # calibration_curve_list["histogram_for_calibration"] = pd.DataFrame.from_dict( - # calibration_curve_list["histogram_for_calibration"] - # ) - - # calibration_curve = create_plotly_curve_from_calibration_curve_list( - # calibration_curve_list=calibration_curve_list, calibration_type=calibration_type - # ) - - # return calibration_curve + calibration_curve_list = _create_calibration_curve_list( + probs, reals, size=size, color_values=color_values + ) + + calibration_curve = _create_plotly_curve_from_calibration_curve_list( + calibration_curve_list, calibration_type=calibration_type + ) + + return calibration_curve def _create_plotly_curve_from_calibration_curve_list( From e0032db7b955990ae875379ddf2cb990ccbce6e5 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Tue, 23 Dec 2025 13:38:37 +0200 Subject: [PATCH 10/17] build: Add AGENTS.md file --- AGENTS.md | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..aa1968c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,69 @@ +# rtichoke Agent Information + +This document provides guidance for AI agents working on the `rtichoke` repository. + +## Development Environment + +To set up the development environment, follow these steps: + +1. **Install `uv`**: If you don't have `uv` installed, please follow the official installation instructions. +2. **Create a virtual environment**: Use `uv venv` to create a virtual environment. +3. **Install dependencies**: Install the project dependencies, including the `dev` dependencies, with the following command: + + ```bash + uv pip install -e .[dev] + ``` + +## Running Tests + +The test suite is run using `pytest`. To run the tests, use the following command: + +```bash +uv run pytest +``` + +## Coding Conventions + +### Functional Programming + +Strive to use a functional programming style as much as possible. Avoid side effects and mutable state where practical. + +### Docstrings + +All exported functions must have NumPy-style docstrings. This is to ensure that the documentation is clear, consistent, and can be easily parsed by tools like `quartodoc`. + +Example of a NumPy-style docstring: + +```python +def my_function(param1, param2): + """Summary of the function's purpose. + + Parameters + ---------- + param1 : int + Description of the first parameter. + param2 : str + Description of the second parameter. + + Returns + ------- + bool + Description of the return value. + """ + # function body + return True +``` + +## Pre-commit Hooks + +This repository uses pre-commit hooks to ensure code quality and consistency. The following hooks are configured: + +* **`ruff-check`**: A linter to check for common errors and style issues. +* **`ruff-format`**: A code formatter to ensure a consistent code style. +* **`uv-lock`**: A hook to keep the `uv.lock` file up to date. + +Before committing, please ensure that the pre-commit hooks pass. You can run them manually on all files with `pre-commit run --all-files`. + +## Documentation + +The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually. From 2672757441f885f394d541c9beed1adeb983e6b0 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 12:48:15 +0000 Subject: [PATCH 11/17] feat: Add smoothed calibration curve Adds a new `smooth` option to the `calibration_type` parameter in the `create_calibration_curve` function. This option generates a smoothed calibration curve using a LOESS smoother, replicating the functionality of the R `rtichoke` package. The implementation includes: - Adding `statsmodels` as a new dependency for the LOESS implementation. - A new helper function `_calculate_smooth_curve` to perform the smoothing and interpolation. - Updates to the plotting logic to render the smoothed curve. - New tests to verify the functionality of the smoothed curve, including an edge case for a single unique probability value. - A fix for a `DuplicateError` in the existing decile calculation when handling uniform probability distributions. --- pyproject.toml | 1 + src/rtichoke/calibration/calibration.py | 129 +++++++++++++++--------- tests/test_calibration.py | 29 ++++++ uv.lock | 4 +- 4 files changed, 115 insertions(+), 48 deletions(-) create mode 100644 tests/test_calibration.py diff --git a/pyproject.toml b/pyproject.toml index 7ff976a..d9c3f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "polarstate==0.1.8", "marimo>=0.17.0", "pyarrow>=21.0.0", + "statsmodels>=0.14.0", ] name = "rtichoke" version = "0.1.25" diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index fd1977e..0389931 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -183,43 +183,43 @@ def _create_plotly_curve_from_calibration_curve_list( ) if calibration_type == "smooth": - for reference_group in list(calibration_curve_list["colors_dictionary"].keys()): - if any( - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["smooth_dat"]["x"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ], - y=calibration_curve_list["smooth_dat"]["y"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ], - hovertext=calibration_curve_list["smooth_dat"]["text"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ], - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines", - marker={ - "size": 10, - "color": calibration_curve_list["colors_dictionary"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + smooth_dat = calibration_curve_list["smooth_dat"] + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + + for reference_group in reference_groups: + smooth_sub = smooth_dat.filter(pl.col("reference_group") == reference_group) + if smooth_sub.height == 0: + continue + + mode = "lines+markers" if smooth_sub.height == 1 else "lines" + + calibration_curve.add_trace( + go.Scatter( + x=smooth_sub.get_column("x").to_list(), + y=smooth_sub.get_column("y").to_list(), + hovertext=smooth_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + hoverinfo="text", + mode=mode, + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) hist = calibration_curve_list["histogram_for_calibration"] - for reference_group in calibration_curve_list["group_colors_vec"].keys(): + for reference_group in reference_groups: hist_sub = hist.filter(pl.col("reference_group") == reference_group) if hist_sub.height == 0: continue @@ -233,11 +233,11 @@ def _create_plotly_curve_from_calibration_curve_list( width=0.01, legendgroup=reference_group, hoverinfo="text", - marker_color=calibration_curve_list["group_colors_vec"][ + marker_color=calibration_curve_list["colors_dictionary"][ reference_group ][0], showlegend=False, - opacity=calibration_curve_list["histogram_opacity"][0], + opacity=0.4, ), row=2, col=1, @@ -346,7 +346,7 @@ def _make_deciles_dat_binary( pl.col("prob").cast(pl.Float64), pl.col("real").cast(pl.Float64), pl.col("prob") - .qcut(n_bins, labels=labels) + .qcut(n_bins, labels=labels, allow_duplicates=True) .over(["reference_group", "model"]) .alias("decile"), ] @@ -469,9 +469,11 @@ def _create_calibration_curve_list( limits = _define_limits_for_calibration_plot(deciles_data) axes_ranges = {"xaxis": limits, "yaxis": limits} + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) + calibration_curve_list = { "deciles_dat": deciles_data, - # "smooth_dat": smooth_dat, + "smooth_dat": smooth_dat, "reference_data": reference_data, "histogram_for_calibration": histogram_for_calibration, # "histogram_opacity": [0.4], @@ -501,21 +503,53 @@ def _create_reference_data_for_calibration_curve() -> pl.DataFrame: def _calculate_smooth_curve( - deciles_dat: pl.DataFrame, performance_type: str + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + performance_type: str, ) -> pl.DataFrame: """ Calculate the smoothed calibration curve using lowess. """ + from statsmodels.nonparametric.smoothers_lowess import lowess + smooth_frames = [] - for group in deciles_dat["reference_group"].unique(): - group_data = deciles_dat.filter(pl.col("reference_group") == group) - # Assuming lowess is available from statsmodels - from statsmodels.nonparametric.smoothers_lowess import lowess - smoothed = lowess(group_data["y"], group_data["x"], frac=0.5) - smooth_df = pl.DataFrame({"x": smoothed[:, 0], "y": smoothed[:, 1]}) - smooth_df = smooth_df.with_columns(pl.lit(group).alias("reference_group")) - smooth_frames.append(smooth_df) + # Helper function to process a single probability and real array + def process_single_array(p, r, group_name): + if len(np.unique(p)) == 1: + return pl.DataFrame( + {"x": [np.unique(p)[0]], "y": [np.mean(r)], "reference_group": [group_name]} + ) + else: + # lowess returns a 2D array where the first column is x and the second is y + smoothed = lowess(r, p, it=0) + xout = np.linspace(0, 1, 101) + yout = np.interp(xout, smoothed[:, 0], smoothed[:, 1]) + return pl.DataFrame( + {"x": xout, "y": yout, "reference_group": [group_name] * len(xout)} + ) + + if isinstance(reals, dict): + for model_name, prob_array in probs.items(): + # This logic assumes that for multiple populations, one model's probs are evaluated against multiple real outcomes. + # This might need adjustment based on the exact structure for multiple models and populations. + if len(probs) == 1 and len(reals) > 1: # One model, multiple populations + for pop_name, real_array in reals.items(): + frame = process_single_array(prob_array, real_array, pop_name) + smooth_frames.append(frame) + else: # Multiple models, potentially multiple populations + for group_name in reals.keys(): + if group_name in probs: + frame = process_single_array(probs[group_name], reals[group_name], group_name) + smooth_frames.append(frame) + + else: # reals is a single numpy array + for group_name, prob_array in probs.items(): + frame = process_single_array(prob_array, reals, group_name) + smooth_frames.append(frame) + + if not smooth_frames: + return pl.DataFrame(schema={"x": pl.Float64, "y": pl.Float64, "reference_group": pl.Utf8, "text": pl.Utf8}) smooth_dat = pl.concat(smooth_frames) @@ -554,6 +588,7 @@ def _calculate_smooth_curve( return smooth_dat + def _create_colors_dictionary_for_calibration( reference_groups: List[str], color_values: List[str], diff --git a/tests/test_calibration.py b/tests/test_calibration.py new file mode 100644 index 0000000..3729c4d --- /dev/null +++ b/tests/test_calibration.py @@ -0,0 +1,29 @@ + +import numpy as np +import polars as pl +from rtichoke.calibration.calibration import create_calibration_curve + +def test_create_calibration_curve_smooth(): + probs = {"model_1": np.linspace(0, 1, 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, calibration_type="smooth") + + # Check if the figure has the correct number of traces (smooth curve, histogram, and reference line) + assert len(fig.data) == 3 + + # Check reference line data + reference_line = fig.data[0] + assert reference_line.name == "Perfectly Calibrated" + + +def test_create_calibration_curve_smooth_single_point(): + probs = {"model_1": np.array([0.5] * 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, calibration_type="smooth") + + # Check that the plot mode is "lines+markers" + assert fig.data[1].mode == "lines+markers" + + # Check histogram data + histogram = fig.data[2] + assert histogram.type == "bar" diff --git a/uv.lock b/uv.lock index 7cf5135..1a9f3e2 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.13'", @@ -3899,6 +3899,7 @@ dependencies = [ { name = "polarstate" }, { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pyarrow", version = "22.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "statsmodels" }, { name = "typing" }, ] @@ -3933,6 +3934,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.28.0" }, { name = "polarstate", specifier = "==0.1.8" }, { name = "pyarrow", specifier = ">=21.0.0" }, + { name = "statsmodels", specifier = ">=0.14.0" }, { name = "typing", specifier = ">=3.7.4.3" }, ] From 5cf0a968d2da20c69086d610000ae06842387186 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:04:20 +0000 Subject: [PATCH 12/17] feat: Add create_calibration_curve_times function This commit introduces a new function, `create_calibration_curve_times`, for generating time-dependent calibration curves. This function provides an interactive calibration curve plot with a slider to adjust the time horizon. It also handles administrative censoring and various data filtering heuristics for competing risks. Key changes: - Added `create_calibration_curve_times` function to `rtichoke/calibration/calibration.py`. - Implemented helper functions for data preparation (`_create_calibration_curve_list_times`) and plotting (`_create_plotly_curve_from_calibration_curve_list_times`). - Added a new helper function `_apply_heuristics_and_censoring` to handle censoring and competing risk logic. - Added unit tests for the new heuristic logic in `tests/test_heuristics.py`. - Added a smoke test for the new `create_calibration_curve_times` function in `tests/test_calibration_times.py`. - Refactored hover text logic into a shared helper function to reduce code duplication. --- src/rtichoke/calibration/__init__.py | 3 + src/rtichoke/calibration/calibration.py | 469 +++++++++++++++++++++--- tests/test_calibration_times.py | 23 ++ tests/test_heuristics.py | 66 ++++ 4 files changed, 518 insertions(+), 43 deletions(-) create mode 100644 tests/test_calibration_times.py create mode 100644 tests/test_heuristics.py diff --git a/src/rtichoke/calibration/__init__.py b/src/rtichoke/calibration/__init__.py index 4267999..c4b9dcb 100644 --- a/src/rtichoke/calibration/__init__.py +++ b/src/rtichoke/calibration/__init__.py @@ -1,3 +1,6 @@ """ Subpackage for Calibration """ +from .calibration import create_calibration_curve, create_calibration_curve_times + +__all__ = ["create_calibration_curve", "create_calibration_curve_times"] diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 0389931..da33490 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -68,6 +68,179 @@ def create_calibration_curve( return calibration_curve +def create_calibration_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + calibration_type: str = "discrete", + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Creates a time-dependent Calibration Curve with a slider for different time horizons.""" + + calibration_curve_list_times = _create_calibration_curve_list_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + size=size, + color_values=color_values, + ) + + fig = _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list_times, calibration_type=calibration_type + ) + + return fig + + +def _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" +) -> Figure: + """ + Creates a plotly figure for time-dependent calibration curves. + """ + fig = make_subplots( + rows=2, cols=1, shared_xaxes=True, x_title="Predicted", row_heights=[0.8, 0.2] + ) + + initial_horizon = calibration_curve_list["fixed_time_horizons"][0] + + # Add traces for each horizon, initially visible only for the first horizon + for horizon in calibration_curve_list["fixed_time_horizons"]: + visible = horizon == initial_horizon + + # Reference Line + fig.add_trace( + go.Scatter( + x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], + name="Perfectly Calibrated", + legendgroup="Perfectly Calibrated", + hoverinfo="text", + line={"width": 2, "dash": "dot", "color": "#BEBEBE"}, + showlegend=False, + visible=visible, + ), + row=1, col=1, + ) + + for group in calibration_curve_list["reference_group_keys"]: + color = calibration_curve_list["colors_dictionary"][group][0] + + # Calibration curve (discrete or smooth) + if calibration_type == "discrete": + data_subset = calibration_curve_list["deciles_dat"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" + else: # smooth + data_subset = calibration_curve_list["smooth_dat"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" if data_subset.height == 1 else "lines" + + fig.add_trace( + go.Scatter( + x=data_subset["x"], + y=data_subset["y"], + hovertext=data_subset["text"], + name=group, + legendgroup=group, + hoverinfo="text", + mode=mode, + marker={"size": 10, "color": color}, + visible=visible, + ), + row=1, col=1, + ) + + # Histogram + hist_subset = calibration_curve_list["histogram_for_calibration"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + fig.add_trace( + go.Bar( + x=hist_subset["mids"], + y=hist_subset["counts"], + hovertext=hist_subset["text"], + name=group, + width=0.01, + legendgroup=group, + hoverinfo="text", + marker_color=color, + showlegend=False, + opacity=0.4, + visible=visible, + ), + row=2, col=1, + ) + + # Create slider + steps = [] + num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"]) + + for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]): + step = dict( + method="restyle", + args=[{"visible": [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]))}], + label=str(horizon), + ) + for j in range(num_traces_per_horizon): + step["args"][0]["visible"][i * num_traces_per_horizon + j] = True + steps.append(step) + + sliders = [dict( + active=0, + currentvalue={"prefix": "Time Horizon: "}, + pad={"t": 50}, + steps=steps, + )] + + # Layout + fig.update_layout( + sliders=sliders, + xaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["xaxis"]}, + yaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["yaxis"], "title": "Observed"}, + barmode="overlay", + plot_bgcolor="rgba(0, 0, 0, 0)", + legend={ + "orientation": "h", "xanchor": "center", "yanchor": "top", + "x": 0.5, "y": 1.3, "bgcolor": "rgba(0, 0, 0, 0)", + }, + showlegend=calibration_curve_list["performance_type"][0] != "one model", + width=calibration_curve_list["size"][0][0], + height=calibration_curve_list["size"][0][0], + ) + + return fig + + def _create_plotly_curve_from_calibration_curve_list( calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" ) -> Figure: @@ -406,51 +579,12 @@ def _create_calibration_curve_list( ], ) -> Dict[str, Any]: deciles_data = _make_deciles_dat_binary(probs, reals) - performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) - if performance_type != "one model": - deciles_data = deciles_data.with_columns( - pl.concat_str( - [ - pl.lit(""), - pl.col("reference_group"), - pl.lit("
Predicted: "), - pl.col("x").map_elements( - lambda x: f"{x:.3f}", return_dtype=pl.Utf8 - ), - pl.lit("
Observed: "), - pl.col("y").map_elements( - lambda y: f"{y:.3f}", return_dtype=pl.Utf8 - ), - pl.lit(" ( "), - pl.col("n_reals").cast(pl.Int64).cast(pl.Utf8), - pl.lit(" / "), - pl.col("n").cast(pl.Utf8), - pl.lit(" )"), - ] - ).alias("text") - ) - else: - deciles_data = deciles_data.with_columns( - pl.concat_str( - [ - pl.lit("Predicted: "), - pl.col("x").map_elements( - lambda x: f"{x:.3f}", return_dtype=pl.Utf8 - ), - pl.lit("
Observed: "), - pl.col("y").map_elements( - lambda y: f"{y:.3f}", return_dtype=pl.Utf8 - ), - pl.lit(" ( "), - pl.col("n_reals").cast(pl.Int64).cast(pl.Utf8), - pl.lit(" / "), - pl.col("n").cast(pl.Utf8), - pl.lit(" )"), - ] - ).alias("text") - ) + deciles_data, smooth_dat = _add_hover_text_to_calibration_data( + deciles_data, smooth_dat, performance_type + ) reference_data = _create_reference_data_for_calibration_curve() @@ -588,6 +722,42 @@ def process_single_array(p, r, group_name): return smooth_dat +def _add_hover_text_to_calibration_data( + deciles_dat: pl.DataFrame, + smooth_dat: pl.DataFrame, + performance_type: str, +) -> (pl.DataFrame, pl.DataFrame): + """Adds hover text to the deciles and smooth dataframes.""" + if performance_type != "one model": + deciles_dat = deciles_dat.with_columns( + pl.concat_str([ + pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), + pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), + pl.lit(" ( "), pl.col("n_reals"), pl.lit(" / "), pl.col("n"), pl.lit(" )"), + ]).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str([ + pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), + pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), + ]).alias("text") + ) + else: + deciles_dat = deciles_dat.with_columns( + pl.concat_str([ + pl.lit("Predicted: "), pl.col("x").round(3), pl.lit("
Observed: "), + pl.col("y").round(3), pl.lit(" ( "), pl.col("n_reals"), + pl.lit(" / "), pl.col("n"), pl.lit(" )"), + ]).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str([ + pl.lit("Predicted: "), pl.col("x").round(3), + pl.lit("
Observed: "), pl.col("y").round(3), + ]).alias("text") + ) + return deciles_dat, smooth_dat + def _create_colors_dictionary_for_calibration( reference_groups: List[str], @@ -642,3 +812,216 @@ def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float lower_bound - (upper_bound - lower_bound) * 0.05, upper_bound + (upper_bound - lower_bound) * 0.05, ] + + +def _build_initial_df_for_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], +) -> pl.DataFrame: + """Builds the initial DataFrame for time-dependent calibration curves.""" + + # Case 1: Multiple populations (reals is a dict) + if isinstance(reals, dict): + if not isinstance(times, dict): + raise TypeError("If reals is a dict, times must also be a dict.") + + # Unnest reals and times dictionaries into a long DataFrame + reals_df = pl.DataFrame( + [ + pl.Series("reference_group", list(reals.keys())), + pl.Series("real", list(reals.values())), + pl.Series("time", list(times.values())), + ] + ).explode(["real", "time"]) + + # Unnest probs and join + probs_df = pl.DataFrame( + [ + pl.Series("model", list(probs.keys())), + pl.Series("prob", list(probs.values())), + ] + ).explode("prob") + + # If one model for many populations, cross join + if len(probs) == 1 and len(reals) > 1: + return reals_df.join(probs_df, how="cross") + else: # otherwise, assume a 1-to-1 mapping on reference_group/model + return reals_df.join(probs_df, left_on="reference_group", right_on="model") + + # Case 2: Single population (reals is an array) + else: + if not isinstance(times, np.ndarray): + raise TypeError("If reals is an array, times must also be an array.") + + base_df = pl.DataFrame({"real": reals, "time": times}) + + prob_frames = [] + for model_name, prob_array in probs.items(): + prob_frames.append( + base_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias("reference_group") + ) + ) + return pl.concat(prob_frames) + + +def _apply_heuristics_and_censoring( + df: pl.DataFrame, + horizon: float, + censoring_heuristic: str, + competing_heuristic: str, +) -> pl.DataFrame: + """ + Applies censoring and competing risk heuristics to the data for a given time horizon. + """ + # Administrative censoring: outcomes after horizon are negative + df_adj = df.with_columns( + pl.when(pl.col("time") > horizon).then(0).otherwise(pl.col("real")).alias("real") + ) + + # Heuristics for events before or at horizon + if censoring_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 0) & (pl.col("time") <= horizon))) + + if competing_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 2) & (pl.col("time") <= horizon))) + elif competing_heuristic == "adjusted_as_negative": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(0) + .otherwise(pl.col("real")) + .alias("real") + ) + elif competing_heuristic == "adjusted_as_composite": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(1) + .otherwise(pl.col("real")) + .alias("real") + ) + + return df_adj + + +def _create_calibration_curve_list_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Dict[str, Any]: + """ + Creates the data structures needed for a time-dependent calibration curve plot. + """ + # Part 1: Prepare initial dataframe from inputs + initial_df = _build_initial_df_for_times(probs, reals, times) + + # Part 2: Iterate and generate calibration data for each horizon/heuristic + all_deciles = [] + all_smooth = [] + all_histograms = [] + + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + + for horizon in fixed_time_horizons: + for heuristics in heuristics_sets: + censoring_heuristic = heuristics["censoring_heuristic"] + competing_heuristic = heuristics["competing_heuristic"] + + if censoring_heuristic == "adjusted" or competing_heuristic == "adjusted_as_censored": + continue + + df_adj = _apply_heuristics_and_censoring( + initial_df, horizon, censoring_heuristic, competing_heuristic + ) + + if df_adj.height == 0: + continue + + # Re-create probs and reals dicts for helpers + probs_adj = { + group[0]: group_df["prob"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + reals_adj = { + group[0]: group_df["real"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + # If single population initially, reals_adj should be an array + if not isinstance(reals, dict) and len(probs) == 1: + reals_adj = next(iter(reals_adj.values())) + + + # Deciles + deciles_data = _make_deciles_dat_binary(probs_adj, reals_adj) + all_deciles.append(deciles_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + # Smooth curve + smooth_data = _calculate_smooth_curve(probs_adj, reals_adj, performance_type) + all_smooth.append(smooth_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + # Histogram + hist_data = _create_histogram_for_calibration(probs_adj) + all_histograms.append(hist_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + + # Part 3: Combine results and create final dictionary + if not all_deciles: + raise ValueError("No data remaining after applying heuristics and time horizons.") + deciles_dat_final = pl.concat(all_deciles) + smooth_dat_final = pl.concat(all_smooth) + histogram_final = pl.concat(all_histograms) + + # Add hover text + deciles_dat_final, smooth_dat_final = _add_hover_text_to_calibration_data( + deciles_dat_final, smooth_dat_final, performance_type + ) + + + reference_data = _create_reference_data_for_calibration_curve() + reference_groups = deciles_dat_final["reference_group"].unique().to_list() + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + limits = _define_limits_for_calibration_plot(deciles_dat_final) + axes_ranges = {"xaxis": limits, "yaxis": limits} + + calibration_curve_list = { + "deciles_dat": deciles_dat_final, + "smooth_dat": smooth_dat_final, + "reference_data": reference_data, + "histogram_for_calibration": histogram_final, + "axes_ranges": axes_ranges, + "colors_dictionary": colors_dictionary, + "performance_type": [performance_type], + "size": [(size, size)], + "fixed_time_horizons": fixed_time_horizons, + "reference_group_keys": reference_groups + } + + return calibration_curve_list diff --git a/tests/test_calibration_times.py b/tests/test_calibration_times.py new file mode 100644 index 0000000..6820382 --- /dev/null +++ b/tests/test_calibration_times.py @@ -0,0 +1,23 @@ +import pytest +import numpy as np +import polars as pl +from rtichoke.calibration import create_calibration_curve_times + +def test_create_calibration_curve_times(): + probs = {"model_1": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])} + reals = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + times = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + fixed_time_horizons = [5, 10] + heuristics_sets = [{"censoring_heuristic": "excluded", "competing_heuristic": "excluded"}] + + fig = create_calibration_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + ) + + assert fig is not None + assert len(fig.data) > 0 + assert len(fig.layout.sliders) > 0 diff --git a/tests/test_heuristics.py b/tests/test_heuristics.py new file mode 100644 index 0000000..ad78d24 --- /dev/null +++ b/tests/test_heuristics.py @@ -0,0 +1,66 @@ +import pytest +import polars as pl +from polars.testing import assert_frame_equal +from rtichoke.calibration.calibration import _apply_heuristics_and_censoring + +@pytest.fixture +def sample_data(): + return pl.DataFrame({ + "real": [1, 0, 2, 1, 2, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + }) + +def test_competing_as_negative_logic(sample_data): + # Heuristics that shouldn't change data before horizon + result = _apply_heuristics_and_censoring(sample_data, 15, "adjusted", "adjusted_as_negative") + # Competing events at times 3 and 9 should become 0. + expected = pl.DataFrame({ + "real": [1, 0, 0, 1, 0, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + }) + assert_frame_equal(result, expected) + +def test_admin_censoring(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 7, "adjusted", "adjusted_as_negative") + # Admin censoring for times > 7. Competing event at time=3 becomes 0. + expected = pl.DataFrame({ + "real": [1, 0, 0, 0, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + }) + assert_frame_equal(result, expected) + +def test_censoring_excluded(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 10, "excluded", "adjusted_as_negative") + # Excludes censored at times 2, 10. Admin censors time > 10. Competing at 3,9 -> 0. + expected = pl.DataFrame({ + "real": [1, 0, 1, 0, 0], + "time": [1, 3, 8, 9, 12], + }) + assert_frame_equal(result.sort("time"), expected.sort("time")) + +def test_competing_excluded(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "excluded") + # Excludes competing at 3, 9. Admin censors time > 10. + expected = pl.DataFrame({ + "real": [1, 0, 1, 0, 0], + "time": [1, 2, 8, 10, 12], + }) + assert_frame_equal(result.sort("time"), expected.sort("time")) + +def test_competing_as_negative(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "adjusted_as_negative") + # Competing at 3,9 -> 0. Admin censors time > 10. + expected = pl.DataFrame({ + "real": [1, 0, 0, 1, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + }) + assert_frame_equal(result, expected) + +def test_competing_as_composite(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "adjusted_as_composite") + # Competing at 3,9 -> 1. Admin censors time > 10. + expected = pl.DataFrame({ + "real": [1, 0, 1, 1, 1, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + }) + assert_frame_equal(result, expected) From b56de94faaf48e8f68cb79dc0a0ab9e836f3f5f9 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:26:57 +0000 Subject: [PATCH 13/17] fix: Improve data loading logic for calibration curves This commit improves the data loading logic for calibration curves to be more robust and flexible. The previous implementation was too strict about the types of input it accepted and was not robust enough in handling different data shapes. This commit corrects the data loading logic to handle different input types (lists and NumPy arrays) and data shapes for both single and multiple population scenarios. --- src/rtichoke/calibration/calibration.py | 105 +++++++++++++++--------- 1 file changed, 64 insertions(+), 41 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index da33490..2dd2212 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -821,50 +821,73 @@ def _build_initial_df_for_times( ) -> pl.DataFrame: """Builds the initial DataFrame for time-dependent calibration curves.""" - # Case 1: Multiple populations (reals is a dict) - if isinstance(reals, dict): - if not isinstance(times, dict): - raise TypeError("If reals is a dict, times must also be a dict.") - - # Unnest reals and times dictionaries into a long DataFrame - reals_df = pl.DataFrame( - [ - pl.Series("reference_group", list(reals.keys())), - pl.Series("real", list(reals.values())), - pl.Series("time", list(times.values())), - ] - ).explode(["real", "time"]) - - # Unnest probs and join - probs_df = pl.DataFrame( - [ - pl.Series("model", list(probs.keys())), - pl.Series("prob", list(probs.values())), - ] - ).explode("prob") - - # If one model for many populations, cross join - if len(probs) == 1 and len(reals) > 1: - return reals_df.join(probs_df, how="cross") - else: # otherwise, assume a 1-to-1 mapping on reference_group/model - return reals_df.join(probs_df, left_on="reference_group", right_on="model") + # Convert all inputs to dictionaries of arrays to unify processing + if not isinstance(reals, dict): + reals = {"single_population": np.asarray(reals)} + if not isinstance(times, dict): + times = {"single_population": np.asarray(times)} + + # Verify matching keys and lengths + if reals.keys() != times.keys(): + raise ValueError("Keys in reals and times dictionaries do not match.") + for key in reals: + if len(reals[key]) != len(times[key]): + raise ValueError(f"Length mismatch for population '{key}' in reals and times.") + + # Create a base DataFrame with population data + population_frames = [] + for key in reals: + population_frames.append( + pl.DataFrame({ + "reference_group": key, + "real": reals[key], + "time": times[key], + }) + ) + base_df = pl.concat(population_frames) + + # Prepare model predictions + # Single model case + if len(probs) == 1: + model_name, prob_array = next(iter(probs.items())) + if len(prob_array) != base_df.height: + raise ValueError(f"Length of probabilities for model '{model_name}' does not match total number of observations.") + return base_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias("model") + ) - # Case 2: Single population (reals is an array) + # Multiple models else: - if not isinstance(times, np.ndarray): - raise TypeError("If reals is an array, times must also be an array.") - - base_df = pl.DataFrame({"real": reals, "time": times}) - - prob_frames = [] - for model_name, prob_array in probs.items(): - prob_frames.append( - base_df.with_columns( - pl.Series("prob", prob_array), - pl.lit(model_name).alias("reference_group") + # One model per population (keys must match) + if probs.keys() == reals.keys(): + prob_frames = [] + for model_name, prob_array in probs.items(): + pop_df = base_df.filter(pl.col("reference_group") == model_name) + if len(prob_array) != pop_df.height: + raise ValueError(f"Length of probabilities for model '{model_name}' does not match population size.") + prob_frames.append( + pop_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias("model") + ) ) - ) - return pl.concat(prob_frames) + return pl.concat(prob_frames) + # Multiple models on a single population + elif len(reals) == 1: + final_frames = [] + for model_name, prob_array in probs.items(): + if len(prob_array) != base_df.height: + raise ValueError(f"Length of probabilities for model '{model_name}' does not match population size.") + final_frames.append( + base_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias("reference_group") # Overwrite reference_group with model name + ) + ) + return pl.concat(final_frames) + + raise ValueError("Unsupported combination of probs, reals, and times structures.") def _apply_heuristics_and_censoring( From 41ff8c4b6108a2b2b07b832bb50a300eb47d1007 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:39:52 +0000 Subject: [PATCH 14/17] fix: Improve data loading logic for calibration curves This commit improves the data loading logic for calibration curves to be more robust and flexible. The previous implementation was too strict about the types of input it accepted and was not robust enough in handling different data shapes. This commit corrects the data loading logic to handle different input types (lists and NumPy arrays) and data shapes for both single and multiple population scenarios. From c5ad9257663e572e2392dc87f7d55bf284d1e5f1 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:58:48 +0000 Subject: [PATCH 15/17] feat: Pass ty check and update AGENTS.md This commit resolves all type errors reported by the `ty` type checker, ensuring that the project's type hints are correct and consistent. Additionally, it updates the `AGENTS.md` file to include instructions on how to run `ty` for future development. --- AGENTS.md | 8 ++++++++ src/rtichoke/calibration/calibration.py | 9 +++++---- src/rtichoke/helpers/plotly_helper_functions.py | 11 ++++++++++- src/rtichoke/helpers/sandbox_observable_helpers.py | 2 +- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index aa1968c..04f3079 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,3 +67,11 @@ Before committing, please ensure that the pre-commit hooks pass. You can run the ## Documentation The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually. + +## Type Checking + +This project uses `ty` for type checking. To check for type errors, run the following command: + +```bash +uv run ty check src tests +``` diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 2dd2212..76c319b 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -206,13 +206,14 @@ def _create_plotly_curve_from_calibration_curve_list_times( num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"]) for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]): + visibility = [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"])) + for j in range(num_traces_per_horizon): + visibility[i * num_traces_per_horizon + j] = True step = dict( method="restyle", - args=[{"visible": [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]))}], + args=[{"visible": visibility}], label=str(horizon), ) - for j in range(num_traces_per_horizon): - step["args"][0]["visible"][i * num_traces_per_horizon + j] = True steps.append(step) sliders = [dict( @@ -726,7 +727,7 @@ def _add_hover_text_to_calibration_data( deciles_dat: pl.DataFrame, smooth_dat: pl.DataFrame, performance_type: str, -) -> (pl.DataFrame, pl.DataFrame): +) -> tuple[pl.DataFrame, pl.DataFrame]: """Adds hover text to the deciles and smooth dataframes.""" if performance_type != "one model": deciles_dat = deciles_dat.with_columns( diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index ab475be..074fc52 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -681,6 +681,15 @@ def _htext(title: pl.Expr) -> pl.Expr: (pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold) ) + return pl.DataFrame( + schema={ + "reference_group": pl.Utf8, + "x": pl.Float64, + "y": pl.Float64, + "text": pl.Utf8, + } + ) + def create_non_interactive_curve_polars( performance_data_ready_for_curve, reference_group_color, reference_group @@ -1157,7 +1166,7 @@ def _add_hover_text_to_performance_data( ) return performance_data.with_columns( - [pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")] + [pl.col(pl.Float64).round(3), hover_text_expr.alias("text")] ) diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py index d3cc352..cc11c1f 100644 --- a/src/rtichoke/helpers/sandbox_observable_helpers.py +++ b/src/rtichoke/helpers/sandbox_observable_helpers.py @@ -152,7 +152,7 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)] - strata_labels = np.array([labels[i] for i in bin_idx], dtype=object) + strata_labels = np.array(labels)[bin_idx] columns_to_add.append( pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels)) From c570bff18ac4c51a8d18d46627653aeefa3b9dda Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:02:31 +0000 Subject: [PATCH 16/17] Refactor helpers module into processing This commit refactors the `helpers` module into a `processing` module with a more logical structure. The `sandbox_observable_helpers.py` file has been split into three new modules: `adjustments.py`, `combinations.py`, and `transforms.py`. This change improves the maintainability and readability of the codebase. --- src/rtichoke/discrimination/gains.py | 2 +- src/rtichoke/discrimination/lift.py | 2 +- .../discrimination/precision_recall.py | 2 +- src/rtichoke/discrimination/roc.py | 2 +- .../helpers/sandbox_observable_helpers.py | 1770 ----------------- .../performance_data/performance_data.py | 10 +- .../performance_data_times.py | 12 +- .../{helpers => processing}/__init__.py | 0 src/rtichoke/processing/adjustments.py | 743 +++++++ src/rtichoke/processing/combinations.py | 218 ++ .../exported_functions.py | 2 +- .../plotly_helper_functions.py | 0 .../send_post_request_to_r_rtichoke.py | 2 +- src/rtichoke/processing/transforms.py | 700 +++++++ src/rtichoke/summary_report/summary_report.py | 6 +- src/rtichoke/utility/decision.py | 2 +- tests/test_rtichoke.py | 2 +- 17 files changed, 1686 insertions(+), 1789 deletions(-) delete mode 100644 src/rtichoke/helpers/sandbox_observable_helpers.py rename src/rtichoke/{helpers => processing}/__init__.py (100%) create mode 100644 src/rtichoke/processing/adjustments.py create mode 100644 src/rtichoke/processing/combinations.py rename src/rtichoke/{helpers => processing}/exported_functions.py (99%) rename src/rtichoke/{helpers => processing}/plotly_helper_functions.py (100%) rename src/rtichoke/{helpers => processing}/send_post_request_to_r_rtichoke.py (98%) create mode 100644 src/rtichoke/processing/transforms.py diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py index 59366f4..8cf7c8d 100644 --- a/src/rtichoke/discrimination/gains.py +++ b/src/rtichoke/discrimination/gains.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py index 5f358af..e3e394c 100644 --- a/src/rtichoke/discrimination/lift.py +++ b/src/rtichoke/discrimination/lift.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py index 565cf5c..1a3d7a0 100644 --- a/src/rtichoke/discrimination/precision_recall.py +++ b/src/rtichoke/discrimination/precision_recall.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py index d8a8ed0..9bcc653 100644 --- a/src/rtichoke/discrimination/roc.py +++ b/src/rtichoke/discrimination/roc.py @@ -4,7 +4,7 @@ from typing import Dict, List, Union, Sequence from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py deleted file mode 100644 index cc11c1f..0000000 --- a/src/rtichoke/helpers/sandbox_observable_helpers.py +++ /dev/null @@ -1,1770 +0,0 @@ -# from lifelines import AalenJohansenFitter -import pandas as pd -import numpy as np -import polars as pl -from polarstate import predict_aj_estimates -from polarstate import prepare_event_table -from typing import Dict, Union -from collections.abc import Sequence - - -def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame: - """Create a single-column DataFrame with an enum dtype.""" - enum_values = list(dict.fromkeys(values)) - enum_dtype = pl.Enum(enum_values) - return pl.DataFrame({column_name: pl.Series(values, dtype=enum_dtype)}) - - -# def extract_aj_estimate(data_to_adjust, fixed_time_horizons): -# """ -# Python implementation of the R extract_aj_estimate function for Aalen-Johansen estimation. - -# Parameters: -# data_to_adjust (pd.DataFrame): DataFrame containing survival data -# fixed_time_horizons (list or float): Time points at which to evaluate the survival - -# Returns: -# pd.DataFrame: DataFrame with Aalen-Johansen estimates -# """ - -# # Ensure fixed_time_horizons is a list -# if not isinstance(fixed_time_horizons, list): -# fixed_time_horizons = [fixed_time_horizons] - -# # Create a categorical version of reals for stratification -# data = data_to_adjust.copy() -# data["reals_cat"] = pd.Categorical( -# data["reals_labels"], -# categories=[ -# "real_negatives", -# "real_positives", -# "real_competing", -# "real_censored", -# ], -# ordered=True, -# ) - -# # Get unique strata values -# strata_values = data["strata"].unique() - -# event_map = { -# "real_negatives": 0, # Treat as censored -# "real_positives": 1, # Event of interest -# "real_competing": 2, # Competing risk -# "real_censored": 0, # Censored -# } - -# data["event_code"] = data["reals_labels"].map(event_map) - -# # Initialize result dataframes -# results = [] - -# # For each stratum, fit Aalen-Johansen model -# for stratum in strata_values: -# # Filter data for current stratum -# stratum_data = data.loc[data["strata"] == stratum] - -# # Initialize Aalen-Johansen fitter -# ajf = AalenJohansenFitter() -# ajf_competing = AalenJohansenFitter() - -# # Fit the model -# ajf.fit(stratum_data["times"], stratum_data["event_code"], event_of_interest=1) - -# ajf_competing.fit( -# stratum_data["times"], stratum_data["event_code"], event_of_interest=2 -# ) - -# # Calculate cumulative incidence at fixed time horizons -# for t in fixed_time_horizons: -# n = len(stratum_data) -# real_positives_est = ajf.predict(t) -# real_competing_est = ajf_competing.predict(t) -# real_negatives_est = 1 - real_positives_est - real_competing_est - -# states = ["real_negatives", "real_positives", "real_competing"] -# estimates = [real_negatives_est, real_positives_est, real_competing_est] - -# for state, estimate in zip(states, estimates): -# results.append( -# { -# "strata": stratum, -# "reals": state, -# "fixed_time_horizon": t, -# "reals_estimate": estimate * n, -# } -# ) - -# # Convert to DataFrame -# result_df = pd.DataFrame(results) - -# # Convert strata to categorical if needed -# result_df["strata"] = pd.Categorical(result_df["strata"]) - -# return result_df - - -def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame: - def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: - probs = group["probs"].to_numpy() - columns_to_add = [] - - breaks = create_breaks_values(probs, "probability_threshold", by) - if "probability_threshold" in stratified_by: - last_bin_index = len(breaks) - 2 - - bin_indices = np.digitize(probs, bins=breaks, right=False) - 1 - bin_indices = np.where(probs == 1.0, last_bin_index, bin_indices) - - lower_bounds = breaks[bin_indices] - upper_bounds = breaks[bin_indices + 1] - - include_upper_bounds = bin_indices == last_bin_index - - strata_prob_labels = np.where( - include_upper_bounds, - [f"[{lo:.2f}, {hi:.2f}]" for lo, hi in zip(lower_bounds, upper_bounds)], - [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(lower_bounds, upper_bounds)], - ).astype(str) - - columns_to_add.append( - pl.Series("strata_probability_threshold", strata_prob_labels) - ) - - if "ppcr" in stratified_by: - # --- Compute strata_ppcr as equal-frequency quantile bins by rank --- - by = float(by) - q = int(round(1 / by)) # e.g. 0.2 -> 5 bins - - probs = np.asarray(probs, float) - - edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear") - - edges = np.maximum.accumulate(edges) - - edges[0] = 0.0 - edges[-1] = 1.0 - - bin_idx = np.digitize(probs, bins=edges[1:-1], right=True) - - s = str(by) - decimals = len(s.split(".")[-1]) if "." in s else 0 - - labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)] - - strata_labels = np.array(labels)[bin_idx] - - columns_to_add.append( - pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels)) - ) - return group.with_columns(columns_to_add) - - # Apply per-group transformation - grouped = data.partition_by("reference_group", as_dict=True) - transformed_groups = [transform_group(group, by) for group in grouped.values()] - return pl.concat(transformed_groups) - - -def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame: - s_by = str(by) - decimals = len(s_by.split(".")[-1]) if "." in s_by else 0 - fmt = f"{{:.{decimals}f}}" - - if stratified_by == "probability_threshold": - upper_bound = breaks[1:] # breaks - lower_bound = breaks[:-1] # np.roll(upper_bound, 1) - # lower_bound[0] = 0.0 - mid_point = upper_bound - by / 2 - include_lower_bound = lower_bound > -0.1 - include_upper_bound = upper_bound == 1.0 # upper_bound != 0.0 - # chosen_cutoff = upper_bound - strata = format_strata_column( - lower_bound=lower_bound, - upper_bound=upper_bound, - include_lower_bound=include_lower_bound, - include_upper_bound=include_upper_bound, - decimals=2, - ) - - elif stratified_by == "ppcr": - strata_mid = breaks[1:] - lower_bound = strata_mid - by / 2 - upper_bound = strata_mid + by / 2 - mid_point = breaks[1:] - include_lower_bound = np.ones_like(strata_mid, dtype=bool) - include_upper_bound = np.zeros_like(strata_mid, dtype=bool) - # chosen_cutoff = strata_mid - strata = np.array([fmt.format(x) for x in strata_mid], dtype=object) - else: - raise ValueError(f"Unsupported stratified_by: {stratified_by}") - - bins_df = pl.DataFrame( - { - "strata": pl.Series(strata), - "lower_bound": lower_bound, - "upper_bound": upper_bound, - "mid_point": mid_point, - "include_lower_bound": include_lower_bound, - "include_upper_bound": include_upper_bound, - # "chosen_cutoff": chosen_cutoff, - "stratified_by": [stratified_by] * len(strata), - } - ) - - cutoffs_df = pl.DataFrame({"chosen_cutoff": breaks}) - - return bins_df.join(cutoffs_df, how="cross") - - -def format_strata_column( - lower_bound: list[float], - upper_bound: list[float], - include_lower_bound: list[bool], - include_upper_bound: list[bool], - decimals: int = 3, -) -> list[str]: - return [ - f"{'[' if ilb else '('}" - f"{round(lb, decimals):.{decimals}f}, " - f"{round(ub, decimals):.{decimals}f}" - f"{']' if iub else ')'}" - for lb, ub, ilb, iub in zip( - lower_bound, upper_bound, include_lower_bound, include_upper_bound - ) - ] - - -def format_strata_interval( - lower: float, upper: float, include_lower: bool, include_upper: bool -) -> str: - left = "[" if include_lower else "(" - right = "]" if include_upper else ")" - return f"{left}{lower:.3f}, {upper:.3f}{right}" - - -def create_breaks_values(probs_vec, stratified_by, by): - if stratified_by != "probability_threshold": - breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1 / by) + 1)) - else: - breaks = np.round( - np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1]) - ) - return breaks - - -def _create_aj_data_combinations_binary( - reference_groups: Sequence[str], - stratified_by: Sequence[str], - by: float, - breaks: Sequence[float], -) -> pl.DataFrame: - dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] - - strata_combinations = pl.concat(dfs, how="vertical") - - strata_cats = ( - strata_combinations.select(pl.col("strata").unique(maintain_order=True)) - .to_series() - .to_list() - ) - - strata_enum = pl.Enum(strata_cats) - stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) - - strata_combinations = strata_combinations.with_columns( - [ - pl.col("strata").cast(strata_enum), - pl.col("stratified_by").cast(stratified_by_enum), - ] - ) - - # Define values for Cartesian product - reals_labels = ["real_negatives", "real_positives"] - - combinations_frames: list[pl.DataFrame] = [ - _enum_dataframe("reference_group", reference_groups), - strata_combinations, - _enum_dataframe("reals_labels", reals_labels), - ] - - result = combinations_frames[0] - for frame in combinations_frames[1:]: - result = result.join(frame, how="cross") - - return result - - -def create_aj_data_combinations( - reference_groups: Sequence[str], - heuristics_sets: list[Dict], - fixed_time_horizons: Sequence[float], - stratified_by: Sequence[str], - by: float, - breaks: Sequence[float], - risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"], -) -> pl.DataFrame: - dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] - strata_combinations = pl.concat(dfs, how="vertical") - - # strata_enum = pl.Enum(strata_combinations["strata"]) - - strata_cats = ( - strata_combinations.select(pl.col("strata").unique(maintain_order=True)) - .to_series() - .to_list() - ) - - strata_enum = pl.Enum(strata_cats) - stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) - - strata_combinations = strata_combinations.with_columns( - [ - pl.col("strata").cast(strata_enum), - pl.col("stratified_by").cast(stratified_by_enum), - ] - ) - - risk_set_scope_combinations = pl.DataFrame( - { - "risk_set_scope": pl.Series(risk_set_scope).cast( - pl.Enum(["within_stratum", "pooled_by_cutoff"]) - ) - } - ) - - # Define values for Cartesian product - reals_labels = [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - - heuristics_combinations = pl.DataFrame(heuristics_sets) - - censoring_heuristics_enum = pl.Enum( - heuristics_combinations["censoring_heuristic"].unique(maintain_order=True) - ) - competing_heuristics_enum = pl.Enum( - heuristics_combinations["competing_heuristic"].unique(maintain_order=True) - ) - - combinations_frames: list[pl.DataFrame] = [ - _enum_dataframe("reference_group", reference_groups), - pl.DataFrame( - {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)} - ), - heuristics_combinations.with_columns( - [ - pl.col("censoring_heuristic").cast(censoring_heuristics_enum), - pl.col("competing_heuristic").cast(competing_heuristics_enum), - ] - ), - strata_combinations, - risk_set_scope_combinations, - _enum_dataframe("reals_labels", reals_labels), - ] - - result = combinations_frames[0] - for frame in combinations_frames[1:]: - result = result.join(frame, how="cross") - - return result - - -def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: - # Identify id_vars and value_vars - id_vars = [col for col in data.columns if not col.startswith("strata_")] - value_vars = [col for col in data.columns if col.startswith("strata_")] - - # Perform the melt (equivalent to pandas.melt) - data_long = data.melt( - id_vars=id_vars, - value_vars=value_vars, - variable_name="stratified_by", - value_name="strata", - ) - - stratified_by_labels = ["probability_threshold", "ppcr"] - stratified_by_enum = pl.Enum(stratified_by_labels) - - # Remove "strata_" prefix from the 'stratified_by' column - data_long = data_long.with_columns( - pl.col("stratified_by").str.replace("^strata_", "").cast(stratified_by_enum) - ) - - return data_long - - -def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame: - return data.with_columns( - [ - pl.when(pl.col("reals") == 0) - .then("real_negatives") - .when(pl.col("reals") == 1) - .then("real_positives") - .when(pl.col("reals") == 2) - .then("real_competing") - .otherwise("real_censored") - .alias("reals") - ] - ) - - -def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame: - data = data.with_columns( - [ - pl.when( - (pl.col("times") > pl.col("fixed_time_horizon")) - & (pl.col("reals_labels") == "real_positives") - ) - .then(pl.lit("real_negatives")) - .when( - (pl.col("times") < pl.col("fixed_time_horizon")) - & (pl.col("reals_labels") == "real_negatives") - ) - .then(pl.lit("real_censored")) - .otherwise(pl.col("reals_labels")) - .alias("reals_labels") - ] - ) - - return data - - -def create_aj_data( - reference_group_data, - breaks, - censoring_heuristic, - competing_heuristic, - fixed_time_horizons, - stratified_by: Sequence[str], - full_event_table: bool = False, - risk_set_scope: Sequence[str] = ["within_stratum"], -): - """ - Create AJ estimates per strata based on censoring and competing heuristicss. - """ - - def aj_estimates_with_cross(df, extra_cols): - return df.join(pl.DataFrame(extra_cols), how="cross") - - exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons) - - event_table = prepare_event_table(reference_group_data) - - # TODO: solve strata in the pipeline - - excluded_events = _extract_excluded_events( - event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic - ) - - aj_dfs = [] - for rscope in risk_set_scope: - aj_res = _aj_adjusted_events( - reference_group_data, - breaks, - exploded, - censoring_heuristic, - competing_heuristic, - fixed_time_horizons, - stratified_by, - full_event_table, - rscope, - ) - - aj_res = aj_res.select( - [ - "strata", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - "fixed_time_horizon", - "risk_set_scope", - ] - ) - - aj_dfs.append(aj_res) - - aj_df = pl.concat(aj_dfs, how="vertical") - - result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left") - - return aj_estimates_with_cross( - result, - { - "censoring_heuristic": censoring_heuristic, - "competing_heuristic": competing_heuristic, - }, - ).select( - [ - "strata", - "chosen_cutoff", - "fixed_time_horizon", - "times", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "real_censored_est", - "censoring_heuristic", - "competing_heuristic", - "estimate_origin", - "risk_set_scope", - ] - ) - - -def _extract_excluded_events( - event_table: pl.DataFrame, - fixed_time_horizons: list[float], - censoring_heuristic: str, - competing_heuristic: str, -) -> pl.DataFrame: - horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times") - - excluded_events = horizons_df.join_asof( - event_table.with_columns( - pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"), - pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"), - ).select( - pl.col("times"), - pl.col("real_censored_est"), - pl.col("real_competing_est"), - ), - left_on="times", - right_on="times", - ).with_columns([pl.col("times").alias("fixed_time_horizon")]) - - if censoring_heuristic != "excluded": - excluded_events = excluded_events.with_columns( - pl.lit(0.0).alias("real_censored_est") - ) - - if competing_heuristic != "excluded": - excluded_events = excluded_events.with_columns( - pl.lit(0.0).alias("real_competing_est") - ) - - return excluded_events - - -def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame: - all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique() - - counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg( - pl.count().alias("reals_estimate") - ) - - return all_combinations.join( - counts, on=["strata", "reals", "fixed_time_horizon"], how="left" - ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)]) - - -# def update_administrative_censoring(data_to_adjust: pd.DataFrame) -> pd.DataFrame: -# pl_df = pl.from_pandas(data_to_adjust) - -# # Perform the transformation using polars -# pl_result = pl_df.with_columns( -# pl.when( -# (pl.col("times") > pl.col("fixed_time_horizon")) & -# (pl.col("reals") == "real_positives") -# ).then( -# "real_negatives" -# ).when( -# (pl.col("times") < pl.col("fixed_time_horizon")) & -# (pl.col("reals") == "real_negatives") -# ).then( -# "real_censored" -# ).otherwise( -# pl.col("reals") -# ).alias("reals") -# ) - -# # Convert back to pandas DataFrame and return -# result_pandas = pl_result.to_pandas() - -# return result_pandas - - -def extract_aj_estimate_by_cutoffs( - data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool -): - # n = data_to_adjust.height - - counts_per_strata = ( - data_to_adjust.group_by( - ["strata", "stratified_by", "upper_bound", "lower_bound"] - ) - .len(name="strata_count") - .with_columns(pl.col("strata_count").cast(pl.Float64)) - ) - - aj_estimates_predicted_positives = pl.DataFrame() - aj_estimates_predicted_negatives = pl.DataFrame() - - for stratification_criteria in stratified_by: - for chosen_cutoff in breaks: - if stratification_criteria == "probability_threshold": - mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & ( - pl.col("stratified_by") == "probability_threshold" - ) - mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & ( - pl.col("stratified_by") == "probability_threshold" - ) - - elif stratification_criteria == "ppcr": - mask_predicted_positives = ( - pl.col("lower_bound") > 1 - chosen_cutoff - ) & (pl.col("stratified_by") == "ppcr") - mask_predicted_negatives = ( - pl.col("lower_bound") <= 1 - chosen_cutoff - ) & (pl.col("stratified_by") == "ppcr") - - predicted_positives = data_to_adjust.filter(mask_predicted_positives) - predicted_negatives = data_to_adjust.filter(mask_predicted_negatives) - - counts_per_strata_predicted_positives = counts_per_strata.filter( - mask_predicted_positives - ) - counts_per_strata_predicted_negatives = counts_per_strata.filter( - mask_predicted_negatives - ) - - event_table_predicted_positives = prepare_event_table(predicted_positives) - event_table_predicted_negatives = prepare_event_table(predicted_negatives) - - aj_estimate_predicted_positives = ( - ( - predict_aj_estimates( - event_table_predicted_positives, - pl.Series(horizons), - full_event_table, - ) - .with_columns( - pl.lit(chosen_cutoff).alias("chosen_cutoff"), - pl.lit(stratification_criteria) - .alias("stratified_by") - .cast(pl.Enum(["probability_threshold", "ppcr"])), - ) - .join( - counts_per_strata_predicted_positives, - on=["stratified_by"], - how="left", - ) - .with_columns( - [ - ( - pl.col("state_occupancy_probability_0") - * pl.col("strata_count") - ).alias("real_negatives_est"), - ( - pl.col("state_occupancy_probability_1") - * pl.col("strata_count") - ).alias("real_positives_est"), - ( - pl.col("state_occupancy_probability_2") - * pl.col("strata_count") - ).alias("real_competing_est"), - ] - ) - ) - .select( - [ - "strata", - # "stratified_by", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - ] - ) - .with_columns([pl.col("times").alias("fixed_time_horizon")]) - ) - - aj_estimate_predicted_negatives = ( - ( - predict_aj_estimates( - event_table_predicted_negatives, - pl.Series(horizons), - full_event_table, - ) - .with_columns( - pl.lit(chosen_cutoff).alias("chosen_cutoff"), - pl.lit(stratification_criteria) - .alias("stratified_by") - .cast(pl.Enum(["probability_threshold", "ppcr"])), - ) - .join( - counts_per_strata_predicted_negatives, - on=["stratified_by"], - how="left", - ) - .with_columns( - [ - ( - pl.col("state_occupancy_probability_0") - * pl.col("strata_count") - ).alias("real_negatives_est"), - ( - pl.col("state_occupancy_probability_1") - * pl.col("strata_count") - ).alias("real_positives_est"), - ( - pl.col("state_occupancy_probability_2") - * pl.col("strata_count") - ).alias("real_competing_est"), - ] - ) - ) - .select( - [ - "strata", - # "stratified_by", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - ] - ) - .with_columns([pl.col("times").alias("fixed_time_horizon")]) - ) - - aj_estimates_predicted_negatives = pl.concat( - [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives], - how="vertical", - ) - - aj_estimates_predicted_positives = pl.concat( - [aj_estimates_predicted_positives, aj_estimate_predicted_positives], - how="vertical", - ) - - aj_estimate_by_cutoffs = pl.concat( - [aj_estimates_predicted_negatives, aj_estimates_predicted_positives], - how="vertical", - ) - - return aj_estimate_by_cutoffs - - -def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool): - n = data_to_adjust.height - - event_table = prepare_event_table(data_to_adjust) - - aj_estimate_for_strata_polars = predict_aj_estimates( - event_table, pl.Series(horizons), full_event_table - ) - - if len(horizons) == 1: - aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns( - pl.lit(horizons[0]).alias("fixed_time_horizon") - ) - - else: - fixed_df = aj_estimate_for_strata_polars.filter( - pl.col("estimate_origin") == "fixed_time_horizons" - ).with_columns([pl.col("times").alias("fixed_time_horizon")]) - - event_df = ( - aj_estimate_for_strata_polars.filter( - pl.col("estimate_origin") == "event_table" - ) - .with_columns([pl.lit(horizons).alias("fixed_time_horizon")]) - .explode("fixed_time_horizon") - ) - - aj_estimate_for_strata_polars = pl.concat( - [fixed_df, event_df], how="vertical" - ).sort("estimate_origin", "fixed_time_horizon", "times") - - return aj_estimate_for_strata_polars.with_columns( - [ - (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"), - (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"), - (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"), - pl.col("fixed_time_horizon").cast(pl.Float64), - pl.lit(data_to_adjust["strata"][0]).alias("strata"), - ] - ).select( - [ - "strata", - "times", - "fixed_time_horizon", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - pl.col("estimate_origin"), - ] - ) - - -def assign_and_explode_polars( - data: pl.DataFrame, fixed_time_horizons: list[float] -) -> pl.DataFrame: - return ( - data.with_columns(pl.lit(fixed_time_horizons).alias("fixed_time_horizon")) - .explode("fixed_time_horizon") - .with_columns(pl.col("fixed_time_horizon").cast(pl.Float64)) - ) - - -def _create_list_data_to_adjust_binary( - aj_data_combinations: pl.DataFrame, - probs_dict: Dict[str, np.ndarray], - reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], - stratified_by, - by, -) -> Dict[str, pl.DataFrame]: - reference_group_labels = list(probs_dict.keys()) - - if isinstance(reals_dict, dict): - num_keys_reals = len(reals_dict) - else: - num_keys_reals = 1 - - reference_group_enum = pl.Enum(reference_group_labels) - - strata_enum_dtype = aj_data_combinations.schema["strata"] - - if len(probs_dict) == 1: - probs_array = np.asarray(probs_dict[reference_group_labels[0]]) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(probs_array)), - "probs": probs_array, - "reals": reals_dict, - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif num_keys_reals == 1: - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(reals_dict)), - "probs": np.concatenate( - [probs_dict[group] for group in reference_group_labels] - ), - "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif isinstance(reals_dict, dict): - data_to_adjust = ( - pl.DataFrame( - { - "reference_group": list(probs_dict.keys()), - "probs": list(probs_dict.values()), - "reals": list(reals_dict.values()), - } - ) - .explode(["probs", "reals"]) - .with_columns(pl.col("reference_group").cast(reference_group_enum)) - ) - - data_to_adjust = add_cutoff_strata( - data_to_adjust, by=by, stratified_by=stratified_by - ) - - data_to_adjust = pivot_longer_strata(data_to_adjust) - - data_to_adjust = ( - data_to_adjust.with_columns([pl.col("strata")]) - .with_columns(pl.col("strata").cast(strata_enum_dtype)) - .join( - aj_data_combinations.select( - pl.col("strata"), - pl.col("stratified_by"), - pl.col("upper_bound"), - pl.col("lower_bound"), - ).unique(), - how="left", - on=["strata", "stratified_by"], - ) - ) - - reals_labels = ["real_negatives", "real_positives"] - - reals_enum = pl.Enum(reals_labels) - - reals_map = {0: "real_negatives", 1: "real_positives"} - - data_to_adjust = data_to_adjust.with_columns( - pl.col("reals") - .replace_strict(reals_map, return_dtype=reals_enum) - .alias("reals_labels") - ) - - list_data_to_adjust = { - group[0]: df - for group, df in data_to_adjust.partition_by( - "reference_group", as_dict=True - ).items() - } - - return list_data_to_adjust - - -def _create_list_data_to_adjust( - aj_data_combinations: pl.DataFrame, - probs_dict: Dict[str, np.ndarray], - reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], - times_dict: Union[np.ndarray, Dict[str, np.ndarray]], - stratified_by, - by, -) -> Dict[str, pl.DataFrame]: - # reference_groups = list(probs_dict.keys()) - reference_group_labels = list(probs_dict.keys()) - - if isinstance(reals_dict, dict): - num_keys_reals = len(reals_dict) - else: - num_keys_reals = 1 - - # num_reals = len(reals_dict) - - reference_group_enum = pl.Enum(reference_group_labels) - - strata_enum_dtype = aj_data_combinations.schema["strata"] - - if len(probs_dict) == 1: - probs_array = np.asarray(probs_dict[reference_group_labels[0]]) - - if isinstance(reals_dict, dict): - reals_array = np.asarray(reals_dict[0]) - else: - reals_array = np.asarray(reals_dict) - - if isinstance(times_dict, dict): - times_array = np.asarray(times_dict[0]) - else: - times_array = np.asarray(times_dict) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(probs_array)), - "probs": probs_array, - "reals": reals_array, - "times": times_array, - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif num_keys_reals == 1: - reals_array = np.asarray(reals_dict) - times_array = np.asarray(times_dict) - n = len(reals_array) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, n), - "probs": np.concatenate( - [np.asarray(probs_dict[g]) for g in reference_group_labels] - ), - "reals": np.tile(reals_array, len(reference_group_labels)), - "times": np.tile(times_array, len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif isinstance(reals_dict, dict) and isinstance(times_dict, dict): - data_to_adjust = ( - pl.DataFrame( - { - "reference_group": reference_group_labels, - "probs": list(probs_dict.values()), - "reals": list(reals_dict.values()), - "times": list(times_dict.values()), - } - ) - .explode(["probs", "reals", "times"]) - .with_columns(pl.col("reference_group").cast(reference_group_enum)) - ) - - data_to_adjust = add_cutoff_strata( - data_to_adjust, by=by, stratified_by=stratified_by - ) - - data_to_adjust = pivot_longer_strata(data_to_adjust) - - data_to_adjust = ( - data_to_adjust.with_columns([pl.col("strata")]) - .with_columns(pl.col("strata").cast(strata_enum_dtype)) - .join( - aj_data_combinations.select( - pl.col("strata"), - pl.col("stratified_by"), - pl.col("upper_bound"), - pl.col("lower_bound"), - ).unique(), - how="left", - on=["strata", "stratified_by"], - ) - ) - - reals_labels = [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - - reals_enum = pl.Enum(reals_labels) - - # Map reals values to strings - reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"} - - data_to_adjust = data_to_adjust.with_columns( - pl.col("reals") - .replace_strict(reals_map, return_dtype=reals_enum) - .alias("reals_labels") - ) - - # Partition by reference_group - list_data_to_adjust = { - group[0]: df - for group, df in data_to_adjust.partition_by( - "reference_group", as_dict=True - ).items() - } - - return list_data_to_adjust - - -def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame: - df = df.copy() - for col in df.select_dtypes(include="category").columns: - df[col] = df[col].astype(str) - return df - - -def extract_aj_estimate_by_heuristics( - df: pl.DataFrame, - breaks: Sequence[float], - heuristics_sets: list[dict], - fixed_time_horizons: list[float], - stratified_by: Sequence[str], - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - aj_dfs = [] - - for heuristic in heuristics_sets: - censoring = heuristic["censoring_heuristic"] - competing = heuristic["competing_heuristic"] - - aj_df = create_aj_data( - df, - breaks, - censoring, - competing, - fixed_time_horizons, - stratified_by=stratified_by, - full_event_table=False, - risk_set_scope=risk_set_scope, - ).with_columns( - [ - pl.lit(censoring).alias("censoring_heuristic"), - pl.lit(competing).alias("competing_heuristic"), - ] - ) - - aj_dfs.append(aj_df) - - aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"]) - - aj_estimates_unpivoted = aj_estimates_data.unpivot( - index=[ - "strata", - "chosen_cutoff", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "risk_set_scope", - ], - variable_name="reals_labels", - value_name="reals_estimate", - ) - - return aj_estimates_unpivoted - - -def _create_adjusted_data_binary( - list_data_to_adjust: dict[str, pl.DataFrame], - breaks: Sequence[float], - stratified_by: Sequence[str], -) -> pl.DataFrame: - long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical") - - adjusted_data_binary = ( - long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"]) - .agg(pl.count().alias("reals_estimate")) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - return adjusted_data_binary - - -def create_adjusted_data( - list_data_to_adjust: dict[str, pl.DataFrame], - heuristics_sets: list[dict[str, str]], - fixed_time_horizons: list[float], - breaks: Sequence[float], - stratified_by: Sequence[str], - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - all_results = [] - - reference_groups = list(list_data_to_adjust.keys()) - reference_group_enum = pl.Enum(reference_groups) - - heuristics_df = pl.DataFrame(heuristics_sets) - censoring_heuristic_enum = pl.Enum( - heuristics_df["censoring_heuristic"].unique(maintain_order=True) - ) - competing_heuristic_enum = pl.Enum( - heuristics_df["competing_heuristic"].unique(maintain_order=True) - ) - - for reference_group, df in list_data_to_adjust.items(): - input_df = df.select( - ["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"] - ) - - aj_result = extract_aj_estimate_by_heuristics( - input_df, - breaks, - heuristics_sets=heuristics_sets, - fixed_time_horizons=fixed_time_horizons, - stratified_by=stratified_by, - risk_set_scope=risk_set_scope, - ) - - aj_result_with_group = aj_result.with_columns( - [ - pl.lit(reference_group) - .cast(reference_group_enum) - .alias("reference_group") - ] - ) - - all_results.append(aj_result_with_group) - - reals_enum_dtype = pl.Enum( - [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - ) - - return ( - pl.concat(all_results) - .with_columns([pl.col("reference_group").cast(reference_group_enum)]) - .with_columns( - [ - pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype), - pl.col("censoring_heuristic").cast(censoring_heuristic_enum), - pl.col("competing_heuristic").cast(competing_heuristic_enum), - ] - ) - ) - - -def _cast_and_join_adjusted_data_binary( - aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame -) -> pl.DataFrame: - strata_enum_dtype = aj_data_combinations.schema["strata"] - - aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( - pl.col("strata").cast(strata_enum_dtype) - ) - - final_adjusted_data_polars = ( - ( - aj_data_combinations.with_columns([pl.col("strata")]).join( - aj_estimates_data, - on=[ - "strata", - "stratified_by", - "reals_labels", - "reference_group", - "chosen_cutoff", - ], - how="left", - ) - ) - .with_columns( - pl.when( - ( - (pl.col("chosen_cutoff") >= pl.col("upper_bound")) - & (pl.col("stratified_by") == "probability_threshold") - ) - | ( - ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) - & (pl.col("stratified_by") == "ppcr") - ) - ) - .then(pl.lit("predicted_negatives")) - .otherwise(pl.lit("predicted_positives")) - .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) - .alias("prediction_label") - ) - .with_columns( - ( - pl.when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("true_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("false_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("false_negatives")) - .cast( - pl.Enum( - [ - "true_positives", - "false_positives", - "true_negatives", - "false_negatives", - ] - ) - ) - ).alias("classification_outcome") - ) - ).with_columns(pl.col("reals_estimate").fill_null(0)) - - return final_adjusted_data_polars - - -def cast_and_join_adjusted_data( - aj_data_combinations, aj_estimates_data -) -> pl.DataFrame: - strata_enum_dtype = aj_data_combinations.schema["strata"] - - aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( - pl.col("strata").cast(strata_enum_dtype) - ) - - final_adjusted_data_polars = ( - aj_data_combinations.with_columns([pl.col("strata")]) - .join( - aj_estimates_data, - on=[ - "strata", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "reals_labels", - "reference_group", - "chosen_cutoff", - "risk_set_scope", - ], - how="left", - ) - .with_columns( - pl.when( - ( - (pl.col("chosen_cutoff") >= pl.col("upper_bound")) - & (pl.col("stratified_by") == "probability_threshold") - ) - | ( - ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) - & (pl.col("stratified_by") == "ppcr") - ) - ) - .then(pl.lit("predicted_negatives")) - .otherwise(pl.lit("predicted_positives")) - .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) - .alias("prediction_label") - ) - .with_columns( - ( - pl.when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("true_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("false_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("false_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_competing")) - & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_competing")) - & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) - ) - .then(pl.lit("false_positives")) - .otherwise(pl.lit("excluded")) # or pl.lit(None) if you prefer nulls - .cast( - pl.Enum( - [ - "true_positives", - "false_positives", - "true_negatives", - "false_negatives", - "excluded", - ] - ) - ) - ).alias("classification_outcome") - ) - ) - return final_adjusted_data_polars - - -def _censored_count(df: pl.DataFrame) -> pl.DataFrame: - return ( - df.with_columns( - ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 0)) - .cast(pl.Float64) - .alias("is_censored") - ) - .group_by(["strata", "fixed_time_horizon"]) - .agg(pl.col("is_censored").sum().alias("real_censored_est")) - ) - - -def _competing_count(df: pl.DataFrame) -> pl.DataFrame: - return ( - df.with_columns( - ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 2)) - .cast(pl.Float64) - .alias("is_competing") - ) - .group_by(["strata", "fixed_time_horizon"]) - .agg(pl.col("is_competing").sum().alias("real_competing_est")) - ) - - -def _aj_estimates_by_cutoff_per_horizon( - df: pl.DataFrame, - horizons: list[float], - breaks: Sequence[float], - stratified_by: Sequence[str], -) -> pl.DataFrame: - return pl.concat( - [ - df.filter(pl.col("fixed_time_horizon") == h) - .group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_by_cutoffs( - group, [h], breaks, stratified_by, full_event_table=False - ) - ) - for h in horizons - ], - how="vertical", - ) - - -def _aj_estimates_per_horizon( - df: pl.DataFrame, horizons: list[float], full_event_table: bool -) -> pl.DataFrame: - return pl.concat( - [ - df.filter(pl.col("fixed_time_horizon") == h) - .group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, [h], full_event_table - ) - ) - for h in horizons - ], - how="vertical", - ) - - -def _aj_adjusted_events( - reference_group_data: pl.DataFrame, - breaks: Sequence[float], - exploded: pl.DataFrame, - censoring: str, - competing: str, - horizons: list[float], - stratified_by: Sequence[str], - full_event_table: bool = False, - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - strata_enum_dtype = reference_group_data.schema["strata"] - - # Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff - if censoring == "adjusted" and competing == "adjusted_as_negative": - if risk_set_scope == "within_stratum": - adjusted = ( - reference_group_data.group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, horizons, full_event_table - ) - ) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - # preserve the original enum dtype for 'strata' coming from reference_group_data - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - - return adjusted - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - reference_group_data, horizons, breaks, stratified_by, full_event_table - ) - adjusted = adjusted.with_columns( - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope") - ) - return adjusted - - # Special-case: both excluded (faster branch in original) - if censoring == "excluded" and competing == "excluded": - non_censored_non_competing = exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") == 1) - ) - - adjusted = _aj_estimates_per_horizon( - non_censored_non_competing, horizons, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ).join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - - return adjusted - - # Special-case: competing excluded (handled by filtering out competing events) - if competing == "excluded": - # Use exploded to apply filters that depend on fixed_time_horizon consistently - non_competing = exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2) - ).with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(0)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - - if risk_set_scope == "within_stratum": - adjusted = ( - _aj_estimates_per_horizon(non_competing, horizons, full_event_table) - # .select(pl.exclude("real_competing_est")) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - non_competing, horizons, breaks, stratified_by, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - return adjusted - - # For remaining cases, determine base dataframe depending on censoring rule: - # - "adjusted": use the full reference_group_data (events censored at horizon are kept/adjusted) - # - "excluded": remove administratively censored observations (use exploded with filter) - base_df = ( - exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") > 0) - ) - if censoring == "excluded" - else reference_group_data - ) - - # Apply competing-event transformation if required - if competing == "adjusted_as_censored": - base_df = base_df.with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(0)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - elif competing == "adjusted_as_composite": - base_df = base_df.with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(1)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - # competing == "adjusted_as_negative": keep reals as-is (no transform) - - # Finally choose aggregation strategy: per-stratum or horizon-wise - if censoring == "excluded": - # For excluded censoring we always evaluate per-horizon on the filtered (exploded) dataset - - if risk_set_scope == "within_stratum": - adjusted = _aj_estimates_per_horizon(base_df, horizons, full_event_table) - - adjusted = adjusted.join( - pl.DataFrame({"chosen_cutoff": breaks}), how="cross" - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = _aj_estimates_by_cutoff_per_horizon( - base_df, horizons, breaks, stratified_by - ) - - adjusted = adjusted.with_columns( - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope") - ) - - return adjusted.with_columns(pl.col("strata").cast(strata_enum_dtype)) - else: - # For adjusted censoring we aggregate within strata - - if risk_set_scope == "within_stratum": - adjusted = ( - base_df.group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, horizons, full_event_table - ) - ) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - base_df, horizons, breaks, stratified_by, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - - return adjusted - - -def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame: - cumulative_aj_data = ( - aj_data.group_by( - [ - "reference_group", - "stratified_by", - "chosen_cutoff", - "classification_outcome", - ] - ) - .agg([pl.col("reals_estimate").sum()]) - .pivot(on="classification_outcome", values="reals_estimate") - .with_columns( - [ - pl.col(col).fill_null(0) - for col in [ - "true_positives", - "true_negatives", - "false_positives", - "false_negatives", - ] - ] - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ) - .alias("n") - .sum(), - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - ) - - return cumulative_aj_data - - -def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame: - cumulative_aj_data = ( - aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff") - .group_by( - [ - "reference_group", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "stratified_by", - "chosen_cutoff", - "classification_outcome", - ] - ) - .agg([pl.col("reals_estimate").sum()]) - .pivot(on="classification_outcome", values="reals_estimate") - .fill_null(0) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - ) - - return cumulative_aj_data - - -def _turn_cumulative_aj_to_performance_data( - cumulative_aj_data: pl.DataFrame, -) -> pl.DataFrame: - performance_data = cumulative_aj_data.with_columns( - (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"), - (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"), - (pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"), - (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"), - (pl.col("false_positives") / pl.col("real_negatives")).alias( - "false_positive_rate" - ), - ( - (pl.col("true_positives") / pl.col("predicted_positives")) - / (pl.col("real_positives") / pl.col("n")) - ).alias("lift"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then( - (pl.col("true_positives") / pl.col("n")) - - (pl.col("false_positives") / pl.col("n")) - * pl.col("chosen_cutoff") - / (1 - pl.col("chosen_cutoff")) - ) - .otherwise(None) - .alias("net_benefit"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then( - 100 * (pl.col("true_negatives") / pl.col("n")) - - (pl.col("false_negatives") / pl.col("n")) - * (1 - pl.col("chosen_cutoff")) - / pl.col("chosen_cutoff") - ) - .otherwise(None) - .alias("net_benefit_interventions_avoided"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then(pl.col("predicted_positives") / pl.col("n")) - .otherwise(pl.col("chosen_cutoff")) - .alias("ppcr"), - ) - - return performance_data diff --git a/src/rtichoke/performance_data/performance_data.py b/src/rtichoke/performance_data/performance_data.py index 3922723..8fa2d30 100644 --- a/src/rtichoke/performance_data/performance_data.py +++ b/src/rtichoke/performance_data/performance_data.py @@ -5,13 +5,15 @@ from typing import Dict, Union import polars as pl from collections.abc import Sequence -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.adjustments import _create_adjusted_data_binary +from rtichoke.processing.combinations import ( _create_aj_data_combinations_binary, create_breaks_values, - _create_list_data_to_adjust_binary, - _create_adjusted_data_binary, - _cast_and_join_adjusted_data_binary, +) +from rtichoke.processing.transforms import ( _calculate_cumulative_aj_data_binary, + _cast_and_join_adjusted_data_binary, + _create_list_data_to_adjust_binary, _turn_cumulative_aj_to_performance_data, ) import numpy as np diff --git a/src/rtichoke/performance_data/performance_data_times.py b/src/rtichoke/performance_data/performance_data_times.py index 901bd59..d1629b4 100644 --- a/src/rtichoke/performance_data/performance_data_times.py +++ b/src/rtichoke/performance_data/performance_data_times.py @@ -5,14 +5,16 @@ from typing import Dict, Union import polars as pl from collections.abc import Sequence -from rtichoke.helpers.sandbox_observable_helpers import ( - create_breaks_values, +from rtichoke.processing.adjustments import create_adjusted_data +from rtichoke.processing.combinations import ( create_aj_data_combinations, - _create_list_data_to_adjust, - create_adjusted_data, - cast_and_join_adjusted_data, + create_breaks_values, +) +from rtichoke.processing.transforms import ( _calculate_cumulative_aj_data, + _create_list_data_to_adjust, _turn_cumulative_aj_to_performance_data, + cast_and_join_adjusted_data, ) import numpy as np diff --git a/src/rtichoke/helpers/__init__.py b/src/rtichoke/processing/__init__.py similarity index 100% rename from src/rtichoke/helpers/__init__.py rename to src/rtichoke/processing/__init__.py diff --git a/src/rtichoke/processing/adjustments.py b/src/rtichoke/processing/adjustments.py new file mode 100644 index 0000000..05c52ca --- /dev/null +++ b/src/rtichoke/processing/adjustments.py @@ -0,0 +1,743 @@ +import pandas as pd +import polars as pl +from polarstate import predict_aj_estimates +from polarstate import prepare_event_table +from typing import Dict +from collections.abc import Sequence + + +def create_aj_data( + reference_group_data, + breaks, + censoring_heuristic, + competing_heuristic, + fixed_time_horizons, + stratified_by: Sequence[str], + full_event_table: bool = False, + risk_set_scope: Sequence[str] = ["within_stratum"], +): + """ + Create AJ estimates per strata based on censoring and competing heuristicss. + """ + + def aj_estimates_with_cross(df, extra_cols): + return df.join(pl.DataFrame(extra_cols), how="cross") + + exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons) + + event_table = prepare_event_table(reference_group_data) + + # TODO: solve strata in the pipeline + + excluded_events = _extract_excluded_events( + event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic + ) + + aj_dfs = [] + for rscope in risk_set_scope: + aj_res = _aj_adjusted_events( + reference_group_data, + breaks, + exploded, + censoring_heuristic, + competing_heuristic, + fixed_time_horizons, + stratified_by, + full_event_table, + rscope, + ) + + aj_res = aj_res.select( + [ + "strata", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + "fixed_time_horizon", + "risk_set_scope", + ] + ) + + aj_dfs.append(aj_res) + + aj_df = pl.concat(aj_dfs, how="vertical") + + result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left") + + return aj_estimates_with_cross( + result, + { + "censoring_heuristic": censoring_heuristic, + "competing_heuristic": competing_heuristic, + }, + ).select( + [ + "strata", + "chosen_cutoff", + "fixed_time_horizon", + "times", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "real_censored_est", + "censoring_heuristic", + "competing_heuristic", + "estimate_origin", + "risk_set_scope", + ] + ) + + +def _extract_excluded_events( + event_table: pl.DataFrame, + fixed_time_horizons: list[float], + censoring_heuristic: str, + competing_heuristic: str, +) -> pl.DataFrame: + horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times") + + excluded_events = horizons_df.join_asof( + event_table.with_columns( + pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"), + pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"), + ).select( + pl.col("times"), + pl.col("real_censored_est"), + pl.col("real_competing_est"), + ), + left_on="times", + right_on="times", + ).with_columns([pl.col("times").alias("fixed_time_horizon")]) + + if censoring_heuristic != "excluded": + excluded_events = excluded_events.with_columns( + pl.lit(0.0).alias("real_censored_est") + ) + + if competing_heuristic != "excluded": + excluded_events = excluded_events.with_columns( + pl.lit(0.0).alias("real_competing_est") + ) + + return excluded_events + + +def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame: + all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique() + + counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg( + pl.count().alias("reals_estimate") + ) + + return all_combinations.join( + counts, on=["strata", "reals", "fixed_time_horizon"], how="left" + ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)]) + + +def extract_aj_estimate_by_cutoffs( + data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool +): + # n = data_to_adjust.height + + counts_per_strata = ( + data_to_adjust.group_by( + ["strata", "stratified_by", "upper_bound", "lower_bound"] + ) + .len(name="strata_count") + .with_columns(pl.col("strata_count").cast(pl.Float64)) + ) + + aj_estimates_predicted_positives = pl.DataFrame() + aj_estimates_predicted_negatives = pl.DataFrame() + + for stratification_criteria in stratified_by: + for chosen_cutoff in breaks: + if stratification_criteria == "probability_threshold": + mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & ( + pl.col("stratified_by") == "probability_threshold" + ) + mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & ( + pl.col("stratified_by") == "probability_threshold" + ) + + elif stratification_criteria == "ppcr": + mask_predicted_positives = ( + pl.col("lower_bound") > 1 - chosen_cutoff + ) & (pl.col("stratified_by") == "ppcr") + mask_predicted_negatives = ( + pl.col("lower_bound") <= 1 - chosen_cutoff + ) & (pl.col("stratified_by") == "ppcr") + + predicted_positives = data_to_adjust.filter(mask_predicted_positives) + predicted_negatives = data_to_adjust.filter(mask_predicted_negatives) + + counts_per_strata_predicted_positives = counts_per_strata.filter( + mask_predicted_positives + ) + counts_per_strata_predicted_negatives = counts_per_strata.filter( + mask_predicted_negatives + ) + + event_table_predicted_positives = prepare_event_table(predicted_positives) + event_table_predicted_negatives = prepare_event_table(predicted_negatives) + + aj_estimate_predicted_positives = ( + ( + predict_aj_estimates( + event_table_predicted_positives, + pl.Series(horizons), + full_event_table, + ) + .with_columns( + pl.lit(chosen_cutoff).alias("chosen_cutoff"), + pl.lit(stratification_criteria) + .alias("stratified_by") + .cast(pl.Enum(["probability_threshold", "ppcr"])), + ) + .join( + counts_per_strata_predicted_positives, + on=["stratified_by"], + how="left", + ) + .with_columns( + [ + ( + pl.col("state_occupancy_probability_0") + * pl.col("strata_count") + ).alias("real_negatives_est"), + ( + pl.col("state_occupancy_probability_1") + * pl.col("strata_count") + ).alias("real_positives_est"), + ( + pl.col("state_occupancy_probability_2") + * pl.col("strata_count") + ).alias("real_competing_est"), + ] + ) + ) + .select( + [ + "strata", + # "stratified_by", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + ] + ) + .with_columns([pl.col("times").alias("fixed_time_horizon")]) + ) + + aj_estimate_predicted_negatives = ( + ( + predict_aj_estimates( + event_table_predicted_negatives, + pl.Series(horizons), + full_event_table, + ) + .with_columns( + pl.lit(chosen_cutoff).alias("chosen_cutoff"), + pl.lit(stratification_criteria) + .alias("stratified_by") + .cast(pl.Enum(["probability_threshold", "ppcr"])), + ) + .join( + counts_per_strata_predicted_negatives, + on=["stratified_by"], + how="left", + ) + .with_columns( + [ + ( + pl.col("state_occupancy_probability_0") + * pl.col("strata_count") + ).alias("real_negatives_est"), + ( + pl.col("state_occupancy_probability_1") + * pl.col("strata_count") + ).alias("real_positives_est"), + ( + pl.col("state_occupancy_probability_2") + * pl.col("strata_count") + ).alias("real_competing_est"), + ] + ) + ) + .select( + [ + "strata", + # "stratified_by", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + ] + ) + .with_columns([pl.col("times").alias("fixed_time_horizon")]) + ) + + aj_estimates_predicted_negatives = pl.concat( + [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives], + how="vertical", + ) + + aj_estimates_predicted_positives = pl.concat( + [aj_estimates_predicted_positives, aj_estimate_predicted_positives], + how="vertical", + ) + + aj_estimate_by_cutoffs = pl.concat( + [aj_estimates_predicted_negatives, aj_estimates_predicted_positives], + how="vertical", + ) + + return aj_estimate_by_cutoffs + + +def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool): + n = data_to_adjust.height + + event_table = prepare_event_table(data_to_adjust) + + aj_estimate_for_strata_polars = predict_aj_estimates( + event_table, pl.Series(horizons), full_event_table + ) + + if len(horizons) == 1: + aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns( + pl.lit(horizons[0]).alias("fixed_time_horizon") + ) + + else: + fixed_df = aj_estimate_for_strata_polars.filter( + pl.col("estimate_origin") == "fixed_time_horizons" + ).with_columns([pl.col("times").alias("fixed_time_horizon")]) + + event_df = ( + aj_estimate_for_strata_polars.filter( + pl.col("estimate_origin") == "event_table" + ) + .with_columns([pl.lit(horizons).alias("fixed_time_horizon")]) + .explode("fixed_time_horizon") + ) + + aj_estimate_for_strata_polars = pl.concat( + [fixed_df, event_df], how="vertical" + ).sort("estimate_origin", "fixed_time_horizon", "times") + + return aj_estimate_for_strata_polars.with_columns( + [ + (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"), + (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"), + (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"), + pl.col("fixed_time_horizon").cast(pl.Float64), + pl.lit(data_to_adjust["strata"][0]).alias("strata"), + ] + ).select( + [ + "strata", + "times", + "fixed_time_horizon", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + pl.col("estimate_origin"), + ] + ) + + +def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame: + df = df.copy() + for col in df.select_dtypes(include="category").columns: + df[col] = df[col].astype(str) + return df + + +def extract_aj_estimate_by_heuristics( + df: pl.DataFrame, + breaks: Sequence[float], + heuristics_sets: list[dict], + fixed_time_horizons: list[float], + stratified_by: Sequence[str], + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + aj_dfs = [] + + for heuristic in heuristics_sets: + censoring = heuristic["censoring_heuristic"] + competing = heuristic["competing_heuristic"] + + aj_df = create_aj_data( + df, + breaks, + censoring, + competing, + fixed_time_horizons, + stratified_by=stratified_by, + full_event_table=False, + risk_set_scope=risk_set_scope, + ).with_columns( + [ + pl.lit(censoring).alias("censoring_heuristic"), + pl.lit(competing).alias("competing_heuristic"), + ] + ) + + aj_dfs.append(aj_df) + + aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"]) + + aj_estimates_unpivoted = aj_estimates_data.unpivot( + index=[ + "strata", + "chosen_cutoff", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "risk_set_scope", + ], + variable_name="reals_labels", + value_name="reals_estimate", + ) + + return aj_estimates_unpivoted + + +def _create_adjusted_data_binary( + list_data_to_adjust: dict[str, pl.DataFrame], + breaks: Sequence[float], + stratified_by: Sequence[str], +) -> pl.DataFrame: + long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical") + + adjusted_data_binary = ( + long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"]) + .agg(pl.count().alias("reals_estimate")) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + return adjusted_data_binary + + +def create_adjusted_data( + list_data_to_adjust: dict[str, pl.DataFrame], + heuristics_sets: list[dict[str, str]], + fixed_time_horizons: list[float], + breaks: Sequence[float], + stratified_by: Sequence[str], + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + all_results = [] + + reference_groups = list(list_data_to_adjust.keys()) + reference_group_enum = pl.Enum(reference_groups) + + heuristics_df = pl.DataFrame(heuristics_sets) + censoring_heuristic_enum = pl.Enum( + heuristics_df["censoring_heuristic"].unique(maintain_order=True) + ) + competing_heuristic_enum = pl.Enum( + heuristics_df["competing_heuristic"].unique(maintain_order=True) + ) + + for reference_group, df in list_data_to_adjust.items(): + input_df = df.select( + ["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"] + ) + + aj_result = extract_aj_estimate_by_heuristics( + input_df, + breaks, + heuristics_sets=heuristics_sets, + fixed_time_horizons=fixed_time_horizons, + stratified_by=stratified_by, + risk_set_scope=risk_set_scope, + ) + + aj_result_with_group = aj_result.with_columns( + [ + pl.lit(reference_group) + .cast(reference_group_enum) + .alias("reference_group") + ] + ) + + all_results.append(aj_result_with_group) + + reals_enum_dtype = pl.Enum( + [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + ) + + return ( + pl.concat(all_results) + .with_columns([pl.col("reference_group").cast(reference_group_enum)]) + .with_columns( + [ + pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype), + pl.col("censoring_heuristic").cast(censoring_heuristic_enum), + pl.col("competing_heuristic").cast(competing_heuristic_enum), + ] + ) + ) + + +def _censored_count(df: pl.DataFrame) -> pl.DataFrame: + return ( + df.with_columns( + ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 0)) + .cast(pl.Float64) + .alias("is_censored") + ) + .group_by(["strata", "fixed_time_horizon"]) + .agg(pl.col("is_censored").sum().alias("real_censored_est")) + ) + + +def _competing_count(df: pl.DataFrame) -> pl.DataFrame: + return ( + df.with_columns( + ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 2)) + .cast(pl.Float64) + .alias("is_competing") + ) + .group_by(["strata", "fixed_time_horizon"]) + .agg(pl.col("is_competing").sum().alias("real_competing_est")) + ) + + +def _aj_estimates_by_cutoff_per_horizon( + df: pl.DataFrame, + horizons: list[float], + breaks: Sequence[float], + stratified_by: Sequence[str], +) -> pl.DataFrame: + return pl.concat( + [ + df.filter(pl.col("fixed_time_horizon") == h) + .group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_by_cutoffs( + group, [h], breaks, stratified_by, full_event_table=False + ) + ) + for h in horizons + ], + how="vertical", + ) + + +def _aj_estimates_per_horizon( + df: pl.DataFrame, horizons: list[float], full_event_table: bool +) -> pl.DataFrame: + return pl.concat( + [ + df.filter(pl.col("fixed_time_horizon") == h) + .group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, [h], full_event_table + ) + ) + for h in horizons + ], + how="vertical", + ) + + +def _aj_adjusted_events( + reference_group_data: pl.DataFrame, + breaks: Sequence[float], + exploded: pl.DataFrame, + censoring: str, + competing: str, + horizons: list[float], + stratified_by: Sequence[str], + full_event_table: bool = False, + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + strata_enum_dtype = reference_group_data.schema["strata"] + + # Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff + if censoring == "adjusted" and competing == "adjusted_as_negative": + if risk_set_scope == "within_stratum": + adjusted = ( + reference_group_data.group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, horizons, full_event_table + ) + ) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + # preserve the original enum dtype for 'strata' coming from reference_group_data + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + + return adjusted + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + reference_group_data, horizons, breaks, stratified_by, full_event_table + ) + adjusted = adjusted.with_columns( + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope") + ) + return adjusted + + # Special-case: both excluded (faster branch in original) + if censoring == "excluded" and competing == "excluded": + non_censored_non_competing = exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") == 1) + ) + + adjusted = _aj_estimates_per_horizon( + non_censored_non_competing, horizons, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ).join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + + return adjusted + + # Special-case: competing excluded (handled by filtering out competing events) + if competing == "excluded": + # Use exploded to apply filters that depend on fixed_time_horizon consistently + non_competing = exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2) + ).with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(0)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + + if risk_set_scope == "within_stratum": + adjusted = ( + _aj_estimates_per_horizon(non_competing, horizons, full_event_table) + # .select(pl.exclude("real_competing_est")) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + non_competing, horizons, breaks, stratified_by, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + return adjusted + + # For remaining cases, determine base dataframe depending on censoring rule: + # - "adjusted": use the full reference_group_data (events censored at horizon are kept/adjusted) + # - "excluded": remove administratively censored observations (use exploded with filter) + base_df = ( + exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") > 0) + ) + if censoring == "excluded" + else reference_group_data + ) + + # Apply competing-event transformation if required + if competing == "adjusted_as_censored": + base_df = base_df.with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(0)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + elif competing == "adjusted_as_composite": + base_df = base_df.with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(1)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + # competing == "adjusted_as_negative": keep reals as-is (no transform) + + # Finally choose aggregation strategy: per-stratum or horizon-wise + if censoring == "excluded": + # For excluded censoring we always evaluate per-horizon on the filtered (exploded) dataset + + if risk_set_scope == "within_stratum": + adjusted = _aj_estimates_per_horizon(base_df, horizons, full_event_table) + + adjusted = adjusted.join( + pl.DataFrame({"chosen_cutoff": breaks}), how="cross" + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = _aj_estimates_by_cutoff_per_horizon( + base_df, horizons, breaks, stratified_by + ) + + adjusted = adjusted.with_columns( + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope") + ) + + return adjusted.with_columns(pl.col("strata").cast(strata_enum_dtype)) + else: + # For adjusted censoring we aggregate within strata + + if risk_set_scope == "within_stratum": + adjusted = ( + base_df.group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, horizons, full_event_table + ) + ) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + base_df, horizons, breaks, stratified_by, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + + return adjusted diff --git a/src/rtichoke/processing/combinations.py b/src/rtichoke/processing/combinations.py new file mode 100644 index 0000000..b790929 --- /dev/null +++ b/src/rtichoke/processing/combinations.py @@ -0,0 +1,218 @@ +import numpy as np +import polars as pl +from typing import Dict +from collections.abc import Sequence + + +def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame: + """Create a single-column DataFrame with an enum dtype.""" + enum_values = list(dict.fromkeys(values)) + enum_dtype = pl.Enum(enum_values) + return pl.DataFrame({column_name: pl.Series(values, dtype=enum_dtype)}) + + +def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame: + s_by = str(by) + decimals = len(s_by.split(".")[-1]) if "." in s_by else 0 + fmt = f"{{:.{decimals}f}}" + + if stratified_by == "probability_threshold": + upper_bound = breaks[1:] # breaks + lower_bound = breaks[:-1] # np.roll(upper_bound, 1) + # lower_bound[0] = 0.0 + mid_point = upper_bound - by / 2 + include_lower_bound = lower_bound > -0.1 + include_upper_bound = upper_bound == 1.0 # upper_bound != 0.0 + # chosen_cutoff = upper_bound + strata = format_strata_column( + lower_bound=lower_bound, + upper_bound=upper_bound, + include_lower_bound=include_lower_bound, + include_upper_bound=include_upper_bound, + decimals=2, + ) + + elif stratified_by == "ppcr": + strata_mid = breaks[1:] + lower_bound = strata_mid - by / 2 + upper_bound = strata_mid + by / 2 + mid_point = breaks[1:] + include_lower_bound = np.ones_like(strata_mid, dtype=bool) + include_upper_bound = np.zeros_like(strata_mid, dtype=bool) + # chosen_cutoff = strata_mid + strata = np.array([fmt.format(x) for x in strata_mid], dtype=object) + else: + raise ValueError(f"Unsupported stratified_by: {stratified_by}") + + bins_df = pl.DataFrame( + { + "strata": pl.Series(strata), + "lower_bound": lower_bound, + "upper_bound": upper_bound, + "mid_point": mid_point, + "include_lower_bound": include_lower_bound, + "include_upper_bound": include_upper_bound, + # "chosen_cutoff": chosen_cutoff, + "stratified_by": [stratified_by] * len(strata), + } + ) + + cutoffs_df = pl.DataFrame({"chosen_cutoff": breaks}) + + return bins_df.join(cutoffs_df, how="cross") + + +def format_strata_column( + lower_bound: list[float], + upper_bound: list[float], + include_lower_bound: list[bool], + include_upper_bound: list[bool], + decimals: int = 3, +) -> list[str]: + return [ + f"{'[' if ilb else '('}" + f"{round(lb, decimals):.{decimals}f}, " + f"{round(ub, decimals):.{decimals}f}" + f"{']' if iub else ')'}" + for lb, ub, ilb, iub in zip( + lower_bound, upper_bound, include_lower_bound, include_upper_bound + ) + ] + + +def format_strata_interval( + lower: float, upper: float, include_lower: bool, include_upper: bool +) -> str: + left = "[" if include_lower else "(" + right = "]" if include_upper else ")" + return f"{left}{lower:.3f}, {upper:.3f}{right}" + + +def create_breaks_values(probs_vec, stratified_by, by): + if stratified_by != "probability_threshold": + breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1 / by) + 1)) + else: + breaks = np.round( + np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1]) + ) + return breaks + + +def _create_aj_data_combinations_binary( + reference_groups: Sequence[str], + stratified_by: Sequence[str], + by: float, + breaks: Sequence[float], +) -> pl.DataFrame: + dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] + + strata_combinations = pl.concat(dfs, how="vertical") + + strata_cats = ( + strata_combinations.select(pl.col("strata").unique(maintain_order=True)) + .to_series() + .to_list() + ) + + strata_enum = pl.Enum(strata_cats) + stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) + + strata_combinations = strata_combinations.with_columns( + [ + pl.col("strata").cast(strata_enum), + pl.col("stratified_by").cast(stratified_by_enum), + ] + ) + + # Define values for Cartesian product + reals_labels = ["real_negatives", "real_positives"] + + combinations_frames: list[pl.DataFrame] = [ + _enum_dataframe("reference_group", reference_groups), + strata_combinations, + _enum_dataframe("reals_labels", reals_labels), + ] + + result = combinations_frames[0] + for frame in combinations_frames[1:]: + result = result.join(frame, how="cross") + + return result + + +def create_aj_data_combinations( + reference_groups: Sequence[str], + heuristics_sets: list[Dict], + fixed_time_horizons: Sequence[float], + stratified_by: Sequence[str], + by: float, + breaks: Sequence[float], + risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"], +) -> pl.DataFrame: + dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] + strata_combinations = pl.concat(dfs, how="vertical") + + # strata_enum = pl.Enum(strata_combinations["strata"]) + + strata_cats = ( + strata_combinations.select(pl.col("strata").unique(maintain_order=True)) + .to_series() + .to_list() + ) + + strata_enum = pl.Enum(strata_cats) + stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) + + strata_combinations = strata_combinations.with_columns( + [ + pl.col("strata").cast(strata_enum), + pl.col("stratified_by").cast(stratified_by_enum), + ] + ) + + risk_set_scope_combinations = pl.DataFrame( + { + "risk_set_scope": pl.Series(risk_set_scope).cast( + pl.Enum(["within_stratum", "pooled_by_cutoff"]) + ) + } + ) + + # Define values for Cartesian product + reals_labels = [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + + heuristics_combinations = pl.DataFrame(heuristics_sets) + + censoring_heuristics_enum = pl.Enum( + heuristics_combinations["censoring_heuristic"].unique(maintain_order=True) + ) + competing_heuristics_enum = pl.Enum( + heuristics_combinations["competing_heuristic"].unique(maintain_order=True) + ) + + combinations_frames: list[pl.DataFrame] = [ + _enum_dataframe("reference_group", reference_groups), + pl.DataFrame( + {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)} + ), + heuristics_combinations.with_columns( + [ + pl.col("censoring_heuristic").cast(censoring_heuristics_enum), + pl.col("competing_heuristic").cast(competing_heuristics_enum), + ] + ), + strata_combinations, + risk_set_scope_combinations, + _enum_dataframe("reals_labels", reals_labels), + ] + + result = combinations_frames[0] + for frame in combinations_frames[1:]: + result = result.join(frame, how="cross") + + return result diff --git a/src/rtichoke/helpers/exported_functions.py b/src/rtichoke/processing/exported_functions.py similarity index 99% rename from src/rtichoke/helpers/exported_functions.py rename to src/rtichoke/processing/exported_functions.py index a273346..778ad91 100644 --- a/src/rtichoke/helpers/exported_functions.py +++ b/src/rtichoke/processing/exported_functions.py @@ -4,7 +4,7 @@ import plotly.graph_objects as go -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( create_non_interactive_curve, create_interactive_marker, create_reference_lines_for_plotly, diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/processing/plotly_helper_functions.py similarity index 100% rename from src/rtichoke/helpers/plotly_helper_functions.py rename to src/rtichoke/processing/plotly_helper_functions.py diff --git a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py similarity index 98% rename from src/rtichoke/helpers/send_post_request_to_r_rtichoke.py rename to src/rtichoke/processing/send_post_request_to_r_rtichoke.py index aaca6f3..f8254e6 100644 --- a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py +++ b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py @@ -4,7 +4,7 @@ # import requests import pandas as pd -from rtichoke.helpers.exported_functions import create_plotly_curve +from rtichoke.processing.exported_functions import create_plotly_curve def send_requests_to_rtichoke_r(dictionary_to_send, url_api, endpoint): diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py new file mode 100644 index 0000000..f2acb92 --- /dev/null +++ b/src/rtichoke/processing/transforms.py @@ -0,0 +1,700 @@ +import numpy as np +import polars as pl +from typing import Dict, Union +from collections.abc import Sequence + + +def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame: + def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: + probs = group["probs"].to_numpy() + columns_to_add = [] + + breaks = create_breaks_values(probs, "probability_threshold", by) + if "probability_threshold" in stratified_by: + last_bin_index = len(breaks) - 2 + + bin_indices = np.digitize(probs, bins=breaks, right=False) - 1 + bin_indices = np.where(probs == 1.0, last_bin_index, bin_indices) + + lower_bounds = breaks[bin_indices] + upper_bounds = breaks[bin_indices + 1] + + include_upper_bounds = bin_indices == last_bin_index + + strata_prob_labels = np.where( + include_upper_bounds, + [f"[{lo:.2f}, {hi:.2f}]" for lo, hi in zip(lower_bounds, upper_bounds)], + [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(lower_bounds, upper_bounds)], + ).astype(str) + + columns_to_add.append( + pl.Series("strata_probability_threshold", strata_prob_labels) + ) + + if "ppcr" in stratified_by: + # --- Compute strata_ppcr as equal-frequency quantile bins by rank --- + by = float(by) + q = int(round(1 / by)) # e.g. 0.2 -> 5 bins + + probs = np.asarray(probs, float) + + edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear") + + edges = np.maximum.accumulate(edges) + + edges[0] = 0.0 + edges[-1] = 1.0 + + bin_idx = np.digitize(probs, bins=edges[1:-1], right=True) + + s = str(by) + decimals = len(s.split(".")[-1]) if "." in s else 0 + + labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)] + + strata_labels = np.array(labels)[bin_idx] + + columns_to_add.append( + pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels)) + ) + return group.with_columns(columns_to_add) + + # Apply per-group transformation + grouped = data.partition_by("reference_group", as_dict=True) + transformed_groups = [transform_group(group, by) for group in grouped.values()] + return pl.concat(transformed_groups) + + +def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: + # Identify id_vars and value_vars + id_vars = [col for col in data.columns if not col.startswith("strata_")] + value_vars = [col for col in data.columns if col.startswith("strata_")] + + # Perform the melt (equivalent to pandas.melt) + data_long = data.melt( + id_vars=id_vars, + value_vars=value_vars, + variable_name="stratified_by", + value_name="strata", + ) + + stratified_by_labels = ["probability_threshold", "ppcr"] + stratified_by_enum = pl.Enum(stratified_by_labels) + + # Remove "strata_" prefix from the 'stratified_by' column + data_long = data_long.with_columns( + pl.col("stratified_by").str.replace("^strata_", "").cast(stratified_by_enum) + ) + + return data_long + + +def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame: + return data.with_columns( + [ + pl.when(pl.col("reals") == 0) + .then("real_negatives") + .when(pl.col("reals") == 1) + .then("real_positives") + .when(pl.col("reals") == 2) + .then("real_competing") + .otherwise("real_censored") + .alias("reals") + ] + ) + + +def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame: + data = data.with_columns( + [ + pl.when( + (pl.col("times") > pl.col("fixed_time_horizon")) + & (pl.col("reals_labels") == "real_positives") + ) + .then(pl.lit("real_negatives")) + .when( + (pl.col("times") < pl.col("fixed_time_horizon")) + & (pl.col("reals_labels") == "real_negatives") + ) + .then(pl.lit("real_censored")) + .otherwise(pl.col("reals_labels")) + .alias("reals_labels") + ] + ) + + return data + + +def assign_and_explode_polars( + data: pl.DataFrame, fixed_time_horizons: list[float] +) -> pl.DataFrame: + return ( + data.with_columns(pl.lit(fixed_time_horizons).alias("fixed_time_horizon")) + .explode("fixed_time_horizon") + .with_columns(pl.col("fixed_time_horizon").cast(pl.Float64)) + ) + + +def _create_list_data_to_adjust_binary( + aj_data_combinations: pl.DataFrame, + probs_dict: Dict[str, np.ndarray], + reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], + stratified_by, + by, +) -> Dict[str, pl.DataFrame]: + reference_group_labels = list(probs_dict.keys()) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 + + reference_group_enum = pl.Enum(reference_group_labels) + + strata_enum_dtype = aj_data_combinations.schema["strata"] + + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_dict, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(reals_dict)), + "probs": np.concatenate( + [probs_dict[group] for group in reference_group_labels] + ), + "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": list(probs_dict.keys()), + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + } + ) + .explode(["probs", "reals"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) + + data_to_adjust = add_cutoff_strata( + data_to_adjust, by=by, stratified_by=stratified_by + ) + + data_to_adjust = pivot_longer_strata(data_to_adjust) + + data_to_adjust = ( + data_to_adjust.with_columns([pl.col("strata")]) + .with_columns(pl.col("strata").cast(strata_enum_dtype)) + .join( + aj_data_combinations.select( + pl.col("strata"), + pl.col("stratified_by"), + pl.col("upper_bound"), + pl.col("lower_bound"), + ).unique(), + how="left", + on=["strata", "stratified_by"], + ) + ) + + reals_labels = ["real_negatives", "real_positives"] + + reals_enum = pl.Enum(reals_labels) + + reals_map = {0: "real_negatives", 1: "real_positives"} + + data_to_adjust = data_to_adjust.with_columns( + pl.col("reals") + .replace_strict(reals_map, return_dtype=reals_enum) + .alias("reals_labels") + ) + + list_data_to_adjust = { + group[0]: df + for group, df in data_to_adjust.partition_by( + "reference_group", as_dict=True + ).items() + } + + return list_data_to_adjust + + +def _create_list_data_to_adjust( + aj_data_combinations: pl.DataFrame, + probs_dict: Dict[str, np.ndarray], + reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], + times_dict: Union[np.ndarray, Dict[str, np.ndarray]], + stratified_by, + by, +) -> Dict[str, pl.DataFrame]: + # reference_groups = list(probs_dict.keys()) + reference_group_labels = list(probs_dict.keys()) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 + + # num_reals = len(reals_dict) + + reference_group_enum = pl.Enum(reference_group_labels) + + strata_enum_dtype = aj_data_combinations.schema["strata"] + + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + if isinstance(reals_dict, dict): + reals_array = np.asarray(reals_dict[0]) + else: + reals_array = np.asarray(reals_dict) + + if isinstance(times_dict, dict): + times_array = np.asarray(times_dict[0]) + else: + times_array = np.asarray(times_dict) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_array, + "times": times_array, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + reals_array = np.asarray(reals_dict) + times_array = np.asarray(times_dict) + n = len(reals_array) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, n), + "probs": np.concatenate( + [np.asarray(probs_dict[g]) for g in reference_group_labels] + ), + "reals": np.tile(reals_array, len(reference_group_labels)), + "times": np.tile(times_array, len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict) and isinstance(times_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": reference_group_labels, + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + "times": list(times_dict.values()), + } + ) + .explode(["probs", "reals", "times"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) + + data_to_adjust = add_cutoff_strata( + data_to_adjust, by=by, stratified_by=stratified_by + ) + + data_to_adjust = pivot_longer_strata(data_to_adjust) + + data_to_adjust = ( + data_to_adjust.with_columns([pl.col("strata")]) + .with_columns(pl.col("strata").cast(strata_enum_dtype)) + .join( + aj_data_combinations.select( + pl.col("strata"), + pl.col("stratified_by"), + pl.col("upper_bound"), + pl.col("lower_bound"), + ).unique(), + how="left", + on=["strata", "stratified_by"], + ) + ) + + reals_labels = [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + + reals_enum = pl.Enum(reals_labels) + + # Map reals values to strings + reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"} + + data_to_adjust = data_to_adjust.with_columns( + pl.col("reals") + .replace_strict(reals_map, return_dtype=reals_enum) + .alias("reals_labels") + ) + + # Partition by reference_group + list_data_to_adjust = { + group[0]: df + for group, df in data_to_adjust.partition_by( + "reference_group", as_dict=True + ).items() + } + + return list_data_to_adjust + + +def _cast_and_join_adjusted_data_binary( + aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame +) -> pl.DataFrame: + strata_enum_dtype = aj_data_combinations.schema["strata"] + + aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( + pl.col("strata").cast(strata_enum_dtype) + ) + + final_adjusted_data_polars = ( + ( + aj_data_combinations.with_columns([pl.col("strata")]).join( + aj_estimates_data, + on=[ + "strata", + "stratified_by", + "reals_labels", + "reference_group", + "chosen_cutoff", + ], + how="left", + ) + ) + .with_columns( + pl.when( + ( + (pl.col("chosen_cutoff") >= pl.col("upper_bound")) + & (pl.col("stratified_by") == "probability_threshold") + ) + | ( + ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) + & (pl.col("stratified_by") == "ppcr") + ) + ) + .then(pl.lit("predicted_negatives")) + .otherwise(pl.lit("predicted_positives")) + .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) + .alias("prediction_label") + ) + .with_columns( + ( + pl.when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("true_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("false_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("false_negatives")) + .cast( + pl.Enum( + [ + "true_positives", + "false_positives", + "true_negatives", + "false_negatives", + ] + ) + ) + ).alias("classification_outcome") + ) + ).with_columns(pl.col("reals_estimate").fill_null(0)) + + return final_adjusted_data_polars + + +def cast_and_join_adjusted_data( + aj_data_combinations, aj_estimates_data +) -> pl.DataFrame: + strata_enum_dtype = aj_data_combinations.schema["strata"] + + aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( + pl.col("strata").cast(strata_enum_dtype) + ) + + final_adjusted_data_polars = ( + aj_data_combinations.with_columns([pl.col("strata")]) + .join( + aj_estimates_data, + on=[ + "strata", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "reals_labels", + "reference_group", + "chosen_cutoff", + "risk_set_scope", + ], + how="left", + ) + .with_columns( + pl.when( + ( + (pl.col("chosen_cutoff") >= pl.col("upper_bound")) + & (pl.col("stratified_by") == "probability_threshold") + ) + | ( + ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) + & (pl.col("stratified_by") == "ppcr") + ) + ) + .then(pl.lit("predicted_negatives")) + .otherwise(pl.lit("predicted_positives")) + .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) + .alias("prediction_label") + ) + .with_columns( + ( + pl.when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("true_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("false_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("false_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_competing")) + & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_competing")) + & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) + ) + .then(pl.lit("false_positives")) + .otherwise(pl.lit("excluded")) # or pl.lit(None) if you prefer nulls + .cast( + pl.Enum( + [ + "true_positives", + "false_positives", + "true_negatives", + "false_negatives", + "excluded", + ] + ) + ) + ).alias("classification_outcome") + ) + ) + return final_adjusted_data_polars + + +def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame: + cumulative_aj_data = ( + aj_data.group_by( + [ + "reference_group", + "stratified_by", + "chosen_cutoff", + "classification_outcome", + ] + ) + .agg([pl.col("reals_estimate").sum()]) + .pivot(on="classification_outcome", values="reals_estimate") + .with_columns( + [ + pl.col(col).fill_null(0) + for col in [ + "true_positives", + "true_negatives", + "false_positives", + "false_negatives", + ] + ] + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ) + .alias("n") + .sum(), + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + ) + + return cumulative_aj_data + + +def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame: + cumulative_aj_data = ( + aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff") + .group_by( + [ + "reference_group", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "stratified_by", + "chosen_cutoff", + "classification_outcome", + ] + ) + .agg([pl.col("reals_estimate").sum()]) + .pivot(on="classification_outcome", values="reals_estimate") + .fill_null(0) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + ) + + return cumulative_aj_data + + +def _turn_cumulative_aj_to_performance_data( + cumulative_aj_data: pl.DataFrame, +) -> pl.DataFrame: + performance_data = cumulative_aj_data.with_columns( + (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"), + (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"), + (pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"), + (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"), + (pl.col("false_positives") / pl.col("real_negatives")).alias( + "false_positive_rate" + ), + ( + (pl.col("true_positives") / pl.col("predicted_positives")) + / (pl.col("real_positives") / pl.col("n")) + ).alias("lift"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then( + (pl.col("true_positives") / pl.col("n")) + - (pl.col("false_positives") / pl.col("n")) + * pl.col("chosen_cutoff") + / (1 - pl.col("chosen_cutoff")) + ) + .otherwise(None) + .alias("net_benefit"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then( + 100 * (pl.col("true_negatives") / pl.col("n")) + - (pl.col("false_negatives") / pl.col("n")) + * (1 - pl.col("chosen_cutoff")) + / pl.col("chosen_cutoff") + ) + .otherwise(None) + .alias("net_benefit_interventions_avoided"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then(pl.col("predicted_positives") / pl.col("n")) + .otherwise(pl.col("chosen_cutoff")) + .alias("ppcr"), + ) + + return performance_data \ No newline at end of file diff --git a/src/rtichoke/summary_report/summary_report.py b/src/rtichoke/summary_report/summary_report.py index 9549260..8506794 100644 --- a/src/rtichoke/summary_report/summary_report.py +++ b/src/rtichoke/summary_report/summary_report.py @@ -2,8 +2,10 @@ A module for Summary Report """ -from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.send_post_request_to_r_rtichoke import ( + send_requests_to_rtichoke_r, +) +from rtichoke.processing.transforms import ( _create_list_data_to_adjust, ) import subprocess diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py index 50f0e6d..57fdf53 100644 --- a/src/rtichoke/utility/decision.py +++ b/src/rtichoke/utility/decision.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_binary, _create_rtichoke_plotly_curve_times, _plot_rtichoke_curve_binary, diff --git a/tests/test_rtichoke.py b/tests/test_rtichoke.py index 0ff1b91..1dd7916 100644 --- a/tests/test_rtichoke.py +++ b/tests/test_rtichoke.py @@ -2,7 +2,7 @@ A module for tests """ -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.adjustments import ( extract_aj_estimate_for_strata, ) From f8dd249fd68a0613e9e667970e10ffa976e31449 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Wed, 24 Dec 2025 08:17:04 +0200 Subject: [PATCH 17/17] build: bump version and update __init__.py --- pyproject.toml | 2 +- src/rtichoke/__init__.py | 1 + src/rtichoke/calibration/__init__.py | 1 + src/rtichoke/calibration/calibration.py | 229 ++++++++++++++++-------- src/rtichoke/processing/adjustments.py | 2 +- src/rtichoke/processing/transforms.py | 4 +- tests/test_calibration.py | 3 +- tests/test_calibration_times.py | 7 +- tests/test_heuristics.py | 97 ++++++---- uv.lock | 2 +- 10 files changed, 233 insertions(+), 115 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d9c3f98..2388c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "statsmodels>=0.14.0", ] name = "rtichoke" -version = "0.1.25" +version = "0.1.26" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index 97fbc14..bc84b74 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -32,6 +32,7 @@ from rtichoke.calibration.calibration import ( create_calibration_curve as create_calibration_curve, + create_calibration_curve_times as create_calibration_curve_times, ) from rtichoke.utility.decision import ( diff --git a/src/rtichoke/calibration/__init__.py b/src/rtichoke/calibration/__init__.py index c4b9dcb..190e74e 100644 --- a/src/rtichoke/calibration/__init__.py +++ b/src/rtichoke/calibration/__init__.py @@ -1,6 +1,7 @@ """ Subpackage for Calibration """ + from .calibration import create_calibration_curve, create_calibration_curve_times __all__ = ["create_calibration_curve", "create_calibration_curve_times"] diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 76c319b..4e245a9 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -147,7 +147,8 @@ def _create_plotly_curve_from_calibration_curve_list_times( showlegend=False, visible=visible, ), - row=1, col=1, + row=1, + col=1, ) for group in calibration_curve_list["reference_group_keys"]: @@ -156,12 +157,14 @@ def _create_plotly_curve_from_calibration_curve_list_times( # Calibration curve (discrete or smooth) if calibration_type == "discrete": data_subset = calibration_curve_list["deciles_dat"].filter( - (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) ) mode = "lines+markers" - else: # smooth + else: # smooth data_subset = calibration_curve_list["smooth_dat"].filter( - (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) ) mode = "lines+markers" if data_subset.height == 1 else "lines" @@ -177,12 +180,14 @@ def _create_plotly_curve_from_calibration_curve_list_times( marker={"size": 10, "color": color}, visible=visible, ), - row=1, col=1, + row=1, + col=1, ) # Histogram hist_subset = calibration_curve_list["histogram_for_calibration"].filter( - (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) ) fig.add_trace( go.Bar( @@ -198,7 +203,8 @@ def _create_plotly_curve_from_calibration_curve_list_times( opacity=0.4, visible=visible, ), - row=2, col=1, + row=2, + col=1, ) # Create slider @@ -206,7 +212,9 @@ def _create_plotly_curve_from_calibration_curve_list_times( num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"]) for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]): - visibility = [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"])) + visibility = [False] * ( + num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]) + ) for j in range(num_traces_per_horizon): visibility[i * num_traces_per_horizon + j] = True step = dict( @@ -216,23 +224,36 @@ def _create_plotly_curve_from_calibration_curve_list_times( ) steps.append(step) - sliders = [dict( - active=0, - currentvalue={"prefix": "Time Horizon: "}, - pad={"t": 50}, - steps=steps, - )] + sliders = [ + dict( + active=0, + currentvalue={"prefix": "Time Horizon: "}, + pad={"t": 50}, + steps=steps, + ) + ] # Layout fig.update_layout( sliders=sliders, - xaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["xaxis"]}, - yaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["yaxis"], "title": "Observed"}, + xaxis={ + "showgrid": False, + "range": calibration_curve_list["axes_ranges"]["xaxis"], + }, + yaxis={ + "showgrid": False, + "range": calibration_curve_list["axes_ranges"]["yaxis"], + "title": "Observed", + }, barmode="overlay", plot_bgcolor="rgba(0, 0, 0, 0)", legend={ - "orientation": "h", "xanchor": "center", "yanchor": "top", - "x": 0.5, "y": 1.3, "bgcolor": "rgba(0, 0, 0, 0)", + "orientation": "h", + "xanchor": "center", + "yanchor": "top", + "x": 0.5, + "y": 1.3, + "bgcolor": "rgba(0, 0, 0, 0)", }, showlegend=calibration_curve_list["performance_type"][0] != "one model", width=calibration_curve_list["size"][0][0], @@ -653,7 +674,11 @@ def _calculate_smooth_curve( def process_single_array(p, r, group_name): if len(np.unique(p)) == 1: return pl.DataFrame( - {"x": [np.unique(p)[0]], "y": [np.mean(r)], "reference_group": [group_name]} + { + "x": [np.unique(p)[0]], + "y": [np.mean(r)], + "reference_group": [group_name], + } ) else: # lowess returns a 2D array where the first column is x and the second is y @@ -668,23 +693,32 @@ def process_single_array(p, r, group_name): for model_name, prob_array in probs.items(): # This logic assumes that for multiple populations, one model's probs are evaluated against multiple real outcomes. # This might need adjustment based on the exact structure for multiple models and populations. - if len(probs) == 1 and len(reals) > 1: # One model, multiple populations - for pop_name, real_array in reals.items(): - frame = process_single_array(prob_array, real_array, pop_name) - smooth_frames.append(frame) - else: # Multiple models, potentially multiple populations + if len(probs) == 1 and len(reals) > 1: # One model, multiple populations + for pop_name, real_array in reals.items(): + frame = process_single_array(prob_array, real_array, pop_name) + smooth_frames.append(frame) + else: # Multiple models, potentially multiple populations for group_name in reals.keys(): if group_name in probs: - frame = process_single_array(probs[group_name], reals[group_name], group_name) + frame = process_single_array( + probs[group_name], reals[group_name], group_name + ) smooth_frames.append(frame) - else: # reals is a single numpy array + else: # reals is a single numpy array for group_name, prob_array in probs.items(): frame = process_single_array(prob_array, reals, group_name) smooth_frames.append(frame) if not smooth_frames: - return pl.DataFrame(schema={"x": pl.Float64, "y": pl.Float64, "reference_group": pl.Utf8, "text": pl.Utf8}) + return pl.DataFrame( + schema={ + "x": pl.Float64, + "y": pl.Float64, + "reference_group": pl.Utf8, + "text": pl.Utf8, + } + ) smooth_dat = pl.concat(smooth_frames) @@ -731,31 +765,59 @@ def _add_hover_text_to_calibration_data( """Adds hover text to the deciles and smooth dataframes.""" if performance_type != "one model": deciles_dat = deciles_dat.with_columns( - pl.concat_str([ - pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), - pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), - pl.lit(" ( "), pl.col("n_reals"), pl.lit(" / "), pl.col("n"), pl.lit(" )"), - ]).alias("text") + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + pl.lit(" ( "), + pl.col("n_reals"), + pl.lit(" / "), + pl.col("n"), + pl.lit(" )"), + ] + ).alias("text") ) smooth_dat = smooth_dat.with_columns( - pl.concat_str([ - pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), - pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), - ]).alias("text") + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + ] + ).alias("text") ) else: deciles_dat = deciles_dat.with_columns( - pl.concat_str([ - pl.lit("Predicted: "), pl.col("x").round(3), pl.lit("
Observed: "), - pl.col("y").round(3), pl.lit(" ( "), pl.col("n_reals"), - pl.lit(" / "), pl.col("n"), pl.lit(" )"), - ]).alias("text") + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + pl.lit(" ( "), + pl.col("n_reals"), + pl.lit(" / "), + pl.col("n"), + pl.lit(" )"), + ] + ).alias("text") ) smooth_dat = smooth_dat.with_columns( - pl.concat_str([ - pl.lit("Predicted: "), pl.col("x").round(3), - pl.lit("
Observed: "), pl.col("y").round(3), - ]).alias("text") + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + ] + ).alias("text") ) return deciles_dat, smooth_dat @@ -833,17 +895,21 @@ def _build_initial_df_for_times( raise ValueError("Keys in reals and times dictionaries do not match.") for key in reals: if len(reals[key]) != len(times[key]): - raise ValueError(f"Length mismatch for population '{key}' in reals and times.") + raise ValueError( + f"Length mismatch for population '{key}' in reals and times." + ) # Create a base DataFrame with population data population_frames = [] for key in reals: population_frames.append( - pl.DataFrame({ - "reference_group": key, - "real": reals[key], - "time": times[key], - }) + pl.DataFrame( + { + "reference_group": key, + "real": reals[key], + "time": times[key], + } + ) ) base_df = pl.concat(population_frames) @@ -852,10 +918,11 @@ def _build_initial_df_for_times( if len(probs) == 1: model_name, prob_array = next(iter(probs.items())) if len(prob_array) != base_df.height: - raise ValueError(f"Length of probabilities for model '{model_name}' does not match total number of observations.") + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match total number of observations." + ) return base_df.with_columns( - pl.Series("prob", prob_array), - pl.lit(model_name).alias("model") + pl.Series("prob", prob_array), pl.lit(model_name).alias("model") ) # Multiple models @@ -866,11 +933,12 @@ def _build_initial_df_for_times( for model_name, prob_array in probs.items(): pop_df = base_df.filter(pl.col("reference_group") == model_name) if len(prob_array) != pop_df.height: - raise ValueError(f"Length of probabilities for model '{model_name}' does not match population size.") + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match population size." + ) prob_frames.append( pop_df.with_columns( - pl.Series("prob", prob_array), - pl.lit(model_name).alias("model") + pl.Series("prob", prob_array), pl.lit(model_name).alias("model") ) ) return pl.concat(prob_frames) @@ -879,11 +947,15 @@ def _build_initial_df_for_times( final_frames = [] for model_name, prob_array in probs.items(): if len(prob_array) != base_df.height: - raise ValueError(f"Length of probabilities for model '{model_name}' does not match population size.") + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match population size." + ) final_frames.append( base_df.with_columns( - pl.Series("prob", prob_array), - pl.lit(model_name).alias("reference_group") # Overwrite reference_group with model name + pl.Series("prob", prob_array), + pl.lit(model_name).alias( + "reference_group" + ), # Overwrite reference_group with model name ) ) return pl.concat(final_frames) @@ -902,7 +974,10 @@ def _apply_heuristics_and_censoring( """ # Administrative censoring: outcomes after horizon are negative df_adj = df.with_columns( - pl.when(pl.col("time") > horizon).then(0).otherwise(pl.col("real")).alias("real") + pl.when(pl.col("time") > horizon) + .then(0) + .otherwise(pl.col("real")) + .alias("real") ) # Heuristics for events before or at horizon @@ -977,7 +1052,10 @@ def _create_calibration_curve_list_times( censoring_heuristic = heuristics["censoring_heuristic"] competing_heuristic = heuristics["competing_heuristic"] - if censoring_heuristic == "adjusted" or competing_heuristic == "adjusted_as_censored": + if ( + censoring_heuristic == "adjusted" + or competing_heuristic == "adjusted_as_censored" + ): continue df_adj = _apply_heuristics_and_censoring( @@ -998,25 +1076,33 @@ def _create_calibration_curve_list_times( } # If single population initially, reals_adj should be an array if not isinstance(reals, dict) and len(probs) == 1: - reals_adj = next(iter(reals_adj.values())) - + reals_adj = next(iter(reals_adj.values())) # Deciles deciles_data = _make_deciles_dat_binary(probs_adj, reals_adj) - all_deciles.append(deciles_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + all_deciles.append( + deciles_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) # Smooth curve - smooth_data = _calculate_smooth_curve(probs_adj, reals_adj, performance_type) - all_smooth.append(smooth_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + smooth_data = _calculate_smooth_curve( + probs_adj, reals_adj, performance_type + ) + all_smooth.append( + smooth_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) # Histogram hist_data = _create_histogram_for_calibration(probs_adj) - all_histograms.append(hist_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) - + all_histograms.append( + hist_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) # Part 3: Combine results and create final dictionary if not all_deciles: - raise ValueError("No data remaining after applying heuristics and time horizons.") + raise ValueError( + "No data remaining after applying heuristics and time horizons." + ) deciles_dat_final = pl.concat(all_deciles) smooth_dat_final = pl.concat(all_smooth) histogram_final = pl.concat(all_histograms) @@ -1026,7 +1112,6 @@ def _create_calibration_curve_list_times( deciles_dat_final, smooth_dat_final, performance_type ) - reference_data = _create_reference_data_for_calibration_curve() reference_groups = deciles_dat_final["reference_group"].unique().to_list() colors_dictionary = _create_colors_dictionary_for_calibration( @@ -1045,7 +1130,7 @@ def _create_calibration_curve_list_times( "performance_type": [performance_type], "size": [(size, size)], "fixed_time_horizons": fixed_time_horizons, - "reference_group_keys": reference_groups + "reference_group_keys": reference_groups, } return calibration_curve_list diff --git a/src/rtichoke/processing/adjustments.py b/src/rtichoke/processing/adjustments.py index 05c52ca..4bb982b 100644 --- a/src/rtichoke/processing/adjustments.py +++ b/src/rtichoke/processing/adjustments.py @@ -2,8 +2,8 @@ import polars as pl from polarstate import predict_aj_estimates from polarstate import prepare_event_table -from typing import Dict from collections.abc import Sequence +from rtichoke.processing.transforms import assign_and_explode_polars def create_aj_data( diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py index f2acb92..4e4339a 100644 --- a/src/rtichoke/processing/transforms.py +++ b/src/rtichoke/processing/transforms.py @@ -1,7 +1,7 @@ import numpy as np import polars as pl from typing import Dict, Union -from collections.abc import Sequence +from rtichoke.processing.combinations import create_breaks_values def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame: @@ -697,4 +697,4 @@ def _turn_cumulative_aj_to_performance_data( .alias("ppcr"), ) - return performance_data \ No newline at end of file + return performance_data diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 3729c4d..4e79687 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -1,8 +1,7 @@ - import numpy as np -import polars as pl from rtichoke.calibration.calibration import create_calibration_curve + def test_create_calibration_curve_smooth(): probs = {"model_1": np.linspace(0, 1, 100)} reals = np.random.randint(0, 2, 100) diff --git a/tests/test_calibration_times.py b/tests/test_calibration_times.py index 6820382..b6570c9 100644 --- a/tests/test_calibration_times.py +++ b/tests/test_calibration_times.py @@ -1,14 +1,15 @@ -import pytest import numpy as np -import polars as pl from rtichoke.calibration import create_calibration_curve_times + def test_create_calibration_curve_times(): probs = {"model_1": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])} reals = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) times = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) fixed_time_horizons = [5, 10] - heuristics_sets = [{"censoring_heuristic": "excluded", "competing_heuristic": "excluded"}] + heuristics_sets = [ + {"censoring_heuristic": "excluded", "competing_heuristic": "excluded"} + ] fig = create_calibration_curve_times( probs, diff --git a/tests/test_heuristics.py b/tests/test_heuristics.py index ad78d24..6730c1c 100644 --- a/tests/test_heuristics.py +++ b/tests/test_heuristics.py @@ -3,64 +3,95 @@ from polars.testing import assert_frame_equal from rtichoke.calibration.calibration import _apply_heuristics_and_censoring + @pytest.fixture def sample_data(): - return pl.DataFrame({ - "real": [1, 0, 2, 1, 2, 0, 1], - "time": [1, 2, 3, 8, 9, 10, 12], - }) + return pl.DataFrame( + { + "real": [1, 0, 2, 1, 2, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + def test_competing_as_negative_logic(sample_data): # Heuristics that shouldn't change data before horizon - result = _apply_heuristics_and_censoring(sample_data, 15, "adjusted", "adjusted_as_negative") + result = _apply_heuristics_and_censoring( + sample_data, 15, "adjusted", "adjusted_as_negative" + ) # Competing events at times 3 and 9 should become 0. - expected = pl.DataFrame({ - "real": [1, 0, 0, 1, 0, 0, 1], - "time": [1, 2, 3, 8, 9, 10, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 0, 1, 0, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) assert_frame_equal(result, expected) + def test_admin_censoring(sample_data): - result = _apply_heuristics_and_censoring(sample_data, 7, "adjusted", "adjusted_as_negative") + result = _apply_heuristics_and_censoring( + sample_data, 7, "adjusted", "adjusted_as_negative" + ) # Admin censoring for times > 7. Competing event at time=3 becomes 0. - expected = pl.DataFrame({ - "real": [1, 0, 0, 0, 0, 0, 0], - "time": [1, 2, 3, 8, 9, 10, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 0, 0, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) assert_frame_equal(result, expected) + def test_censoring_excluded(sample_data): - result = _apply_heuristics_and_censoring(sample_data, 10, "excluded", "adjusted_as_negative") + result = _apply_heuristics_and_censoring( + sample_data, 10, "excluded", "adjusted_as_negative" + ) # Excludes censored at times 2, 10. Admin censors time > 10. Competing at 3,9 -> 0. - expected = pl.DataFrame({ - "real": [1, 0, 1, 0, 0], - "time": [1, 3, 8, 9, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 1, 0, 0], + "time": [1, 3, 8, 9, 12], + } + ) assert_frame_equal(result.sort("time"), expected.sort("time")) + def test_competing_excluded(sample_data): result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "excluded") # Excludes competing at 3, 9. Admin censors time > 10. - expected = pl.DataFrame({ - "real": [1, 0, 1, 0, 0], - "time": [1, 2, 8, 10, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 1, 0, 0], + "time": [1, 2, 8, 10, 12], + } + ) assert_frame_equal(result.sort("time"), expected.sort("time")) + def test_competing_as_negative(sample_data): - result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "adjusted_as_negative") + result = _apply_heuristics_and_censoring( + sample_data, 10, "adjusted", "adjusted_as_negative" + ) # Competing at 3,9 -> 0. Admin censors time > 10. - expected = pl.DataFrame({ - "real": [1, 0, 0, 1, 0, 0, 0], - "time": [1, 2, 3, 8, 9, 10, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 0, 1, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) assert_frame_equal(result, expected) + def test_competing_as_composite(sample_data): - result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "adjusted_as_composite") + result = _apply_heuristics_and_censoring( + sample_data, 10, "adjusted", "adjusted_as_composite" + ) # Competing at 3,9 -> 1. Admin censors time > 10. - expected = pl.DataFrame({ - "real": [1, 0, 1, 1, 1, 0, 0], - "time": [1, 2, 3, 8, 9, 10, 12], - }) + expected = pl.DataFrame( + { + "real": [1, 0, 1, 1, 1, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) assert_frame_equal(result, expected) diff --git a/uv.lock b/uv.lock index 1a9f3e2..a23d1da 100644 --- a/uv.lock +++ b/uv.lock @@ -3888,7 +3888,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.25" +version = "0.1.26" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },