From 838c033cb0477ed9cf1c55840a46830cc9448e5a Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Wed, 5 Nov 2025 08:10:12 +0200 Subject: [PATCH 1/2] fix: close #198 --- .../helpers/sandbox_observable_helpers.py | 52 ++++++++++++++----- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py index 629bc32..aee95e6 100644 --- a/src/rtichoke/helpers/sandbox_observable_helpers.py +++ b/src/rtichoke/helpers/sandbox_observable_helpers.py @@ -824,21 +824,50 @@ def _create_list_data_to_adjust_binary( by, ) -> Dict[str, pl.DataFrame]: reference_group_labels = list(probs_dict.keys()) - num_reals = len(reals_dict) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 reference_group_enum = pl.Enum(reference_group_labels) strata_enum_dtype = aj_data_combinations.schema["strata"] - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, num_reals), - "probs": np.concatenate( - [probs_dict[group] for group in reference_group_labels] - ), - "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_dict, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(reals_dict)), + "probs": np.concatenate( + [probs_dict[group] for group in reference_group_labels] + ), + "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": list(probs_dict.keys()), + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + } + ) + .explode(["probs", "reals"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) data_to_adjust = add_cutoff_strata( data_to_adjust, by=by, stratified_by=stratified_by @@ -873,7 +902,6 @@ def _create_list_data_to_adjust_binary( .alias("reals_labels") ) - # Partition by reference_group list_data_to_adjust = { group[0]: df for group, df in data_to_adjust.partition_by( @@ -1029,7 +1057,7 @@ def _create_adjusted_data_binary( adjusted_data_binary = ( long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"]) - .agg(pl.sum("reals").alias("reals_estimate")) + .agg(pl.count().alias("reals_estimate")) .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross") ) From 96d24c9696bbc33f17eed6afe7b93860937a2293 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 10 Nov 2025 03:04:44 +0200 Subject: [PATCH 2/2] fix: close #197 --- pyproject.toml | 2 +- src/rtichoke/discrimination/gains.py | 129 +- src/rtichoke/discrimination/lift.py | 129 +- .../discrimination/precision_recall.py | 129 +- src/rtichoke/discrimination/roc.py | 128 +- src/rtichoke/helpers/exported_functions.py | 5 + .../helpers/plotly_helper_functions.py | 1156 ++++++++++++++++- uv.lock | 2 +- 8 files changed, 1431 insertions(+), 249 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 059e19f..055c1da 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "polarstate==0.1.8", ] name = "rtichoke" -version = "0.1.13" +version = "0.1.14" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py index e2a552b..39334ef 100644 --- a/src/rtichoke/discrimination/gains.py +++ b/src/rtichoke/discrimination/gains.py @@ -1,20 +1,23 @@ """ -A module for Gains Curves +A module for Gains Curves using Plotly helpers """ -from typing import Dict, List, Optional -from pandas import DataFrame +from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve -from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve +from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_binary, + _plot_rtichoke_curve_binary, +) +import numpy as np +import polars as pl def create_gains_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], by: float = 0.01, - stratified_by: str = "probability_threshold", - size: Optional[int] = None, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, color_values: List[str] = [ "#1b9e77", "#d95f02", @@ -37,78 +40,86 @@ def create_gains_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: - """Create Gains Curve + """Create Gains Curve. - Args: - probs (Dict[str, List[float]]): _description_ - reals (Dict[str, List[int]]): _description_ - by (float, optional): _description_. Defaults to 0.01. - stratified_by (str, optional): _description_. Defaults to "probability_threshold". - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Parameters + ---------- + probs : Dict[str, np.ndarray] + Dictionary mapping a label or group name to an array of predicted + probabilities for the positive class. + reals : Union[np.ndarray, Dict[str, np.ndarray]] + Ground-truth binary labels (0/1) as a single array, or a dictionary + mapping the same label/group keys used in ``probs`` to arrays of + ground-truth labels. + by : float, optional + Resolution for probability thresholds when computing the curve + (step size). Default is 0.01. + stratified_by : Sequence[str], optional + Sequence of column names to stratify the performance data by. + Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. + color_values : List[str], optional + List of color hex strings to use for the plotted lines. If not + provided, a default palette is used. - Returns: - Figure: _description_ + Returns + ------- + Figure + A Plotly ``Figure`` containing the Gains curve(s). + + Notes + ----- + The function delegates computation and plotting to + ``_create_rtichoke_plotly_curve_binary`` and returns the resulting + Plotly figure. 
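+
+    Examples
+    --------
+    An illustrative sketch with made-up probabilities and labels (the
+    arrays below are hypothetical, not package data):
+
+    >>> import numpy as np
+    >>> from rtichoke.discrimination.gains import create_gains_curve
+    >>> probs = {"model": np.array([0.2, 0.4, 0.7, 0.9])}  # hypothetical scores
+    >>> reals = np.array([0, 0, 1, 1])  # hypothetical binary outcomes
+    >>> fig = create_gains_curve(probs, reals, by=0.25)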
""" - fig = create_rtichoke_curve( + fig = _create_rtichoke_plotly_curve_binary( probs, reals, by=by, stratified_by=stratified_by, size=size, color_values=color_values, - url_api=url_api, curve="gains", ) return fig def plot_gains_curve( - performance_data: DataFrame, - size: Optional[int] = None, - color_values: List[str] = [ - "#1b9e77", - "#d95f02", - "#7570b3", - "#e7298a", - "#07004D", - "#E6AB02", - "#FE5F55", - "#54494B", - "#006E90", - "#BC96E6", - "#52050A", - "#1F271B", - "#BE7C4D", - "#63768D", - "#08A045", - "#320A28", - "#82FF9E", - "#2176FF", - "#D1603D", - "#585123", - ], - url_api: str = "http://localhost:4242/", + performance_data: pl.DataFrame, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, ) -> Figure: - """Plot Gains Curve + """Plot Gains curve from performance data. + + Parameters + ---------- + performance_data : pl.DataFrame + A Polars DataFrame containing performance metrics for the Gains curve. + Expected columns include (but may not be limited to) + ``probability_threshold`` and gains-related metrics, plus any + stratification columns. + stratified_by : Sequence[str], optional + Sequence of column names used for stratification in the + ``performance_data``. Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. - Args: - performance_data (DataFrame): _description_ - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Returns + ------- + Figure + A Plotly ``Figure`` containing the Gains plot. - Returns: - Figure: _description_ + Notes + ----- + This function wraps ``_plot_rtichoke_curve_binary`` to produce a + ready-to-render Plotly figure from precomputed performance data. """ - fig = plot_rtichoke_curve( + fig = _plot_rtichoke_curve_binary( performance_data, size=size, - color_values=color_values, - url_api=url_api, curve="gains", ) return fig diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py index a796c29..65f2553 100644 --- a/src/rtichoke/discrimination/lift.py +++ b/src/rtichoke/discrimination/lift.py @@ -1,20 +1,23 @@ """ -A module for Lift Curves +A module for Lift Curves using Plotly helpers """ -from typing import Dict, List, Optional +from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from pandas import DataFrame -from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve -from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve +from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_binary, + _plot_rtichoke_curve_binary, +) +import numpy as np +import polars as pl def create_lift_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], by: float = 0.01, - stratified_by: str = "probability_threshold", - size: Optional[int] = None, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, color_values: List[str] = [ "#1b9e77", "#d95f02", @@ -37,78 +40,86 @@ def create_lift_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: - """Create Lift Curve + """Create Lift Curve. 
- Args: - probs (Dict[str, List[float]]): _description_ - reals (Dict[str, List[int]]): _description_ - by (float, optional): _description_. Defaults to 0.01. - stratified_by (str, optional): _description_. Defaults to "probability_threshold". - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Parameters + ---------- + probs : Dict[str, np.ndarray] + Dictionary mapping a label or group name to an array of predicted + probabilities for the positive class. + reals : Union[np.ndarray, Dict[str, np.ndarray]] + Ground-truth binary labels (0/1) as a single array, or a dictionary + mapping the same label/group keys used in ``probs`` to arrays of + ground-truth labels. + by : float, optional + Resolution for probability thresholds when computing the curve + (step size). Default is 0.01. + stratified_by : Sequence[str], optional + Sequence of column names to stratify the performance data by. + Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. + color_values : List[str], optional + List of color hex strings to use for the plotted lines. If not + provided, a default palette is used. - Returns: - Figure: _description_ + Returns + ------- + Figure + A Plotly ``Figure`` containing the Lift curve(s). + + Notes + ----- + The function delegates computation and plotting to + ``_create_rtichoke_plotly_curve_binary`` and returns the resulting + Plotly figure. """ - fig = create_rtichoke_curve( + fig = _create_rtichoke_plotly_curve_binary( probs, reals, by=by, stratified_by=stratified_by, size=size, color_values=color_values, - url_api=url_api, curve="lift", ) return fig def plot_lift_curve( - performance_data: DataFrame, - size: Optional[int] = None, - color_values: List[str] = [ - "#1b9e77", - "#d95f02", - "#7570b3", - "#e7298a", - "#07004D", - "#E6AB02", - "#FE5F55", - "#54494B", - "#006E90", - "#BC96E6", - "#52050A", - "#1F271B", - "#BE7C4D", - "#63768D", - "#08A045", - "#320A28", - "#82FF9E", - "#2176FF", - "#D1603D", - "#585123", - ], - url_api: str = "http://localhost:4242/", + performance_data: pl.DataFrame, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, ) -> Figure: - """Plot Lift Curve + """Plot Lift curve from performance data. + + Parameters + ---------- + performance_data : pl.DataFrame + A Polars DataFrame containing performance metrics for the Lift curve. + Expected columns include (but may not be limited to) + ``probability_threshold`` and lift-related metrics, plus any + stratification columns. + stratified_by : Sequence[str], optional + Sequence of column names used for stratification in the + ``performance_data``. Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. - Args: - performance_data (DataFrame): _description_ - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Returns + ------- + Figure + A Plotly ``Figure`` containing the Lift plot. - Returns: - Figure: _description_ + Notes + ----- + This function wraps ``_plot_rtichoke_curve_binary`` to produce a + ready-to-render Plotly figure from precomputed performance data. 
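+
+    Examples
+    --------
+    A sketch of the intended two-step flow, assuming ``prepare_performance_data``
+    is used as in the plotting helpers (inputs below are made up):
+
+    >>> import numpy as np
+    >>> from rtichoke.performance_data.performance_data import prepare_performance_data
+    >>> from rtichoke.discrimination.lift import plot_lift_curve
+    >>> performance_data = prepare_performance_data(
+    ...     probs={"model": np.array([0.2, 0.4, 0.7, 0.9])},  # hypothetical scores
+    ...     reals=np.array([0, 0, 1, 1]),
+    ...     stratified_by=["probability_threshold"],
+    ...     by=0.25,
+    ... )
+    >>> fig = plot_lift_curve(performance_data)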
""" - fig = plot_rtichoke_curve( + fig = _plot_rtichoke_curve_binary( performance_data, size=size, - color_values=color_values, - url_api=url_api, curve="lift", ) return fig diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py index 274e36e..06314d7 100644 --- a/src/rtichoke/discrimination/precision_recall.py +++ b/src/rtichoke/discrimination/precision_recall.py @@ -1,20 +1,23 @@ """ -A module for Precision Recall Curves +A module for Precision-Recall Curves using Plotly helpers """ -from typing import Dict, List, Optional +from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure -from pandas import DataFrame -from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve -from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve +from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_binary, + _plot_rtichoke_curve_binary, +) +import numpy as np +import polars as pl def create_precision_recall_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], by: float = 0.01, - stratified_by: str = "probability_threshold", - size: Optional[int] = None, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, color_values: List[str] = [ "#1b9e77", "#d95f02", @@ -37,78 +40,86 @@ def create_precision_recall_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: - """Create Precision Recall Curve + """Create Precision-Recall Curve. - Args: - probs (Dict[str, List[float]]): _description_ - reals (Dict[str, List[int]]): _description_ - by (float, optional): _description_. Defaults to 0.01. - stratified_by (str, optional): _description_. Defaults to "probability_threshold". - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Parameters + ---------- + probs : Dict[str, np.ndarray] + Dictionary mapping a label or group name to an array of predicted + probabilities for the positive class. + reals : Union[np.ndarray, Dict[str, np.ndarray]] + Ground-truth binary labels (0/1) as a single array, or a dictionary + mapping the same label/group keys used in ``probs`` to arrays of + ground-truth labels. + by : float, optional + Resolution for probability thresholds when computing the curve + (step size). Default is 0.01. + stratified_by : Sequence[str], optional + Sequence of column names to stratify the performance data by. + Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. + color_values : List[str], optional + List of color hex strings to use for the plotted lines. If not + provided, a default palette is used. - Returns: - Figure: _description_ + Returns + ------- + Figure + A Plotly ``Figure`` containing the Precision-Recall curve(s). + + Notes + ----- + The function delegates computation and plotting to + ``_create_rtichoke_plotly_curve_binary`` and returns the resulting + Plotly figure. 
""" - fig = create_rtichoke_curve( + fig = _create_rtichoke_plotly_curve_binary( probs, reals, by=by, stratified_by=stratified_by, size=size, color_values=color_values, - url_api=url_api, curve="precision recall", ) return fig def plot_precision_recall_curve( - performance_data: DataFrame, - size: Optional[int] = None, - color_values: List[str] = [ - "#1b9e77", - "#d95f02", - "#7570b3", - "#e7298a", - "#07004D", - "#E6AB02", - "#FE5F55", - "#54494B", - "#006E90", - "#BC96E6", - "#52050A", - "#1F271B", - "#BE7C4D", - "#63768D", - "#08A045", - "#320A28", - "#82FF9E", - "#2176FF", - "#D1603D", - "#585123", - ], - url_api: str = "http://localhost:4242/", + performance_data: pl.DataFrame, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, ) -> Figure: - """Plot Precision Recall Curve + """Plot Precision-Recall curve from performance data. + + Parameters + ---------- + performance_data : pl.DataFrame + A Polars DataFrame containing performance metrics for the + Precision-Recall curve. Expected columns include (but may not be + limited to) ``probability_threshold``, precision and recall values, + plus any stratification columns. + stratified_by : Sequence[str], optional + Sequence of column names used for stratification in the + ``performance_data``. Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. - Args: - performance_data (DataFrame): _description_ - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Returns + ------- + Figure + A Plotly ``Figure`` containing the Precision-Recall plot. - Returns: - Figure: _description_ + Notes + ----- + This function wraps ``_plot_rtichoke_curve_binary`` to produce a + ready-to-render Plotly figure from precomputed performance data. """ - fig = plot_rtichoke_curve( + fig = _plot_rtichoke_curve_binary( performance_data, size=size, - color_values=color_values, - url_api=url_api, curve="precision recall", ) return fig diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py index 4c1c3bf..ae2d7e4 100644 --- a/src/rtichoke/discrimination/roc.py +++ b/src/rtichoke/discrimination/roc.py @@ -2,19 +2,22 @@ A module for ROC Curves """ -from typing import Dict, List, Optional +from typing import Dict, List, Union, Sequence from plotly.graph_objs._figure import Figure -from pandas import DataFrame -from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve -from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve +from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_binary, + _plot_rtichoke_curve_binary, +) +import numpy as np +import polars as pl def create_roc_curve( - probs: Dict[str, List[float]], - reals: Dict[str, List[int]], + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], by: float = 0.01, - stratified_by: str = "probability_threshold", - size: Optional[int] = None, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, color_values: List[str] = [ "#1b9e77", "#d95f02", @@ -37,78 +40,87 @@ def create_roc_curve( "#D1603D", "#585123", ], - url_api: str = "http://localhost:4242/", ) -> Figure: - """Create ROC Curve + """Create ROC Curve. 
- Args: - probs (Dict[str, List[float]]): _description_ - reals (Dict[str, List[int]]): _description_ - by (float, optional): _description_. Defaults to 0.01. - stratified_by (str, optional): _description_. Defaults to "probability_threshold". - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Parameters + ---------- + probs : Dict[str, np.ndarray] + Dictionary mapping a label or group name to an array of predicted + probabilities for the positive class. + reals : Union[np.ndarray, Dict[str, np.ndarray]] + Ground-truth binary labels (0/1) as a single array, or a dictionary + mapping the same label/group keys used in ``probs`` to arrays of + ground-truth labels. + by : float, optional + Resolution for probability thresholds when computing the curve + (step size). Default is 0.01. + stratified_by : Sequence[str], optional + Sequence of column names to stratify the performance data by. + Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. + color_values : List[str], optional + List of color hex strings to use for the plotted lines. If not + provided, a default palette is used. - Returns: - Figure: _description_ + Returns + ------- + Figure + A Plotly ``Figure`` containing the ROC curve(s). + + Notes + ----- + The function delegates computation and plotting to + ``_create_rtichoke_plotly_curve_binary`` and returns the resulting + Plotly figure. """ - fig = create_rtichoke_curve( + fig = _create_rtichoke_plotly_curve_binary( probs, reals, by=by, stratified_by=stratified_by, size=size, color_values=color_values, - url_api=url_api, curve="roc", ) return fig def plot_roc_curve( - performance_data: DataFrame, - size: Optional[int] = None, - color_values: List[str] = [ - "#1b9e77", - "#d95f02", - "#7570b3", - "#e7298a", - "#07004D", - "#E6AB02", - "#FE5F55", - "#54494B", - "#006E90", - "#BC96E6", - "#52050A", - "#1F271B", - "#BE7C4D", - "#63768D", - "#08A045", - "#320A28", - "#82FF9E", - "#2176FF", - "#D1603D", - "#585123", - ], - url_api: str = "http://localhost:4242/", + performance_data: pl.DataFrame, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, ) -> Figure: - """Plot ROC Curve + """Plot ROC curve from performance data. - Args: - performance_data (DataFrame): _description_ - size (Optional[int], optional): _description_. Defaults to None. - color_values (List[str], optional): _description_. Defaults to None. - url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/". + Parameters + ---------- + performance_data : pl.DataFrame + A Polars DataFrame containing performance metrics for the ROC curve. + Expected columns include (but may not be limited to) ``probability_threshold``, + true positive rate (TPR) and false positive rate (FPR), plus any + stratification columns. + stratified_by : Sequence[str], optional + Sequence of column names used for stratification in the + ``performance_data``. Default is ["probability_threshold"]. + size : int, optional + Plot size in pixels (width and height). Default is 600. - Returns: - Figure: _description_ + Returns + ------- + Figure + A Plotly ``Figure`` containing the ROC plot. + + Notes + ----- + This function wraps ``_plot_rtichoke_curve_binary`` to produce a + ready-to-render Plotly figure from precomputed performance data. 
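+
+    Examples
+    --------
+    A sketch for validating one model on two populations, where ``reals``
+    is a dictionary keyed like ``probs`` (all arrays are hypothetical):
+
+    >>> import numpy as np
+    >>> from rtichoke.performance_data.performance_data import prepare_performance_data
+    >>> from rtichoke.discrimination.roc import plot_roc_curve
+    >>> performance_data = prepare_performance_data(
+    ...     probs={"pop_a": np.array([0.2, 0.7, 0.9]), "pop_b": np.array([0.1, 0.4, 0.8])},
+    ...     reals={"pop_a": np.array([0, 1, 1]), "pop_b": np.array([0, 0, 1])},
+    ...     stratified_by=["probability_threshold"],
+    ...     by=0.25,
+    ... )
+    >>> fig = plot_roc_curve(performance_data)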
""" - fig = plot_rtichoke_curve( + fig = _plot_rtichoke_curve_binary( performance_data, size=size, - color_values=color_values, - url_api=url_api, curve="roc", ) + return fig diff --git a/src/rtichoke/helpers/exported_functions.py b/src/rtichoke/helpers/exported_functions.py index 5afcd9a..a273346 100644 --- a/src/rtichoke/helpers/exported_functions.py +++ b/src/rtichoke/helpers/exported_functions.py @@ -13,6 +13,11 @@ # TODO: Fix zoom for plotly curves +def create_plotly_curve_polars(rtichoke_curve_dict): + # non_interactive_curve_list = [] + return None + + def create_plotly_curve(rtichoke_curve_dict): """ diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 9cfe9b1..9b375d9 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -3,6 +3,640 @@ """ import plotly.graph_objects as go +import polars as pl +import math +from typing import Any, Dict, Union, Sequence +import numpy as np +from rtichoke.performance_data.performance_data import prepare_performance_data + + +def _create_rtichoke_plotly_curve_binary( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values=None, + curve: str = "roc", + min_p_threshold: float = 0, + max_p_threshold: float = 1, +) -> go.Figure: + performance_data = prepare_performance_data( + probs=probs, + reals=reals, + stratified_by=stratified_by, + by=by, + ) + + fig = _plot_rtichoke_curve_binary( + performance_data=performance_data, + stratified_by=stratified_by[0], + curve=curve, + size=size, + ) + + return fig + + +def _plot_rtichoke_curve_binary( + performance_data: pl.DataFrame, + stratified_by: Sequence[str] = ["probability_threshold"], + curve: str = "roc", + size: int = 600, +) -> go.Figure: + rtichoke_curve_list = _create_rtichoke_curve_list_binary( + performance_data=performance_data, + stratified_by=stratified_by[0], + curve=curve, + size=size, + ) + + fig = _create_plotly_curve_binary(rtichoke_curve_list) + + return fig + + +def _grid(start: float, stop: float, step: float) -> pl.Series: + """Like R seq(start, stop, by=step).""" + n = int(round((stop - start) / step)) + 1 + return pl.Series(np.round(np.linspace(start, stop, n), 10)) + + +def _is_several_populations(perf_dat_type: str) -> bool: + return perf_dat_type.strip().lower() == "several populations" + + +# Perfect/strategy reference formulas (vectorized) +def _perfect_gains_y(x: pl.Series, p: float) -> pl.Series: + # Gains perfect: y = min(x/p, 1); if p<=0 => 0 + xa = x.to_numpy() + if p <= 0: + return pl.Series(np.zeros_like(xa, dtype=float)) + return pl.Series(np.minimum(xa / p, 1.0)) + + +def _perfect_lift_y_series(x: pl.Series, p: float) -> pl.Series: + if not (0.0 <= p <= 1.0): + raise ValueError(f"p must be in [0,1], got {p}") + if p == 0.0: + return pl.Series(np.full(x.len(), np.nan), dtype=pl.Float64) + if p == 1.0: + return pl.Series(np.full(x.len(), 1.0), dtype=pl.Float64) + + # materialize via DataFrame -> Series to avoid Expr return + df = pl.DataFrame({"x": x.cast(pl.Float64)}) + y = df.select( + pl.when(pl.col("x") <= p) + .then(1.0 / p) + .otherwise(1.0 / p + ((1.0 - 1.0 / p) / (1.0 - p)) * (pl.col("x") - p)) + .cast(pl.Float64) + .alias("y") + ).to_series() + return y + + +def _perfect_lift_expr(x_col: str = "x", p_col: str = "p") -> pl.Expr: + # piecewise: p<=0 -> NaN, p>=1 -> 1, x<=p -> 1/p, else linear to 1 at x=1 
+ x = pl.col(x_col) + p = pl.col(p_col) + m = (1.0 - 1.0 / p) / (1.0 - p) + return ( + pl.when(p <= 0) + .then(pl.lit(np.nan)) + .when(p >= 1) + .then(pl.lit(1.0)) + .when(x <= p) + .then(1.0 / p) + .otherwise(1.0 / p + m * (x - p)) + .cast(pl.Float64) + ) + + +def _random_guess_pr_expr(p_col: str = "p", x_col: str = "x") -> pl.Expr: + # Baseline precision equals prevalence p; undefined (NaN) if p <= 0 + p = pl.col(p_col) + return pl.when(p <= 0).then(pl.lit(np.nan)).otherwise(p).cast(pl.Float64) + + +def _treat_all_nb_y(x: pl.Series, p: float) -> pl.Series: + # Decision curve "treat all" NB = p - (1-p)*x/(1-x) + xa = x.to_numpy() + y = p - (1 - p) * (xa / (1 - xa)) + return pl.Series(y) + + +def _treat_none_interventions_avoided_y(x: pl.Series, p: float) -> pl.Series: + # Interventions avoided (per 100) for "treat none": 100*(1 - p - p*(1-x)/x) + xa = x.to_numpy() + y = 100.0 * (1 - p - p * (1 - xa) / xa) + return pl.Series(y) + + +def _htext_pr(title_expr: pl.Expr) -> pl.Expr: + return pl.format( + "{}
<br>PPV: {}<br>
Sensitivity: {}", + title_expr, + pl.col("y").round(3), + pl.col("x").round(2), + ).alias("text") + + +def _ensure_series_float(x) -> pl.Series: + return ( + x.cast(pl.Float64) + if isinstance(x, pl.Series) + else pl.Series(x, dtype=pl.Float64) + ) + + +def _perfect_gains_expr(x_col: str = "x", p_col: str = "p") -> pl.Expr: + x, p = pl.col(x_col), pl.col(p_col) + return ( + pl.when(p == 0) + .then(pl.lit(np.nan)) # no positives -> undefined + .when(p == 1) + .then(x) # all positives -> y=x + .when(x < p) + .then(x / p) # linear up to full recall at x=p + .otherwise(pl.lit(1.0)) # plateau at 1 afterwards + .cast(pl.Float64) + ) + + +def _perfect_gains_series(x: pl.Series, p: float) -> pl.Series: + if p <= 0.0: + return pl.Series(np.full(len(x), np.nan), dtype=pl.Float64) + if p >= 1.0: + return x.cast(pl.Float64) + df = pl.DataFrame({"x": x}) + return df.select( + pl.when(pl.col("x") <= p) + .then(pl.col("x") / p) + .otherwise(1.0) + .cast(pl.Float64) + .alias("y") + ).to_series() + + +def _odds_expr(x_col: str = "x") -> pl.Expr: + x = pl.col(x_col) + return pl.when(x == 0).then(pl.lit("∞")).otherwise(((1 - x) / x).round(2)) + + +def _treat_all_nb_expr(x_col: str = "x", p_col: str = "p") -> pl.Expr: + # Net Benefit (treat-all): NB = p - (1 - p) * (pt / (1 - pt)) + # Guard x=1 (division by zero) and invalid p + x, p = pl.col(x_col), pl.col(p_col) + w = x / (1 - x) + return ( + pl.when((p < 0) | (p > 1) | (x >= 1)) + .then(pl.lit(np.nan)) + .otherwise(p - (1 - p) * w) + .cast(pl.Float64) + ) + + +def _htext_nb(title_expr: pl.Expr, x_col: str = "x") -> pl.Expr: + return pl.format( + "{}
<br>NB: {}<br>Probability Threshold: {}<br>
Odds of Prob. Threshold: 1:{}", + title_expr, + pl.col("y").round(3), + pl.col(x_col), + _odds_expr(x_col), + ).alias("text") + + +def _htext_ia(title_expr: pl.Expr, x_col: str = "x") -> pl.Expr: + return pl.format( + "{}
<br>Interventions Avoided (per 100): {}" + "<br>Probability Threshold: {}<br>
Odds of Prob. Threshold: 1:{}", + title_expr, + pl.col("y").round(3), + pl.col(x_col), + ((1 - pl.col(x_col)) / pl.col(x_col)).round(2), + ).alias("text") + + +def _create_reference_lines_data( + curve: str, + aj_estimates_from_performance_data: pl.DataFrame, + multiple_populations: bool, + min_p_threshold: float = 0.0, + max_p_threshold: float = 1.0, +) -> pl.DataFrame: + curve = curve.strip().lower() + # --- ROC --- + if curve == "roc": + x = _grid(0.0, 1.0, 0.01) + return pl.DataFrame( + { + "reference_group": pl.Series(["random_guess"] * len(x)), + "x": x, + "y": x, + } + ).with_columns( + pl.format( + "Random Guess
<br>Sensitivity: {}<br>
1 - Specificity: {}", + pl.col("y"), + pl.col("x"), + ).alias("text") + ) + + # --- Lift --- + if curve == "lift": + x = _grid(0.01, 1.0, 0.01) + x_s = pl.Series("x", x, dtype=pl.Float64) + + if multiple_populations: + aj_df = aj_estimates_from_performance_data.select( + pl.col("reference_group"), + pl.col("aj_estimate").cast(pl.Float64).alias("p"), + ) + + # random-guess (y=1 unless all p==0 -> NaN) + all_zero = ( + aj_df["p"].len() > 0 + and float(aj_df["p"].max()) == 0.0 + and float(aj_df["p"].min()) == 0.0 + ) + rand_y = pl.Series( + np.full(len(x_s), np.nan) if all_zero else np.ones(len(x_s)), + dtype=pl.Float64, + ) + + random_guess = pl.DataFrame( + {"reference_group": ["random_guess"] * len(x_s), "x": x_s, "y": rand_y} + ).with_columns( + pl.format( + "Random Guess
<br>Lift: {}<br>
Predicted Positives: {}%", + pl.col("y").round(3), + (100 * pl.col("x")).round(0), + ).alias("text") + ) + + # perfect per population (cross-join x with per-group p) + perfect_df = ( + pl.DataFrame({"x": x_s}) + .join(aj_df, how="cross") + .with_columns( + _perfect_lift_expr("x", "p").alias("y"), + pl.format("perfect_model_{}", pl.col("reference_group")).alias( + "reference_group_fmt" + ), + ) + .with_columns( + pl.format( + "Perfect Prediction ({})
<br>Lift: {}<br>
Predicted Positives: {}%", + pl.col("reference_group"), + pl.col("y").round(3), + (100 * pl.col("x")).round(0), + ).alias("text") + ) + .select( + [ + pl.col("reference_group_fmt").alias("reference_group"), + "x", + "y", + "text", + ] + ) + ) + + return pl.concat([random_guess, perfect_df], how="vertical") + + else: + # single population + p = float( + aj_estimates_from_performance_data.select( + pl.col("aj_estimate").cast(pl.Float64).first() + ).item() + ) + + rand_y = ( + pl.Series(np.ones(len(x_s)), dtype=pl.Float64) + if p > 0.0 + else pl.Series(np.full(len(x_s), np.nan), dtype=pl.Float64) + ) + + perfect_y = _perfect_lift_y_series(x_s, p) + + return pl.concat( + [ + pl.DataFrame( + { + "reference_group": ["random_guess"] * len(x_s), + "x": x_s, + "y": rand_y, + } + ).with_columns( + pl.format( + "Random Guess
<br>Lift: {}<br>
Predicted Positives: {}%", + pl.col("y").round(3), + (100 * pl.col("x")).round(0), + ).alias("text") + ), + pl.DataFrame( + { + "reference_group": ["perfect_model"] * len(x_s), + "x": x_s, + "y": perfect_y, + } + ).with_columns( + pl.format( + "Perfect Prediction
<br>Lift: {}<br>
Predicted Positives: {}%", + pl.col("y").round(3), + (100 * pl.col("x")).round(0), + ).alias("text") + ), + ], + how="vertical", + ) + + # --- Precision–Recall --- + if curve == "precision recall": + x = pl.Series("x", _grid(0.01, 1.0, 0.01), dtype=pl.Float64) + + def _htext(title_expr: pl.Expr) -> pl.Expr: + return pl.format( + "{}
<br>PPV: {}<br>
Sensitivity: {}", + title_expr, + pl.col("y").round(3), + pl.col("x").round(2), + ).alias("text") + + if multiple_populations: + # Expect aj_estimates_from_performance_data: [reference_group, aj_estimate] + aj_df = aj_estimates_from_performance_data.select( + pl.col("reference_group"), + pl.col("aj_estimate").cast(pl.Float64).alias("p"), + ) + + base = pl.DataFrame({"x": x}) + + # Random guess per population + random_guess = ( + base.join(aj_df, how="cross") + .with_columns( + _random_guess_pr_expr("p", "x").alias("y"), + pl.format("random_guess_{}", pl.col("reference_group")).alias( + "reference_group" + ), + ) + .with_columns( + _htext(pl.format("Random Guess ({})", pl.col("reference_group"))) + ) + .select(["reference_group", "x", "y", "text"]) + ) + + return random_guess + + else: + # Single population + p = float( + aj_estimates_from_performance_data.select( + pl.col("aj_estimate").cast(pl.Float64).first() + ).item() + ) + + n = len(x) + y_baseline = ( + pl.Series(np.full(n, np.nan), dtype=pl.Float64) + if p <= 0.0 + else pl.Series(np.full(n, p), dtype=pl.Float64) + ) + + return pl.DataFrame( + { + "reference_group": ["random_guess"] * n, + "x": x, + "y": y_baseline, + } + ).with_columns(_htext(pl.lit("Random Guess"))) + + # --- Gains --- + if curve == "gains": + x = pl.Series("x", _grid(0.0, 1.0, 0.01), dtype=pl.Float64) + base = pl.DataFrame({"x": x}) + + def _htext(title: pl.Expr) -> pl.Expr: + return pl.format( + "{}
<br>Sensitivity: {}<br>
Predicted Positives: {}%", + title, + pl.col("y").round(3), + (100 * pl.col("x")).round(0), + ).alias("text") + + random_guess = ( + base.with_columns( + pl.lit("random_guess").alias("reference_group"), + pl.col("x").alias("y"), + ) + .with_columns(_htext(pl.lit("Random Guess"))) + .select(["reference_group", "x", "y", "text"]) + ) + + if multiple_populations: + # Expect DF with columns: reference_group, aj_estimate + aj_df = aj_estimates_from_performance_data.select( + pl.col("reference_group"), + pl.col("aj_estimate").cast(pl.Float64).alias("p"), + ) + + perfect_df = ( + base.join(aj_df, how="cross") + .with_columns( + _perfect_gains_expr("x", "p").alias("y"), + pl.format("perfect_model_{}", pl.col("reference_group")).alias( + "reference_group" + ), + ) + .with_columns( + _htext( + pl.format("Perfect Prediction ({})", pl.col("reference_group")) + ) + ) + .select(["reference_group", "x", "y", "text"]) + ) + + return pl.concat([random_guess, perfect_df], how="vertical") + + else: + # Single population: take first aj_estimate + p = float( + aj_estimates_from_performance_data.select( + pl.col("aj_estimate").cast(pl.Float64).first() + ).item() + ) + perfect_y = _perfect_gains_series(x, p) + perfect_df = pl.DataFrame( + {"reference_group": ["perfect_model"] * len(x), "x": x, "y": perfect_y} + ).with_columns(_htext(pl.lit("Perfect Prediction"))) + return pl.concat([random_guess, perfect_df], how="vertical") + + # ===== Decision Curve ===== + if curve == "decision": + x = pl.Series("x", _grid(0.0, 0.99, 0.01), dtype=pl.Float64) + base = pl.DataFrame({"x": x}) + + # Treat-none (reference line, NB=0) + treat_none = ( + base.with_columns( + pl.lit("treat_none").alias("reference_group"), + pl.lit(0.0, dtype=pl.Float64).alias("y"), + ) + .with_columns(_htext_nb(pl.lit("Treat None"))) + .select(["reference_group", "x", "y", "text"]) + ) + + if multiple_populations: + # expect aj_estimates_from_performance_data: [reference_group, aj_estimate] + aj_df = aj_estimates_from_performance_data.select( + pl.col("reference_group"), + pl.col("aj_estimate").cast(pl.Float64).alias("p"), + ) + + treat_all = ( + base.join(aj_df, how="cross") + .with_columns( + _treat_all_nb_expr("x", "p").alias("y"), + pl.format("treat_all_{}", pl.col("reference_group")).alias( + "reference_group" + ), + ) + .with_columns( + _htext_nb(pl.format("Treat All ({})", pl.col("reference_group"))) + ) + .select(["reference_group", "x", "y", "text"]) + ) + + df = pl.concat([treat_none, treat_all], how="vertical") + + else: + # single population + p = float( + aj_estimates_from_performance_data.select( + pl.col("aj_estimate").cast(pl.Float64).first() + ).item() + ) + + treat_all = ( + base.with_columns( + _treat_all_nb_expr("x", p_col=None).map_elements( + lambda v: v, return_dtype=pl.Float64 + ) # no-op cast + if False + # (keeps linter happy; we inline p below) + else pl.when((p < 0) | (p > 1) | (pl.col("x") >= 1)) + .then(pl.lit(np.nan)) + .otherwise( + pl.lit(p) - (1 - pl.lit(p)) * (pl.col("x") / (1 - pl.col("x"))) + ) + .cast(pl.Float64) + .alias("y") + ) + .with_columns(pl.lit("treat_all").alias("reference_group")) + .with_columns(_htext_nb(pl.lit("Treat All"))) + .select(["reference_group", "x", "y", "text"]) + ) + + df = pl.concat([treat_none, treat_all], how="vertical") + + # clamp thresholds post-build + return df.filter( + (pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold) + ) + + # ===== Interventions Avoided (reference lines) ===== + if curve == "interventions avoided": + x = pl.Series( + "x", _grid(0.01, 0.99, 
0.01), dtype=pl.Float64 + ) # avoid x=0,1 divisions + base = pl.DataFrame({"x": x}) + + # Treat-all reference (0 per 100) + treat_all_ref = ( + base.with_columns( + pl.lit("treat_all").alias("reference_group"), + pl.lit(0.0, dtype=pl.Float64).alias("y"), + ) + .with_columns(_htext_ia(pl.lit("Treat All"))) + .select(["reference_group", "x", "y", "text"]) + ) + + if multiple_populations: + # expect aj_estimates_from_performance_data: [reference_group, aj_estimate] + aj_df = aj_estimates_from_performance_data.select( + pl.col("reference_group"), + pl.col("aj_estimate").cast(pl.Float64).alias("p"), + ) + + # Use your existing helper for correctness of the IA math + parts = [treat_all_ref] + for row in aj_df.iter_rows(named=True): + name, p = row["reference_group"], float(row["p"]) + y = _treat_none_interventions_avoided_y(x, p) # your helper + parts.append( + pl.DataFrame( + { + "reference_group": [f"treat_none_{name}"] * len(x), + "x": x, + "y": y, + } + ).with_columns(_htext_ia(pl.lit(f"Treat None ({name})"))) + ) + df = pl.concat(parts, how="vertical") + + else: + p = float( + aj_estimates_from_performance_data.select( + pl.col("aj_estimate").cast(pl.Float64).first() + ).item() + ) + y = _treat_none_interventions_avoided_y(x, p) # your helper + + df = pl.concat( + [ + treat_all_ref, + pl.DataFrame( + {"reference_group": ["treat_none"] * len(x), "x": x, "y": y} + ).with_columns(_htext_ia(pl.lit("Treat None"))), + ], + how="vertical", + ) + + return df.filter( + (pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold) + ) + + +def create_non_interactive_curve_polars( + performance_data_ready_for_curve, reference_group_color, reference_group +): + """ + + Parameters + ---------- + performance_data_ready_for_curve : + + reference_group_color : + + + Returns + ------- + + """ + + non_interactive_curve = go.Scatter( + x=performance_data_ready_for_curve["x"], + y=performance_data_ready_for_curve["y"], + mode="markers+lines", + hoverinfo="text", + # hovertext=performance_data_ready_for_curve["text"], + name=reference_group, + legendgroup=reference_group, + line={"width": 2, "color": reference_group_color}, + ) + return non_interactive_curve def create_non_interactive_curve( @@ -22,10 +656,7 @@ def create_non_interactive_curve( """ performance_data_ready_for_curve = performance_data_ready_for_curve.dropna() - # print("Print y values non interactive") - # print(performance_data_ready_for_curve['y'].values) - # print("Done Printing non interactive") - print(reference_group) + non_interactive_curve = go.Scatter( x=performance_data_ready_for_curve["x"].values.tolist(), y=performance_data_ready_for_curve["y"].values.tolist(), @@ -61,14 +692,6 @@ def create_interactive_marker( column_name=performance_data_ready_for_curve.loc[:, "y"].fillna(-1) ) - # print("Print y values in k") - # print(performance_data_ready_for_curve["x"].values.tolist()[k]) - # print("Done Printing") - - # print("Print y values") - # print(performance_data_ready_for_curve['y'].values) - # print("Done Printing") - interactive_marker = go.Scatter( x=[performance_data_ready_for_curve["x"].values.tolist()[k]], y=[performance_data_ready_for_curve["y"].values.tolist()[k]], @@ -112,3 +735,512 @@ def create_reference_lines_for_plotly(reference_data, reference_line_color): showlegend=False, ) return reference_lines + + +_CURVE_CONFIG = { + "roc": ( + "false_positive_rate", + "sensitivity", + "1 - Specificity", + "Sensitivity", + ), + "precision recall": ( + "sensitivity", + "ppv", + "Sensitivity", + "PPV", + ), + "lift": ( + "ppcr", 
+ "lift", + "Predicted Positives (Rate)", + "Lift", + ), + "gains": ( + "ppcr", + "sensitivity", + "Predicted Positives (Rate)", + "Sensitivity", + ), + "decision": ( + "chosen_cutoff", + "net_benefit", + "Probability Threshold", + "Net Benefit", + ), + "interventions avoided": ( + "chosen_cutoff", + "net_benefit_interventions_avoided", + "Probability Threshold", + "Interventions Avoided (per 100)", + ), +} + + +def _finite_vals(series: pl.Series) -> list[float]: + vals = series.to_list() + out = [] + for v in vals: + if v is None: + continue + # Allow ints too + if ( + isinstance(v, (int, float)) + and math.isfinite(v) + and not math.isnan(float(v)) + ): + out.append(float(v)) + return out + + +def extract_axes_ranges( + performance_data_ready: pl.DataFrame, + curve: str, + min_p_threshold: float = 0.0, + max_p_threshold: float = 1.0, +) -> dict[str, list[float]]: + y_vals = _finite_vals(performance_data_ready["y"]) + + if curve == "roc": + rng = {"xaxis": [0.0, 1.0], "yaxis": [0.0, 1.0]} + + elif curve == "precision recall": + rng = {"xaxis": [0.0, 1.0], "yaxis": [0.0, 1.0]} + + elif curve == "gains": + rng = {"xaxis": [0.0, 1.0], "yaxis": [0.0, 1.0]} + + elif curve == "lift": + max_y = max([1.0] + y_vals) if y_vals else 1.0 + rng = {"xaxis": [0.0, 1.0], "yaxis": [0.0, max_y]} + + elif curve == "decision": + max_y = max(y_vals) if y_vals else 0.0 + min_y = min(y_vals) if y_vals else 0.0 + rng = { + "xaxis": [float(min_p_threshold), float(max_p_threshold)], + "yaxis": [min(min_y, 0.0), max_y], + } + + elif curve == "interventions avoided": + min_y = min(y_vals) if y_vals else 0.0 + rng = { + "xaxis": [float(min_p_threshold), float(max_p_threshold)], + "yaxis": [min(0.0, min_y), 100.0], + } + + else: + # Sensible default + rng = {"xaxis": [0.0, 1.0], "yaxis": [0.0, 1.0]} + + # Match the R post-step: purrr::map(~ extend_axis_ranges(.x)) + rng["xaxis"] = _extend_axis_ranges(rng["xaxis"]) + rng["yaxis"] = _extend_axis_ranges(rng["yaxis"]) + return rng + + +def _extend_axis_ranges(bounds, pad_frac=0.02): + lo, hi = bounds + # Handle None or identical values + if lo is None or hi is None: + return bounds + span = hi - lo + if span <= 0: + pad = 1e-6 + return [lo - pad, hi + pad] + pad = span * pad_frac + return [lo - pad, hi + pad] + + +def _get_prevalence_from_performance_data( + performance_data: pl.DataFrame, +) -> dict[str, float]: + cols_to_keep = [ + c for c in ["model", "population", "ppv"] if c in performance_data.columns + ] + if "ppcr" not in performance_data.columns or "ppv" not in performance_data.columns: + raise ValueError("performance_data must include 'ppcr' and 'PPV' columns") + + df = performance_data.filter(pl.col("ppcr") == 1).select(cols_to_keep).unique() + + if len(df.columns) == 1: + return {"single": float(df["ppv"][0])} + + key_col = df.columns[0] + return dict(zip(df[key_col].to_list(), df["ppv"].to_list())) + + +def _get_aj_estimates_from_performance_data( + performance_data: pl.DataFrame, +) -> pl.DataFrame: + return ( + performance_data.select("reference_group", "real_positives", "n") + .unique() + .with_columns((pl.col("real_positives") / pl.col("n")).alias("aj_estimate")) + .select(pl.col("reference_group"), pl.col("aj_estimate")) + ) + + +def _check_if_multiple_populations_are_being_validated( + aj_estimates: pl.DataFrame, +) -> bool: + return aj_estimates["aj_estimate"].unique().len() > 1 + + +def _create_rtichoke_curve_list_binary( + performance_data: pl.DataFrame, + stratified_by: str, + size: int = 500, + color_values=None, + curve="roc", + min_p_threshold=0, 
+ max_p_threshold=1, +) -> dict[str, Any]: + animation_slider_prefix = ( + "Prob. Threshold: " + if stratified_by == "probability_threshold" + else "Predicted Positives (Rate):" + ) + + x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve] + + performance_data_ready_for_curve = _select_and_rename_necessary_variables( + performance_data.sort("chosen_cutoff"), x_metric, y_metric + ) + + aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data( + performance_data + ) + + multiple_populations = _check_if_multiple_populations_are_being_validated( + aj_estimates_from_performance_data + ) + + reference_data = _create_reference_lines_data( + curve=curve, + aj_estimates_from_performance_data=aj_estimates_from_performance_data, + multiple_populations=multiple_populations, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + + axes_ranges = extract_axes_ranges( + performance_data_ready_for_curve, + curve=curve, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + + reference_group_keys = performance_data["reference_group"].unique().to_list() + + cutoffs = ( + performance_data_ready_for_curve.select(pl.col("chosen_cutoff")) + .drop_nulls() + .unique() + .sort("chosen_cutoff") + .to_series() + .to_list() + ) + + palette = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ] + + colors_dictionary = { + **{ + key: "#BEBEBE" + for key in [ + "random_guess", + "perfect_model", + "treat_none", + "treat_all", + ] + }, + **{ + variant_key: (palette[group_index] if multiple_populations else "#000000") + for group_index, reference_group in enumerate(reference_group_keys) + for variant_key in [ + reference_group, + f"random_guess_{reference_group}", + f"perfect_model_{reference_group}", + f"treat_none_{reference_group}", + f"treat_all_{reference_group}", + ] + }, + } + + rtichoke_curve_list = { + "size": size, + "axes_ranges": axes_ranges, + "x_label": x_label, + "y_label": y_label, + "animation_slider_prefix": animation_slider_prefix, + "reference_group_keys": reference_group_keys, + "performance_data_ready_for_curve": performance_data_ready_for_curve, + "reference_data": reference_data, + "cutoffs": cutoffs, + "colors_dictionary": colors_dictionary, + "multiple_populations": multiple_populations, + } + + return rtichoke_curve_list + + +def _select_and_rename_necessary_variables( + performance_data: pl.DataFrame, x_perf_metric: str, y_perf_metric: str +) -> pl.DataFrame: + return performance_data.select( + pl.col("reference_group"), + pl.col("chosen_cutoff"), + pl.col(x_perf_metric).alias("x"), + pl.col(y_perf_metric).alias("y"), + ) + + +def _create_slider_dict(animation_slider_prefic: str, steps: dict) -> dict[str, Any]: + slider_dict = { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 20}, + "prefix": animation_slider_prefic, + "visible": True, + "xanchor": "left", + }, + "transition": {"duration": 300, "easing": "linear"}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": steps, + } + + return slider_dict + + +def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figure: + non_interactive_curve = [ + go.Scatter( + x=rtichoke_curve_list["performance_data_ready_for_curve"] + .filter(pl.col("reference_group") == group)["x"] + .to_list(), 
+ y=rtichoke_curve_list["performance_data_ready_for_curve"] + .filter(pl.col("reference_group") == group)["y"] + .to_list(), + mode="markers+lines", + name=group, + legendgroup=group, + line={ + "width": 2, + "color": rtichoke_curve_list["colors_dictionary"].get(group), + }, + showlegend=True, + ) + for group in rtichoke_curve_list["reference_group_keys"] + ] + + initial_interactive_markers = [ + go.Scatter( + x=[], + y=[], + mode="markers", + marker={ + "size": 12, + "color": ( + rtichoke_curve_list["colors_dictionary"].get(group) + if rtichoke_curve_list["multiple_populations"] + else "#f6e3be" + ), + "line": {"width": 3, "color": "black"}, + }, + name=f"{group} @ cutoff", + legendgroup=group, + showlegend=False, + hovertemplate=f"{group}
<br>x=%{{x:.4f}}<br>
y=%{{y:.4f}}", + ) + for group in rtichoke_curve_list["reference_group_keys"] + ] + + reference_traces = [ + go.Scatter( + x=rtichoke_curve_list["reference_data"] + .filter(pl.col("reference_group") == group)["x"] + .to_list(), + y=rtichoke_curve_list["reference_data"] + .filter(pl.col("reference_group") == group)["y"] + .to_list(), + mode="lines", + name=group, + legendgroup=group, + line=dict( + dash="dot", + color=rtichoke_curve_list["colors_dictionary"].get(group), + width=1.5, + ), + hoverinfo="text", + text=rtichoke_curve_list["reference_data"] + .filter(pl.col("reference_group") == group)["text"] + .to_list(), + showlegend=False, + ) + for group in rtichoke_curve_list["colors_dictionary"].keys() + ] + + dyn_idx = list( + range( + len(rtichoke_curve_list["reference_group_keys"]), + len(rtichoke_curve_list["reference_group_keys"]) * 2, + ) + ) + + def xy_at_cutoff(group, c): + row = ( + rtichoke_curve_list["performance_data_ready_for_curve"] + .filter( + (pl.col("reference_group") == group) + & (pl.col("chosen_cutoff") == c) + & pl.col("x").is_not_null() + & pl.col("y").is_not_null() + ) + .select(["x", "y"]) + .limit(1) + ) + if row.height == 0: + return None, None + r = row.row(0) # (x, y) + return r[0], r[1] + + steps = [ + { + "method": "restyle", + "args": [ + { + "x": [ + [xy_at_cutoff(group, cutoff)[0]] + if xy_at_cutoff(group, cutoff)[0] is not None + else [] + for group in rtichoke_curve_list["reference_group_keys"] + ], + "y": [ + [xy_at_cutoff(group, cutoff)[1]] + if xy_at_cutoff(group, cutoff)[1] is not None + else [] + for group in rtichoke_curve_list["reference_group_keys"] + ], + }, + dyn_idx, + ], + "label": f"{cutoff:g}", + } + for cutoff in rtichoke_curve_list["cutoffs"] + ] + + slider_dict = _create_slider_dict( + rtichoke_curve_list["animation_slider_prefix"], steps + ) + + curve_layout = _create_curve_layout( + size=rtichoke_curve_list["size"], + slider_dict=slider_dict, + axes_ranges=rtichoke_curve_list["axes_ranges"], + x_label=rtichoke_curve_list["x_label"], + y_label=rtichoke_curve_list["y_label"], + ) + + return go.Figure( + data=non_interactive_curve + initial_interactive_markers + reference_traces, + layout=curve_layout, + ) + + +def _create_curve_layout( + size: int, + slider_dict: dict, + axes_ranges: dict[str, list[float]] | None = None, + x_label: str | None = None, + y_label: str | None = None, +) -> dict[str, Any]: + curve_layout = { + "xaxis": {"showgrid": False}, + "yaxis": {"showgrid": False}, + "plot_bgcolor": "rgba(0, 0, 0, 0)", + "showlegend": True, + "legend": { + "orientation": "h", + "xanchor": "center", + "yanchor": "top", + "x": 0.5, + "y": 1.3, + }, + "height": size, + "width": size, + "updatemenus": [ + { + "type": "buttons", + "buttons": [ + { + "label": "Play", + "method": "animate", + "visible": False, + "args": [None, {"frame": {"duration": 500, "redraw": False}}], + } + ], + } + ], + "sliders": [slider_dict], + } + + if axes_ranges is not None: + curve_layout["xaxis"]["range"] = axes_ranges["xaxis"] + curve_layout["yaxis"]["range"] = axes_ranges["yaxis"] + + if x_label: + curve_layout["xaxis"]["title"] = {"text": x_label} + if y_label: + curve_layout["yaxis"]["title"] = {"text": y_label} + + return curve_layout + + +def _create_interactive_marker_polars( + performance_data_ready_for_curve: pl.DataFrame, + interactive_marker_color: str, + k: int, + reference_group: str, +): + interactive_marker = go.Scatter( + x=[performance_data_ready_for_curve["x"][k]], + y=[performance_data_ready_for_curve["y"][k]], + mode="markers", + # 
hoverinfo="text", + # hovertext=[performance_data_ready_for_curve["text"].values.tolist()[k]], + name=reference_group, + legendgroup=reference_group, + showlegend=False, + marker={ + "size": 12, + "color": interactive_marker_color, + "line": {"width": 2, "color": "black"}, + }, + ) + return interactive_marker diff --git a/uv.lock b/uv.lock index ad22b84..ea3c64d 100644 --- a/uv.lock +++ b/uv.lock @@ -4054,7 +4054,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.13" +version = "0.1.14" source = { editable = "." } dependencies = [ { name = "importlib" },