diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..04f3079 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,77 @@ +# rtichoke Agent Information + +This document provides guidance for AI agents working on the `rtichoke` repository. + +## Development Environment + +To set up the development environment, follow these steps: + +1. **Install `uv`**: If you don't have `uv` installed, please follow the official installation instructions. +2. **Create a virtual environment**: Use `uv venv` to create a virtual environment. +3. **Install dependencies**: Install the project dependencies, including the `dev` dependencies, with the following command: + + ```bash + uv pip install -e .[dev] + ``` + +## Running Tests + +The test suite is run using `pytest`. To run the tests, use the following command: + +```bash +uv run pytest +``` + +## Coding Conventions + +### Functional Programming + +Strive to use a functional programming style as much as possible. Avoid side effects and mutable state where practical. + +### Docstrings + +All exported functions must have NumPy-style docstrings. This is to ensure that the documentation is clear, consistent, and can be easily parsed by tools like `quartodoc`. + +Example of a NumPy-style docstring: + +```python +def my_function(param1, param2): + """Summary of the function's purpose. + + Parameters + ---------- + param1 : int + Description of the first parameter. + param2 : str + Description of the second parameter. + + Returns + ------- + bool + Description of the return value. + """ + # function body + return True +``` + +## Pre-commit Hooks + +This repository uses pre-commit hooks to ensure code quality and consistency. The following hooks are configured: + +* **`ruff-check`**: A linter to check for common errors and style issues. +* **`ruff-format`**: A code formatter to ensure a consistent code style. +* **`uv-lock`**: A hook to keep the `uv.lock` file up to date. + +Before committing, please ensure that the pre-commit hooks pass. You can run them manually on all files with `pre-commit run --all-files`. + +## Documentation + +The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually. + +## Type Checking + +This project uses `ty` for type checking. To check for type errors, run the following command: + +```bash +uv run ty check src tests +``` diff --git a/pyproject.toml b/pyproject.toml index 7ff976a..2388c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,10 @@ dependencies = [ "polarstate==0.1.8", "marimo>=0.17.0", "pyarrow>=21.0.0", + "statsmodels>=0.14.0", ] name = "rtichoke" -version = "0.1.25" +version = "0.1.26" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index cc5b47a..bc84b74 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -30,9 +30,10 @@ ) from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve -# from rtichoke.calibration.calibration import ( -# create_calibration_curve as create_calibration_curve, -# ) +from rtichoke.calibration.calibration import ( + create_calibration_curve as create_calibration_curve, + create_calibration_curve_times as create_calibration_curve_times, +) from rtichoke.utility.decision import ( create_decision_curve as create_decision_curve, diff --git a/src/rtichoke/calibration/__init__.py b/src/rtichoke/calibration/__init__.py index 4267999..190e74e 100644 --- a/src/rtichoke/calibration/__init__.py +++ b/src/rtichoke/calibration/__init__.py @@ -1,3 +1,7 @@ """ Subpackage for Calibration """ + +from .calibration import create_calibration_curve, create_calibration_curve_times + +__all__ = ["create_calibration_curve", "create_calibration_curve_times"] diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index f9d32f3..4e245a9 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,21 +2,24 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Union # import pandas as pd import plotly.graph_objects as go from plotly.subplots import make_subplots from plotly.graph_objs._figure import Figure +import polars as pl +import numpy as np + # from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r def create_calibration_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], calibration_type: str = "discrete", - size: Optional[int] = None, - color_values: Optional[List[str]] = [ + size: int = 600, + color_values: List[str] = [ "#1b9e77", "#d95f02", "#7570b3", @@ -38,7 +41,6 @@ def create_calibration_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: """Creates Calibration Curve @@ -55,40 +57,213 @@ def create_calibration_curve( """ pass - # rtichoke_response = send_requests_to_rtichoke_r( - # dictionary_to_send={ - # "probs": probs, - # "reals": reals, - # "size": size, - # "color_values ": color_values, - # }, - # url_api=url_api, - # endpoint="create_calibration_curve_list", - # ) - - # calibration_curve_list = rtichoke_response.json() - - # calibration_curve_list["deciles_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["deciles_dat"] - # ) - # calibration_curve_list["smooth_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["smooth_dat"] - # ) - # calibration_curve_list["reference_data"] = pd.DataFrame.from_dict( - # calibration_curve_list["reference_data"] - # ) - # calibration_curve_list["histogram_for_calibration"] = pd.DataFrame.from_dict( - # calibration_curve_list["histogram_for_calibration"] - # ) - - # calibration_curve = create_plotly_curve_from_calibration_curve_list( - # calibration_curve_list=calibration_curve_list, calibration_type=calibration_type - # ) - - # return calibration_curve - - -def create_plotly_curve_from_calibration_curve_list( + calibration_curve_list = _create_calibration_curve_list( + probs, reals, size=size, color_values=color_values + ) + + calibration_curve = _create_plotly_curve_from_calibration_curve_list( + calibration_curve_list, calibration_type=calibration_type + ) + + return calibration_curve + + +def create_calibration_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + calibration_type: str = "discrete", + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Creates a time-dependent Calibration Curve with a slider for different time horizons.""" + + calibration_curve_list_times = _create_calibration_curve_list_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + size=size, + color_values=color_values, + ) + + fig = _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list_times, calibration_type=calibration_type + ) + + return fig + + +def _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" +) -> Figure: + """ + Creates a plotly figure for time-dependent calibration curves. + """ + fig = make_subplots( + rows=2, cols=1, shared_xaxes=True, x_title="Predicted", row_heights=[0.8, 0.2] + ) + + initial_horizon = calibration_curve_list["fixed_time_horizons"][0] + + # Add traces for each horizon, initially visible only for the first horizon + for horizon in calibration_curve_list["fixed_time_horizons"]: + visible = horizon == initial_horizon + + # Reference Line + fig.add_trace( + go.Scatter( + x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], + name="Perfectly Calibrated", + legendgroup="Perfectly Calibrated", + hoverinfo="text", + line={"width": 2, "dash": "dot", "color": "#BEBEBE"}, + showlegend=False, + visible=visible, + ), + row=1, + col=1, + ) + + for group in calibration_curve_list["reference_group_keys"]: + color = calibration_curve_list["colors_dictionary"][group][0] + + # Calibration curve (discrete or smooth) + if calibration_type == "discrete": + data_subset = calibration_curve_list["deciles_dat"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" + else: # smooth + data_subset = calibration_curve_list["smooth_dat"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" if data_subset.height == 1 else "lines" + + fig.add_trace( + go.Scatter( + x=data_subset["x"], + y=data_subset["y"], + hovertext=data_subset["text"], + name=group, + legendgroup=group, + hoverinfo="text", + mode=mode, + marker={"size": 10, "color": color}, + visible=visible, + ), + row=1, + col=1, + ) + + # Histogram + hist_subset = calibration_curve_list["histogram_for_calibration"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == horizon) + ) + fig.add_trace( + go.Bar( + x=hist_subset["mids"], + y=hist_subset["counts"], + hovertext=hist_subset["text"], + name=group, + width=0.01, + legendgroup=group, + hoverinfo="text", + marker_color=color, + showlegend=False, + opacity=0.4, + visible=visible, + ), + row=2, + col=1, + ) + + # Create slider + steps = [] + num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"]) + + for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]): + visibility = [False] * ( + num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]) + ) + for j in range(num_traces_per_horizon): + visibility[i * num_traces_per_horizon + j] = True + step = dict( + method="restyle", + args=[{"visible": visibility}], + label=str(horizon), + ) + steps.append(step) + + sliders = [ + dict( + active=0, + currentvalue={"prefix": "Time Horizon: "}, + pad={"t": 50}, + steps=steps, + ) + ] + + # Layout + fig.update_layout( + sliders=sliders, + xaxis={ + "showgrid": False, + "range": calibration_curve_list["axes_ranges"]["xaxis"], + }, + yaxis={ + "showgrid": False, + "range": calibration_curve_list["axes_ranges"]["yaxis"], + "title": "Observed", + }, + barmode="overlay", + plot_bgcolor="rgba(0, 0, 0, 0)", + legend={ + "orientation": "h", + "xanchor": "center", + "yanchor": "top", + "x": 0.5, + "y": 1.3, + "bgcolor": "rgba(0, 0, 0, 0)", + }, + showlegend=calibration_curve_list["performance_type"][0] != "one model", + width=calibration_curve_list["size"][0][0], + height=calibration_curve_list["size"][0][0], + ) + + return fig + + +def _create_plotly_curve_from_calibration_curve_list( calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" ) -> Figure: """Create plotly curve from calibration curve list @@ -124,16 +299,16 @@ def create_plotly_curve_from_calibration_curve_list( calibration_curve.add_trace( go.Scatter( - x=calibration_curve_list["reference_data"]["x"].values.tolist(), - y=calibration_curve_list["reference_data"]["y"].values.tolist(), - hovertext=calibration_curve_list["reference_data"]["text"].values.tolist(), + x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], name="Perfectly Calibrated", legendgroup="Perfectly Calibrated", hoverinfo="text", line={ "width": 2, "dash": "dot", - "color": calibration_curve_list["group_colors_vec"]["reference_line"][ + "color": calibration_curve_list["colors_dictionary"]["reference_line"][ 0 ], }, @@ -144,111 +319,120 @@ def create_plotly_curve_from_calibration_curve_list( ) if calibration_type == "discrete": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["deciles_dat"]["x"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - y=calibration_curve_list["deciles_dat"]["y"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["deciles_dat"]["text"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines+markers", - marker={ - "size": 10, - "color": calibration_curve_list["group_colors_vec"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + for reference_group in reference_groups: + dec_sub = calibration_curve_list["deciles_dat"].filter( + pl.col("reference_group") == reference_group + ) + + print(dec_sub) + + calibration_curve.add_trace( + go.Scatter( + x=dec_sub.get_column("x").to_list(), + y=dec_sub.get_column("y").to_list(), + hovertext=dec_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + hoverinfo="text", + mode="lines+markers", + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) + + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in reference_groups: + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue + + calibration_curve.add_trace( + go.Bar( + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), + name=reference_group, + width=0.01, + legendgroup=reference_group, + hoverinfo="text", + marker_color=calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + showlegend=False, + opacity=0.4, + ), + row=2, + col=1, + ) if calibration_type == "smooth": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["smooth_dat"]["x"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - y=calibration_curve_list["smooth_dat"]["y"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["smooth_dat"]["text"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines", - marker={ - "size": 10, - "color": calibration_curve_list["group_colors_vec"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + smooth_dat = calibration_curve_list["smooth_dat"] + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + + for reference_group in reference_groups: + smooth_sub = smooth_dat.filter(pl.col("reference_group") == reference_group) + if smooth_sub.height == 0: + continue + + mode = "lines+markers" if smooth_sub.height == 1 else "lines" + + calibration_curve.add_trace( + go.Scatter( + x=smooth_sub.get_column("x").to_list(), + y=smooth_sub.get_column("y").to_list(), + hovertext=smooth_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + hoverinfo="text", + mode=mode, + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) + + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in reference_groups: + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["histogram_for_calibration"]["reference_group"] - == reference_group - ): calibration_curve.add_trace( go.Bar( - x=calibration_curve_list["histogram_for_calibration"]["mids"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), - y=calibration_curve_list["histogram_for_calibration"]["counts"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["histogram_for_calibration"][ - "text" - ][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), name=reference_group, width=0.01, legendgroup=reference_group, hoverinfo="text", - marker_color=calibration_curve_list["group_colors_vec"][ + marker_color=calibration_curve_list["colors_dictionary"][ reference_group ][0], showlegend=False, - opacity=calibration_curve_list["histogram_opacity"][0], + opacity=0.4, ), row=2, col=1, @@ -284,3 +468,669 @@ def create_plotly_curve_from_calibration_curve_list( ) return calibration_curve + + +def _make_deciles_dat_binary( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + n_bins: int = 10, +) -> pl.DataFrame: + if isinstance(reals, dict): + reference_groups_keys = list(reals.keys()) + y_list = [ + np.asarray(reals[reference_group]).ravel() + for reference_group in reference_groups_keys + ] + lengths = np.array([len(y) for y in y_list], dtype=np.int64) + offsets = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(lengths)]) + n_total = int(offsets[-1]) + + frames: list[pl.DataFrame] = [] + for model, p_all in probs.items(): + p_all = np.asarray(p_all).ravel() + if p_all.shape[0] != n_total: + raise ValueError( + f"probs['{model}'] length={p_all.shape[0]} does not match " + f"sum of population sizes={n_total}." + ) + + for i, pop in enumerate(reference_groups_keys): + start = int(offsets[i]) + end = int(offsets[i + 1]) + + frames.append( + pl.DataFrame( + { + "reference_group": pop, + "model": model, + "prob": p_all[start:end].astype(float, copy=False), + "real": y_list[i].astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + else: + y = np.asarray(reals).ravel() + n = y.shape[0] + frames = [] + for model, p in probs.items(): + p = np.asarray(p).ravel() + if p.shape[0] != n: + raise ValueError( + f"probs['{model}'] length={p.shape[0]} does not match reals length={n}." + ) + frames.append( + pl.DataFrame( + { + "reference_group": model, + "model": model, + "prob": p.astype(float, copy=False), + "real": y.astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + labels = [str(i) for i in range(1, n_bins + 1)] + + df = df.with_columns( + [ + pl.col("prob").cast(pl.Float64), + pl.col("real").cast(pl.Float64), + pl.col("prob") + .qcut(n_bins, labels=labels, allow_duplicates=True) + .over(["reference_group", "model"]) + .alias("decile"), + ] + ).with_columns(pl.col("decile").cast(pl.Int32)) + + deciles_data = ( + df.group_by(["reference_group", "model", "decile"]) + .agg( + [ + pl.len().alias("n"), + pl.mean("prob").alias("x"), + pl.mean("real").alias("y"), + pl.sum("real").alias("n_reals"), + ] + ) + .sort(["reference_group", "model", "decile"]) + ) + + return deciles_data + + +def _check_performance_type_by_probs_and_reals( + probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]] +) -> str: + if isinstance(reals, dict) and len(reals) > 1: + return "multiple populations" + if len(probs) > 1: + return "multiple models" + return "one model" + + +def _create_calibration_curve_list( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Dict[str, Any]: + deciles_data = _make_deciles_dat_binary(probs, reals) + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) + + deciles_data, smooth_dat = _add_hover_text_to_calibration_data( + deciles_data, smooth_dat, performance_type + ) + + reference_data = _create_reference_data_for_calibration_curve() + + reference_groups = deciles_data["reference_group"].unique().to_list() + + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + + print("histogram for calibration") + + histogram_for_calibration = _create_histogram_for_calibration(probs) + + print(histogram_for_calibration) + + limits = _define_limits_for_calibration_plot(deciles_data) + axes_ranges = {"xaxis": limits, "yaxis": limits} + + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) + + calibration_curve_list = { + "deciles_dat": deciles_data, + "smooth_dat": smooth_dat, + "reference_data": reference_data, + "histogram_for_calibration": histogram_for_calibration, + # "histogram_opacity": [0.4], + "axes_ranges": axes_ranges, + "colors_dictionary": colors_dictionary, + "performance_type": [performance_type], + "size": [(size, size)], + } + + return calibration_curve_list + + +def _create_reference_data_for_calibration_curve() -> pl.DataFrame: + x_ref = np.linspace(0, 1, 101) + reference_data = pl.DataFrame({"x": x_ref, "y": x_ref}) + reference_data = reference_data.with_columns( + pl.concat_str( + [ + pl.lit("Perfectly Calibrated
Predicted: "), + pl.col("x").map_elements(lambda x: f"{x:.3f}", return_dtype=pl.Utf8), + pl.lit("
Observed: "), + pl.col("y").map_elements(lambda y: f"{y:.3f}", return_dtype=pl.Utf8), + ] + ).alias("text") + ) + return reference_data + + +def _calculate_smooth_curve( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + performance_type: str, +) -> pl.DataFrame: + """ + Calculate the smoothed calibration curve using lowess. + """ + from statsmodels.nonparametric.smoothers_lowess import lowess + + smooth_frames = [] + + # Helper function to process a single probability and real array + def process_single_array(p, r, group_name): + if len(np.unique(p)) == 1: + return pl.DataFrame( + { + "x": [np.unique(p)[0]], + "y": [np.mean(r)], + "reference_group": [group_name], + } + ) + else: + # lowess returns a 2D array where the first column is x and the second is y + smoothed = lowess(r, p, it=0) + xout = np.linspace(0, 1, 101) + yout = np.interp(xout, smoothed[:, 0], smoothed[:, 1]) + return pl.DataFrame( + {"x": xout, "y": yout, "reference_group": [group_name] * len(xout)} + ) + + if isinstance(reals, dict): + for model_name, prob_array in probs.items(): + # This logic assumes that for multiple populations, one model's probs are evaluated against multiple real outcomes. + # This might need adjustment based on the exact structure for multiple models and populations. + if len(probs) == 1 and len(reals) > 1: # One model, multiple populations + for pop_name, real_array in reals.items(): + frame = process_single_array(prob_array, real_array, pop_name) + smooth_frames.append(frame) + else: # Multiple models, potentially multiple populations + for group_name in reals.keys(): + if group_name in probs: + frame = process_single_array( + probs[group_name], reals[group_name], group_name + ) + smooth_frames.append(frame) + + else: # reals is a single numpy array + for group_name, prob_array in probs.items(): + frame = process_single_array(prob_array, reals, group_name) + smooth_frames.append(frame) + + if not smooth_frames: + return pl.DataFrame( + schema={ + "x": pl.Float64, + "y": pl.Float64, + "reference_group": pl.Utf8, + "text": pl.Utf8, + } + ) + + smooth_dat = pl.concat(smooth_frames) + + if performance_type != "one model": + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + else: + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + return smooth_dat + + +def _add_hover_text_to_calibration_data( + deciles_dat: pl.DataFrame, + smooth_dat: pl.DataFrame, + performance_type: str, +) -> tuple[pl.DataFrame, pl.DataFrame]: + """Adds hover text to the deciles and smooth dataframes.""" + if performance_type != "one model": + deciles_dat = deciles_dat.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + pl.lit(" ( "), + pl.col("n_reals"), + pl.lit(" / "), + pl.col("n"), + pl.lit(" )"), + ] + ).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + ] + ).alias("text") + ) + else: + deciles_dat = deciles_dat.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + pl.lit(" ( "), + pl.col("n_reals"), + pl.lit(" / "), + pl.col("n"), + pl.lit(" )"), + ] + ).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").round(3), + pl.lit("
Observed: "), + pl.col("y").round(3), + ] + ).alias("text") + ) + return deciles_dat, smooth_dat + + +def _create_colors_dictionary_for_calibration( + reference_groups: List[str], + color_values: List[str], + performance_type: str = "one model", +) -> Dict[str, List[str]]: + if performance_type == "one model": + colors = ["black"] + else: + colors = color_values[: len(reference_groups)] + + return { + "reference_line": ["#BEBEBE"], + **{ + group: [colors[i % len(colors)]] for i, group in enumerate(reference_groups) + }, + } + + +def _create_histogram_for_calibration(probs: Dict[str, np.ndarray]) -> pl.DataFrame: + hist_dfs = [] + for group, prob_values in probs.items(): + counts, mids = np.histogram(prob_values, bins=np.arange(0, 1.01, 0.01)) + hist_df = pl.DataFrame( + {"mids": mids[:-1] + 0.005, "counts": counts, "reference_group": group} + ) + hist_df = hist_df.with_columns( + ( + pl.col("counts").cast(str) + + " observations in [" + + (pl.col("mids") - 0.005).round(3).cast(str) + + ", " + + (pl.col("mids") + 0.005).round(3).cast(str) + + "]" + ).alias("text") + ) + hist_dfs.append(hist_df) + + histogram_for_calibration = pl.concat(hist_dfs) + + return histogram_for_calibration + + +def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]: + if deciles_dat.height == 1: + lower_bound, upper_bound = 0.0, 1.0 + else: + lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min()))) + upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max())) + + return [ + lower_bound - (upper_bound - lower_bound) * 0.05, + upper_bound + (upper_bound - lower_bound) * 0.05, + ] + + +def _build_initial_df_for_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], +) -> pl.DataFrame: + """Builds the initial DataFrame for time-dependent calibration curves.""" + + # Convert all inputs to dictionaries of arrays to unify processing + if not isinstance(reals, dict): + reals = {"single_population": np.asarray(reals)} + if not isinstance(times, dict): + times = {"single_population": np.asarray(times)} + + # Verify matching keys and lengths + if reals.keys() != times.keys(): + raise ValueError("Keys in reals and times dictionaries do not match.") + for key in reals: + if len(reals[key]) != len(times[key]): + raise ValueError( + f"Length mismatch for population '{key}' in reals and times." + ) + + # Create a base DataFrame with population data + population_frames = [] + for key in reals: + population_frames.append( + pl.DataFrame( + { + "reference_group": key, + "real": reals[key], + "time": times[key], + } + ) + ) + base_df = pl.concat(population_frames) + + # Prepare model predictions + # Single model case + if len(probs) == 1: + model_name, prob_array = next(iter(probs.items())) + if len(prob_array) != base_df.height: + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match total number of observations." + ) + return base_df.with_columns( + pl.Series("prob", prob_array), pl.lit(model_name).alias("model") + ) + + # Multiple models + else: + # One model per population (keys must match) + if probs.keys() == reals.keys(): + prob_frames = [] + for model_name, prob_array in probs.items(): + pop_df = base_df.filter(pl.col("reference_group") == model_name) + if len(prob_array) != pop_df.height: + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match population size." + ) + prob_frames.append( + pop_df.with_columns( + pl.Series("prob", prob_array), pl.lit(model_name).alias("model") + ) + ) + return pl.concat(prob_frames) + # Multiple models on a single population + elif len(reals) == 1: + final_frames = [] + for model_name, prob_array in probs.items(): + if len(prob_array) != base_df.height: + raise ValueError( + f"Length of probabilities for model '{model_name}' does not match population size." + ) + final_frames.append( + base_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias( + "reference_group" + ), # Overwrite reference_group with model name + ) + ) + return pl.concat(final_frames) + + raise ValueError("Unsupported combination of probs, reals, and times structures.") + + +def _apply_heuristics_and_censoring( + df: pl.DataFrame, + horizon: float, + censoring_heuristic: str, + competing_heuristic: str, +) -> pl.DataFrame: + """ + Applies censoring and competing risk heuristics to the data for a given time horizon. + """ + # Administrative censoring: outcomes after horizon are negative + df_adj = df.with_columns( + pl.when(pl.col("time") > horizon) + .then(0) + .otherwise(pl.col("real")) + .alias("real") + ) + + # Heuristics for events before or at horizon + if censoring_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 0) & (pl.col("time") <= horizon))) + + if competing_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 2) & (pl.col("time") <= horizon))) + elif competing_heuristic == "adjusted_as_negative": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(0) + .otherwise(pl.col("real")) + .alias("real") + ) + elif competing_heuristic == "adjusted_as_composite": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(1) + .otherwise(pl.col("real")) + .alias("real") + ) + + return df_adj + + +def _create_calibration_curve_list_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Dict[str, Any]: + """ + Creates the data structures needed for a time-dependent calibration curve plot. + """ + # Part 1: Prepare initial dataframe from inputs + initial_df = _build_initial_df_for_times(probs, reals, times) + + # Part 2: Iterate and generate calibration data for each horizon/heuristic + all_deciles = [] + all_smooth = [] + all_histograms = [] + + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + + for horizon in fixed_time_horizons: + for heuristics in heuristics_sets: + censoring_heuristic = heuristics["censoring_heuristic"] + competing_heuristic = heuristics["competing_heuristic"] + + if ( + censoring_heuristic == "adjusted" + or competing_heuristic == "adjusted_as_censored" + ): + continue + + df_adj = _apply_heuristics_and_censoring( + initial_df, horizon, censoring_heuristic, competing_heuristic + ) + + if df_adj.height == 0: + continue + + # Re-create probs and reals dicts for helpers + probs_adj = { + group[0]: group_df["prob"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + reals_adj = { + group[0]: group_df["real"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + # If single population initially, reals_adj should be an array + if not isinstance(reals, dict) and len(probs) == 1: + reals_adj = next(iter(reals_adj.values())) + + # Deciles + deciles_data = _make_deciles_dat_binary(probs_adj, reals_adj) + all_deciles.append( + deciles_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) + + # Smooth curve + smooth_data = _calculate_smooth_curve( + probs_adj, reals_adj, performance_type + ) + all_smooth.append( + smooth_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) + + # Histogram + hist_data = _create_histogram_for_calibration(probs_adj) + all_histograms.append( + hist_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon")) + ) + + # Part 3: Combine results and create final dictionary + if not all_deciles: + raise ValueError( + "No data remaining after applying heuristics and time horizons." + ) + deciles_dat_final = pl.concat(all_deciles) + smooth_dat_final = pl.concat(all_smooth) + histogram_final = pl.concat(all_histograms) + + # Add hover text + deciles_dat_final, smooth_dat_final = _add_hover_text_to_calibration_data( + deciles_dat_final, smooth_dat_final, performance_type + ) + + reference_data = _create_reference_data_for_calibration_curve() + reference_groups = deciles_dat_final["reference_group"].unique().to_list() + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + limits = _define_limits_for_calibration_plot(deciles_dat_final) + axes_ranges = {"xaxis": limits, "yaxis": limits} + + calibration_curve_list = { + "deciles_dat": deciles_dat_final, + "smooth_dat": smooth_dat_final, + "reference_data": reference_data, + "histogram_for_calibration": histogram_final, + "axes_ranges": axes_ranges, + "colors_dictionary": colors_dictionary, + "performance_type": [performance_type], + "size": [(size, size)], + "fixed_time_horizons": fixed_time_horizons, + "reference_group_keys": reference_groups, + } + + return calibration_curve_list diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py index 59366f4..8cf7c8d 100644 --- a/src/rtichoke/discrimination/gains.py +++ b/src/rtichoke/discrimination/gains.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py index 5f358af..e3e394c 100644 --- a/src/rtichoke/discrimination/lift.py +++ b/src/rtichoke/discrimination/lift.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py index 565cf5c..1a3d7a0 100644 --- a/src/rtichoke/discrimination/precision_recall.py +++ b/src/rtichoke/discrimination/precision_recall.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py index d8a8ed0..9bcc653 100644 --- a/src/rtichoke/discrimination/roc.py +++ b/src/rtichoke/discrimination/roc.py @@ -4,7 +4,7 @@ from typing import Dict, List, Union, Sequence from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py deleted file mode 100644 index d3cc352..0000000 --- a/src/rtichoke/helpers/sandbox_observable_helpers.py +++ /dev/null @@ -1,1770 +0,0 @@ -# from lifelines import AalenJohansenFitter -import pandas as pd -import numpy as np -import polars as pl -from polarstate import predict_aj_estimates -from polarstate import prepare_event_table -from typing import Dict, Union -from collections.abc import Sequence - - -def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame: - """Create a single-column DataFrame with an enum dtype.""" - enum_values = list(dict.fromkeys(values)) - enum_dtype = pl.Enum(enum_values) - return pl.DataFrame({column_name: pl.Series(values, dtype=enum_dtype)}) - - -# def extract_aj_estimate(data_to_adjust, fixed_time_horizons): -# """ -# Python implementation of the R extract_aj_estimate function for Aalen-Johansen estimation. - -# Parameters: -# data_to_adjust (pd.DataFrame): DataFrame containing survival data -# fixed_time_horizons (list or float): Time points at which to evaluate the survival - -# Returns: -# pd.DataFrame: DataFrame with Aalen-Johansen estimates -# """ - -# # Ensure fixed_time_horizons is a list -# if not isinstance(fixed_time_horizons, list): -# fixed_time_horizons = [fixed_time_horizons] - -# # Create a categorical version of reals for stratification -# data = data_to_adjust.copy() -# data["reals_cat"] = pd.Categorical( -# data["reals_labels"], -# categories=[ -# "real_negatives", -# "real_positives", -# "real_competing", -# "real_censored", -# ], -# ordered=True, -# ) - -# # Get unique strata values -# strata_values = data["strata"].unique() - -# event_map = { -# "real_negatives": 0, # Treat as censored -# "real_positives": 1, # Event of interest -# "real_competing": 2, # Competing risk -# "real_censored": 0, # Censored -# } - -# data["event_code"] = data["reals_labels"].map(event_map) - -# # Initialize result dataframes -# results = [] - -# # For each stratum, fit Aalen-Johansen model -# for stratum in strata_values: -# # Filter data for current stratum -# stratum_data = data.loc[data["strata"] == stratum] - -# # Initialize Aalen-Johansen fitter -# ajf = AalenJohansenFitter() -# ajf_competing = AalenJohansenFitter() - -# # Fit the model -# ajf.fit(stratum_data["times"], stratum_data["event_code"], event_of_interest=1) - -# ajf_competing.fit( -# stratum_data["times"], stratum_data["event_code"], event_of_interest=2 -# ) - -# # Calculate cumulative incidence at fixed time horizons -# for t in fixed_time_horizons: -# n = len(stratum_data) -# real_positives_est = ajf.predict(t) -# real_competing_est = ajf_competing.predict(t) -# real_negatives_est = 1 - real_positives_est - real_competing_est - -# states = ["real_negatives", "real_positives", "real_competing"] -# estimates = [real_negatives_est, real_positives_est, real_competing_est] - -# for state, estimate in zip(states, estimates): -# results.append( -# { -# "strata": stratum, -# "reals": state, -# "fixed_time_horizon": t, -# "reals_estimate": estimate * n, -# } -# ) - -# # Convert to DataFrame -# result_df = pd.DataFrame(results) - -# # Convert strata to categorical if needed -# result_df["strata"] = pd.Categorical(result_df["strata"]) - -# return result_df - - -def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame: - def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: - probs = group["probs"].to_numpy() - columns_to_add = [] - - breaks = create_breaks_values(probs, "probability_threshold", by) - if "probability_threshold" in stratified_by: - last_bin_index = len(breaks) - 2 - - bin_indices = np.digitize(probs, bins=breaks, right=False) - 1 - bin_indices = np.where(probs == 1.0, last_bin_index, bin_indices) - - lower_bounds = breaks[bin_indices] - upper_bounds = breaks[bin_indices + 1] - - include_upper_bounds = bin_indices == last_bin_index - - strata_prob_labels = np.where( - include_upper_bounds, - [f"[{lo:.2f}, {hi:.2f}]" for lo, hi in zip(lower_bounds, upper_bounds)], - [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(lower_bounds, upper_bounds)], - ).astype(str) - - columns_to_add.append( - pl.Series("strata_probability_threshold", strata_prob_labels) - ) - - if "ppcr" in stratified_by: - # --- Compute strata_ppcr as equal-frequency quantile bins by rank --- - by = float(by) - q = int(round(1 / by)) # e.g. 0.2 -> 5 bins - - probs = np.asarray(probs, float) - - edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear") - - edges = np.maximum.accumulate(edges) - - edges[0] = 0.0 - edges[-1] = 1.0 - - bin_idx = np.digitize(probs, bins=edges[1:-1], right=True) - - s = str(by) - decimals = len(s.split(".")[-1]) if "." in s else 0 - - labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)] - - strata_labels = np.array([labels[i] for i in bin_idx], dtype=object) - - columns_to_add.append( - pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels)) - ) - return group.with_columns(columns_to_add) - - # Apply per-group transformation - grouped = data.partition_by("reference_group", as_dict=True) - transformed_groups = [transform_group(group, by) for group in grouped.values()] - return pl.concat(transformed_groups) - - -def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame: - s_by = str(by) - decimals = len(s_by.split(".")[-1]) if "." in s_by else 0 - fmt = f"{{:.{decimals}f}}" - - if stratified_by == "probability_threshold": - upper_bound = breaks[1:] # breaks - lower_bound = breaks[:-1] # np.roll(upper_bound, 1) - # lower_bound[0] = 0.0 - mid_point = upper_bound - by / 2 - include_lower_bound = lower_bound > -0.1 - include_upper_bound = upper_bound == 1.0 # upper_bound != 0.0 - # chosen_cutoff = upper_bound - strata = format_strata_column( - lower_bound=lower_bound, - upper_bound=upper_bound, - include_lower_bound=include_lower_bound, - include_upper_bound=include_upper_bound, - decimals=2, - ) - - elif stratified_by == "ppcr": - strata_mid = breaks[1:] - lower_bound = strata_mid - by / 2 - upper_bound = strata_mid + by / 2 - mid_point = breaks[1:] - include_lower_bound = np.ones_like(strata_mid, dtype=bool) - include_upper_bound = np.zeros_like(strata_mid, dtype=bool) - # chosen_cutoff = strata_mid - strata = np.array([fmt.format(x) for x in strata_mid], dtype=object) - else: - raise ValueError(f"Unsupported stratified_by: {stratified_by}") - - bins_df = pl.DataFrame( - { - "strata": pl.Series(strata), - "lower_bound": lower_bound, - "upper_bound": upper_bound, - "mid_point": mid_point, - "include_lower_bound": include_lower_bound, - "include_upper_bound": include_upper_bound, - # "chosen_cutoff": chosen_cutoff, - "stratified_by": [stratified_by] * len(strata), - } - ) - - cutoffs_df = pl.DataFrame({"chosen_cutoff": breaks}) - - return bins_df.join(cutoffs_df, how="cross") - - -def format_strata_column( - lower_bound: list[float], - upper_bound: list[float], - include_lower_bound: list[bool], - include_upper_bound: list[bool], - decimals: int = 3, -) -> list[str]: - return [ - f"{'[' if ilb else '('}" - f"{round(lb, decimals):.{decimals}f}, " - f"{round(ub, decimals):.{decimals}f}" - f"{']' if iub else ')'}" - for lb, ub, ilb, iub in zip( - lower_bound, upper_bound, include_lower_bound, include_upper_bound - ) - ] - - -def format_strata_interval( - lower: float, upper: float, include_lower: bool, include_upper: bool -) -> str: - left = "[" if include_lower else "(" - right = "]" if include_upper else ")" - return f"{left}{lower:.3f}, {upper:.3f}{right}" - - -def create_breaks_values(probs_vec, stratified_by, by): - if stratified_by != "probability_threshold": - breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1 / by) + 1)) - else: - breaks = np.round( - np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1]) - ) - return breaks - - -def _create_aj_data_combinations_binary( - reference_groups: Sequence[str], - stratified_by: Sequence[str], - by: float, - breaks: Sequence[float], -) -> pl.DataFrame: - dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] - - strata_combinations = pl.concat(dfs, how="vertical") - - strata_cats = ( - strata_combinations.select(pl.col("strata").unique(maintain_order=True)) - .to_series() - .to_list() - ) - - strata_enum = pl.Enum(strata_cats) - stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) - - strata_combinations = strata_combinations.with_columns( - [ - pl.col("strata").cast(strata_enum), - pl.col("stratified_by").cast(stratified_by_enum), - ] - ) - - # Define values for Cartesian product - reals_labels = ["real_negatives", "real_positives"] - - combinations_frames: list[pl.DataFrame] = [ - _enum_dataframe("reference_group", reference_groups), - strata_combinations, - _enum_dataframe("reals_labels", reals_labels), - ] - - result = combinations_frames[0] - for frame in combinations_frames[1:]: - result = result.join(frame, how="cross") - - return result - - -def create_aj_data_combinations( - reference_groups: Sequence[str], - heuristics_sets: list[Dict], - fixed_time_horizons: Sequence[float], - stratified_by: Sequence[str], - by: float, - breaks: Sequence[float], - risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"], -) -> pl.DataFrame: - dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] - strata_combinations = pl.concat(dfs, how="vertical") - - # strata_enum = pl.Enum(strata_combinations["strata"]) - - strata_cats = ( - strata_combinations.select(pl.col("strata").unique(maintain_order=True)) - .to_series() - .to_list() - ) - - strata_enum = pl.Enum(strata_cats) - stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) - - strata_combinations = strata_combinations.with_columns( - [ - pl.col("strata").cast(strata_enum), - pl.col("stratified_by").cast(stratified_by_enum), - ] - ) - - risk_set_scope_combinations = pl.DataFrame( - { - "risk_set_scope": pl.Series(risk_set_scope).cast( - pl.Enum(["within_stratum", "pooled_by_cutoff"]) - ) - } - ) - - # Define values for Cartesian product - reals_labels = [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - - heuristics_combinations = pl.DataFrame(heuristics_sets) - - censoring_heuristics_enum = pl.Enum( - heuristics_combinations["censoring_heuristic"].unique(maintain_order=True) - ) - competing_heuristics_enum = pl.Enum( - heuristics_combinations["competing_heuristic"].unique(maintain_order=True) - ) - - combinations_frames: list[pl.DataFrame] = [ - _enum_dataframe("reference_group", reference_groups), - pl.DataFrame( - {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)} - ), - heuristics_combinations.with_columns( - [ - pl.col("censoring_heuristic").cast(censoring_heuristics_enum), - pl.col("competing_heuristic").cast(competing_heuristics_enum), - ] - ), - strata_combinations, - risk_set_scope_combinations, - _enum_dataframe("reals_labels", reals_labels), - ] - - result = combinations_frames[0] - for frame in combinations_frames[1:]: - result = result.join(frame, how="cross") - - return result - - -def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: - # Identify id_vars and value_vars - id_vars = [col for col in data.columns if not col.startswith("strata_")] - value_vars = [col for col in data.columns if col.startswith("strata_")] - - # Perform the melt (equivalent to pandas.melt) - data_long = data.melt( - id_vars=id_vars, - value_vars=value_vars, - variable_name="stratified_by", - value_name="strata", - ) - - stratified_by_labels = ["probability_threshold", "ppcr"] - stratified_by_enum = pl.Enum(stratified_by_labels) - - # Remove "strata_" prefix from the 'stratified_by' column - data_long = data_long.with_columns( - pl.col("stratified_by").str.replace("^strata_", "").cast(stratified_by_enum) - ) - - return data_long - - -def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame: - return data.with_columns( - [ - pl.when(pl.col("reals") == 0) - .then("real_negatives") - .when(pl.col("reals") == 1) - .then("real_positives") - .when(pl.col("reals") == 2) - .then("real_competing") - .otherwise("real_censored") - .alias("reals") - ] - ) - - -def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame: - data = data.with_columns( - [ - pl.when( - (pl.col("times") > pl.col("fixed_time_horizon")) - & (pl.col("reals_labels") == "real_positives") - ) - .then(pl.lit("real_negatives")) - .when( - (pl.col("times") < pl.col("fixed_time_horizon")) - & (pl.col("reals_labels") == "real_negatives") - ) - .then(pl.lit("real_censored")) - .otherwise(pl.col("reals_labels")) - .alias("reals_labels") - ] - ) - - return data - - -def create_aj_data( - reference_group_data, - breaks, - censoring_heuristic, - competing_heuristic, - fixed_time_horizons, - stratified_by: Sequence[str], - full_event_table: bool = False, - risk_set_scope: Sequence[str] = ["within_stratum"], -): - """ - Create AJ estimates per strata based on censoring and competing heuristicss. - """ - - def aj_estimates_with_cross(df, extra_cols): - return df.join(pl.DataFrame(extra_cols), how="cross") - - exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons) - - event_table = prepare_event_table(reference_group_data) - - # TODO: solve strata in the pipeline - - excluded_events = _extract_excluded_events( - event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic - ) - - aj_dfs = [] - for rscope in risk_set_scope: - aj_res = _aj_adjusted_events( - reference_group_data, - breaks, - exploded, - censoring_heuristic, - competing_heuristic, - fixed_time_horizons, - stratified_by, - full_event_table, - rscope, - ) - - aj_res = aj_res.select( - [ - "strata", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - "fixed_time_horizon", - "risk_set_scope", - ] - ) - - aj_dfs.append(aj_res) - - aj_df = pl.concat(aj_dfs, how="vertical") - - result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left") - - return aj_estimates_with_cross( - result, - { - "censoring_heuristic": censoring_heuristic, - "competing_heuristic": competing_heuristic, - }, - ).select( - [ - "strata", - "chosen_cutoff", - "fixed_time_horizon", - "times", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "real_censored_est", - "censoring_heuristic", - "competing_heuristic", - "estimate_origin", - "risk_set_scope", - ] - ) - - -def _extract_excluded_events( - event_table: pl.DataFrame, - fixed_time_horizons: list[float], - censoring_heuristic: str, - competing_heuristic: str, -) -> pl.DataFrame: - horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times") - - excluded_events = horizons_df.join_asof( - event_table.with_columns( - pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"), - pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"), - ).select( - pl.col("times"), - pl.col("real_censored_est"), - pl.col("real_competing_est"), - ), - left_on="times", - right_on="times", - ).with_columns([pl.col("times").alias("fixed_time_horizon")]) - - if censoring_heuristic != "excluded": - excluded_events = excluded_events.with_columns( - pl.lit(0.0).alias("real_censored_est") - ) - - if competing_heuristic != "excluded": - excluded_events = excluded_events.with_columns( - pl.lit(0.0).alias("real_competing_est") - ) - - return excluded_events - - -def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame: - all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique() - - counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg( - pl.count().alias("reals_estimate") - ) - - return all_combinations.join( - counts, on=["strata", "reals", "fixed_time_horizon"], how="left" - ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)]) - - -# def update_administrative_censoring(data_to_adjust: pd.DataFrame) -> pd.DataFrame: -# pl_df = pl.from_pandas(data_to_adjust) - -# # Perform the transformation using polars -# pl_result = pl_df.with_columns( -# pl.when( -# (pl.col("times") > pl.col("fixed_time_horizon")) & -# (pl.col("reals") == "real_positives") -# ).then( -# "real_negatives" -# ).when( -# (pl.col("times") < pl.col("fixed_time_horizon")) & -# (pl.col("reals") == "real_negatives") -# ).then( -# "real_censored" -# ).otherwise( -# pl.col("reals") -# ).alias("reals") -# ) - -# # Convert back to pandas DataFrame and return -# result_pandas = pl_result.to_pandas() - -# return result_pandas - - -def extract_aj_estimate_by_cutoffs( - data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool -): - # n = data_to_adjust.height - - counts_per_strata = ( - data_to_adjust.group_by( - ["strata", "stratified_by", "upper_bound", "lower_bound"] - ) - .len(name="strata_count") - .with_columns(pl.col("strata_count").cast(pl.Float64)) - ) - - aj_estimates_predicted_positives = pl.DataFrame() - aj_estimates_predicted_negatives = pl.DataFrame() - - for stratification_criteria in stratified_by: - for chosen_cutoff in breaks: - if stratification_criteria == "probability_threshold": - mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & ( - pl.col("stratified_by") == "probability_threshold" - ) - mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & ( - pl.col("stratified_by") == "probability_threshold" - ) - - elif stratification_criteria == "ppcr": - mask_predicted_positives = ( - pl.col("lower_bound") > 1 - chosen_cutoff - ) & (pl.col("stratified_by") == "ppcr") - mask_predicted_negatives = ( - pl.col("lower_bound") <= 1 - chosen_cutoff - ) & (pl.col("stratified_by") == "ppcr") - - predicted_positives = data_to_adjust.filter(mask_predicted_positives) - predicted_negatives = data_to_adjust.filter(mask_predicted_negatives) - - counts_per_strata_predicted_positives = counts_per_strata.filter( - mask_predicted_positives - ) - counts_per_strata_predicted_negatives = counts_per_strata.filter( - mask_predicted_negatives - ) - - event_table_predicted_positives = prepare_event_table(predicted_positives) - event_table_predicted_negatives = prepare_event_table(predicted_negatives) - - aj_estimate_predicted_positives = ( - ( - predict_aj_estimates( - event_table_predicted_positives, - pl.Series(horizons), - full_event_table, - ) - .with_columns( - pl.lit(chosen_cutoff).alias("chosen_cutoff"), - pl.lit(stratification_criteria) - .alias("stratified_by") - .cast(pl.Enum(["probability_threshold", "ppcr"])), - ) - .join( - counts_per_strata_predicted_positives, - on=["stratified_by"], - how="left", - ) - .with_columns( - [ - ( - pl.col("state_occupancy_probability_0") - * pl.col("strata_count") - ).alias("real_negatives_est"), - ( - pl.col("state_occupancy_probability_1") - * pl.col("strata_count") - ).alias("real_positives_est"), - ( - pl.col("state_occupancy_probability_2") - * pl.col("strata_count") - ).alias("real_competing_est"), - ] - ) - ) - .select( - [ - "strata", - # "stratified_by", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - ] - ) - .with_columns([pl.col("times").alias("fixed_time_horizon")]) - ) - - aj_estimate_predicted_negatives = ( - ( - predict_aj_estimates( - event_table_predicted_negatives, - pl.Series(horizons), - full_event_table, - ) - .with_columns( - pl.lit(chosen_cutoff).alias("chosen_cutoff"), - pl.lit(stratification_criteria) - .alias("stratified_by") - .cast(pl.Enum(["probability_threshold", "ppcr"])), - ) - .join( - counts_per_strata_predicted_negatives, - on=["stratified_by"], - how="left", - ) - .with_columns( - [ - ( - pl.col("state_occupancy_probability_0") - * pl.col("strata_count") - ).alias("real_negatives_est"), - ( - pl.col("state_occupancy_probability_1") - * pl.col("strata_count") - ).alias("real_positives_est"), - ( - pl.col("state_occupancy_probability_2") - * pl.col("strata_count") - ).alias("real_competing_est"), - ] - ) - ) - .select( - [ - "strata", - # "stratified_by", - "times", - "chosen_cutoff", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - "estimate_origin", - ] - ) - .with_columns([pl.col("times").alias("fixed_time_horizon")]) - ) - - aj_estimates_predicted_negatives = pl.concat( - [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives], - how="vertical", - ) - - aj_estimates_predicted_positives = pl.concat( - [aj_estimates_predicted_positives, aj_estimate_predicted_positives], - how="vertical", - ) - - aj_estimate_by_cutoffs = pl.concat( - [aj_estimates_predicted_negatives, aj_estimates_predicted_positives], - how="vertical", - ) - - return aj_estimate_by_cutoffs - - -def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool): - n = data_to_adjust.height - - event_table = prepare_event_table(data_to_adjust) - - aj_estimate_for_strata_polars = predict_aj_estimates( - event_table, pl.Series(horizons), full_event_table - ) - - if len(horizons) == 1: - aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns( - pl.lit(horizons[0]).alias("fixed_time_horizon") - ) - - else: - fixed_df = aj_estimate_for_strata_polars.filter( - pl.col("estimate_origin") == "fixed_time_horizons" - ).with_columns([pl.col("times").alias("fixed_time_horizon")]) - - event_df = ( - aj_estimate_for_strata_polars.filter( - pl.col("estimate_origin") == "event_table" - ) - .with_columns([pl.lit(horizons).alias("fixed_time_horizon")]) - .explode("fixed_time_horizon") - ) - - aj_estimate_for_strata_polars = pl.concat( - [fixed_df, event_df], how="vertical" - ).sort("estimate_origin", "fixed_time_horizon", "times") - - return aj_estimate_for_strata_polars.with_columns( - [ - (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"), - (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"), - (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"), - pl.col("fixed_time_horizon").cast(pl.Float64), - pl.lit(data_to_adjust["strata"][0]).alias("strata"), - ] - ).select( - [ - "strata", - "times", - "fixed_time_horizon", - "real_negatives_est", - "real_positives_est", - "real_competing_est", - pl.col("estimate_origin"), - ] - ) - - -def assign_and_explode_polars( - data: pl.DataFrame, fixed_time_horizons: list[float] -) -> pl.DataFrame: - return ( - data.with_columns(pl.lit(fixed_time_horizons).alias("fixed_time_horizon")) - .explode("fixed_time_horizon") - .with_columns(pl.col("fixed_time_horizon").cast(pl.Float64)) - ) - - -def _create_list_data_to_adjust_binary( - aj_data_combinations: pl.DataFrame, - probs_dict: Dict[str, np.ndarray], - reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], - stratified_by, - by, -) -> Dict[str, pl.DataFrame]: - reference_group_labels = list(probs_dict.keys()) - - if isinstance(reals_dict, dict): - num_keys_reals = len(reals_dict) - else: - num_keys_reals = 1 - - reference_group_enum = pl.Enum(reference_group_labels) - - strata_enum_dtype = aj_data_combinations.schema["strata"] - - if len(probs_dict) == 1: - probs_array = np.asarray(probs_dict[reference_group_labels[0]]) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(probs_array)), - "probs": probs_array, - "reals": reals_dict, - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif num_keys_reals == 1: - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(reals_dict)), - "probs": np.concatenate( - [probs_dict[group] for group in reference_group_labels] - ), - "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif isinstance(reals_dict, dict): - data_to_adjust = ( - pl.DataFrame( - { - "reference_group": list(probs_dict.keys()), - "probs": list(probs_dict.values()), - "reals": list(reals_dict.values()), - } - ) - .explode(["probs", "reals"]) - .with_columns(pl.col("reference_group").cast(reference_group_enum)) - ) - - data_to_adjust = add_cutoff_strata( - data_to_adjust, by=by, stratified_by=stratified_by - ) - - data_to_adjust = pivot_longer_strata(data_to_adjust) - - data_to_adjust = ( - data_to_adjust.with_columns([pl.col("strata")]) - .with_columns(pl.col("strata").cast(strata_enum_dtype)) - .join( - aj_data_combinations.select( - pl.col("strata"), - pl.col("stratified_by"), - pl.col("upper_bound"), - pl.col("lower_bound"), - ).unique(), - how="left", - on=["strata", "stratified_by"], - ) - ) - - reals_labels = ["real_negatives", "real_positives"] - - reals_enum = pl.Enum(reals_labels) - - reals_map = {0: "real_negatives", 1: "real_positives"} - - data_to_adjust = data_to_adjust.with_columns( - pl.col("reals") - .replace_strict(reals_map, return_dtype=reals_enum) - .alias("reals_labels") - ) - - list_data_to_adjust = { - group[0]: df - for group, df in data_to_adjust.partition_by( - "reference_group", as_dict=True - ).items() - } - - return list_data_to_adjust - - -def _create_list_data_to_adjust( - aj_data_combinations: pl.DataFrame, - probs_dict: Dict[str, np.ndarray], - reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], - times_dict: Union[np.ndarray, Dict[str, np.ndarray]], - stratified_by, - by, -) -> Dict[str, pl.DataFrame]: - # reference_groups = list(probs_dict.keys()) - reference_group_labels = list(probs_dict.keys()) - - if isinstance(reals_dict, dict): - num_keys_reals = len(reals_dict) - else: - num_keys_reals = 1 - - # num_reals = len(reals_dict) - - reference_group_enum = pl.Enum(reference_group_labels) - - strata_enum_dtype = aj_data_combinations.schema["strata"] - - if len(probs_dict) == 1: - probs_array = np.asarray(probs_dict[reference_group_labels[0]]) - - if isinstance(reals_dict, dict): - reals_array = np.asarray(reals_dict[0]) - else: - reals_array = np.asarray(reals_dict) - - if isinstance(times_dict, dict): - times_array = np.asarray(times_dict[0]) - else: - times_array = np.asarray(times_dict) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, len(probs_array)), - "probs": probs_array, - "reals": reals_array, - "times": times_array, - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif num_keys_reals == 1: - reals_array = np.asarray(reals_dict) - times_array = np.asarray(times_dict) - n = len(reals_array) - - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, n), - "probs": np.concatenate( - [np.asarray(probs_dict[g]) for g in reference_group_labels] - ), - "reals": np.tile(reals_array, len(reference_group_labels)), - "times": np.tile(times_array, len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) - - elif isinstance(reals_dict, dict) and isinstance(times_dict, dict): - data_to_adjust = ( - pl.DataFrame( - { - "reference_group": reference_group_labels, - "probs": list(probs_dict.values()), - "reals": list(reals_dict.values()), - "times": list(times_dict.values()), - } - ) - .explode(["probs", "reals", "times"]) - .with_columns(pl.col("reference_group").cast(reference_group_enum)) - ) - - data_to_adjust = add_cutoff_strata( - data_to_adjust, by=by, stratified_by=stratified_by - ) - - data_to_adjust = pivot_longer_strata(data_to_adjust) - - data_to_adjust = ( - data_to_adjust.with_columns([pl.col("strata")]) - .with_columns(pl.col("strata").cast(strata_enum_dtype)) - .join( - aj_data_combinations.select( - pl.col("strata"), - pl.col("stratified_by"), - pl.col("upper_bound"), - pl.col("lower_bound"), - ).unique(), - how="left", - on=["strata", "stratified_by"], - ) - ) - - reals_labels = [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - - reals_enum = pl.Enum(reals_labels) - - # Map reals values to strings - reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"} - - data_to_adjust = data_to_adjust.with_columns( - pl.col("reals") - .replace_strict(reals_map, return_dtype=reals_enum) - .alias("reals_labels") - ) - - # Partition by reference_group - list_data_to_adjust = { - group[0]: df - for group, df in data_to_adjust.partition_by( - "reference_group", as_dict=True - ).items() - } - - return list_data_to_adjust - - -def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame: - df = df.copy() - for col in df.select_dtypes(include="category").columns: - df[col] = df[col].astype(str) - return df - - -def extract_aj_estimate_by_heuristics( - df: pl.DataFrame, - breaks: Sequence[float], - heuristics_sets: list[dict], - fixed_time_horizons: list[float], - stratified_by: Sequence[str], - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - aj_dfs = [] - - for heuristic in heuristics_sets: - censoring = heuristic["censoring_heuristic"] - competing = heuristic["competing_heuristic"] - - aj_df = create_aj_data( - df, - breaks, - censoring, - competing, - fixed_time_horizons, - stratified_by=stratified_by, - full_event_table=False, - risk_set_scope=risk_set_scope, - ).with_columns( - [ - pl.lit(censoring).alias("censoring_heuristic"), - pl.lit(competing).alias("competing_heuristic"), - ] - ) - - aj_dfs.append(aj_df) - - aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"]) - - aj_estimates_unpivoted = aj_estimates_data.unpivot( - index=[ - "strata", - "chosen_cutoff", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "risk_set_scope", - ], - variable_name="reals_labels", - value_name="reals_estimate", - ) - - return aj_estimates_unpivoted - - -def _create_adjusted_data_binary( - list_data_to_adjust: dict[str, pl.DataFrame], - breaks: Sequence[float], - stratified_by: Sequence[str], -) -> pl.DataFrame: - long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical") - - adjusted_data_binary = ( - long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"]) - .agg(pl.count().alias("reals_estimate")) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - return adjusted_data_binary - - -def create_adjusted_data( - list_data_to_adjust: dict[str, pl.DataFrame], - heuristics_sets: list[dict[str, str]], - fixed_time_horizons: list[float], - breaks: Sequence[float], - stratified_by: Sequence[str], - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - all_results = [] - - reference_groups = list(list_data_to_adjust.keys()) - reference_group_enum = pl.Enum(reference_groups) - - heuristics_df = pl.DataFrame(heuristics_sets) - censoring_heuristic_enum = pl.Enum( - heuristics_df["censoring_heuristic"].unique(maintain_order=True) - ) - competing_heuristic_enum = pl.Enum( - heuristics_df["competing_heuristic"].unique(maintain_order=True) - ) - - for reference_group, df in list_data_to_adjust.items(): - input_df = df.select( - ["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"] - ) - - aj_result = extract_aj_estimate_by_heuristics( - input_df, - breaks, - heuristics_sets=heuristics_sets, - fixed_time_horizons=fixed_time_horizons, - stratified_by=stratified_by, - risk_set_scope=risk_set_scope, - ) - - aj_result_with_group = aj_result.with_columns( - [ - pl.lit(reference_group) - .cast(reference_group_enum) - .alias("reference_group") - ] - ) - - all_results.append(aj_result_with_group) - - reals_enum_dtype = pl.Enum( - [ - "real_negatives", - "real_positives", - "real_competing", - "real_censored", - ] - ) - - return ( - pl.concat(all_results) - .with_columns([pl.col("reference_group").cast(reference_group_enum)]) - .with_columns( - [ - pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype), - pl.col("censoring_heuristic").cast(censoring_heuristic_enum), - pl.col("competing_heuristic").cast(competing_heuristic_enum), - ] - ) - ) - - -def _cast_and_join_adjusted_data_binary( - aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame -) -> pl.DataFrame: - strata_enum_dtype = aj_data_combinations.schema["strata"] - - aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( - pl.col("strata").cast(strata_enum_dtype) - ) - - final_adjusted_data_polars = ( - ( - aj_data_combinations.with_columns([pl.col("strata")]).join( - aj_estimates_data, - on=[ - "strata", - "stratified_by", - "reals_labels", - "reference_group", - "chosen_cutoff", - ], - how="left", - ) - ) - .with_columns( - pl.when( - ( - (pl.col("chosen_cutoff") >= pl.col("upper_bound")) - & (pl.col("stratified_by") == "probability_threshold") - ) - | ( - ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) - & (pl.col("stratified_by") == "ppcr") - ) - ) - .then(pl.lit("predicted_negatives")) - .otherwise(pl.lit("predicted_positives")) - .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) - .alias("prediction_label") - ) - .with_columns( - ( - pl.when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("true_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("false_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("false_negatives")) - .cast( - pl.Enum( - [ - "true_positives", - "false_positives", - "true_negatives", - "false_negatives", - ] - ) - ) - ).alias("classification_outcome") - ) - ).with_columns(pl.col("reals_estimate").fill_null(0)) - - return final_adjusted_data_polars - - -def cast_and_join_adjusted_data( - aj_data_combinations, aj_estimates_data -) -> pl.DataFrame: - strata_enum_dtype = aj_data_combinations.schema["strata"] - - aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( - pl.col("strata").cast(strata_enum_dtype) - ) - - final_adjusted_data_polars = ( - aj_data_combinations.with_columns([pl.col("strata")]) - .join( - aj_estimates_data, - on=[ - "strata", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "reals_labels", - "reference_group", - "chosen_cutoff", - "risk_set_scope", - ], - how="left", - ) - .with_columns( - pl.when( - ( - (pl.col("chosen_cutoff") >= pl.col("upper_bound")) - & (pl.col("stratified_by") == "probability_threshold") - ) - | ( - ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) - & (pl.col("stratified_by") == "ppcr") - ) - ) - .then(pl.lit("predicted_negatives")) - .otherwise(pl.lit("predicted_positives")) - .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) - .alias("prediction_label") - ) - .with_columns( - ( - pl.when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("true_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("false_positives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_negatives")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_positives")) - ) - .then(pl.lit("false_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_negatives")) - & (pl.col("reals_labels") == pl.lit("real_competing")) - & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) - ) - .then(pl.lit("true_negatives")) - .when( - (pl.col("prediction_label") == pl.lit("predicted_positives")) - & (pl.col("reals_labels") == pl.lit("real_competing")) - & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) - ) - .then(pl.lit("false_positives")) - .otherwise(pl.lit("excluded")) # or pl.lit(None) if you prefer nulls - .cast( - pl.Enum( - [ - "true_positives", - "false_positives", - "true_negatives", - "false_negatives", - "excluded", - ] - ) - ) - ).alias("classification_outcome") - ) - ) - return final_adjusted_data_polars - - -def _censored_count(df: pl.DataFrame) -> pl.DataFrame: - return ( - df.with_columns( - ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 0)) - .cast(pl.Float64) - .alias("is_censored") - ) - .group_by(["strata", "fixed_time_horizon"]) - .agg(pl.col("is_censored").sum().alias("real_censored_est")) - ) - - -def _competing_count(df: pl.DataFrame) -> pl.DataFrame: - return ( - df.with_columns( - ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 2)) - .cast(pl.Float64) - .alias("is_competing") - ) - .group_by(["strata", "fixed_time_horizon"]) - .agg(pl.col("is_competing").sum().alias("real_competing_est")) - ) - - -def _aj_estimates_by_cutoff_per_horizon( - df: pl.DataFrame, - horizons: list[float], - breaks: Sequence[float], - stratified_by: Sequence[str], -) -> pl.DataFrame: - return pl.concat( - [ - df.filter(pl.col("fixed_time_horizon") == h) - .group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_by_cutoffs( - group, [h], breaks, stratified_by, full_event_table=False - ) - ) - for h in horizons - ], - how="vertical", - ) - - -def _aj_estimates_per_horizon( - df: pl.DataFrame, horizons: list[float], full_event_table: bool -) -> pl.DataFrame: - return pl.concat( - [ - df.filter(pl.col("fixed_time_horizon") == h) - .group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, [h], full_event_table - ) - ) - for h in horizons - ], - how="vertical", - ) - - -def _aj_adjusted_events( - reference_group_data: pl.DataFrame, - breaks: Sequence[float], - exploded: pl.DataFrame, - censoring: str, - competing: str, - horizons: list[float], - stratified_by: Sequence[str], - full_event_table: bool = False, - risk_set_scope: Sequence[str] = ["within_stratum"], -) -> pl.DataFrame: - strata_enum_dtype = reference_group_data.schema["strata"] - - # Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff - if censoring == "adjusted" and competing == "adjusted_as_negative": - if risk_set_scope == "within_stratum": - adjusted = ( - reference_group_data.group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, horizons, full_event_table - ) - ) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - # preserve the original enum dtype for 'strata' coming from reference_group_data - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - - return adjusted - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - reference_group_data, horizons, breaks, stratified_by, full_event_table - ) - adjusted = adjusted.with_columns( - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope") - ) - return adjusted - - # Special-case: both excluded (faster branch in original) - if censoring == "excluded" and competing == "excluded": - non_censored_non_competing = exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") == 1) - ) - - adjusted = _aj_estimates_per_horizon( - non_censored_non_competing, horizons, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ).join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - - return adjusted - - # Special-case: competing excluded (handled by filtering out competing events) - if competing == "excluded": - # Use exploded to apply filters that depend on fixed_time_horizon consistently - non_competing = exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2) - ).with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(0)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - - if risk_set_scope == "within_stratum": - adjusted = ( - _aj_estimates_per_horizon(non_competing, horizons, full_event_table) - # .select(pl.exclude("real_competing_est")) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - non_competing, horizons, breaks, stratified_by, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - return adjusted - - # For remaining cases, determine base dataframe depending on censoring rule: - # - "adjusted": use the full reference_group_data (events censored at horizon are kept/adjusted) - # - "excluded": remove administratively censored observations (use exploded with filter) - base_df = ( - exploded.filter( - (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") > 0) - ) - if censoring == "excluded" - else reference_group_data - ) - - # Apply competing-event transformation if required - if competing == "adjusted_as_censored": - base_df = base_df.with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(0)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - elif competing == "adjusted_as_composite": - base_df = base_df.with_columns( - pl.when(pl.col("reals") == 2) - .then(pl.lit(1)) - .otherwise(pl.col("reals")) - .alias("reals") - ) - # competing == "adjusted_as_negative": keep reals as-is (no transform) - - # Finally choose aggregation strategy: per-stratum or horizon-wise - if censoring == "excluded": - # For excluded censoring we always evaluate per-horizon on the filtered (exploded) dataset - - if risk_set_scope == "within_stratum": - adjusted = _aj_estimates_per_horizon(base_df, horizons, full_event_table) - - adjusted = adjusted.join( - pl.DataFrame({"chosen_cutoff": breaks}), how="cross" - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = _aj_estimates_by_cutoff_per_horizon( - base_df, horizons, breaks, stratified_by - ) - - adjusted = adjusted.with_columns( - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope") - ) - - return adjusted.with_columns(pl.col("strata").cast(strata_enum_dtype)) - else: - # For adjusted censoring we aggregate within strata - - if risk_set_scope == "within_stratum": - adjusted = ( - base_df.group_by("strata") - .map_groups( - lambda group: extract_aj_estimate_for_strata( - group, horizons, full_event_table - ) - ) - .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") - ) - - elif risk_set_scope == "pooled_by_cutoff": - adjusted = extract_aj_estimate_by_cutoffs( - base_df, horizons, breaks, stratified_by, full_event_table - ) - - adjusted = adjusted.with_columns( - [ - pl.col("strata").cast(strata_enum_dtype), - pl.lit(risk_set_scope) - .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) - .alias("risk_set_scope"), - ] - ) - - return adjusted - - -def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame: - cumulative_aj_data = ( - aj_data.group_by( - [ - "reference_group", - "stratified_by", - "chosen_cutoff", - "classification_outcome", - ] - ) - .agg([pl.col("reals_estimate").sum()]) - .pivot(on="classification_outcome", values="reals_estimate") - .with_columns( - [ - pl.col(col).fill_null(0) - for col in [ - "true_positives", - "true_negatives", - "false_positives", - "false_negatives", - ] - ] - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ) - .alias("n") - .sum(), - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - ) - - return cumulative_aj_data - - -def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame: - cumulative_aj_data = ( - aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff") - .group_by( - [ - "reference_group", - "fixed_time_horizon", - "censoring_heuristic", - "competing_heuristic", - "stratified_by", - "chosen_cutoff", - "classification_outcome", - ] - ) - .agg([pl.col("reals_estimate").sum()]) - .pivot(on="classification_outcome", values="reals_estimate") - .fill_null(0) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - .with_columns( - (pl.col("true_positives") + pl.col("false_positives")).alias( - "predicted_positives" - ), - (pl.col("true_negatives") + pl.col("false_negatives")).alias( - "predicted_negatives" - ), - (pl.col("true_positives") + pl.col("false_negatives")).alias( - "real_positives" - ), - (pl.col("false_positives") + pl.col("true_negatives")).alias( - "real_negatives" - ), - ( - pl.col("true_positives") - + pl.col("true_negatives") - + pl.col("false_positives") - + pl.col("false_negatives") - ).alias("n"), - ) - ) - - return cumulative_aj_data - - -def _turn_cumulative_aj_to_performance_data( - cumulative_aj_data: pl.DataFrame, -) -> pl.DataFrame: - performance_data = cumulative_aj_data.with_columns( - (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"), - (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"), - (pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"), - (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"), - (pl.col("false_positives") / pl.col("real_negatives")).alias( - "false_positive_rate" - ), - ( - (pl.col("true_positives") / pl.col("predicted_positives")) - / (pl.col("real_positives") / pl.col("n")) - ).alias("lift"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then( - (pl.col("true_positives") / pl.col("n")) - - (pl.col("false_positives") / pl.col("n")) - * pl.col("chosen_cutoff") - / (1 - pl.col("chosen_cutoff")) - ) - .otherwise(None) - .alias("net_benefit"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then( - 100 * (pl.col("true_negatives") / pl.col("n")) - - (pl.col("false_negatives") / pl.col("n")) - * (1 - pl.col("chosen_cutoff")) - / pl.col("chosen_cutoff") - ) - .otherwise(None) - .alias("net_benefit_interventions_avoided"), - pl.when(pl.col("stratified_by") == "probability_threshold") - .then(pl.col("predicted_positives") / pl.col("n")) - .otherwise(pl.col("chosen_cutoff")) - .alias("ppcr"), - ) - - return performance_data diff --git a/src/rtichoke/performance_data/performance_data.py b/src/rtichoke/performance_data/performance_data.py index 3922723..8fa2d30 100644 --- a/src/rtichoke/performance_data/performance_data.py +++ b/src/rtichoke/performance_data/performance_data.py @@ -5,13 +5,15 @@ from typing import Dict, Union import polars as pl from collections.abc import Sequence -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.adjustments import _create_adjusted_data_binary +from rtichoke.processing.combinations import ( _create_aj_data_combinations_binary, create_breaks_values, - _create_list_data_to_adjust_binary, - _create_adjusted_data_binary, - _cast_and_join_adjusted_data_binary, +) +from rtichoke.processing.transforms import ( _calculate_cumulative_aj_data_binary, + _cast_and_join_adjusted_data_binary, + _create_list_data_to_adjust_binary, _turn_cumulative_aj_to_performance_data, ) import numpy as np diff --git a/src/rtichoke/performance_data/performance_data_times.py b/src/rtichoke/performance_data/performance_data_times.py index 901bd59..d1629b4 100644 --- a/src/rtichoke/performance_data/performance_data_times.py +++ b/src/rtichoke/performance_data/performance_data_times.py @@ -5,14 +5,16 @@ from typing import Dict, Union import polars as pl from collections.abc import Sequence -from rtichoke.helpers.sandbox_observable_helpers import ( - create_breaks_values, +from rtichoke.processing.adjustments import create_adjusted_data +from rtichoke.processing.combinations import ( create_aj_data_combinations, - _create_list_data_to_adjust, - create_adjusted_data, - cast_and_join_adjusted_data, + create_breaks_values, +) +from rtichoke.processing.transforms import ( _calculate_cumulative_aj_data, + _create_list_data_to_adjust, _turn_cumulative_aj_to_performance_data, + cast_and_join_adjusted_data, ) import numpy as np diff --git a/src/rtichoke/helpers/__init__.py b/src/rtichoke/processing/__init__.py similarity index 100% rename from src/rtichoke/helpers/__init__.py rename to src/rtichoke/processing/__init__.py diff --git a/src/rtichoke/processing/adjustments.py b/src/rtichoke/processing/adjustments.py new file mode 100644 index 0000000..4bb982b --- /dev/null +++ b/src/rtichoke/processing/adjustments.py @@ -0,0 +1,743 @@ +import pandas as pd +import polars as pl +from polarstate import predict_aj_estimates +from polarstate import prepare_event_table +from collections.abc import Sequence +from rtichoke.processing.transforms import assign_and_explode_polars + + +def create_aj_data( + reference_group_data, + breaks, + censoring_heuristic, + competing_heuristic, + fixed_time_horizons, + stratified_by: Sequence[str], + full_event_table: bool = False, + risk_set_scope: Sequence[str] = ["within_stratum"], +): + """ + Create AJ estimates per strata based on censoring and competing heuristicss. + """ + + def aj_estimates_with_cross(df, extra_cols): + return df.join(pl.DataFrame(extra_cols), how="cross") + + exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons) + + event_table = prepare_event_table(reference_group_data) + + # TODO: solve strata in the pipeline + + excluded_events = _extract_excluded_events( + event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic + ) + + aj_dfs = [] + for rscope in risk_set_scope: + aj_res = _aj_adjusted_events( + reference_group_data, + breaks, + exploded, + censoring_heuristic, + competing_heuristic, + fixed_time_horizons, + stratified_by, + full_event_table, + rscope, + ) + + aj_res = aj_res.select( + [ + "strata", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + "fixed_time_horizon", + "risk_set_scope", + ] + ) + + aj_dfs.append(aj_res) + + aj_df = pl.concat(aj_dfs, how="vertical") + + result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left") + + return aj_estimates_with_cross( + result, + { + "censoring_heuristic": censoring_heuristic, + "competing_heuristic": competing_heuristic, + }, + ).select( + [ + "strata", + "chosen_cutoff", + "fixed_time_horizon", + "times", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "real_censored_est", + "censoring_heuristic", + "competing_heuristic", + "estimate_origin", + "risk_set_scope", + ] + ) + + +def _extract_excluded_events( + event_table: pl.DataFrame, + fixed_time_horizons: list[float], + censoring_heuristic: str, + competing_heuristic: str, +) -> pl.DataFrame: + horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times") + + excluded_events = horizons_df.join_asof( + event_table.with_columns( + pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"), + pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"), + ).select( + pl.col("times"), + pl.col("real_censored_est"), + pl.col("real_competing_est"), + ), + left_on="times", + right_on="times", + ).with_columns([pl.col("times").alias("fixed_time_horizon")]) + + if censoring_heuristic != "excluded": + excluded_events = excluded_events.with_columns( + pl.lit(0.0).alias("real_censored_est") + ) + + if competing_heuristic != "excluded": + excluded_events = excluded_events.with_columns( + pl.lit(0.0).alias("real_competing_est") + ) + + return excluded_events + + +def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame: + all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique() + + counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg( + pl.count().alias("reals_estimate") + ) + + return all_combinations.join( + counts, on=["strata", "reals", "fixed_time_horizon"], how="left" + ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)]) + + +def extract_aj_estimate_by_cutoffs( + data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool +): + # n = data_to_adjust.height + + counts_per_strata = ( + data_to_adjust.group_by( + ["strata", "stratified_by", "upper_bound", "lower_bound"] + ) + .len(name="strata_count") + .with_columns(pl.col("strata_count").cast(pl.Float64)) + ) + + aj_estimates_predicted_positives = pl.DataFrame() + aj_estimates_predicted_negatives = pl.DataFrame() + + for stratification_criteria in stratified_by: + for chosen_cutoff in breaks: + if stratification_criteria == "probability_threshold": + mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & ( + pl.col("stratified_by") == "probability_threshold" + ) + mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & ( + pl.col("stratified_by") == "probability_threshold" + ) + + elif stratification_criteria == "ppcr": + mask_predicted_positives = ( + pl.col("lower_bound") > 1 - chosen_cutoff + ) & (pl.col("stratified_by") == "ppcr") + mask_predicted_negatives = ( + pl.col("lower_bound") <= 1 - chosen_cutoff + ) & (pl.col("stratified_by") == "ppcr") + + predicted_positives = data_to_adjust.filter(mask_predicted_positives) + predicted_negatives = data_to_adjust.filter(mask_predicted_negatives) + + counts_per_strata_predicted_positives = counts_per_strata.filter( + mask_predicted_positives + ) + counts_per_strata_predicted_negatives = counts_per_strata.filter( + mask_predicted_negatives + ) + + event_table_predicted_positives = prepare_event_table(predicted_positives) + event_table_predicted_negatives = prepare_event_table(predicted_negatives) + + aj_estimate_predicted_positives = ( + ( + predict_aj_estimates( + event_table_predicted_positives, + pl.Series(horizons), + full_event_table, + ) + .with_columns( + pl.lit(chosen_cutoff).alias("chosen_cutoff"), + pl.lit(stratification_criteria) + .alias("stratified_by") + .cast(pl.Enum(["probability_threshold", "ppcr"])), + ) + .join( + counts_per_strata_predicted_positives, + on=["stratified_by"], + how="left", + ) + .with_columns( + [ + ( + pl.col("state_occupancy_probability_0") + * pl.col("strata_count") + ).alias("real_negatives_est"), + ( + pl.col("state_occupancy_probability_1") + * pl.col("strata_count") + ).alias("real_positives_est"), + ( + pl.col("state_occupancy_probability_2") + * pl.col("strata_count") + ).alias("real_competing_est"), + ] + ) + ) + .select( + [ + "strata", + # "stratified_by", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + ] + ) + .with_columns([pl.col("times").alias("fixed_time_horizon")]) + ) + + aj_estimate_predicted_negatives = ( + ( + predict_aj_estimates( + event_table_predicted_negatives, + pl.Series(horizons), + full_event_table, + ) + .with_columns( + pl.lit(chosen_cutoff).alias("chosen_cutoff"), + pl.lit(stratification_criteria) + .alias("stratified_by") + .cast(pl.Enum(["probability_threshold", "ppcr"])), + ) + .join( + counts_per_strata_predicted_negatives, + on=["stratified_by"], + how="left", + ) + .with_columns( + [ + ( + pl.col("state_occupancy_probability_0") + * pl.col("strata_count") + ).alias("real_negatives_est"), + ( + pl.col("state_occupancy_probability_1") + * pl.col("strata_count") + ).alias("real_positives_est"), + ( + pl.col("state_occupancy_probability_2") + * pl.col("strata_count") + ).alias("real_competing_est"), + ] + ) + ) + .select( + [ + "strata", + # "stratified_by", + "times", + "chosen_cutoff", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + "estimate_origin", + ] + ) + .with_columns([pl.col("times").alias("fixed_time_horizon")]) + ) + + aj_estimates_predicted_negatives = pl.concat( + [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives], + how="vertical", + ) + + aj_estimates_predicted_positives = pl.concat( + [aj_estimates_predicted_positives, aj_estimate_predicted_positives], + how="vertical", + ) + + aj_estimate_by_cutoffs = pl.concat( + [aj_estimates_predicted_negatives, aj_estimates_predicted_positives], + how="vertical", + ) + + return aj_estimate_by_cutoffs + + +def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool): + n = data_to_adjust.height + + event_table = prepare_event_table(data_to_adjust) + + aj_estimate_for_strata_polars = predict_aj_estimates( + event_table, pl.Series(horizons), full_event_table + ) + + if len(horizons) == 1: + aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns( + pl.lit(horizons[0]).alias("fixed_time_horizon") + ) + + else: + fixed_df = aj_estimate_for_strata_polars.filter( + pl.col("estimate_origin") == "fixed_time_horizons" + ).with_columns([pl.col("times").alias("fixed_time_horizon")]) + + event_df = ( + aj_estimate_for_strata_polars.filter( + pl.col("estimate_origin") == "event_table" + ) + .with_columns([pl.lit(horizons).alias("fixed_time_horizon")]) + .explode("fixed_time_horizon") + ) + + aj_estimate_for_strata_polars = pl.concat( + [fixed_df, event_df], how="vertical" + ).sort("estimate_origin", "fixed_time_horizon", "times") + + return aj_estimate_for_strata_polars.with_columns( + [ + (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"), + (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"), + (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"), + pl.col("fixed_time_horizon").cast(pl.Float64), + pl.lit(data_to_adjust["strata"][0]).alias("strata"), + ] + ).select( + [ + "strata", + "times", + "fixed_time_horizon", + "real_negatives_est", + "real_positives_est", + "real_competing_est", + pl.col("estimate_origin"), + ] + ) + + +def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame: + df = df.copy() + for col in df.select_dtypes(include="category").columns: + df[col] = df[col].astype(str) + return df + + +def extract_aj_estimate_by_heuristics( + df: pl.DataFrame, + breaks: Sequence[float], + heuristics_sets: list[dict], + fixed_time_horizons: list[float], + stratified_by: Sequence[str], + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + aj_dfs = [] + + for heuristic in heuristics_sets: + censoring = heuristic["censoring_heuristic"] + competing = heuristic["competing_heuristic"] + + aj_df = create_aj_data( + df, + breaks, + censoring, + competing, + fixed_time_horizons, + stratified_by=stratified_by, + full_event_table=False, + risk_set_scope=risk_set_scope, + ).with_columns( + [ + pl.lit(censoring).alias("censoring_heuristic"), + pl.lit(competing).alias("competing_heuristic"), + ] + ) + + aj_dfs.append(aj_df) + + aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"]) + + aj_estimates_unpivoted = aj_estimates_data.unpivot( + index=[ + "strata", + "chosen_cutoff", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "risk_set_scope", + ], + variable_name="reals_labels", + value_name="reals_estimate", + ) + + return aj_estimates_unpivoted + + +def _create_adjusted_data_binary( + list_data_to_adjust: dict[str, pl.DataFrame], + breaks: Sequence[float], + stratified_by: Sequence[str], +) -> pl.DataFrame: + long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical") + + adjusted_data_binary = ( + long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"]) + .agg(pl.count().alias("reals_estimate")) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + return adjusted_data_binary + + +def create_adjusted_data( + list_data_to_adjust: dict[str, pl.DataFrame], + heuristics_sets: list[dict[str, str]], + fixed_time_horizons: list[float], + breaks: Sequence[float], + stratified_by: Sequence[str], + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + all_results = [] + + reference_groups = list(list_data_to_adjust.keys()) + reference_group_enum = pl.Enum(reference_groups) + + heuristics_df = pl.DataFrame(heuristics_sets) + censoring_heuristic_enum = pl.Enum( + heuristics_df["censoring_heuristic"].unique(maintain_order=True) + ) + competing_heuristic_enum = pl.Enum( + heuristics_df["competing_heuristic"].unique(maintain_order=True) + ) + + for reference_group, df in list_data_to_adjust.items(): + input_df = df.select( + ["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"] + ) + + aj_result = extract_aj_estimate_by_heuristics( + input_df, + breaks, + heuristics_sets=heuristics_sets, + fixed_time_horizons=fixed_time_horizons, + stratified_by=stratified_by, + risk_set_scope=risk_set_scope, + ) + + aj_result_with_group = aj_result.with_columns( + [ + pl.lit(reference_group) + .cast(reference_group_enum) + .alias("reference_group") + ] + ) + + all_results.append(aj_result_with_group) + + reals_enum_dtype = pl.Enum( + [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + ) + + return ( + pl.concat(all_results) + .with_columns([pl.col("reference_group").cast(reference_group_enum)]) + .with_columns( + [ + pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype), + pl.col("censoring_heuristic").cast(censoring_heuristic_enum), + pl.col("competing_heuristic").cast(competing_heuristic_enum), + ] + ) + ) + + +def _censored_count(df: pl.DataFrame) -> pl.DataFrame: + return ( + df.with_columns( + ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 0)) + .cast(pl.Float64) + .alias("is_censored") + ) + .group_by(["strata", "fixed_time_horizon"]) + .agg(pl.col("is_censored").sum().alias("real_censored_est")) + ) + + +def _competing_count(df: pl.DataFrame) -> pl.DataFrame: + return ( + df.with_columns( + ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 2)) + .cast(pl.Float64) + .alias("is_competing") + ) + .group_by(["strata", "fixed_time_horizon"]) + .agg(pl.col("is_competing").sum().alias("real_competing_est")) + ) + + +def _aj_estimates_by_cutoff_per_horizon( + df: pl.DataFrame, + horizons: list[float], + breaks: Sequence[float], + stratified_by: Sequence[str], +) -> pl.DataFrame: + return pl.concat( + [ + df.filter(pl.col("fixed_time_horizon") == h) + .group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_by_cutoffs( + group, [h], breaks, stratified_by, full_event_table=False + ) + ) + for h in horizons + ], + how="vertical", + ) + + +def _aj_estimates_per_horizon( + df: pl.DataFrame, horizons: list[float], full_event_table: bool +) -> pl.DataFrame: + return pl.concat( + [ + df.filter(pl.col("fixed_time_horizon") == h) + .group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, [h], full_event_table + ) + ) + for h in horizons + ], + how="vertical", + ) + + +def _aj_adjusted_events( + reference_group_data: pl.DataFrame, + breaks: Sequence[float], + exploded: pl.DataFrame, + censoring: str, + competing: str, + horizons: list[float], + stratified_by: Sequence[str], + full_event_table: bool = False, + risk_set_scope: Sequence[str] = ["within_stratum"], +) -> pl.DataFrame: + strata_enum_dtype = reference_group_data.schema["strata"] + + # Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff + if censoring == "adjusted" and competing == "adjusted_as_negative": + if risk_set_scope == "within_stratum": + adjusted = ( + reference_group_data.group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, horizons, full_event_table + ) + ) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + # preserve the original enum dtype for 'strata' coming from reference_group_data + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + + return adjusted + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + reference_group_data, horizons, breaks, stratified_by, full_event_table + ) + adjusted = adjusted.with_columns( + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope") + ) + return adjusted + + # Special-case: both excluded (faster branch in original) + if censoring == "excluded" and competing == "excluded": + non_censored_non_competing = exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") == 1) + ) + + adjusted = _aj_estimates_per_horizon( + non_censored_non_competing, horizons, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ).join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + + return adjusted + + # Special-case: competing excluded (handled by filtering out competing events) + if competing == "excluded": + # Use exploded to apply filters that depend on fixed_time_horizon consistently + non_competing = exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2) + ).with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(0)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + + if risk_set_scope == "within_stratum": + adjusted = ( + _aj_estimates_per_horizon(non_competing, horizons, full_event_table) + # .select(pl.exclude("real_competing_est")) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + non_competing, horizons, breaks, stratified_by, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + return adjusted + + # For remaining cases, determine base dataframe depending on censoring rule: + # - "adjusted": use the full reference_group_data (events censored at horizon are kept/adjusted) + # - "excluded": remove administratively censored observations (use exploded with filter) + base_df = ( + exploded.filter( + (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") > 0) + ) + if censoring == "excluded" + else reference_group_data + ) + + # Apply competing-event transformation if required + if competing == "adjusted_as_censored": + base_df = base_df.with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(0)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + elif competing == "adjusted_as_composite": + base_df = base_df.with_columns( + pl.when(pl.col("reals") == 2) + .then(pl.lit(1)) + .otherwise(pl.col("reals")) + .alias("reals") + ) + # competing == "adjusted_as_negative": keep reals as-is (no transform) + + # Finally choose aggregation strategy: per-stratum or horizon-wise + if censoring == "excluded": + # For excluded censoring we always evaluate per-horizon on the filtered (exploded) dataset + + if risk_set_scope == "within_stratum": + adjusted = _aj_estimates_per_horizon(base_df, horizons, full_event_table) + + adjusted = adjusted.join( + pl.DataFrame({"chosen_cutoff": breaks}), how="cross" + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = _aj_estimates_by_cutoff_per_horizon( + base_df, horizons, breaks, stratified_by + ) + + adjusted = adjusted.with_columns( + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope") + ) + + return adjusted.with_columns(pl.col("strata").cast(strata_enum_dtype)) + else: + # For adjusted censoring we aggregate within strata + + if risk_set_scope == "within_stratum": + adjusted = ( + base_df.group_by("strata") + .map_groups( + lambda group: extract_aj_estimate_for_strata( + group, horizons, full_event_table + ) + ) + .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") + ) + + elif risk_set_scope == "pooled_by_cutoff": + adjusted = extract_aj_estimate_by_cutoffs( + base_df, horizons, breaks, stratified_by, full_event_table + ) + + adjusted = adjusted.with_columns( + [ + pl.col("strata").cast(strata_enum_dtype), + pl.lit(risk_set_scope) + .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"])) + .alias("risk_set_scope"), + ] + ) + + return adjusted diff --git a/src/rtichoke/processing/combinations.py b/src/rtichoke/processing/combinations.py new file mode 100644 index 0000000..b790929 --- /dev/null +++ b/src/rtichoke/processing/combinations.py @@ -0,0 +1,218 @@ +import numpy as np +import polars as pl +from typing import Dict +from collections.abc import Sequence + + +def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame: + """Create a single-column DataFrame with an enum dtype.""" + enum_values = list(dict.fromkeys(values)) + enum_dtype = pl.Enum(enum_values) + return pl.DataFrame({column_name: pl.Series(values, dtype=enum_dtype)}) + + +def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame: + s_by = str(by) + decimals = len(s_by.split(".")[-1]) if "." in s_by else 0 + fmt = f"{{:.{decimals}f}}" + + if stratified_by == "probability_threshold": + upper_bound = breaks[1:] # breaks + lower_bound = breaks[:-1] # np.roll(upper_bound, 1) + # lower_bound[0] = 0.0 + mid_point = upper_bound - by / 2 + include_lower_bound = lower_bound > -0.1 + include_upper_bound = upper_bound == 1.0 # upper_bound != 0.0 + # chosen_cutoff = upper_bound + strata = format_strata_column( + lower_bound=lower_bound, + upper_bound=upper_bound, + include_lower_bound=include_lower_bound, + include_upper_bound=include_upper_bound, + decimals=2, + ) + + elif stratified_by == "ppcr": + strata_mid = breaks[1:] + lower_bound = strata_mid - by / 2 + upper_bound = strata_mid + by / 2 + mid_point = breaks[1:] + include_lower_bound = np.ones_like(strata_mid, dtype=bool) + include_upper_bound = np.zeros_like(strata_mid, dtype=bool) + # chosen_cutoff = strata_mid + strata = np.array([fmt.format(x) for x in strata_mid], dtype=object) + else: + raise ValueError(f"Unsupported stratified_by: {stratified_by}") + + bins_df = pl.DataFrame( + { + "strata": pl.Series(strata), + "lower_bound": lower_bound, + "upper_bound": upper_bound, + "mid_point": mid_point, + "include_lower_bound": include_lower_bound, + "include_upper_bound": include_upper_bound, + # "chosen_cutoff": chosen_cutoff, + "stratified_by": [stratified_by] * len(strata), + } + ) + + cutoffs_df = pl.DataFrame({"chosen_cutoff": breaks}) + + return bins_df.join(cutoffs_df, how="cross") + + +def format_strata_column( + lower_bound: list[float], + upper_bound: list[float], + include_lower_bound: list[bool], + include_upper_bound: list[bool], + decimals: int = 3, +) -> list[str]: + return [ + f"{'[' if ilb else '('}" + f"{round(lb, decimals):.{decimals}f}, " + f"{round(ub, decimals):.{decimals}f}" + f"{']' if iub else ')'}" + for lb, ub, ilb, iub in zip( + lower_bound, upper_bound, include_lower_bound, include_upper_bound + ) + ] + + +def format_strata_interval( + lower: float, upper: float, include_lower: bool, include_upper: bool +) -> str: + left = "[" if include_lower else "(" + right = "]" if include_upper else ")" + return f"{left}{lower:.3f}, {upper:.3f}{right}" + + +def create_breaks_values(probs_vec, stratified_by, by): + if stratified_by != "probability_threshold": + breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1 / by) + 1)) + else: + breaks = np.round( + np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1]) + ) + return breaks + + +def _create_aj_data_combinations_binary( + reference_groups: Sequence[str], + stratified_by: Sequence[str], + by: float, + breaks: Sequence[float], +) -> pl.DataFrame: + dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] + + strata_combinations = pl.concat(dfs, how="vertical") + + strata_cats = ( + strata_combinations.select(pl.col("strata").unique(maintain_order=True)) + .to_series() + .to_list() + ) + + strata_enum = pl.Enum(strata_cats) + stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) + + strata_combinations = strata_combinations.with_columns( + [ + pl.col("strata").cast(strata_enum), + pl.col("stratified_by").cast(stratified_by_enum), + ] + ) + + # Define values for Cartesian product + reals_labels = ["real_negatives", "real_positives"] + + combinations_frames: list[pl.DataFrame] = [ + _enum_dataframe("reference_group", reference_groups), + strata_combinations, + _enum_dataframe("reals_labels", reals_labels), + ] + + result = combinations_frames[0] + for frame in combinations_frames[1:]: + result = result.join(frame, how="cross") + + return result + + +def create_aj_data_combinations( + reference_groups: Sequence[str], + heuristics_sets: list[Dict], + fixed_time_horizons: Sequence[float], + stratified_by: Sequence[str], + by: float, + breaks: Sequence[float], + risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"], +) -> pl.DataFrame: + dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by] + strata_combinations = pl.concat(dfs, how="vertical") + + # strata_enum = pl.Enum(strata_combinations["strata"]) + + strata_cats = ( + strata_combinations.select(pl.col("strata").unique(maintain_order=True)) + .to_series() + .to_list() + ) + + strata_enum = pl.Enum(strata_cats) + stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"]) + + strata_combinations = strata_combinations.with_columns( + [ + pl.col("strata").cast(strata_enum), + pl.col("stratified_by").cast(stratified_by_enum), + ] + ) + + risk_set_scope_combinations = pl.DataFrame( + { + "risk_set_scope": pl.Series(risk_set_scope).cast( + pl.Enum(["within_stratum", "pooled_by_cutoff"]) + ) + } + ) + + # Define values for Cartesian product + reals_labels = [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + + heuristics_combinations = pl.DataFrame(heuristics_sets) + + censoring_heuristics_enum = pl.Enum( + heuristics_combinations["censoring_heuristic"].unique(maintain_order=True) + ) + competing_heuristics_enum = pl.Enum( + heuristics_combinations["competing_heuristic"].unique(maintain_order=True) + ) + + combinations_frames: list[pl.DataFrame] = [ + _enum_dataframe("reference_group", reference_groups), + pl.DataFrame( + {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)} + ), + heuristics_combinations.with_columns( + [ + pl.col("censoring_heuristic").cast(censoring_heuristics_enum), + pl.col("competing_heuristic").cast(competing_heuristics_enum), + ] + ), + strata_combinations, + risk_set_scope_combinations, + _enum_dataframe("reals_labels", reals_labels), + ] + + result = combinations_frames[0] + for frame in combinations_frames[1:]: + result = result.join(frame, how="cross") + + return result diff --git a/src/rtichoke/helpers/exported_functions.py b/src/rtichoke/processing/exported_functions.py similarity index 99% rename from src/rtichoke/helpers/exported_functions.py rename to src/rtichoke/processing/exported_functions.py index a273346..778ad91 100644 --- a/src/rtichoke/helpers/exported_functions.py +++ b/src/rtichoke/processing/exported_functions.py @@ -4,7 +4,7 @@ import plotly.graph_objects as go -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( create_non_interactive_curve, create_interactive_marker, create_reference_lines_for_plotly, diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/processing/plotly_helper_functions.py similarity index 99% rename from src/rtichoke/helpers/plotly_helper_functions.py rename to src/rtichoke/processing/plotly_helper_functions.py index ab475be..074fc52 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/processing/plotly_helper_functions.py @@ -681,6 +681,15 @@ def _htext(title: pl.Expr) -> pl.Expr: (pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold) ) + return pl.DataFrame( + schema={ + "reference_group": pl.Utf8, + "x": pl.Float64, + "y": pl.Float64, + "text": pl.Utf8, + } + ) + def create_non_interactive_curve_polars( performance_data_ready_for_curve, reference_group_color, reference_group @@ -1157,7 +1166,7 @@ def _add_hover_text_to_performance_data( ) return performance_data.with_columns( - [pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")] + [pl.col(pl.Float64).round(3), hover_text_expr.alias("text")] ) diff --git a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py similarity index 98% rename from src/rtichoke/helpers/send_post_request_to_r_rtichoke.py rename to src/rtichoke/processing/send_post_request_to_r_rtichoke.py index aaca6f3..f8254e6 100644 --- a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py +++ b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py @@ -4,7 +4,7 @@ # import requests import pandas as pd -from rtichoke.helpers.exported_functions import create_plotly_curve +from rtichoke.processing.exported_functions import create_plotly_curve def send_requests_to_rtichoke_r(dictionary_to_send, url_api, endpoint): diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py new file mode 100644 index 0000000..4e4339a --- /dev/null +++ b/src/rtichoke/processing/transforms.py @@ -0,0 +1,700 @@ +import numpy as np +import polars as pl +from typing import Dict, Union +from rtichoke.processing.combinations import create_breaks_values + + +def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame: + def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: + probs = group["probs"].to_numpy() + columns_to_add = [] + + breaks = create_breaks_values(probs, "probability_threshold", by) + if "probability_threshold" in stratified_by: + last_bin_index = len(breaks) - 2 + + bin_indices = np.digitize(probs, bins=breaks, right=False) - 1 + bin_indices = np.where(probs == 1.0, last_bin_index, bin_indices) + + lower_bounds = breaks[bin_indices] + upper_bounds = breaks[bin_indices + 1] + + include_upper_bounds = bin_indices == last_bin_index + + strata_prob_labels = np.where( + include_upper_bounds, + [f"[{lo:.2f}, {hi:.2f}]" for lo, hi in zip(lower_bounds, upper_bounds)], + [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(lower_bounds, upper_bounds)], + ).astype(str) + + columns_to_add.append( + pl.Series("strata_probability_threshold", strata_prob_labels) + ) + + if "ppcr" in stratified_by: + # --- Compute strata_ppcr as equal-frequency quantile bins by rank --- + by = float(by) + q = int(round(1 / by)) # e.g. 0.2 -> 5 bins + + probs = np.asarray(probs, float) + + edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear") + + edges = np.maximum.accumulate(edges) + + edges[0] = 0.0 + edges[-1] = 1.0 + + bin_idx = np.digitize(probs, bins=edges[1:-1], right=True) + + s = str(by) + decimals = len(s.split(".")[-1]) if "." in s else 0 + + labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)] + + strata_labels = np.array(labels)[bin_idx] + + columns_to_add.append( + pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels)) + ) + return group.with_columns(columns_to_add) + + # Apply per-group transformation + grouped = data.partition_by("reference_group", as_dict=True) + transformed_groups = [transform_group(group, by) for group in grouped.values()] + return pl.concat(transformed_groups) + + +def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: + # Identify id_vars and value_vars + id_vars = [col for col in data.columns if not col.startswith("strata_")] + value_vars = [col for col in data.columns if col.startswith("strata_")] + + # Perform the melt (equivalent to pandas.melt) + data_long = data.melt( + id_vars=id_vars, + value_vars=value_vars, + variable_name="stratified_by", + value_name="strata", + ) + + stratified_by_labels = ["probability_threshold", "ppcr"] + stratified_by_enum = pl.Enum(stratified_by_labels) + + # Remove "strata_" prefix from the 'stratified_by' column + data_long = data_long.with_columns( + pl.col("stratified_by").str.replace("^strata_", "").cast(stratified_by_enum) + ) + + return data_long + + +def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame: + return data.with_columns( + [ + pl.when(pl.col("reals") == 0) + .then("real_negatives") + .when(pl.col("reals") == 1) + .then("real_positives") + .when(pl.col("reals") == 2) + .then("real_competing") + .otherwise("real_censored") + .alias("reals") + ] + ) + + +def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame: + data = data.with_columns( + [ + pl.when( + (pl.col("times") > pl.col("fixed_time_horizon")) + & (pl.col("reals_labels") == "real_positives") + ) + .then(pl.lit("real_negatives")) + .when( + (pl.col("times") < pl.col("fixed_time_horizon")) + & (pl.col("reals_labels") == "real_negatives") + ) + .then(pl.lit("real_censored")) + .otherwise(pl.col("reals_labels")) + .alias("reals_labels") + ] + ) + + return data + + +def assign_and_explode_polars( + data: pl.DataFrame, fixed_time_horizons: list[float] +) -> pl.DataFrame: + return ( + data.with_columns(pl.lit(fixed_time_horizons).alias("fixed_time_horizon")) + .explode("fixed_time_horizon") + .with_columns(pl.col("fixed_time_horizon").cast(pl.Float64)) + ) + + +def _create_list_data_to_adjust_binary( + aj_data_combinations: pl.DataFrame, + probs_dict: Dict[str, np.ndarray], + reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], + stratified_by, + by, +) -> Dict[str, pl.DataFrame]: + reference_group_labels = list(probs_dict.keys()) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 + + reference_group_enum = pl.Enum(reference_group_labels) + + strata_enum_dtype = aj_data_combinations.schema["strata"] + + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_dict, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(reals_dict)), + "probs": np.concatenate( + [probs_dict[group] for group in reference_group_labels] + ), + "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": list(probs_dict.keys()), + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + } + ) + .explode(["probs", "reals"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) + + data_to_adjust = add_cutoff_strata( + data_to_adjust, by=by, stratified_by=stratified_by + ) + + data_to_adjust = pivot_longer_strata(data_to_adjust) + + data_to_adjust = ( + data_to_adjust.with_columns([pl.col("strata")]) + .with_columns(pl.col("strata").cast(strata_enum_dtype)) + .join( + aj_data_combinations.select( + pl.col("strata"), + pl.col("stratified_by"), + pl.col("upper_bound"), + pl.col("lower_bound"), + ).unique(), + how="left", + on=["strata", "stratified_by"], + ) + ) + + reals_labels = ["real_negatives", "real_positives"] + + reals_enum = pl.Enum(reals_labels) + + reals_map = {0: "real_negatives", 1: "real_positives"} + + data_to_adjust = data_to_adjust.with_columns( + pl.col("reals") + .replace_strict(reals_map, return_dtype=reals_enum) + .alias("reals_labels") + ) + + list_data_to_adjust = { + group[0]: df + for group, df in data_to_adjust.partition_by( + "reference_group", as_dict=True + ).items() + } + + return list_data_to_adjust + + +def _create_list_data_to_adjust( + aj_data_combinations: pl.DataFrame, + probs_dict: Dict[str, np.ndarray], + reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], + times_dict: Union[np.ndarray, Dict[str, np.ndarray]], + stratified_by, + by, +) -> Dict[str, pl.DataFrame]: + # reference_groups = list(probs_dict.keys()) + reference_group_labels = list(probs_dict.keys()) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 + + # num_reals = len(reals_dict) + + reference_group_enum = pl.Enum(reference_group_labels) + + strata_enum_dtype = aj_data_combinations.schema["strata"] + + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + if isinstance(reals_dict, dict): + reals_array = np.asarray(reals_dict[0]) + else: + reals_array = np.asarray(reals_dict) + + if isinstance(times_dict, dict): + times_array = np.asarray(times_dict[0]) + else: + times_array = np.asarray(times_dict) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_array, + "times": times_array, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + reals_array = np.asarray(reals_dict) + times_array = np.asarray(times_dict) + n = len(reals_array) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, n), + "probs": np.concatenate( + [np.asarray(probs_dict[g]) for g in reference_group_labels] + ), + "reals": np.tile(reals_array, len(reference_group_labels)), + "times": np.tile(times_array, len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict) and isinstance(times_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": reference_group_labels, + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + "times": list(times_dict.values()), + } + ) + .explode(["probs", "reals", "times"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) + + data_to_adjust = add_cutoff_strata( + data_to_adjust, by=by, stratified_by=stratified_by + ) + + data_to_adjust = pivot_longer_strata(data_to_adjust) + + data_to_adjust = ( + data_to_adjust.with_columns([pl.col("strata")]) + .with_columns(pl.col("strata").cast(strata_enum_dtype)) + .join( + aj_data_combinations.select( + pl.col("strata"), + pl.col("stratified_by"), + pl.col("upper_bound"), + pl.col("lower_bound"), + ).unique(), + how="left", + on=["strata", "stratified_by"], + ) + ) + + reals_labels = [ + "real_negatives", + "real_positives", + "real_competing", + "real_censored", + ] + + reals_enum = pl.Enum(reals_labels) + + # Map reals values to strings + reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"} + + data_to_adjust = data_to_adjust.with_columns( + pl.col("reals") + .replace_strict(reals_map, return_dtype=reals_enum) + .alias("reals_labels") + ) + + # Partition by reference_group + list_data_to_adjust = { + group[0]: df + for group, df in data_to_adjust.partition_by( + "reference_group", as_dict=True + ).items() + } + + return list_data_to_adjust + + +def _cast_and_join_adjusted_data_binary( + aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame +) -> pl.DataFrame: + strata_enum_dtype = aj_data_combinations.schema["strata"] + + aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( + pl.col("strata").cast(strata_enum_dtype) + ) + + final_adjusted_data_polars = ( + ( + aj_data_combinations.with_columns([pl.col("strata")]).join( + aj_estimates_data, + on=[ + "strata", + "stratified_by", + "reals_labels", + "reference_group", + "chosen_cutoff", + ], + how="left", + ) + ) + .with_columns( + pl.when( + ( + (pl.col("chosen_cutoff") >= pl.col("upper_bound")) + & (pl.col("stratified_by") == "probability_threshold") + ) + | ( + ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) + & (pl.col("stratified_by") == "ppcr") + ) + ) + .then(pl.lit("predicted_negatives")) + .otherwise(pl.lit("predicted_positives")) + .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) + .alias("prediction_label") + ) + .with_columns( + ( + pl.when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("true_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("false_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("false_negatives")) + .cast( + pl.Enum( + [ + "true_positives", + "false_positives", + "true_negatives", + "false_negatives", + ] + ) + ) + ).alias("classification_outcome") + ) + ).with_columns(pl.col("reals_estimate").fill_null(0)) + + return final_adjusted_data_polars + + +def cast_and_join_adjusted_data( + aj_data_combinations, aj_estimates_data +) -> pl.DataFrame: + strata_enum_dtype = aj_data_combinations.schema["strata"] + + aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns( + pl.col("strata").cast(strata_enum_dtype) + ) + + final_adjusted_data_polars = ( + aj_data_combinations.with_columns([pl.col("strata")]) + .join( + aj_estimates_data, + on=[ + "strata", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "reals_labels", + "reference_group", + "chosen_cutoff", + "risk_set_scope", + ], + how="left", + ) + .with_columns( + pl.when( + ( + (pl.col("chosen_cutoff") >= pl.col("upper_bound")) + & (pl.col("stratified_by") == "probability_threshold") + ) + | ( + ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point")) + & (pl.col("stratified_by") == "ppcr") + ) + ) + .then(pl.lit("predicted_negatives")) + .otherwise(pl.lit("predicted_positives")) + .cast(pl.Enum(["predicted_negatives", "predicted_positives"])) + .alias("prediction_label") + ) + .with_columns( + ( + pl.when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("true_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("false_positives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_negatives")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_positives")) + ) + .then(pl.lit("false_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_negatives")) + & (pl.col("reals_labels") == pl.lit("real_competing")) + & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) + ) + .then(pl.lit("true_negatives")) + .when( + (pl.col("prediction_label") == pl.lit("predicted_positives")) + & (pl.col("reals_labels") == pl.lit("real_competing")) + & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative")) + ) + .then(pl.lit("false_positives")) + .otherwise(pl.lit("excluded")) # or pl.lit(None) if you prefer nulls + .cast( + pl.Enum( + [ + "true_positives", + "false_positives", + "true_negatives", + "false_negatives", + "excluded", + ] + ) + ) + ).alias("classification_outcome") + ) + ) + return final_adjusted_data_polars + + +def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame: + cumulative_aj_data = ( + aj_data.group_by( + [ + "reference_group", + "stratified_by", + "chosen_cutoff", + "classification_outcome", + ] + ) + .agg([pl.col("reals_estimate").sum()]) + .pivot(on="classification_outcome", values="reals_estimate") + .with_columns( + [ + pl.col(col).fill_null(0) + for col in [ + "true_positives", + "true_negatives", + "false_positives", + "false_negatives", + ] + ] + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ) + .alias("n") + .sum(), + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + ) + + return cumulative_aj_data + + +def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame: + cumulative_aj_data = ( + aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff") + .group_by( + [ + "reference_group", + "fixed_time_horizon", + "censoring_heuristic", + "competing_heuristic", + "stratified_by", + "chosen_cutoff", + "classification_outcome", + ] + ) + .agg([pl.col("reals_estimate").sum()]) + .pivot(on="classification_outcome", values="reals_estimate") + .fill_null(0) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + .with_columns( + (pl.col("true_positives") + pl.col("false_positives")).alias( + "predicted_positives" + ), + (pl.col("true_negatives") + pl.col("false_negatives")).alias( + "predicted_negatives" + ), + (pl.col("true_positives") + pl.col("false_negatives")).alias( + "real_positives" + ), + (pl.col("false_positives") + pl.col("true_negatives")).alias( + "real_negatives" + ), + ( + pl.col("true_positives") + + pl.col("true_negatives") + + pl.col("false_positives") + + pl.col("false_negatives") + ).alias("n"), + ) + ) + + return cumulative_aj_data + + +def _turn_cumulative_aj_to_performance_data( + cumulative_aj_data: pl.DataFrame, +) -> pl.DataFrame: + performance_data = cumulative_aj_data.with_columns( + (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"), + (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"), + (pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"), + (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"), + (pl.col("false_positives") / pl.col("real_negatives")).alias( + "false_positive_rate" + ), + ( + (pl.col("true_positives") / pl.col("predicted_positives")) + / (pl.col("real_positives") / pl.col("n")) + ).alias("lift"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then( + (pl.col("true_positives") / pl.col("n")) + - (pl.col("false_positives") / pl.col("n")) + * pl.col("chosen_cutoff") + / (1 - pl.col("chosen_cutoff")) + ) + .otherwise(None) + .alias("net_benefit"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then( + 100 * (pl.col("true_negatives") / pl.col("n")) + - (pl.col("false_negatives") / pl.col("n")) + * (1 - pl.col("chosen_cutoff")) + / pl.col("chosen_cutoff") + ) + .otherwise(None) + .alias("net_benefit_interventions_avoided"), + pl.when(pl.col("stratified_by") == "probability_threshold") + .then(pl.col("predicted_positives") / pl.col("n")) + .otherwise(pl.col("chosen_cutoff")) + .alias("ppcr"), + ) + + return performance_data diff --git a/src/rtichoke/summary_report/summary_report.py b/src/rtichoke/summary_report/summary_report.py index 9549260..8506794 100644 --- a/src/rtichoke/summary_report/summary_report.py +++ b/src/rtichoke/summary_report/summary_report.py @@ -2,8 +2,10 @@ A module for Summary Report """ -from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.send_post_request_to_r_rtichoke import ( + send_requests_to_rtichoke_r, +) +from rtichoke.processing.transforms import ( _create_list_data_to_adjust, ) import subprocess diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py index 50f0e6d..57fdf53 100644 --- a/src/rtichoke/utility/decision.py +++ b/src/rtichoke/utility/decision.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.plotly_helper_functions import ( +from rtichoke.processing.plotly_helper_functions import ( _create_rtichoke_plotly_curve_binary, _create_rtichoke_plotly_curve_times, _plot_rtichoke_curve_binary, diff --git a/tests/test_calibration.py b/tests/test_calibration.py new file mode 100644 index 0000000..4e79687 --- /dev/null +++ b/tests/test_calibration.py @@ -0,0 +1,28 @@ +import numpy as np +from rtichoke.calibration.calibration import create_calibration_curve + + +def test_create_calibration_curve_smooth(): + probs = {"model_1": np.linspace(0, 1, 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, calibration_type="smooth") + + # Check if the figure has the correct number of traces (smooth curve, histogram, and reference line) + assert len(fig.data) == 3 + + # Check reference line data + reference_line = fig.data[0] + assert reference_line.name == "Perfectly Calibrated" + + +def test_create_calibration_curve_smooth_single_point(): + probs = {"model_1": np.array([0.5] * 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, calibration_type="smooth") + + # Check that the plot mode is "lines+markers" + assert fig.data[1].mode == "lines+markers" + + # Check histogram data + histogram = fig.data[2] + assert histogram.type == "bar" diff --git a/tests/test_calibration_times.py b/tests/test_calibration_times.py new file mode 100644 index 0000000..b6570c9 --- /dev/null +++ b/tests/test_calibration_times.py @@ -0,0 +1,24 @@ +import numpy as np +from rtichoke.calibration import create_calibration_curve_times + + +def test_create_calibration_curve_times(): + probs = {"model_1": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])} + reals = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + times = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + fixed_time_horizons = [5, 10] + heuristics_sets = [ + {"censoring_heuristic": "excluded", "competing_heuristic": "excluded"} + ] + + fig = create_calibration_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + ) + + assert fig is not None + assert len(fig.data) > 0 + assert len(fig.layout.sliders) > 0 diff --git a/tests/test_heuristics.py b/tests/test_heuristics.py new file mode 100644 index 0000000..6730c1c --- /dev/null +++ b/tests/test_heuristics.py @@ -0,0 +1,97 @@ +import pytest +import polars as pl +from polars.testing import assert_frame_equal +from rtichoke.calibration.calibration import _apply_heuristics_and_censoring + + +@pytest.fixture +def sample_data(): + return pl.DataFrame( + { + "real": [1, 0, 2, 1, 2, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + + +def test_competing_as_negative_logic(sample_data): + # Heuristics that shouldn't change data before horizon + result = _apply_heuristics_and_censoring( + sample_data, 15, "adjusted", "adjusted_as_negative" + ) + # Competing events at times 3 and 9 should become 0. + expected = pl.DataFrame( + { + "real": [1, 0, 0, 1, 0, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + assert_frame_equal(result, expected) + + +def test_admin_censoring(sample_data): + result = _apply_heuristics_and_censoring( + sample_data, 7, "adjusted", "adjusted_as_negative" + ) + # Admin censoring for times > 7. Competing event at time=3 becomes 0. + expected = pl.DataFrame( + { + "real": [1, 0, 0, 0, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + assert_frame_equal(result, expected) + + +def test_censoring_excluded(sample_data): + result = _apply_heuristics_and_censoring( + sample_data, 10, "excluded", "adjusted_as_negative" + ) + # Excludes censored at times 2, 10. Admin censors time > 10. Competing at 3,9 -> 0. + expected = pl.DataFrame( + { + "real": [1, 0, 1, 0, 0], + "time": [1, 3, 8, 9, 12], + } + ) + assert_frame_equal(result.sort("time"), expected.sort("time")) + + +def test_competing_excluded(sample_data): + result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "excluded") + # Excludes competing at 3, 9. Admin censors time > 10. + expected = pl.DataFrame( + { + "real": [1, 0, 1, 0, 0], + "time": [1, 2, 8, 10, 12], + } + ) + assert_frame_equal(result.sort("time"), expected.sort("time")) + + +def test_competing_as_negative(sample_data): + result = _apply_heuristics_and_censoring( + sample_data, 10, "adjusted", "adjusted_as_negative" + ) + # Competing at 3,9 -> 0. Admin censors time > 10. + expected = pl.DataFrame( + { + "real": [1, 0, 0, 1, 0, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + assert_frame_equal(result, expected) + + +def test_competing_as_composite(sample_data): + result = _apply_heuristics_and_censoring( + sample_data, 10, "adjusted", "adjusted_as_composite" + ) + # Competing at 3,9 -> 1. Admin censors time > 10. + expected = pl.DataFrame( + { + "real": [1, 0, 1, 1, 1, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + assert_frame_equal(result, expected) diff --git a/tests/test_rtichoke.py b/tests/test_rtichoke.py index 0ff1b91..1dd7916 100644 --- a/tests/test_rtichoke.py +++ b/tests/test_rtichoke.py @@ -2,7 +2,7 @@ A module for tests """ -from rtichoke.helpers.sandbox_observable_helpers import ( +from rtichoke.processing.adjustments import ( extract_aj_estimate_for_strata, ) diff --git a/uv.lock b/uv.lock index 7cf5135..a23d1da 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.13'", @@ -3888,7 +3888,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.25" +version = "0.1.26" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, @@ -3899,6 +3899,7 @@ dependencies = [ { name = "polarstate" }, { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pyarrow", version = "22.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "statsmodels" }, { name = "typing" }, ] @@ -3933,6 +3934,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.28.0" }, { name = "polarstate", specifier = "==0.1.8" }, { name = "pyarrow", specifier = ">=21.0.0" }, + { name = "statsmodels", specifier = ">=0.14.0" }, { name = "typing", specifier = ">=3.7.4.3" }, ]