From 6eeac7679400230ff1f80912ef32308a0b050339 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:15:28 +0000 Subject: [PATCH] fix(calibration): Handle mixed input types in create_calibration_curve_times The `_build_initial_df_for_times` function now correctly handles cases where `probs` is a dictionary and `reals` and `times` are polars Series. This is achieved by relaxing the type check to allow for both NumPy arrays and polars Series. This change fixes a `TypeError` that occurred when calling `create_calibration_curve_times` with mixed input types, as reported by the user. An additional test case has been added to cover this scenario and prevent future regressions. --- AGENTS.md | 69 ++ new_test.py | 28 + pyproject.toml | 1 + src/rtichoke/__init__.py | 6 +- src/rtichoke/calibration/__init__.py | 4 + src/rtichoke/calibration/calibration.py | 1015 ++++++++++++++++++++--- tests/test_calibration.py | 28 + tests/test_calibration_times.py | 24 + tests/test_heuristics.py | 97 +++ tests/test_rtichoke.py | 30 + uv.lock | 4 +- 11 files changed, 1165 insertions(+), 141 deletions(-) create mode 100644 AGENTS.md create mode 100644 new_test.py create mode 100644 tests/test_calibration.py create mode 100644 tests/test_calibration_times.py create mode 100644 tests/test_heuristics.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..aa1968c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,69 @@ +# rtichoke Agent Information + +This document provides guidance for AI agents working on the `rtichoke` repository. + +## Development Environment + +To set up the development environment, follow these steps: + +1. **Install `uv`**: If you don't have `uv` installed, please follow the official installation instructions. +2. **Create a virtual environment**: Use `uv venv` to create a virtual environment. +3. 
**Install dependencies**: Install the project dependencies, including the `dev` dependencies, with the following command: + + ```bash + uv pip install -e .[dev] + ``` + +## Running Tests + +The test suite is run using `pytest`. To run the tests, use the following command: + +```bash +uv run pytest +``` + +## Coding Conventions + +### Functional Programming + +Strive to use a functional programming style as much as possible. Avoid side effects and mutable state where practical. + +### Docstrings + +All exported functions must have NumPy-style docstrings. This is to ensure that the documentation is clear, consistent, and can be easily parsed by tools like `quartodoc`. + +Example of a NumPy-style docstring: + +```python +def my_function(param1, param2): + """Summary of the function's purpose. + + Parameters + ---------- + param1 : int + Description of the first parameter. + param2 : str + Description of the second parameter. + + Returns + ------- + bool + Description of the return value. + """ + # function body + return True +``` + +## Pre-commit Hooks + +This repository uses pre-commit hooks to ensure code quality and consistency. The following hooks are configured: + +* **`ruff-check`**: A linter to check for common errors and style issues. +* **`ruff-format`**: A code formatter to ensure a consistent code style. +* **`uv-lock`**: A hook to keep the `uv.lock` file up to date. + +Before committing, please ensure that the pre-commit hooks pass. You can run them manually on all files with `pre-commit run --all-files`. + +## Documentation + +The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually. 
diff --git a/new_test.py b/new_test.py new file mode 100644 index 0000000..f41e6b7 --- /dev/null +++ b/new_test.py @@ -0,0 +1,28 @@ +def test_create_calibration_curve_times_mixed_inputs(): + from rtichoke.calibration.calibration import create_calibration_curve_times + import polars as pl + + probs_dict = { + "full": pl.Series( + "pr_failure18", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + ) + } + reals_series = pl.Series("cancer_cr", [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + times_series = pl.Series( + "ttcancer", [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ) + + try: + create_calibration_curve_times( + probs_dict, + reals_series, + times_series, + fixed_time_horizons=[1.0, 2.0, 3.0, 4.0, 5.0], + heuristics_sets=[ + {"censoring_heuristic": "excluded", "competing_heuristic": "excluded"} + ], + ) + except TypeError: + assert False, ( + "create_calibration_curve_times raised a TypeError with mixed input types" + ) diff --git a/pyproject.toml b/pyproject.toml index 7ff976a..d9c3f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "polarstate==0.1.8", "marimo>=0.17.0", "pyarrow>=21.0.0", + "statsmodels>=0.14.0", ] name = "rtichoke" version = "0.1.25" diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index cc5b47a..97fbc14 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -30,9 +30,9 @@ ) from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve -# from rtichoke.calibration.calibration import ( -# create_calibration_curve as create_calibration_curve, -# ) +from rtichoke.calibration.calibration import ( + create_calibration_curve as create_calibration_curve, +) from rtichoke.utility.decision import ( create_decision_curve as create_decision_curve, diff --git a/src/rtichoke/calibration/__init__.py b/src/rtichoke/calibration/__init__.py index 4267999..190e74e 100644 --- a/src/rtichoke/calibration/__init__.py +++ b/src/rtichoke/calibration/__init__.py @@ -1,3 +1,7 @@ """ 
Subpackage for Calibration """ + +from .calibration import create_calibration_curve, create_calibration_curve_times + +__all__ = ["create_calibration_curve", "create_calibration_curve_times"] diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index f9d32f3..1328456 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,21 +2,24 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Union # import pandas as pd import plotly.graph_objects as go from plotly.subplots import make_subplots from plotly.graph_objs._figure import Figure +import polars as pl +import numpy as np + # from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r def create_calibration_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], calibration_type: str = "discrete", - size: Optional[int] = None, - color_values: Optional[List[str]] = [ + size: int = 600, + color_values: List[str] = [ "#1b9e77", "#d95f02", "#7570b3", @@ -38,7 +41,6 @@ def create_calibration_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: """Creates Calibration Curve @@ -55,40 +57,191 @@ def create_calibration_curve( """ pass - # rtichoke_response = send_requests_to_rtichoke_r( - # dictionary_to_send={ - # "probs": probs, - # "reals": reals, - # "size": size, - # "color_values ": color_values, - # }, - # url_api=url_api, - # endpoint="create_calibration_curve_list", - # ) - - # calibration_curve_list = rtichoke_response.json() - - # calibration_curve_list["deciles_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["deciles_dat"] - # ) - # calibration_curve_list["smooth_dat"] = pd.DataFrame.from_dict( - # calibration_curve_list["smooth_dat"] - # ) - # calibration_curve_list["reference_data"] = 
pd.DataFrame.from_dict( - # calibration_curve_list["reference_data"] - # ) - # calibration_curve_list["histogram_for_calibration"] = pd.DataFrame.from_dict( - # calibration_curve_list["histogram_for_calibration"] - # ) - - # calibration_curve = create_plotly_curve_from_calibration_curve_list( - # calibration_curve_list=calibration_curve_list, calibration_type=calibration_type - # ) - - # return calibration_curve - - -def create_plotly_curve_from_calibration_curve_list( + calibration_curve_list = _create_calibration_curve_list( + probs, reals, size=size, color_values=color_values + ) + + calibration_curve = _create_plotly_curve_from_calibration_curve_list( + calibration_curve_list, calibration_type=calibration_type + ) + + return calibration_curve + + +def create_calibration_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + calibration_type: str = "discrete", + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Creates a time-dependent Calibration Curve with a slider for different time horizons.""" + + calibration_curve_list_times = _create_calibration_curve_list_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + size=size, + color_values=color_values, + ) + + fig = _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list_times, calibration_type=calibration_type + ) + + return fig + + +def _create_plotly_curve_from_calibration_curve_list_times( + calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" +) -> Figure: 
+ """ + Creates a plotly figure for time-dependent calibration curves. + """ + fig = make_subplots( + rows=2, cols=1, shared_xaxes=True, x_title="Predicted", row_heights=[0.8, 0.2] + ) + + initial_horizon = calibration_curve_list["fixed_time_horizons"][0] + + # Add traces for each horizon, initially visible only for the first horizon + for horizon in calibration_curve_list["fixed_time_horizons"]: + visible = horizon == initial_horizon + + # Reference Line + fig.add_trace( + go.Scatter( + x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], + name="Perfectly Calibrated", + legendgroup="Perfectly Calibrated", + hoverinfo="text", + line={"width": 2, "dash": "dot", "color": "#BEBEBE"}, + showlegend=False, + visible=visible, + ), + row=1, col=1, + ) + + for group in calibration_curve_list["reference_group_keys"]: + color = calibration_curve_list["colors_dictionary"][group][0] + + # Calibration curve (discrete or smooth) + if calibration_type == "discrete": + data_subset = calibration_curve_list["deciles_dat"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" + else: # smooth + data_subset = calibration_curve_list["smooth_dat"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + mode = "lines+markers" if data_subset.height == 1 else "lines" + + fig.add_trace( + go.Scatter( + x=data_subset["x"], + y=data_subset["y"], + hovertext=data_subset["text"], + name=group, + legendgroup=group, + hoverinfo="text", + mode=mode, + marker={"size": 10, "color": color}, + visible=visible, + ), + row=1, col=1, + ) + + # Histogram + hist_subset = calibration_curve_list["histogram_for_calibration"].filter( + (pl.col("reference_group") == group) & (pl.col("fixed_time_horizon") == horizon) + ) + fig.add_trace( + go.Bar( + x=hist_subset["mids"], + y=hist_subset["counts"], + 
hovertext=hist_subset["text"], + name=group, + width=0.01, + legendgroup=group, + hoverinfo="text", + marker_color=color, + showlegend=False, + opacity=0.4, + visible=visible, + ), + row=2, col=1, + ) + + # Create slider + steps = [] + num_traces_per_horizon = 1 + 2 * len(calibration_curve_list["reference_group_keys"]) + + for i, horizon in enumerate(calibration_curve_list["fixed_time_horizons"]): + step = dict( + method="restyle", + args=[{"visible": [False] * (num_traces_per_horizon * len(calibration_curve_list["fixed_time_horizons"]))}], + label=str(horizon), + ) + for j in range(num_traces_per_horizon): + step["args"][0]["visible"][i * num_traces_per_horizon + j] = True + steps.append(step) + + sliders = [dict( + active=0, + currentvalue={"prefix": "Time Horizon: "}, + pad={"t": 50}, + steps=steps, + )] + + # Layout + fig.update_layout( + sliders=sliders, + xaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["xaxis"]}, + yaxis={"showgrid": False, "range": calibration_curve_list["axes_ranges"]["yaxis"], "title": "Observed"}, + barmode="overlay", + plot_bgcolor="rgba(0, 0, 0, 0)", + legend={ + "orientation": "h", "xanchor": "center", "yanchor": "top", + "x": 0.5, "y": 1.3, "bgcolor": "rgba(0, 0, 0, 0)", + }, + showlegend=calibration_curve_list["performance_type"][0] != "one model", + width=calibration_curve_list["size"][0][0], + height=calibration_curve_list["size"][0][0], + ) + + return fig + + +def _create_plotly_curve_from_calibration_curve_list( calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete" ) -> Figure: """Create plotly curve from calibration curve list @@ -124,16 +277,16 @@ def create_plotly_curve_from_calibration_curve_list( calibration_curve.add_trace( go.Scatter( - x=calibration_curve_list["reference_data"]["x"].values.tolist(), - y=calibration_curve_list["reference_data"]["y"].values.tolist(), - hovertext=calibration_curve_list["reference_data"]["text"].values.tolist(), + 
x=calibration_curve_list["reference_data"]["x"], + y=calibration_curve_list["reference_data"]["y"], + hovertext=calibration_curve_list["reference_data"]["text"], name="Perfectly Calibrated", legendgroup="Perfectly Calibrated", hoverinfo="text", line={ "width": 2, "dash": "dot", - "color": calibration_curve_list["group_colors_vec"]["reference_line"][ + "color": calibration_curve_list["colors_dictionary"]["reference_line"][ 0 ], }, @@ -144,111 +297,120 @@ def create_plotly_curve_from_calibration_curve_list( ) if calibration_type == "discrete": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["deciles_dat"]["x"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - y=calibration_curve_list["deciles_dat"]["y"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["deciles_dat"]["text"][ - calibration_curve_list["deciles_dat"]["reference_group"] - == reference_group - ].values.tolist(), - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines+markers", - marker={ - "size": 10, - "color": calibration_curve_list["group_colors_vec"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + for reference_group in reference_groups: + dec_sub = calibration_curve_list["deciles_dat"].filter( + pl.col("reference_group") == reference_group + ) + + print(dec_sub) + + calibration_curve.add_trace( + go.Scatter( + x=dec_sub.get_column("x").to_list(), + y=dec_sub.get_column("y").to_list(), + hovertext=dec_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + 
hoverinfo="text", + mode="lines+markers", + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) + + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in reference_groups: + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue + + calibration_curve.add_trace( + go.Bar( + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), + name=reference_group, + width=0.01, + legendgroup=reference_group, + hoverinfo="text", + marker_color=calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + showlegend=False, + opacity=0.4, + ), + row=2, + col=1, + ) if calibration_type == "smooth": - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ): - calibration_curve.add_trace( - go.Scatter( - x=calibration_curve_list["smooth_dat"]["x"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - y=calibration_curve_list["smooth_dat"]["y"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["smooth_dat"]["text"][ - calibration_curve_list["smooth_dat"]["reference_group"] - == reference_group - ].values.tolist(), - name=reference_group, - legendgroup=reference_group, - hoverinfo="text", - mode="lines", - marker={ - "size": 10, - "color": calibration_curve_list["group_colors_vec"][ - reference_group - ][0], - }, - ), - row=1, - col=1, - ) + smooth_dat = calibration_curve_list["smooth_dat"] + reference_groups = [ + k + for k in calibration_curve_list["colors_dictionary"].keys() + if k != "reference_line" + ] + + for reference_group in reference_groups: + smooth_sub = 
smooth_dat.filter(pl.col("reference_group") == reference_group) + if smooth_sub.height == 0: + continue + + mode = "lines+markers" if smooth_sub.height == 1 else "lines" + + calibration_curve.add_trace( + go.Scatter( + x=smooth_sub.get_column("x").to_list(), + y=smooth_sub.get_column("y").to_list(), + hovertext=smooth_sub.get_column("text").to_list(), + name=reference_group, + legendgroup=reference_group, + hoverinfo="text", + mode=mode, + marker={ + "size": 10, + "color": calibration_curve_list["colors_dictionary"][ + reference_group + ][0], + }, + ), + row=1, + col=1, + ) + + hist = calibration_curve_list["histogram_for_calibration"] + + for reference_group in reference_groups: + hist_sub = hist.filter(pl.col("reference_group") == reference_group) + if hist_sub.height == 0: + continue - for reference_group in list(calibration_curve_list["group_colors_vec"].keys()): - if any( - calibration_curve_list["histogram_for_calibration"]["reference_group"] - == reference_group - ): calibration_curve.add_trace( go.Bar( - x=calibration_curve_list["histogram_for_calibration"]["mids"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), - y=calibration_curve_list["histogram_for_calibration"]["counts"][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), - hovertext=calibration_curve_list["histogram_for_calibration"][ - "text" - ][ - calibration_curve_list["histogram_for_calibration"][ - "reference_group" - ] - == reference_group - ].values.tolist(), + x=hist_sub.get_column("mids").to_list(), + y=hist_sub.get_column("counts").to_list(), + hovertext=hist_sub.get_column("text").to_list(), name=reference_group, width=0.01, legendgroup=reference_group, hoverinfo="text", - marker_color=calibration_curve_list["group_colors_vec"][ + marker_color=calibration_curve_list["colors_dictionary"][ reference_group ][0], showlegend=False, - 
opacity=calibration_curve_list["histogram_opacity"][0], + opacity=0.4, ), row=2, col=1, @@ -284,3 +446,582 @@ def create_plotly_curve_from_calibration_curve_list( ) return calibration_curve + + +def _make_deciles_dat_binary( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + n_bins: int = 10, +) -> pl.DataFrame: + if isinstance(reals, dict): + reference_groups_keys = list(reals.keys()) + y_list = [ + np.asarray(reals[reference_group]).ravel() + for reference_group in reference_groups_keys + ] + lengths = np.array([len(y) for y in y_list], dtype=np.int64) + offsets = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(lengths)]) + n_total = int(offsets[-1]) + + frames: list[pl.DataFrame] = [] + for model, p_all in probs.items(): + p_all = np.asarray(p_all).ravel() + if p_all.shape[0] != n_total: + raise ValueError( + f"probs['{model}'] length={p_all.shape[0]} does not match " + f"sum of population sizes={n_total}." + ) + + for i, pop in enumerate(reference_groups_keys): + start = int(offsets[i]) + end = int(offsets[i + 1]) + + frames.append( + pl.DataFrame( + { + "reference_group": pop, + "model": model, + "prob": p_all[start:end].astype(float, copy=False), + "real": y_list[i].astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + else: + y = np.asarray(reals).ravel() + n = y.shape[0] + frames = [] + for model, p in probs.items(): + p = np.asarray(p).ravel() + if p.shape[0] != n: + raise ValueError( + f"probs['{model}'] length={p.shape[0]} does not match reals length={n}." 
+ ) + frames.append( + pl.DataFrame( + { + "reference_group": model, + "model": model, + "prob": p.astype(float, copy=False), + "real": y.astype(float, copy=False), + } + ) + ) + + df = pl.concat(frames, how="vertical") + + labels = [str(i) for i in range(1, n_bins + 1)] + + df = df.with_columns( + [ + pl.col("prob").cast(pl.Float64), + pl.col("real").cast(pl.Float64), + pl.col("prob") + .qcut(n_bins, labels=labels, allow_duplicates=True) + .over(["reference_group", "model"]) + .alias("decile"), + ] + ).with_columns(pl.col("decile").cast(pl.Int32)) + + deciles_data = ( + df.group_by(["reference_group", "model", "decile"]) + .agg( + [ + pl.len().alias("n"), + pl.mean("prob").alias("x"), + pl.mean("real").alias("y"), + pl.sum("real").alias("n_reals"), + ] + ) + .sort(["reference_group", "model", "decile"]) + ) + + return deciles_data + + +def _check_performance_type_by_probs_and_reals( + probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]] +) -> str: + if isinstance(reals, dict) and len(reals) > 1: + return "multiple populations" + if len(probs) > 1: + return "multiple models" + return "one model" + + +def _create_calibration_curve_list( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Dict[str, Any]: + deciles_data = _make_deciles_dat_binary(probs, reals) + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) + + deciles_data, smooth_dat = _add_hover_text_to_calibration_data( + deciles_data, smooth_dat, performance_type + ) + + reference_data = _create_reference_data_for_calibration_curve() + + 
reference_groups = deciles_data["reference_group"].unique().to_list() + + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + + print("histogram for calibration") + + histogram_for_calibration = _create_histogram_for_calibration(probs) + + print(histogram_for_calibration) + + limits = _define_limits_for_calibration_plot(deciles_data) + axes_ranges = {"xaxis": limits, "yaxis": limits} + + smooth_dat = _calculate_smooth_curve(probs, reals, performance_type) + + calibration_curve_list = { + "deciles_dat": deciles_data, + "smooth_dat": smooth_dat, + "reference_data": reference_data, + "histogram_for_calibration": histogram_for_calibration, + # "histogram_opacity": [0.4], + "axes_ranges": axes_ranges, + "colors_dictionary": colors_dictionary, + "performance_type": [performance_type], + "size": [(size, size)], + } + + return calibration_curve_list + + +def _create_reference_data_for_calibration_curve() -> pl.DataFrame: + x_ref = np.linspace(0, 1, 101) + reference_data = pl.DataFrame({"x": x_ref, "y": x_ref}) + reference_data = reference_data.with_columns( + pl.concat_str( + [ + pl.lit("Perfectly Calibrated
Predicted: "), + pl.col("x").map_elements(lambda x: f"{x:.3f}", return_dtype=pl.Utf8), + pl.lit("
Observed: "), + pl.col("y").map_elements(lambda y: f"{y:.3f}", return_dtype=pl.Utf8), + ] + ).alias("text") + ) + return reference_data + + +def _calculate_smooth_curve( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + performance_type: str, +) -> pl.DataFrame: + """ + Calculate the smoothed calibration curve using lowess. + """ + from statsmodels.nonparametric.smoothers_lowess import lowess + + smooth_frames = [] + + # Helper function to process a single probability and real array + def process_single_array(p, r, group_name): + if len(np.unique(p)) == 1: + return pl.DataFrame( + {"x": [np.unique(p)[0]], "y": [np.mean(r)], "reference_group": [group_name]} + ) + else: + # lowess returns a 2D array where the first column is x and the second is y + smoothed = lowess(r, p, it=0) + xout = np.linspace(0, 1, 101) + yout = np.interp(xout, smoothed[:, 0], smoothed[:, 1]) + return pl.DataFrame( + {"x": xout, "y": yout, "reference_group": [group_name] * len(xout)} + ) + + if isinstance(reals, dict): + for model_name, prob_array in probs.items(): + # This logic assumes that for multiple populations, one model's probs are evaluated against multiple real outcomes. + # This might need adjustment based on the exact structure for multiple models and populations. 
+ if len(probs) == 1 and len(reals) > 1: # One model, multiple populations + for pop_name, real_array in reals.items(): + frame = process_single_array(prob_array, real_array, pop_name) + smooth_frames.append(frame) + else: # Multiple models, potentially multiple populations + for group_name in reals.keys(): + if group_name in probs: + frame = process_single_array(probs[group_name], reals[group_name], group_name) + smooth_frames.append(frame) + + else: # reals is a single numpy array + for group_name, prob_array in probs.items(): + frame = process_single_array(prob_array, reals, group_name) + smooth_frames.append(frame) + + if not smooth_frames: + return pl.DataFrame(schema={"x": pl.Float64, "y": pl.Float64, "reference_group": pl.Utf8, "text": pl.Utf8}) + + smooth_dat = pl.concat(smooth_frames) + + if performance_type != "one model": + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit(""), + pl.col("reference_group"), + pl.lit("
Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + else: + smooth_dat = smooth_dat.with_columns( + pl.concat_str( + [ + pl.lit("Predicted: "), + pl.col("x").map_elements( + lambda x: f"{x:.3f}", return_dtype=pl.Utf8 + ), + pl.lit("
Observed: "), + pl.col("y").map_elements( + lambda y: f"{y:.3f}", return_dtype=pl.Utf8 + ), + ] + ).alias("text") + ) + return smooth_dat + + +def _add_hover_text_to_calibration_data( + deciles_dat: pl.DataFrame, + smooth_dat: pl.DataFrame, + performance_type: str, +) -> (pl.DataFrame, pl.DataFrame): + """Adds hover text to the deciles and smooth dataframes.""" + if performance_type != "one model": + deciles_dat = deciles_dat.with_columns( + pl.concat_str([ + pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), + pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), + pl.lit(" ( "), pl.col("n_reals"), pl.lit(" / "), pl.col("n"), pl.lit(" )"), + ]).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str([ + pl.lit(""), pl.col("reference_group"), pl.lit("
Predicted: "), + pl.col("x").round(3), pl.lit("
Observed: "), pl.col("y").round(3), + ]).alias("text") + ) + else: + deciles_dat = deciles_dat.with_columns( + pl.concat_str([ + pl.lit("Predicted: "), pl.col("x").round(3), pl.lit("
Observed: "), + pl.col("y").round(3), pl.lit(" ( "), pl.col("n_reals"), + pl.lit(" / "), pl.col("n"), pl.lit(" )"), + ]).alias("text") + ) + smooth_dat = smooth_dat.with_columns( + pl.concat_str([ + pl.lit("Predicted: "), pl.col("x").round(3), + pl.lit("
Observed: "), pl.col("y").round(3), + ]).alias("text") + ) + return deciles_dat, smooth_dat + + +def _create_colors_dictionary_for_calibration( + reference_groups: List[str], + color_values: List[str], + performance_type: str = "one model", +) -> Dict[str, List[str]]: + if performance_type == "one model": + colors = ["black"] + else: + colors = color_values[: len(reference_groups)] + + return { + "reference_line": ["#BEBEBE"], + **{ + group: [colors[i % len(colors)]] for i, group in enumerate(reference_groups) + }, + } + + +def _create_histogram_for_calibration(probs: Dict[str, np.ndarray]) -> pl.DataFrame: + hist_dfs = [] + for group, prob_values in probs.items(): + counts, mids = np.histogram(prob_values, bins=np.arange(0, 1.01, 0.01)) + hist_df = pl.DataFrame( + {"mids": mids[:-1] + 0.005, "counts": counts, "reference_group": group} + ) + hist_df = hist_df.with_columns( + ( + pl.col("counts").cast(str) + + " observations in [" + + (pl.col("mids") - 0.005).round(3).cast(str) + + ", " + + (pl.col("mids") + 0.005).round(3).cast(str) + + "]" + ).alias("text") + ) + hist_dfs.append(hist_df) + + histogram_for_calibration = pl.concat(hist_dfs) + + return histogram_for_calibration + + +def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]: + if deciles_dat.height == 1: + lower_bound, upper_bound = 0.0, 1.0 + else: + lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min()))) + upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max())) + + return [ + lower_bound - (upper_bound - lower_bound) * 0.05, + upper_bound + (upper_bound - lower_bound) * 0.05, + ] + + +def _build_initial_df_for_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], +) -> pl.DataFrame: + """Builds the initial DataFrame for time-dependent calibration curves.""" + + # Case 1: Multiple populations (reals is a dict) + if isinstance(reals, dict): + if not 
isinstance(times, dict): + raise TypeError("If reals is a dict, times must also be a dict.") + + # Unnest reals and times dictionaries into a long DataFrame + reals_df = pl.DataFrame( + [ + pl.Series("reference_group", list(reals.keys())), + pl.Series("real", list(reals.values())), + pl.Series("time", list(times.values())), + ] + ).explode(["real", "time"]) + + # Unnest probs and join + probs_df = pl.DataFrame( + [ + pl.Series("model", list(probs.keys())), + pl.Series("prob", list(probs.values())), + ] + ).explode("prob") + + # If one model for many populations, cross join + if len(probs) == 1 and len(reals) > 1: + return reals_df.join(probs_df, how="cross") + else: # otherwise, assume a 1-to-1 mapping on reference_group/model + return reals_df.join(probs_df, left_on="reference_group", right_on="model") + + # Case 2: Single population (reals is an array) + else: + if not isinstance(times, (np.ndarray, pl.Series)): + raise TypeError("If reals is an array or Series, times must also be an array or Series.") + + base_df = pl.DataFrame({"real": reals, "time": times}) + + prob_frames = [] + for model_name, prob_array in probs.items(): + prob_frames.append( + base_df.with_columns( + pl.Series("prob", prob_array), + pl.lit(model_name).alias("reference_group") + ) + ) + return pl.concat(prob_frames) + + +def _apply_heuristics_and_censoring( + df: pl.DataFrame, + horizon: float, + censoring_heuristic: str, + competing_heuristic: str, +) -> pl.DataFrame: + """ + Applies censoring and competing risk heuristics to the data for a given time horizon. 
+ """ + # Administrative censoring: outcomes after horizon are negative + df_adj = df.with_columns( + pl.when(pl.col("time") > horizon).then(0).otherwise(pl.col("real")).alias("real") + ) + + # Heuristics for events before or at horizon + if censoring_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 0) & (pl.col("time") <= horizon))) + + if competing_heuristic == "excluded": + df_adj = df_adj.filter(~((pl.col("real") == 2) & (pl.col("time") <= horizon))) + elif competing_heuristic == "adjusted_as_negative": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(0) + .otherwise(pl.col("real")) + .alias("real") + ) + elif competing_heuristic == "adjusted_as_composite": + df_adj = df_adj.with_columns( + pl.when((pl.col("real") == 2) & (pl.col("time") <= horizon)) + .then(1) + .otherwise(pl.col("real")) + .alias("real") + ) + + return df_adj + + +def _create_calibration_curve_list_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: List[float], + heuristics_sets: List[Dict[str, str]], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Dict[str, Any]: + """ + Creates the data structures needed for a time-dependent calibration curve plot. 
+ """ + # Part 1: Prepare initial dataframe from inputs + initial_df = _build_initial_df_for_times(probs, reals, times) + + # Part 2: Iterate and generate calibration data for each horizon/heuristic + all_deciles = [] + all_smooth = [] + all_histograms = [] + + performance_type = _check_performance_type_by_probs_and_reals(probs, reals) + + for horizon in fixed_time_horizons: + for heuristics in heuristics_sets: + censoring_heuristic = heuristics["censoring_heuristic"] + competing_heuristic = heuristics["competing_heuristic"] + + if censoring_heuristic == "adjusted" or competing_heuristic == "adjusted_as_censored": + continue + + df_adj = _apply_heuristics_and_censoring( + initial_df, horizon, censoring_heuristic, competing_heuristic + ) + + if df_adj.height == 0: + continue + + # Re-create probs and reals dicts for helpers + probs_adj = { + group[0]: group_df["prob"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + reals_adj = { + group[0]: group_df["real"].to_numpy() + for group, group_df in df_adj.group_by("reference_group") + } + # If single population initially, reals_adj should be an array + if not isinstance(reals, dict) and len(probs) == 1: + reals_adj = next(iter(reals_adj.values())) + + + # Deciles + deciles_data = _make_deciles_dat_binary(probs_adj, reals_adj) + all_deciles.append(deciles_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + # Smooth curve + smooth_data = _calculate_smooth_curve(probs_adj, reals_adj, performance_type) + all_smooth.append(smooth_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + # Histogram + hist_data = _create_histogram_for_calibration(probs_adj) + all_histograms.append(hist_data.with_columns(pl.lit(horizon).alias("fixed_time_horizon"))) + + + # Part 3: Combine results and create final dictionary + if not all_deciles: + raise ValueError("No data remaining after applying heuristics and time horizons.") + deciles_dat_final = pl.concat(all_deciles) + 
smooth_dat_final = pl.concat(all_smooth) + histogram_final = pl.concat(all_histograms) + + # Add hover text + deciles_dat_final, smooth_dat_final = _add_hover_text_to_calibration_data( + deciles_dat_final, smooth_dat_final, performance_type + ) + + + reference_data = _create_reference_data_for_calibration_curve() + reference_groups = deciles_dat_final["reference_group"].unique().to_list() + colors_dictionary = _create_colors_dictionary_for_calibration( + reference_groups, color_values, performance_type + ) + limits = _define_limits_for_calibration_plot(deciles_dat_final) + axes_ranges = {"xaxis": limits, "yaxis": limits} + + calibration_curve_list = { + "deciles_dat": deciles_dat_final, + "smooth_dat": smooth_dat_final, + "reference_data": reference_data, + "histogram_for_calibration": histogram_final, + "axes_ranges": axes_ranges, + "colors_dictionary": colors_dictionary, + "performance_type": [performance_type], + "size": [(size, size)], + "fixed_time_horizons": fixed_time_horizons, + "reference_group_keys": reference_groups + } + + return calibration_curve_list diff --git a/tests/test_calibration.py b/tests/test_calibration.py new file mode 100644 index 0000000..4e79687 --- /dev/null +++ b/tests/test_calibration.py @@ -0,0 +1,28 @@ +import numpy as np +from rtichoke.calibration.calibration import create_calibration_curve + + +def test_create_calibration_curve_smooth(): + probs = {"model_1": np.linspace(0, 1, 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, calibration_type="smooth") + + # Check if the figure has the correct number of traces (smooth curve, histogram, and reference line) + assert len(fig.data) == 3 + + # Check reference line data + reference_line = fig.data[0] + assert reference_line.name == "Perfectly Calibrated" + + +def test_create_calibration_curve_smooth_single_point(): + probs = {"model_1": np.array([0.5] * 100)} + reals = np.random.randint(0, 2, 100) + fig = create_calibration_curve(probs, reals, 
calibration_type="smooth") + + # Check that the plot mode is "lines+markers" + assert fig.data[1].mode == "lines+markers" + + # Check histogram data + histogram = fig.data[2] + assert histogram.type == "bar" diff --git a/tests/test_calibration_times.py b/tests/test_calibration_times.py new file mode 100644 index 0000000..b6570c9 --- /dev/null +++ b/tests/test_calibration_times.py @@ -0,0 +1,24 @@ +import numpy as np +from rtichoke.calibration import create_calibration_curve_times + + +def test_create_calibration_curve_times(): + probs = {"model_1": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])} + reals = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + times = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + fixed_time_horizons = [5, 10] + heuristics_sets = [ + {"censoring_heuristic": "excluded", "competing_heuristic": "excluded"} + ] + + fig = create_calibration_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + ) + + assert fig is not None + assert len(fig.data) > 0 + assert len(fig.layout.sliders) > 0 diff --git a/tests/test_heuristics.py b/tests/test_heuristics.py new file mode 100644 index 0000000..6730c1c --- /dev/null +++ b/tests/test_heuristics.py @@ -0,0 +1,97 @@ +import pytest +import polars as pl +from polars.testing import assert_frame_equal +from rtichoke.calibration.calibration import _apply_heuristics_and_censoring + + +@pytest.fixture +def sample_data(): + return pl.DataFrame( + { + "real": [1, 0, 2, 1, 2, 0, 1], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + + +def test_competing_as_negative_logic(sample_data): + # Heuristics that shouldn't change data before horizon + result = _apply_heuristics_and_censoring( + sample_data, 15, "adjusted", "adjusted_as_negative" + ) + # Competing events at times 3 and 9 should become 0. 
    expected = pl.DataFrame(
        {
            "real": [1, 0, 0, 1, 0, 0, 1],
            "time": [1, 2, 3, 8, 9, 10, 12],
        }
    )
    assert_frame_equal(result, expected)


def test_admin_censoring(sample_data):
    """Outcomes occurring after the horizon are recoded as negative (0)."""
    result = _apply_heuristics_and_censoring(
        sample_data, 7, "adjusted", "adjusted_as_negative"
    )
    # Admin censoring for times > 7. Competing event at time=3 becomes 0.
    expected = pl.DataFrame(
        {
            "real": [1, 0, 0, 0, 0, 0, 0],
            "time": [1, 2, 3, 8, 9, 10, 12],
        }
    )
    assert_frame_equal(result, expected)


def test_censoring_excluded(sample_data):
    """Censored rows at or before the horizon are dropped entirely."""
    result = _apply_heuristics_and_censoring(
        sample_data, 10, "excluded", "adjusted_as_negative"
    )
    # Excludes censored at times 2, 10. Admin censors time > 10. Competing at 3,9 -> 0.
    expected = pl.DataFrame(
        {
            "real": [1, 0, 1, 0, 0],
            "time": [1, 3, 8, 9, 12],
        }
    )
    assert_frame_equal(result.sort("time"), expected.sort("time"))


def test_competing_excluded(sample_data):
    """Competing-event rows at or before the horizon are dropped entirely."""
    result = _apply_heuristics_and_censoring(sample_data, 10, "adjusted", "excluded")
    # Excludes competing at 3, 9. Admin censors time > 10.
    expected = pl.DataFrame(
        {
            "real": [1, 0, 1, 0, 0],
            "time": [1, 2, 8, 10, 12],
        }
    )
    assert_frame_equal(result.sort("time"), expected.sort("time"))


def test_competing_as_negative(sample_data):
    """Competing events at or before the horizon are recoded as negatives."""
    result = _apply_heuristics_and_censoring(
        sample_data, 10, "adjusted", "adjusted_as_negative"
    )
    # Competing at 3,9 -> 0. Admin censors time > 10.
    expected = pl.DataFrame(
        {
            "real": [1, 0, 0, 1, 0, 0, 0],
            "time": [1, 2, 3, 8, 9, 10, 12],
        }
    )
    assert_frame_equal(result, expected)


def test_competing_as_composite(sample_data):
    """Competing events at or before the horizon count as events (composite)."""
    result = _apply_heuristics_and_censoring(
        sample_data, 10, "adjusted", "adjusted_as_composite"
    )
    # Competing at 3,9 -> 1. Admin censors time > 10.
+ expected = pl.DataFrame( + { + "real": [1, 0, 1, 1, 1, 0, 0], + "time": [1, 2, 3, 8, 9, 10, 12], + } + ) + assert_frame_equal(result, expected) diff --git a/tests/test_rtichoke.py b/tests/test_rtichoke.py index 0ff1b91..5c2c5b6 100644 --- a/tests/test_rtichoke.py +++ b/tests/test_rtichoke.py @@ -158,3 +158,33 @@ def _expected_aj_df(neg, pos, comp, include_comp=True): cols.append("estimate_origin") return pl.DataFrame(data)[cols] + + +def test_create_calibration_curve_times_mixed_inputs(): + from rtichoke.calibration.calibration import create_calibration_curve_times + import polars as pl + + probs_dict = { + "full": pl.Series( + "pr_failure18", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + ) + } + reals_series = pl.Series("cancer_cr", [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + times_series = pl.Series( + "ttcancer", [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ) + + try: + create_calibration_curve_times( + probs_dict, + reals_series, + times_series, + fixed_time_horizons=[1.0, 2.0, 3.0, 4.0, 5.0], + heuristics_sets=[ + {"censoring_heuristic": "excluded", "competing_heuristic": "excluded"} + ], + ) + except TypeError: + assert False, ( + "create_calibration_curve_times raised a TypeError with mixed input types" + ) diff --git a/uv.lock b/uv.lock index 7cf5135..1a9f3e2 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.13'", @@ -3899,6 +3899,7 @@ dependencies = [ { name = "polarstate" }, { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "pyarrow", version = "22.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "statsmodels" }, { name = "typing" }, ] @@ -3933,6 +3934,7 @@ requires-dist = [ { name = "polars", specifier = ">=1.28.0" }, { name = "polarstate", specifier = "==0.1.8" }, { name = 
"pyarrow", specifier = ">=21.0.0" }, + { name = "statsmodels", specifier = ">=0.14.0" }, { name = "typing", specifier = ">=3.7.4.3" }, ]