diff --git a/pyproject.toml b/pyproject.toml
index 33dfc03..12be92f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"marimo>=0.17.0",
]
name = "rtichoke"
-version = "0.1.17"
+version = "0.1.18"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"
diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py
index 83b6e08..9a65e04 100644
--- a/src/rtichoke/helpers/plotly_helper_functions.py
+++ b/src/rtichoke/helpers/plotly_helper_functions.py
@@ -9,6 +9,19 @@
import numpy as np
from rtichoke.performance_data.performance_data import prepare_performance_data
+_HOVER_LABELS = {
+ "false_positive_rate": "1 - Specificity (FPR)",
+ "sensitivity": "Sensitivity",
+ "specificity": "Specificity",
+ "lift": "Lift",
+ "ppv": "PPV",
+ "npv": "NPV",
+ "net_benefit": "NB",
+ "net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
+ "chosen_cutoff": "Prob. Threshold",
+ "ppcr": "Predicted Positives",
+}
+
def _create_rtichoke_plotly_curve_binary(
probs: Dict[str, np.ndarray],
@@ -887,6 +900,160 @@ def _check_if_multiple_populations_are_being_validated(
return aj_estimates["aj_estimate"].unique().len() > 1
+def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> bool:
+ return aj_estimates["reference_group"].unique().len() > 1
+
+
+def _infer_performance_data_type(
+ aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
+) -> str:
+ multiple_models = _check_if_multiple_models_are_being_validated(
+ aj_estimates_from_performance_data
+ )
+
+ if multiple_populations:
+ return "several populations"
+ elif multiple_models:
+ return "several models"
+ else:
+ return "single model"
+
+
+def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str:
+ lines = text.split("
")
+ for metric in metrics:
+ label = _HOVER_LABELS.get(metric, metric)
+ lines = [
+ f"{line}" if label in line and "" not in line else line
+ for line in lines
+ ]
+ return "
".join(lines)
+
+
+def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str:
+ if perf_dat_type == "several models" and "model" in row:
+ text = f"Model: {row['model']}
{text}"
+ if perf_dat_type == "several populations" and "population" in row:
+ text = f"Population: {row['population']}
{text}"
+ return text
+
+
+def _round_val(value: Any, digits: int = 3):
+ try:
+ if value is None:
+ return ""
+ if isinstance(value, (int, float, np.floating)):
+ return round(float(value), digits)
+ except (TypeError, ValueError):
+ pass
+ return value
+
+
+def _build_hover_text(
+ row: dict,
+ performance_metric_x: str,
+ performance_metric_y: str,
+ stratified_by: str,
+ perf_dat_type: str,
+) -> str:
+ interventions_avoided = performance_metric_y == "net_benefit_interventions_avoided"
+
+ raw_probability_threshold = row.get("chosen_cutoff")
+ probability_threshold = _round_val(raw_probability_threshold)
+ sensitivity = _round_val(row.get("sensitivity"))
+ fpr = _round_val(row.get("false_positive_rate"))
+ specificity = _round_val(row.get("specificity"))
+ lift = _round_val(row.get("lift"))
+ ppv = _round_val(row.get("ppv"))
+ npv = _round_val(row.get("npv"))
+ net_benefit = _round_val(row.get("net_benefit"))
+ nb_interventions_avoided = _round_val(row.get("net_benefit_interventions_avoided"))
+ predicted_positives = _round_val(row.get("predicted_positives"))
+ raw_ppcr = row.get("ppcr")
+ ppcr_percent = (
+ _round_val(100 * raw_ppcr)
+ if isinstance(raw_ppcr, (int, float, np.floating))
+ else ""
+ )
+ tp = _round_val(row.get("true_positives"))
+ tn = _round_val(row.get("true_negatives"))
+ fp = _round_val(row.get("false_positives"))
+ fn = _round_val(row.get("false_negatives"))
+
+ if (
+ isinstance(raw_probability_threshold, (int, float, np.floating))
+ and raw_probability_threshold != 0
+ ):
+ odds = _round_val(
+ (1 - raw_probability_threshold) / raw_probability_threshold, 2
+ )
+ else:
+ odds = None
+
+ if not interventions_avoided:
+ text_lines = [
+ f"Prob. Threshold: {probability_threshold}",
+ f"Sensitivity: {sensitivity}",
+ f"1 - Specificity (FPR): {fpr}",
+ f"Specificity: {specificity}",
+ f"Lift: {lift}",
+ f"PPV: {ppv}",
+ f"NPV: {npv}",
+ ]
+ if stratified_by == "probability_threshold":
+ text_lines.append(f"NB: {net_benefit}")
+ if odds is not None and math.isfinite(float(odds)):
+ text_lines.append(f"Odds of Prob. Threshold: 1:{odds}")
+ text_lines.extend(
+ [
+ f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)",
+ f"TP: {tp}",
+ f"TN: {tn}",
+ f"FP: {fp}",
+ f"FN: {fn}",
+ ]
+ )
+ else:
+ text_lines = [
+ f"Prob. Threshold: {probability_threshold}",
+ f"Interventions Avoided (per 100): {nb_interventions_avoided}",
+ f"NB: {net_benefit}",
+ f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)",
+ f"TN: {tn}",
+ f"FN: {fn}",
+ ]
+ if odds is not None and math.isfinite(float(odds)):
+ text_lines.insert(1, f"Odds of Prob. Threshold: 1:{odds}")
+
+ text = "
".join(text_lines)
+ text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
+ text = _add_model_population_text(text, row, perf_dat_type)
+ return text.replace("NaN", "").replace("nan", "")
+
+
+def _add_hover_text_to_performance_data(
+ performance_data: pl.DataFrame,
+ performance_metric_x: str,
+ performance_metric_y: str,
+ stratified_by: str,
+ perf_dat_type: str,
+) -> pl.DataFrame:
+ hover_text_expr = pl.struct(performance_data.columns).map_elements(
+ lambda row: _build_hover_text(
+ row,
+ performance_metric_x=performance_metric_x,
+ performance_metric_y=performance_metric_y,
+ stratified_by=stratified_by,
+ perf_dat_type=perf_dat_type,
+ ),
+ return_dtype=pl.Utf8,
+ )
+
+ return performance_data.with_columns(
+ [pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")]
+ )
+
+
def _create_rtichoke_curve_list_binary(
performance_data: pl.DataFrame,
stratified_by: str,
@@ -904,10 +1071,6 @@ def _create_rtichoke_curve_list_binary(
x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]
- performance_data_ready_for_curve = _select_and_rename_necessary_variables(
- performance_data.sort("chosen_cutoff"), x_metric, y_metric
- )
-
aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data(
performance_data
)
@@ -916,6 +1079,28 @@ def _create_rtichoke_curve_list_binary(
aj_estimates_from_performance_data
)
+ multiple_models = _check_if_multiple_models_are_being_validated(
+ aj_estimates_from_performance_data
+ )
+
+ perf_dat_type = _infer_performance_data_type(
+ aj_estimates_from_performance_data, multiple_populations
+ )
+
+ multiple_reference_groups = multiple_populations or multiple_models
+
+ performance_data_with_hover_text = _add_hover_text_to_performance_data(
+ performance_data.sort("chosen_cutoff"),
+ performance_metric_x=x_metric,
+ performance_metric_y=y_metric,
+ stratified_by=stratified_by,
+ perf_dat_type=perf_dat_type,
+ )
+
+ performance_data_ready_for_curve = _select_and_rename_necessary_variables(
+ performance_data_with_hover_text, x_metric, y_metric
+ )
+
reference_data = _create_reference_lines_data(
curve=curve,
aj_estimates_from_performance_data=aj_estimates_from_performance_data,
@@ -976,7 +1161,9 @@ def _create_rtichoke_curve_list_binary(
]
},
**{
- variant_key: (palette[group_index] if multiple_populations else "#000000")
+ variant_key: (
+ palette[group_index] if multiple_reference_groups else "#000000"
+ )
for group_index, reference_group in enumerate(reference_group_keys)
for variant_key in [
reference_group,
@@ -999,7 +1186,7 @@ def _create_rtichoke_curve_list_binary(
"reference_data": reference_data,
"cutoffs": cutoffs,
"colors_dictionary": colors_dictionary,
- "multiple_populations": multiple_populations,
+ "multiple_reference_groups": multiple_reference_groups,
}
return rtichoke_curve_list
@@ -1013,6 +1200,7 @@ def _select_and_rename_necessary_variables(
pl.col("chosen_cutoff"),
pl.col(x_perf_metric).alias("x"),
pl.col(y_perf_metric).alias("y"),
+ pl.col("text"),
)
@@ -1047,6 +1235,9 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
y=rtichoke_curve_list["performance_data_ready_for_curve"]
.filter(pl.col("reference_group") == group)["y"]
.to_list(),
+ text=rtichoke_curve_list["performance_data_ready_for_curve"].filter(
+ pl.col("reference_group") == group
+ )["text"],
mode="markers+lines",
name=group,
legendgroup=group,
@@ -1054,6 +1245,12 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
"width": 2,
"color": rtichoke_curve_list["colors_dictionary"].get(group),
},
+ hoverlabel=dict(
+ bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ font_color="white",
+ ),
+ hoverinfo="text",
showlegend=True,
)
for group in rtichoke_curve_list["reference_group_keys"]
@@ -1068,15 +1265,20 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
"size": 12,
"color": (
rtichoke_curve_list["colors_dictionary"].get(group)
- if rtichoke_curve_list["multiple_populations"]
+ if rtichoke_curve_list["multiple_reference_groups"]
else "#f6e3be"
),
"line": {"width": 3, "color": "black"},
},
name=f"{group} @ cutoff",
legendgroup=group,
+ hoverlabel=dict(
+ bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ font_color="white",
+ ),
showlegend=False,
- hovertemplate=f"{group}
x=%{{x:.4f}}
y=%{{y:.4f}}",
+ hoverinfo="text",
)
for group in rtichoke_curve_list["reference_group_keys"]
]
@@ -1097,6 +1299,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
color=rtichoke_curve_list["colors_dictionary"].get(group),
width=1.5,
),
+ hoverlabel=dict(
+ bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
+ font_color="white",
+ ),
hoverinfo="text",
text=rtichoke_curve_list["reference_data"]
.filter(pl.col("reference_group") == group)["text"]
@@ -1183,7 +1390,7 @@ def _create_curve_layout(
curve_layout = {
"xaxis": {"showgrid": False},
"yaxis": {"showgrid": False},
- "template": "none",
+ "template": "plotly",
"plot_bgcolor": "rgba(0, 0, 0, 0)",
"paper_bgcolor": "rgba(0, 0, 0, 0)",
"showlegend": True,
@@ -1198,7 +1405,7 @@ def _create_curve_layout(
},
"height": size + 50,
"width": size,
- "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
+ # "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
"updatemenus": [
{
"type": "buttons",
diff --git a/uv.lock b/uv.lock
index 7e2793a..9cccb56 100644
--- a/uv.lock
+++ b/uv.lock
@@ -4148,7 +4148,7 @@ wheels = [
[[package]]
name = "rtichoke"
-version = "0.1.17"
+version = "0.1.18"
source = { editable = "." }
dependencies = [
{ name = "importlib" },