Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"marimo>=0.17.0",
]
name = "rtichoke"
version = "0.1.17"
version = "0.1.18"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"

Expand Down
227 changes: 217 additions & 10 deletions src/rtichoke/helpers/plotly_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@
import numpy as np
from rtichoke.performance_data.performance_data import prepare_performance_data

_HOVER_LABELS = {
"false_positive_rate": "1 - Specificity (FPR)",
"sensitivity": "Sensitivity",
"specificity": "Specificity",
"lift": "Lift",
"ppv": "PPV",
"npv": "NPV",
"net_benefit": "NB",
"net_benefit_interventions_avoided": "Interventions Avoided (per 100)",
"chosen_cutoff": "Prob. Threshold",
"ppcr": "Predicted Positives",
}


def _create_rtichoke_plotly_curve_binary(
probs: Dict[str, np.ndarray],
Expand Down Expand Up @@ -887,6 +900,160 @@ def _check_if_multiple_populations_are_being_validated(
return aj_estimates["aj_estimate"].unique().len() > 1


def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> bool:
return aj_estimates["reference_group"].unique().len() > 1


def _infer_performance_data_type(
aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
) -> str:
multiple_models = _check_if_multiple_models_are_being_validated(
aj_estimates_from_performance_data
)

if multiple_populations:
return "several populations"
elif multiple_models:
return "several models"
else:
return "single model"


def _bold_hover_metrics(text: str, metrics: Sequence[str]) -> str:
lines = text.split("<br>")
for metric in metrics:
label = _HOVER_LABELS.get(metric, metric)
lines = [
f"<b>{line}</b>" if label in line and "<b>" not in line else line
for line in lines
]
return "<br>".join(lines)


def _add_model_population_text(text: str, row: dict, perf_dat_type: str) -> str:
if perf_dat_type == "several models" and "model" in row:
text = f"<b>Model: {row['model']}</b><br>{text}"
if perf_dat_type == "several populations" and "population" in row:
text = f"<b>Population: {row['population']}</b><br>{text}"
return text


def _round_val(value: Any, digits: int = 3):
try:
if value is None:
return ""
if isinstance(value, (int, float, np.floating)):
return round(float(value), digits)
except (TypeError, ValueError):
pass
return value


def _build_hover_text(
row: dict,
performance_metric_x: str,
performance_metric_y: str,
stratified_by: str,
perf_dat_type: str,
) -> str:
interventions_avoided = performance_metric_y == "net_benefit_interventions_avoided"

raw_probability_threshold = row.get("chosen_cutoff")
probability_threshold = _round_val(raw_probability_threshold)
sensitivity = _round_val(row.get("sensitivity"))
fpr = _round_val(row.get("false_positive_rate"))
specificity = _round_val(row.get("specificity"))
lift = _round_val(row.get("lift"))
ppv = _round_val(row.get("ppv"))
npv = _round_val(row.get("npv"))
net_benefit = _round_val(row.get("net_benefit"))
nb_interventions_avoided = _round_val(row.get("net_benefit_interventions_avoided"))
predicted_positives = _round_val(row.get("predicted_positives"))
raw_ppcr = row.get("ppcr")
ppcr_percent = (
_round_val(100 * raw_ppcr)
if isinstance(raw_ppcr, (int, float, np.floating))
else ""
)
tp = _round_val(row.get("true_positives"))
tn = _round_val(row.get("true_negatives"))
fp = _round_val(row.get("false_positives"))
fn = _round_val(row.get("false_negatives"))

if (
isinstance(raw_probability_threshold, (int, float, np.floating))
and raw_probability_threshold != 0
):
odds = _round_val(
(1 - raw_probability_threshold) / raw_probability_threshold, 2
)
else:
odds = None

if not interventions_avoided:
text_lines = [
f"Prob. Threshold: {probability_threshold}",
f"Sensitivity: {sensitivity}",
f"1 - Specificity (FPR): {fpr}",
f"Specificity: {specificity}",
f"Lift: {lift}",
f"PPV: {ppv}",
f"NPV: {npv}",
]
if stratified_by == "probability_threshold":
text_lines.append(f"NB: {net_benefit}")
if odds is not None and math.isfinite(float(odds)):
text_lines.append(f"Odds of Prob. Threshold: 1:{odds}")
text_lines.extend(
[
f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)",
f"TP: {tp}",
f"TN: {tn}",
f"FP: {fp}",
f"FN: {fn}",
]
)
else:
text_lines = [
f"Prob. Threshold: {probability_threshold}",
f"Interventions Avoided (per 100): {nb_interventions_avoided}",
f"NB: {net_benefit}",
f"Predicted Positives: {predicted_positives} ({ppcr_percent}%)",
f"TN: {tn}",
f"FN: {fn}",
]
if odds is not None and math.isfinite(float(odds)):
text_lines.insert(1, f"Odds of Prob. Threshold: 1:{odds}")

text = "<br>".join(text_lines)
text = _bold_hover_metrics(text, [performance_metric_x, performance_metric_y])
text = _add_model_population_text(text, row, perf_dat_type)
return text.replace("NaN", "").replace("nan", "")


def _add_hover_text_to_performance_data(
performance_data: pl.DataFrame,
performance_metric_x: str,
performance_metric_y: str,
stratified_by: str,
perf_dat_type: str,
) -> pl.DataFrame:
hover_text_expr = pl.struct(performance_data.columns).map_elements(
lambda row: _build_hover_text(
row,
performance_metric_x=performance_metric_x,
performance_metric_y=performance_metric_y,
stratified_by=stratified_by,
perf_dat_type=perf_dat_type,
),
return_dtype=pl.Utf8,
)

return performance_data.with_columns(
[pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")]
)


def _create_rtichoke_curve_list_binary(
performance_data: pl.DataFrame,
stratified_by: str,
Expand All @@ -904,10 +1071,6 @@ def _create_rtichoke_curve_list_binary(

x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]

performance_data_ready_for_curve = _select_and_rename_necessary_variables(
performance_data.sort("chosen_cutoff"), x_metric, y_metric
)

aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data(
performance_data
)
Expand All @@ -916,6 +1079,28 @@ def _create_rtichoke_curve_list_binary(
aj_estimates_from_performance_data
)

multiple_models = _check_if_multiple_models_are_being_validated(
aj_estimates_from_performance_data
)

perf_dat_type = _infer_performance_data_type(
aj_estimates_from_performance_data, multiple_populations
)

multiple_reference_groups = multiple_populations or multiple_models

performance_data_with_hover_text = _add_hover_text_to_performance_data(
performance_data.sort("chosen_cutoff"),
performance_metric_x=x_metric,
performance_metric_y=y_metric,
stratified_by=stratified_by,
perf_dat_type=perf_dat_type,
)

performance_data_ready_for_curve = _select_and_rename_necessary_variables(
performance_data_with_hover_text, x_metric, y_metric
)

reference_data = _create_reference_lines_data(
curve=curve,
aj_estimates_from_performance_data=aj_estimates_from_performance_data,
Expand Down Expand Up @@ -976,7 +1161,9 @@ def _create_rtichoke_curve_list_binary(
]
},
**{
variant_key: (palette[group_index] if multiple_populations else "#000000")
variant_key: (
palette[group_index] if multiple_reference_groups else "#000000"
)
for group_index, reference_group in enumerate(reference_group_keys)
for variant_key in [
reference_group,
Expand All @@ -999,7 +1186,7 @@ def _create_rtichoke_curve_list_binary(
"reference_data": reference_data,
"cutoffs": cutoffs,
"colors_dictionary": colors_dictionary,
"multiple_populations": multiple_populations,
"multiple_reference_groups": multiple_reference_groups,
}

return rtichoke_curve_list
Expand All @@ -1013,6 +1200,7 @@ def _select_and_rename_necessary_variables(
pl.col("chosen_cutoff"),
pl.col(x_perf_metric).alias("x"),
pl.col(y_perf_metric).alias("y"),
pl.col("text"),
)


Expand Down Expand Up @@ -1047,13 +1235,22 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
y=rtichoke_curve_list["performance_data_ready_for_curve"]
.filter(pl.col("reference_group") == group)["y"]
.to_list(),
text=rtichoke_curve_list["performance_data_ready_for_curve"].filter(
pl.col("reference_group") == group
)["text"],
mode="markers+lines",
name=group,
legendgroup=group,
line={
"width": 2,
"color": rtichoke_curve_list["colors_dictionary"].get(group),
},
hoverlabel=dict(
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
font_color="white",
),
hoverinfo="text",
showlegend=True,
)
for group in rtichoke_curve_list["reference_group_keys"]
Expand All @@ -1068,15 +1265,20 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
"size": 12,
"color": (
rtichoke_curve_list["colors_dictionary"].get(group)
if rtichoke_curve_list["multiple_populations"]
if rtichoke_curve_list["multiple_reference_groups"]
else "#f6e3be"
),
"line": {"width": 3, "color": "black"},
},
name=f"{group} @ cutoff",
legendgroup=group,
hoverlabel=dict(
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
font_color="white",
),
showlegend=False,
hovertemplate=f"{group}<br>x=%{{x:.4f}}<br>y=%{{y:.4f}}<extra></extra>",
hoverinfo="text",
)
for group in rtichoke_curve_list["reference_group_keys"]
]
Expand All @@ -1097,6 +1299,11 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur
color=rtichoke_curve_list["colors_dictionary"].get(group),
width=1.5,
),
hoverlabel=dict(
bgcolor=rtichoke_curve_list["colors_dictionary"].get(group),
bordercolor=rtichoke_curve_list["colors_dictionary"].get(group),
font_color="white",
),
hoverinfo="text",
text=rtichoke_curve_list["reference_data"]
.filter(pl.col("reference_group") == group)["text"]
Expand Down Expand Up @@ -1183,7 +1390,7 @@ def _create_curve_layout(
curve_layout = {
"xaxis": {"showgrid": False},
"yaxis": {"showgrid": False},
"template": "none",
"template": "plotly",
"plot_bgcolor": "rgba(0, 0, 0, 0)",
"paper_bgcolor": "rgba(0, 0, 0, 0)",
"showlegend": True,
Expand All @@ -1198,7 +1405,7 @@ def _create_curve_layout(
},
"height": size + 50,
"width": size,
"hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
# "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"},
"updatemenus": [
{
"type": "buttons",
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.