From dbe8891456ed0a82f78d9a660d486237347e6776 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sat, 29 Nov 2025 12:48:35 +0200 Subject: [PATCH 1/4] fix: close #207 --- .../helpers/plotly_helper_functions.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 83b6e08..c35c7cd 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1054,6 +1054,15 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur "width": 2, "color": rtichoke_curve_list["colors_dictionary"].get(group), }, + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get( + group + ), # <-- background = trace color + bordercolor=rtichoke_curve_list["colors_dictionary"].get( + group + ), # <-- border = trace color + font_color="white", # <-- or "black" if your colors are light + ), showlegend=True, ) for group in rtichoke_curve_list["reference_group_keys"] @@ -1075,6 +1084,15 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur }, name=f"{group} @ cutoff", legendgroup=group, + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get( + group + ), # <-- background = trace color + bordercolor=rtichoke_curve_list["colors_dictionary"].get( + group + ), # <-- border = trace color + font_color="white", # <-- or "black" if your colors are light + ), showlegend=False, hovertemplate=f"{group}
x=%{{x:.4f}}
y=%{{y:.4f}}", ) @@ -1097,6 +1115,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur color=rtichoke_curve_list["colors_dictionary"].get(group), width=1.5, ), + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", + ), hoverinfo="text", text=rtichoke_curve_list["reference_data"] .filter(pl.col("reference_group") == group)["text"] @@ -1183,7 +1206,7 @@ def _create_curve_layout( curve_layout = { "xaxis": {"showgrid": False}, "yaxis": {"showgrid": False}, - "template": "none", + "template": "plotly", "plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)", "showlegend": True, @@ -1198,7 +1221,7 @@ def _create_curve_layout( }, "height": size + 50, "width": size, - "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"}, + # "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"}, "updatemenus": [ { "type": "buttons", From 127c9e287654fa98ec57052cf7fa4cfde8d9a48b Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sat, 29 Nov 2025 17:24:15 +0200 Subject: [PATCH 2/4] fix: close #210 --- .../helpers/plotly_helper_functions.py | 195 ++++++++++++++++-- 1 file changed, 179 insertions(+), 16 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index c35c7cd..1befd05 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -9,6 +9,19 @@ import numpy as np from rtichoke.performance_data.performance_data import prepare_performance_data +_HOVER_LABELS = { + "false_positive_rate": "1 - Specificity (FPR)", + "sensitivity": "Sensitivity", + "specificity": "Specificity", + "lift": "Lift", + "ppv": "PPV", + "npv": "NPV", + "net_benefit": "NB", + "net_benefit_interventions_avoided": "Interventions Avoided (per 100)", + "chosen_cutoff": "Prob. Threshold", + "ppcr": "Predicted Positives", +} + def _create_rtichoke_plotly_curve_binary( probs: Dict[str, np.ndarray], @@ -887,6 +900,149 @@ def _check_if_multiple_populations_are_being_validated( return aj_estimates["aj_estimate"].unique().len() > 1 +def _infer_performance_data_type(performance_data: pl.DataFrame) -> str: + if "model" in performance_data.columns: + return "several models" + if "population" in performance_data.columns: + return "several populations" + return "single model" + + +def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str: + lines = text.split("
") + for metric in metrics: + label = _HOVER_LABELS.get(metric, metric) + lines = [ + f"{line}" if label in line and "" not in line else line + for line in lines + ] + return "
".join(lines) + + +def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str: + if perf_dat_type == "several models" and "model" in row: + text = f"Model: {row['model']}
{text}" + if perf_dat_type == "several populations" and "population" in row: + text = f"Population: {row['population']}
{text}" + return text + + +def _round_val(value: Any, digits: int = 3): + try: + if value is None: + return "" + if isinstance(value, (int, float, np.floating)): + return round(float(value), digits) + except (TypeError, ValueError): + pass + return value + + +def _build_hover_text( + row: dict, + performance_metric_x: str, + performance_metric_y: str, + stratified_by: str, + perf_dat_type: str, +) -> str: + interventions_avoided = performance_metric_y == "net_benefit_interventions_avoided" + + raw_probability_threshold = row.get("chosen_cutoff") + probability_threshold = _round_val(raw_probability_threshold) + sensitivity = _round_val(row.get("sensitivity")) + fpr = _round_val(row.get("false_positive_rate")) + specificity = _round_val(row.get("specificity")) + lift = _round_val(row.get("lift")) + ppv = _round_val(row.get("ppv")) + npv = _round_val(row.get("npv")) + net_benefit = _round_val(row.get("net_benefit")) + nb_interventions_avoided = _round_val(row.get("net_benefit_interventions_avoided")) + predicted_positives = _round_val(row.get("predicted_positives")) + raw_ppcr = row.get("ppcr") + ppcr_percent = ( + _round_val(100 * raw_ppcr) + if isinstance(raw_ppcr, (int, float, np.floating)) + else "" + ) + tp = _round_val(row.get("true_positives")) + tn = _round_val(row.get("true_negatives")) + fp = _round_val(row.get("false_positives")) + fn = _round_val(row.get("false_negatives")) + + if ( + isinstance(raw_probability_threshold, (int, float, np.floating)) + and raw_probability_threshold != 0 + ): + odds = _round_val( + (1 - raw_probability_threshold) / raw_probability_threshold, 2 + ) + else: + odds = None + + if not interventions_avoided: + text_lines = [ + f"Prob. Threshold: {probability_threshold}", + f"Sensitivity: {sensitivity}", + f"1 - Specificity (FPR): {fpr}", + f"Specificity: {specificity}", + f"Lift: {lift}", + f"PPV: {ppv}", + f"NPV: {npv}", + ] + if stratified_by == "probability_threshold": + text_lines.append(f"NB: {net_benefit}") + if odds is not None and math.isfinite(float(odds)): + text_lines.append(f"Odds of Prob. Threshold: 1:{odds}") + text_lines.extend( + [ + f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)", + f"TP: {tp}", + f"TN: {tn}", + f"FP: {fp}", + f"FN: {fn}", + ] + ) + else: + text_lines = [ + f"Prob. Threshold: {probability_threshold}", + f"Interventions Avoided (per 100): {nb_interventions_avoided}", + f"NB: {net_benefit}", + f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)", + f"TN: {tn}", + f"FN: {fn}", + ] + if odds is not None and math.isfinite(float(odds)): + text_lines.insert(1, f"Odds of Prob. Threshold: 1:{odds}") + + text = "
".join(text_lines) + text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y]) + text = _add_model_population_text(text, row, perf_dat_type) + return text.replace("NaN", "").replace("nan", "") + + +def _add_hover_text_to_performance_data( + performance_data: pl.DataFrame, + performance_metric_x: str, + performance_metric_y: str, + stratified_by: str, + perf_dat_type: str, +) -> pl.DataFrame: + hover_text_expr = pl.struct(performance_data.columns).map_elements( + lambda row: _build_hover_text( + row, + performance_metric_x=performance_metric_x, + performance_metric_y=performance_metric_y, + stratified_by=stratified_by, + perf_dat_type=perf_dat_type, + ), + return_dtype=pl.Utf8, + ) + + return performance_data.with_columns( + [pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")] + ) + + def _create_rtichoke_curve_list_binary( performance_data: pl.DataFrame, stratified_by: str, @@ -904,8 +1060,18 @@ def _create_rtichoke_curve_list_binary( x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve] + perf_dat_type = _infer_performance_data_type(performance_data) + + performance_data_with_hover_text = _add_hover_text_to_performance_data( + performance_data.sort("chosen_cutoff"), + performance_metric_x=x_metric, + performance_metric_y=y_metric, + stratified_by=stratified_by, + perf_dat_type=perf_dat_type, + ) + performance_data_ready_for_curve = _select_and_rename_necessary_variables( - performance_data.sort("chosen_cutoff"), x_metric, y_metric + performance_data_with_hover_text, x_metric, y_metric ) aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data( @@ -1013,6 +1179,7 @@ def _select_and_rename_necessary_variables( pl.col("chosen_cutoff"), pl.col(x_perf_metric).alias("x"), pl.col(y_perf_metric).alias("y"), + pl.col("text"), ) @@ -1047,6 +1214,9 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur y=rtichoke_curve_list["performance_data_ready_for_curve"] .filter(pl.col("reference_group") == group)["y"] .to_list(), + text=rtichoke_curve_list["performance_data_ready_for_curve"].filter( + pl.col("reference_group") == group + )["text"], mode="markers+lines", name=group, legendgroup=group, @@ -1055,14 +1225,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur "color": rtichoke_curve_list["colors_dictionary"].get(group), }, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get( - group - ), # <-- background = trace color - bordercolor=rtichoke_curve_list["colors_dictionary"].get( - group - ), # <-- border = trace color - font_color="white", # <-- or "black" if your colors are light + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", ), + hoverinfo="text", showlegend=True, ) for group in rtichoke_curve_list["reference_group_keys"] @@ -1085,16 +1252,12 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur name=f"{group} @ cutoff", legendgroup=group, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get( - group - ), # <-- background = trace color - bordercolor=rtichoke_curve_list["colors_dictionary"].get( - group - ), # <-- border = trace color - font_color="white", # <-- or "black" if your colors are light + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", ), showlegend=False, - hovertemplate=f"{group}
x=%{{x:.4f}}
y=%{{y:.4f}}", + hoverinfo="text", ) for group in rtichoke_curve_list["reference_group_keys"] ] From a23d67104658a05f7084288c6c5ed4e7b183bbb6 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sat, 29 Nov 2025 19:33:41 +0200 Subject: [PATCH 3/4] fix: close #209 --- .../helpers/plotly_helper_functions.py | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 1befd05..9a65e04 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -900,12 +900,23 @@ def _check_if_multiple_populations_are_being_validated( return aj_estimates["aj_estimate"].unique().len() > 1 -def _infer_performance_data_type(performance_data: pl.DataFrame) -> str: - if "model" in performance_data.columns: - return "several models" - if "population" in performance_data.columns: +def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> bool: + return aj_estimates["reference_group"].unique().len() > 1 + + +def _infer_performance_data_type( + aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool +) -> str: + multiple_models = _check_if_multiple_models_are_being_validated( + aj_estimates_from_performance_data + ) + + if multiple_populations: return "several populations" - return "single model" + elif multiple_models: + return "several models" + else: + return "single model" def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str: @@ -1060,7 +1071,23 @@ def _create_rtichoke_curve_list_binary( x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve] - perf_dat_type = _infer_performance_data_type(performance_data) + aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data( + performance_data + ) + + multiple_populations = _check_if_multiple_populations_are_being_validated( + aj_estimates_from_performance_data + ) + + multiple_models = _check_if_multiple_models_are_being_validated( + aj_estimates_from_performance_data + ) + + perf_dat_type = _infer_performance_data_type( + aj_estimates_from_performance_data, multiple_populations + ) + + multiple_reference_groups = multiple_populations or multiple_models performance_data_with_hover_text = _add_hover_text_to_performance_data( performance_data.sort("chosen_cutoff"), @@ -1074,14 +1101,6 @@ def _create_rtichoke_curve_list_binary( performance_data_with_hover_text, x_metric, y_metric ) - aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data( - performance_data - ) - - multiple_populations = _check_if_multiple_populations_are_being_validated( - aj_estimates_from_performance_data - ) - reference_data = _create_reference_lines_data( curve=curve, aj_estimates_from_performance_data=aj_estimates_from_performance_data, @@ -1142,7 +1161,9 @@ def _create_rtichoke_curve_list_binary( ] }, **{ - variant_key: (palette[group_index] if multiple_populations else "#000000") + variant_key: ( + palette[group_index] if multiple_reference_groups else "#000000" + ) for group_index, reference_group in enumerate(reference_group_keys) for variant_key in [ reference_group, @@ -1165,7 +1186,7 @@ def _create_rtichoke_curve_list_binary( "reference_data": reference_data, "cutoffs": cutoffs, "colors_dictionary": colors_dictionary, - "multiple_populations": multiple_populations, + "multiple_reference_groups": multiple_reference_groups, } return rtichoke_curve_list @@ -1244,7 +1265,7 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur "size": 12, "color": ( rtichoke_curve_list["colors_dictionary"].get(group) - if rtichoke_curve_list["multiple_populations"] + if rtichoke_curve_list["multiple_reference_groups"] else "#f6e3be" ), "line": {"width": 3, "color": "black"}, From 676a3bc6337f41458dc14d687b7109ba587f1725 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sat, 29 Nov 2025 19:36:45 +0200 Subject: [PATCH 4/4] build: update rtichoke version --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 33dfc03..12be92f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "marimo>=0.17.0", ] name = "rtichoke" -version = "0.1.17" +version = "0.1.18" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/uv.lock b/uv.lock index 7e2793a..9cccb56 100644 --- a/uv.lock +++ b/uv.lock @@ -4148,7 +4148,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.17" +version = "0.1.18" source = { editable = "." } dependencies = [ { name = "importlib" },