2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
"pyarrow>=21.0.0",
]
name = "rtichoke"
version = "0.1.22"
version = "0.1.23"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"

187 changes: 187 additions & 0 deletions src/rtichoke/helpers/plotly_helper_functions.py
@@ -898,6 +898,38 @@ def _get_aj_estimates_from_performance_data(
)


def _get_aj_estimates_from_performance_data_times(
performance_data: pl.DataFrame,
) -> pl.DataFrame:
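    # At the extreme cutoffs (0 and 1) every subject lands on one side of the
    # threshold, so real_positives / n recovers the Aalen-Johansen event rate
    # per reference group and fixed time horizon.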
return (
performance_data.filter(
(pl.col("chosen_cutoff") == 0) | (pl.col("chosen_cutoff") == 1)
)
.select("reference_group", "fixed_time_horizon", "real_positives", "n")
.unique()
.with_columns((pl.col("real_positives") / pl.col("n")).alias("aj_estimate"))
.select(
pl.col("reference_group"),
pl.col("fixed_time_horizon"),
pl.col("aj_estimate"),
)
.sort(by=["reference_group", "fixed_time_horizon"])
)
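
A minimal usage sketch, on hypothetical data (model names and numbers are illustrative), of what this helper consumes and returns:

```python
import polars as pl

# Both extreme cutoffs are present for each (model, horizon) pair.
performance_data = pl.DataFrame(
    {
        "reference_group": ["model_a", "model_a", "model_b", "model_b"],
        "fixed_time_horizon": [3.0, 3.0, 3.0, 3.0],
        "chosen_cutoff": [0.0, 1.0, 0.0, 1.0],
        "real_positives": [30.0, 30.0, 45.0, 45.0],
        "n": [100.0, 100.0, 100.0, 100.0],
    }
)

# One row per (reference_group, fixed_time_horizon): aj_estimate 0.30 and 0.45.
print(_get_aj_estimates_from_performance_data_times(performance_data))
```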


def _check_if_multiple_populations_are_being_validated_times(
aj_estimates: pl.DataFrame,
) -> bool:
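    # More than one distinct AJ estimate at the same fixed time horizon means
    # the reference groups cannot all describe the same validation population.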
max_val = (
aj_estimates.group_by("fixed_time_horizon")
.agg(pl.col("aj_estimate").n_unique().alias("num_populations"))[
"num_populations"
]
.max()
)
return max_val is not None and max_val > 1
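
A quick check of both branches, again on hypothetical estimates:

```python
import polars as pl

# Both models share one event rate per horizon: a single population.
same_population = pl.DataFrame(
    {
        "reference_group": ["model_a", "model_b"],
        "fixed_time_horizon": [3.0, 3.0],
        "aj_estimate": [0.3, 0.3],
    }
)
assert not _check_if_multiple_populations_are_being_validated_times(same_population)

# Two event rates at the same horizon: several populations.
two_populations = same_population.with_columns(pl.Series("aj_estimate", [0.3, 0.5]))
assert _check_if_multiple_populations_are_being_validated_times(two_populations)
```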


def _check_if_multiple_populations_are_being_validated(
aj_estimates: pl.DataFrame,
) -> bool:
@@ -908,6 +940,21 @@ def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) ->
return aj_estimates["reference_group"].unique().len() > 1


def _infer_performance_data_type_times(
aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
) -> str:
    multiple_models = _check_if_multiple_models_are_being_validated(
        aj_estimates_from_performance_data
    )

if multiple_populations:
return "several populations"
elif multiple_models:
return "several models"
else:
return "single model"


def _infer_performance_data_type(
aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool
) -> str:
@@ -1058,6 +1105,146 @@ def _add_hover_text_to_performance_data(
)


def _create_rtichoke_curve_list_times(
performance_data: pl.DataFrame,
stratified_by: str,
size: int = 500,
color_value=None,
curve="roc",
min_p_threshold=0,
max_p_threshold=1,
) -> dict[str, Any]:
animation_slider_cutoff_prefix = (
"Prob. Threshold: "
if stratified_by == "probability_threshold"
else "Predicted Positives (Rate):"
)

x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve]

aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data_times(
performance_data
)

multiple_populations = _check_if_multiple_populations_are_being_validated_times(
aj_estimates_from_performance_data
)

multiple_models = _check_if_multiple_models_are_being_validated(
aj_estimates_from_performance_data
)

perf_dat_type = _infer_performance_data_type_times(
aj_estimates_from_performance_data, multiple_populations
)

multiple_reference_groups = multiple_populations or multiple_models

performance_data_with_hover_text = _add_hover_text_to_performance_data(
performance_data.sort("chosen_cutoff"),
performance_metric_x=x_metric,
performance_metric_y=y_metric,
stratified_by=stratified_by,
perf_dat_type=perf_dat_type,
)

performance_data_ready_for_curve = _select_and_rename_necessary_variables(
performance_data_with_hover_text, x_metric, y_metric
)

reference_data = _create_reference_lines_data(
curve=curve,
aj_estimates_from_performance_data=aj_estimates_from_performance_data,
multiple_populations=multiple_populations,
min_p_threshold=min_p_threshold,
max_p_threshold=max_p_threshold,
)

axes_ranges = extract_axes_ranges(
performance_data_ready_for_curve,
curve=curve,
min_p_threshold=min_p_threshold,
max_p_threshold=max_p_threshold,
)

reference_group_keys = performance_data["reference_group"].unique().to_list()

cutoffs = (
performance_data_ready_for_curve.select(pl.col("chosen_cutoff"))
.drop_nulls()
.unique()
.sort("chosen_cutoff")
.to_series()
.to_list()
)

palette = [
"#1b9e77",
"#d95f02",
"#7570b3",
"#e7298a",
"#07004D",
"#E6AB02",
"#FE5F55",
"#54494B",
"#006E90",
"#BC96E6",
"#52050A",
"#1F271B",
"#BE7C4D",
"#63768D",
"#08A045",
"#320A28",
"#82FF9E",
"#2176FF",
"#D1603D",
"#585123",
]

colors_dictionary = {
**{
key: "#BEBEBE"
for key in [
"random_guess",
"perfect_model",
"treat_none",
"treat_all",
]
},
**{
variant_key: (
palette[group_index] if multiple_reference_groups else "#000000"
)
for group_index, reference_group in enumerate(reference_group_keys)
for variant_key in [
reference_group,
f"random_guess_{reference_group}",
f"perfect_model_{reference_group}",
f"treat_none_{reference_group}",
f"treat_all_{reference_group}",
]
},
}

rtichoke_curve_list = {
"size": size,
"axes_ranges": axes_ranges,
"x_label": x_label,
"y_label": y_label,
"animation_slider_cutoff_prefix": animation_slider_cutoff_prefix,
"reference_group_keys": reference_group_keys,
"performance_data_ready_for_curve": performance_data_ready_for_curve,
"reference_data": reference_data,
"cutoffs": cutoffs,
"colors_dictionary": colors_dictionary,
"multiple_reference_groups": multiple_reference_groups,
}

return rtichoke_curve_list
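
The color bookkeeping above is self-contained enough to sketch in isolation; with hypothetical group names, each reference group and its per-group reference lines share one palette entry, global reference lines stay grey, and a single group falls back to black:

```python
# Standalone sketch of the colors_dictionary construction (illustrative names).
palette = ["#1b9e77", "#d95f02", "#7570b3"]
reference_group_keys = ["model_a", "model_b"]
multiple_reference_groups = len(reference_group_keys) > 1

colors_dictionary = {
    **{
        key: "#BEBEBE"
        for key in ["random_guess", "perfect_model", "treat_none", "treat_all"]
    },
    **{
        variant_key: palette[i] if multiple_reference_groups else "#000000"
        for i, group in enumerate(reference_group_keys)
        for variant_key in [
            group,
            f"random_guess_{group}",
            f"perfect_model_{group}",
            f"treat_none_{group}",
            f"treat_all_{group}",
        ]
    },
}

assert colors_dictionary["model_b"] == "#d95f02"
assert colors_dictionary["treat_all_model_a"] == "#1b9e77"
assert colors_dictionary["random_guess"] == "#BEBEBE"
```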


def _create_rtichoke_curve_list_binary(
performance_data: pl.DataFrame,
stratified_by: str,
74 changes: 60 additions & 14 deletions src/rtichoke/helpers/sandbox_observable_helpers.py
@@ -912,7 +912,7 @@ def _create_list_data_to_adjust_binary(
return list_data_to_adjust


-def create_list_data_to_adjust(
+def _create_list_data_to_adjust(
aj_data_combinations: pl.DataFrame,
probs_dict: Dict[str, np.ndarray],
reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
@@ -922,25 +922,70 @@
) -> Dict[str, pl.DataFrame]:
# reference_groups = list(probs_dict.keys())
reference_group_labels = list(probs_dict.keys())
-    num_reals = len(reals_dict)

+    if isinstance(reals_dict, dict):
+        num_keys_reals = len(reals_dict)
+    else:
+        num_keys_reals = 1

reference_group_enum = pl.Enum(reference_group_labels)

strata_enum_dtype = aj_data_combinations.schema["strata"]

-    # Flatten and ensure list format
-    data_to_adjust = pl.DataFrame(
-        {
-            "reference_group": np.repeat(reference_group_labels, num_reals),
-            "probs": np.concatenate(
-                [probs_dict[group] for group in reference_group_labels]
-            ),
-            "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
-            "times": np.tile(np.asarray(times_dict), len(reference_group_labels)),
-        }
-    ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+    if len(probs_dict) == 1:
+        probs_array = np.asarray(probs_dict[reference_group_labels[0]])
+
+        if isinstance(reals_dict, dict):
+            # A dict is keyed by reference group, not by position.
+            reals_array = np.asarray(next(iter(reals_dict.values())))
+        else:
+            reals_array = np.asarray(reals_dict)
+
+        if isinstance(times_dict, dict):
+            times_array = np.asarray(next(iter(times_dict.values())))
+        else:
+            times_array = np.asarray(times_dict)
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, len(probs_array)),
+                "probs": probs_array,
+                "reals": reals_array,
+                "times": times_array,
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))

+    elif num_keys_reals == 1:
+        # One shared population validated by several models: broadcast the
+        # single reals/times vectors across every reference group.
+        if isinstance(reals_dict, dict):
+            reals_array = np.asarray(next(iter(reals_dict.values())))
+        else:
+            reals_array = np.asarray(reals_dict)
+        if isinstance(times_dict, dict):
+            times_array = np.asarray(next(iter(times_dict.values())))
+        else:
+            times_array = np.asarray(times_dict)
+        n = len(reals_array)
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, n),
+                "probs": np.concatenate(
+                    [np.asarray(probs_dict[g]) for g in reference_group_labels]
+                ),
+                "reals": np.tile(reals_array, len(reference_group_labels)),
+                "times": np.tile(times_array, len(reference_group_labels)),
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))

+    elif isinstance(reals_dict, dict) and isinstance(times_dict, dict):
+        # Several populations: one probs/reals/times vector per reference group.
+        data_to_adjust = (
+            pl.DataFrame(
+                {
+                    "reference_group": reference_group_labels,
+                    "probs": list(probs_dict.values()),
+                    "reals": list(reals_dict.values()),
+                    "times": list(times_dict.values()),
+                }
+            )
+            .explode(["probs", "reals", "times"])
+            .with_columns(pl.col("reference_group").cast(reference_group_enum))
+        )
+    else:
+        # Guard: without this, data_to_adjust would be unbound below.
+        raise ValueError("probs, reals and times have incompatible shapes")

# Apply strata
data_to_adjust = add_cutoff_strata(
data_to_adjust, by=by, stratified_by=stratified_by
)
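
A hedged sketch of the three input shapes the rewritten function now accepts (hypothetical arrays; the `model_*` and `pop_*` keys are illustrative):

```python
import numpy as np

# Shape 1: a single model; reals/times may be plain arrays or one-key dicts.
probs_single = {"model_a": np.array([0.2, 0.7, 0.9])}
reals_shared = np.array([0, 1, 1])
times_shared = np.array([5.0, 3.0, 1.0])

# Shape 2: several models scored on one shared population; the single
# reals/times vectors are tiled (np.tile) across every key of probs_dict.
probs_multi = {
    "model_a": np.array([0.2, 0.7, 0.9]),
    "model_b": np.array([0.1, 0.6, 0.8]),
}

# Shape 3: several populations; probs, reals and times are dicts with matching
# keys, exploded into long format one reference group at a time.
probs_by_pop = {"pop_a": np.array([0.2, 0.7, 0.9]), "pop_b": np.array([0.4, 0.5, 0.6])}
reals_by_pop = {"pop_a": np.array([0, 1, 1]), "pop_b": np.array([1, 0, 1])}
times_by_pop = {"pop_a": np.array([5.0, 3.0, 1.0]), "pop_b": np.array([2.0, 4.0, 6.0])}
```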
@@ -1637,6 +1682,7 @@ def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame:
)
.agg([pl.col("reals_estimate").sum()])
.pivot(on="classification_outcome", values="reals_estimate")
+        .fill_null(0)
.with_columns(
(pl.col("true_positives") + pl.col("false_positives")).alias(
"predicted_positives"
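
Why the added `.fill_null(0)` matters: pivot leaves a null wherever a stratum has no rows for some classification outcome, and null + number is null in polars, which would silently drop `predicted_positives` for that stratum. A small standalone sketch (toy data, explicit index for clarity):

```python
import polars as pl

df = pl.DataFrame(
    {
        "strata": ["a", "a", "b"],
        "classification_outcome": [
            "true_positives",
            "false_positives",
            "true_positives",
        ],
        "reals_estimate": [2.0, 1.0, 3.0],
    }
)

pivoted = df.pivot(on="classification_outcome", values="reals_estimate", index="strata")

# Without fill_null(0), strata "b" would get predicted_positives = null.
print(
    pivoted.fill_null(0).with_columns(
        (pl.col("true_positives") + pl.col("false_positives")).alias(
            "predicted_positives"
        )
    )
)
```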
6 changes: 3 additions & 3 deletions src/rtichoke/performance_data/performance_data_times.py
@@ -8,7 +8,7 @@
from rtichoke.helpers.sandbox_observable_helpers import (
create_breaks_values,
create_aj_data_combinations,
-    create_list_data_to_adjust,
+    _create_list_data_to_adjust,
create_adjusted_data,
cast_and_join_adjusted_data,
_calculate_cumulative_aj_data,
@@ -154,7 +154,7 @@ def prepare_binned_classification_data_times(
risk_set_scope=risk_set_scope,
)

-    list_data_to_adjust = create_list_data_to_adjust(
+    list_data_to_adjust = _create_list_data_to_adjust(
aj_data_combinations,
probs,
reals,
@@ -175,6 +175,6 @@
final_adjusted_data = cast_and_join_adjusted_data(
aj_data_combinations,
adjusted_data,
-    )
+    ).with_columns(pl.col("reals_estimate").fill_null(0.0))

return final_adjusted_data
6 changes: 3 additions & 3 deletions src/rtichoke/summary_report/summary_report.py
@@ -4,7 +4,7 @@

from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
from rtichoke.helpers.sandbox_observable_helpers import (
-    create_list_data_to_adjust,
+    _create_list_data_to_adjust,
)
import subprocess

@@ -67,8 +67,8 @@ def create_data_for_summary_report(probs, reals, times, fixed_time_horizons):
stratified_by = ["probability_threshold", "ppcr"]
by = 0.1

-    list_data_to_adjust_polars = create_list_data_to_adjust(
-        probs, reals, times, stratified_by=stratified_by, by=by
+    list_data_to_adjust_polars = _create_list_data_to_adjust(
+        probs, reals, times, stratified_by=stratified_by, by=by, times_dict={}
)

return list_data_to_adjust_polars
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.