12 changes: 2 additions & 10 deletions pyproject.toml
@@ -6,23 +6,14 @@ license = {text = "MIT"}
 requires-python = ">=3.9"
 dependencies = [
     "plotly<6.0.0,>=5.13.1",
-    "ipython<9.0.0,>=8.11.0",
-    "nbformat<6.0.0,>=5.7.3",
-    "sphinx-autoapi<3.0.0,>=2.1.0",
-    "sphinx-rtd-theme<2.0.0,>=1.2.0",
-    "importlib>=1.0.4",
-    "quartodoc>=0.9.1",
-    "papermill>=2.6.0",
-    "polars>=1.28.0",
-    "pyarrow>=20.0.0",
     "ty>=0.0.1a5",
     "pandas>=2.2.3",
     "typing>=3.7.4.3",
     "polarstate==0.1.8",
     "marimo>=0.17.0",
 ]
 name = "rtichoke"
-version = "0.1.18"
+version = "0.1.20"
 description = "interactive visualizations for performance of predictive models"
 readme = "README.md"
@@ -44,6 +35,7 @@ dev = [
     "ty>=0.0.1a12",
     "scikit-learn>=1.6.1",
     "polarstate>=0.1.6",
+    "quartodoc>=0.11.1",
 ]

 [tool.uv.workspace]
2 changes: 2 additions & 0 deletions src/rtichoke/__init__.py
@@ -29,10 +29,12 @@

 from rtichoke.performance_data.performance_data import (
     prepare_performance_data as prepare_performance_data,
+    prepare_binned_classification_data as prepare_binned_classification_data,
 )

 from rtichoke.performance_data.performance_data_times import (
     prepare_performance_data_times as prepare_performance_data_times,
+    prepare_binned_classification_data_times as prepare_binned_classification_data_times,
 )

 from rtichoke.summary_report.summary_report import (
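Note: with these re-exports, both binned-data helpers become part of the package's public API. A minimal usage sketch, assuming rtichoke 0.1.20 is installed (the dataset name and values below are illustrative, not taken from the PR):

    import numpy as np
    from rtichoke import prepare_binned_classification_data

    # Illustrative inputs: one dataset with predicted probabilities and 0/1 labels
    probs = {"model_a": np.array([0.1, 0.4, 0.35, 0.8])}
    reals = {"model_a": np.array([0, 0, 1, 1])}

    # Returns the probability-binned data (one row per dataset / bin / stratum)
    binned = prepare_binned_classification_data(probs, reals, by=0.25)
    print(binned)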
4 changes: 2 additions & 2 deletions src/rtichoke/helpers/plotly_helper_functions.py
@@ -53,13 +53,13 @@ def _create_rtichoke_plotly_curve_binary(

 def _plot_rtichoke_curve_binary(
     performance_data: pl.DataFrame,
-    stratified_by: Sequence[str] = ["probability_threshold"],
+    stratified_by: str = "probability_threshold",
     curve: str = "roc",
     size: int = 600,
 ) -> go.Figure:
     rtichoke_curve_list = _create_rtichoke_curve_list_binary(
         performance_data=performance_data,
-        stratified_by=stratified_by[0],
+        stratified_by=stratified_by,
         curve=curve,
         size=size,
     )
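Note: `_plot_rtichoke_curve_binary` now accepts a single stratification name instead of a one-element sequence, which removes the `stratified_by[0]` unwrapping. A sketch of the new call shape; this is a private helper, so the import path and inputs here are illustrative rather than documented API:

    import numpy as np
    from rtichoke import prepare_performance_data
    from rtichoke.helpers.plotly_helper_functions import _plot_rtichoke_curve_binary

    probs = {"model_a": np.array([0.2, 0.7, 0.4, 0.9])}
    reals = {"model_a": np.array([0, 1, 0, 1])}
    performance_data = prepare_performance_data(probs, reals, by=0.25)

    # stratified_by is now a plain string; the old one-element-list form
    # (stratified_by=["probability_threshold"]) no longer matches the signature.
    fig = _plot_rtichoke_curve_binary(
        performance_data,
        stratified_by="probability_threshold",
        curve="roc",
    )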
94 changes: 74 additions & 20 deletions src/rtichoke/performance_data/performance_data.py
@@ -17,10 +17,76 @@
 import numpy as np


+def prepare_binned_classification_data(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    stratified_by: Sequence[str] = ("probability_threshold",),
+    by: float = 0.01,
+) -> pl.DataFrame:
+    """
+    Prepare probability-binned classification data for binary outcomes.
+
+    This constructs the underlying, binned data across probability thresholds
+    (and any additional stratification variables). It returns the adjusted data
+    before cumulative Aalen–Johansen and performance computations.
+
+    Parameters
+    ----------
+    probs : Dict[str, np.ndarray]
+        Mapping from dataset name to predicted probabilities (1-D numpy arrays).
+    reals : Union[np.ndarray, Dict[str, np.ndarray]]
+        True event labels. Can be a single array aligned to pooled probabilities
+        or a dictionary mapping each dataset name to its true-label array. Labels
+        are expected to be binary integers (0/1).
+    stratified_by : Sequence[str], optional
+        Stratification variables used to create combinations/breaks. Defaults to
+        ``("probability_threshold",)``.
+    by : float, optional
+        Step width for probability-threshold breaks (used to create the grid of
+        cutoffs). Defaults to ``0.01``.
+
+    Returns
+    -------
+    pl.DataFrame
+        A Polars DataFrame containing probability-binned classification data
+        (one row per combination of dataset / bin / strata). This is the basis
+        for histograms, calibration diagnostics, and performance curves.
+    """
+    breaks = create_breaks_values(None, "probability_threshold", by)
+
+    aj_data_combinations = _create_aj_data_combinations_binary(
+        list(probs.keys()),
+        stratified_by=stratified_by,
+        by=by,
+        breaks=breaks,
+    )
+
+    list_data_to_adjust = _create_list_data_to_adjust_binary(
+        aj_data_combinations,
+        probs,
+        reals,
+        stratified_by=stratified_by,
+        by=by,
+    )
+
+    adjusted_data = _create_adjusted_data_binary(
+        list_data_to_adjust,
+        breaks=breaks,
+        stratified_by=stratified_by,
+    )
+
+    final_adjusted_data = _cast_and_join_adjusted_data_binary(
+        aj_data_combinations,
+        adjusted_data,
+    )
+
+    return final_adjusted_data
+
+
 def prepare_performance_data(
     probs: Dict[str, np.ndarray],
     reals: Union[np.ndarray, Dict[str, np.ndarray]],
-    stratified_by: Sequence[str] = ["probability_threshold"],
+    stratified_by: Sequence[str] = ("probability_threshold",),
     by: float = 0.01,
 ) -> pl.DataFrame:
     """Prepare performance data for binary classification.
@@ -35,7 +101,7 @@ def prepare_performance_data(
         are expected to be binary integers (0/1).
     stratified_by : Sequence[str], optional
         Stratification variables used to create combinations/breaks. Defaults to
-        ``["probability_threshold"]``.
+        ``("probability_threshold",)``.
     by : float, optional
         Step width for probability-threshold breaks (used to create the grid of
         cutoffs). Defaults to ``0.01``.
@@ -59,26 +125,14 @@ def prepare_performance_data(
     >>> prepare_performance_data(
     ...     probs_dict_test,
     ...     reals_dict_test,
-    ...     by = 0.1
+    ...     by=0.1
     ... )
     """
-
-    breaks = create_breaks_values(None, "probability_threshold", by)
-
-    aj_data_combinations = _create_aj_data_combinations_binary(
-        list(probs.keys()), stratified_by=stratified_by, by=by, breaks=breaks
-    )
-
-    list_data_to_adjust = _create_list_data_to_adjust_binary(
-        aj_data_combinations, probs, reals, stratified_by=stratified_by, by=by
-    )
-
-    adjusted_data = _create_adjusted_data_binary(
-        list_data_to_adjust, breaks=breaks, stratified_by=stratified_by
-    )
-
-    final_adjusted_data = _cast_and_join_adjusted_data_binary(
-        aj_data_combinations, adjusted_data
+    final_adjusted_data = prepare_binned_classification_data(
+        probs=probs,
+        reals=reals,
+        stratified_by=stratified_by,
+        by=by,
     )

     cumulative_aj_data = _calculate_cumulative_aj_data_binary(final_adjusted_data)
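Note: the refactor splits the binary pipeline into two explicit layers: `prepare_binned_classification_data` builds the binned intermediate, and `prepare_performance_data` runs the cumulative Aalen–Johansen step on top of it. A sketch of the relationship, reusing the probabilities from the docstring example (the 0/1 labels are illustrative):

    import numpy as np
    from rtichoke import prepare_binned_classification_data, prepare_performance_data

    probs = {"small_data_set": np.array([0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33])}
    reals = {"small_data_set": np.array([1, 1, 1, 1, 0, 1, 1, 0, 0, 1])}

    # The shared intermediate: binned data per probability cutoff
    binned = prepare_binned_classification_data(probs, reals, by=0.1)

    # Equivalent to running the binning above and then the cumulative AJ step
    performance = prepare_performance_data(probs, reals, by=0.1)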
117 changes: 87 additions & 30 deletions src/rtichoke/performance_data/performance_data_times.py
@@ -29,7 +29,7 @@ def prepare_performance_data_times(
             "competing_heuristic": "adjusted_as_negative",
         }
     ],
-    stratified_by: Sequence[str] = ["probability_threshold"],
+    stratified_by: Sequence[str] = ("probability_threshold",),
     by: float = 0.01,
 ) -> pl.DataFrame:
     """Prepare performance data with a time dimension.
@@ -50,10 +50,11 @@ def prepare_performance_data_times(
     heuristics_sets : list[Dict], optional
         List of heuristic dictionaries controlling censoring/competing-event handling.
         Default is a single heuristic set:
-        ``[{"censoring_heuristic": "adjusted", "competing_heuristic": "adjusted_as_negative"}]``
+        ``[{"censoring_heuristic": "adjusted",
+        "competing_heuristic": "adjusted_as_negative"}]``
     stratified_by : Sequence[str], optional
         Stratification variables used to create combinations/breaks. Defaults to
-        ``["probability_threshold"]``.
+        ``("probability_threshold",)``.
     by : float, optional
         Step width for probability-threshold breaks (used to create the grid of
         cutoffs). Defaults to ``0.01``.
@@ -64,30 +65,84 @@ def prepare_performance_data_times(
         A Polars DataFrame containing performance metrics computed across probability
         thresholds and fixed time horizons. Columns include the probability cutoff,
         fixed time horizon, heuristic identifiers, and AJ-derived performance measures.
-
-    Examples
-    --------
-    >>> import numpy as np
-    >>> probs_dict_test = {
-    ...     "small_data_set": np.array(
-    ...         [0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
-    ...     )
-    ... }
-    >>> reals_dict_test = [1, 1, 1, 1, 0, 2, 1, 2, 0, 1]
-    >>> times_dict_test = [24.1, 9.7, 49.9, 18.6, 34.8, 14.2, 39.2, 46.0, 31.5, 4.3]
-    >>> fixed_time_horizons = [10.0, 20.0, 30.0, 40.0, 50.0]
-
-    >>> prepare_performance_data_times(
-    ...     probs_dict_test,
-    ...     reals_dict_test,
-    ...     times_dict_test,
-    ...     fixed_time_horizons,
-    ...     by = 0.1
-    ... )
     """
-
+    # 1. Get the underlying binned time-dependent classification data
+    final_adjusted_data = prepare_binned_classification_data_times(
+        probs=probs,
+        reals=reals,
+        times=times,
+        fixed_time_horizons=fixed_time_horizons,
+        heuristics_sets=heuristics_sets,
+        stratified_by=stratified_by,
+        by=by,
+        risk_set_scope=["pooled_by_cutoff"],
+    )
+
+    # 2. Apply AJ cumulative machinery
+    cumulative_aj_data = _calculate_cumulative_aj_data(final_adjusted_data)
+
+    # 3. Turn AJ output into performance metrics
+    performance_data = _turn_cumulative_aj_to_performance_data(cumulative_aj_data)
+
+    return performance_data
+
+
+def prepare_binned_classification_data_times(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    times: Union[np.ndarray, Dict[str, np.ndarray]],
+    fixed_time_horizons: list[float],
+    heuristics_sets: list[Dict] = [
+        {
+            "censoring_heuristic": "adjusted",
+            "competing_heuristic": "adjusted_as_negative",
+        }
+    ],
+    stratified_by: Sequence[str] = ("probability_threshold",),
+    by: float = 0.01,
+    risk_set_scope: Sequence[str] = ["pooled_by_cutoff", "within_stratum"],
+) -> pl.DataFrame:
+    """
+    Prepare probability-binned, time-dependent classification data.
+
+    This constructs the underlying, binned data across probability thresholds,
+    fixed time horizons, and heuristic sets. It returns the adjusted data
+    before the cumulative Aalen–Johansen and performance-transformation steps.
+
+    Parameters
+    ----------
+    probs : Dict[str, np.ndarray]
+        Mapping from dataset name to predicted probabilities (1-D numpy arrays).
+    reals : Union[np.ndarray, Dict[str, np.ndarray]]
+        True event labels. Can be a single array aligned to pooled probabilities
+        or a dictionary mapping each dataset name to its true-label array. Labels
+        are expected to be integers (e.g., 0/1 for binary, or competing event codes).
+    times : Union[np.ndarray, Dict[str, np.ndarray]]
+        Event or censoring times corresponding to `reals`. Either a single array
+        or a dictionary keyed like `probs` when multiple datasets are provided.
+    fixed_time_horizons : list[float]
+        Fixed time horizons (same units as `times`) at which to evaluate performance.
+    heuristics_sets : list[Dict], optional
+        List of heuristic dictionaries controlling censoring/competing-event handling.
+        Default is a single heuristic set:
+        ``[{"censoring_heuristic": "adjusted",
+        "competing_heuristic": "adjusted_as_negative"}]``
+    stratified_by : Sequence[str], optional
+        Stratification variables used to create combinations/breaks. Defaults to
+        ``("probability_threshold",)``.
+    by : float, optional
+        Step width for probability-threshold breaks (used to create the grid of
+        cutoffs). Defaults to ``0.01``.
+
+    Returns
+    -------
+    pl.DataFrame
+        A Polars DataFrame containing probability-binned, time-dependent
+        classification data (one row per combination of dataset / bin /
+        time horizon / heuristic / strata). This is the basis for histograms,
+        calibration diagnostics, and time-dependent performance curves.
+    """
     breaks = create_breaks_values(None, "probability_threshold", by)
-    risk_set_scope = ["pooled_by_cutoff"]

     aj_data_combinations = create_aj_data_combinations(
         list(probs.keys()),
@@ -100,7 +155,12 @@ def prepare_binned_classification_data_times(
     )

     list_data_to_adjust = create_list_data_to_adjust(
-        aj_data_combinations, probs, reals, times, stratified_by=stratified_by, by=by
+        aj_data_combinations,
+        probs,
+        reals,
+        times,
+        stratified_by=stratified_by,
+        by=by,
     )

     adjusted_data = create_adjusted_data(
@@ -113,11 +173,8 @@ def prepare_binned_classification_data_times(
     )

     final_adjusted_data = cast_and_join_adjusted_data(
-        aj_data_combinations, adjusted_data
+        aj_data_combinations,
+        adjusted_data,
     )

-    cumulative_aj_data = _calculate_cumulative_aj_data(final_adjusted_data)
-
-    performance_data = _turn_cumulative_aj_to_performance_data(cumulative_aj_data)
-
-    return performance_data
+    return final_adjusted_data
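Note: the doctest removed above maps directly onto the new helper. A sketch of calling `prepare_binned_classification_data_times` with those same values, passed dict-keyed per dataset since that is the form the docstring documents:

    import numpy as np
    from rtichoke import prepare_binned_classification_data_times

    probs = {"small_data_set": np.array([0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33])}
    reals = {"small_data_set": np.array([1, 1, 1, 1, 0, 2, 1, 2, 0, 1])}  # 2 = competing event
    times = {"small_data_set": np.array([24.1, 9.7, 49.9, 18.6, 34.8, 14.2, 39.2, 46.0, 31.5, 4.3])}

    binned = prepare_binned_classification_data_times(
        probs,
        reals,
        times,
        fixed_time_horizons=[10.0, 20.0, 30.0, 40.0, 50.0],
        by=0.1,
        risk_set_scope=["pooled_by_cutoff"],
    )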