From c064bb2e8343427623b4d8656f36030630962a01 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 7 Dec 2025 12:04:03 +0200 Subject: [PATCH 1/4] fix: close #224 --- .../helpers/plotly_helper_functions.py | 60 +++++++++++++++ .../helpers/sandbox_observable_helpers.py | 74 +++++++++++++++---- .../performance_data_times.py | 6 +- src/rtichoke/summary_report/summary_report.py | 6 +- 4 files changed, 126 insertions(+), 20 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 5999534..86a42e8 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -898,6 +898,24 @@ def _get_aj_estimates_from_performance_data( ) +def _get_aj_estimates_from_performance_data_times( + performance_data: pl.DataFrame, +) -> pl.DataFrame: + return ( + performance_data.select( + "reference_group", "fixed_time_horizon", "real_positives", "n" + ) + .unique() + .with_columns((pl.col("real_positives") / pl.col("n")).alias("aj_estimate")) + .select( + pl.col("reference_group"), + pl.col("fixed_time_horizon"), + pl.col("aj_estimate"), + ) + .sort(by=["reference_group", "fixed_time_horizon"]) + ) + + def _check_if_multiple_populations_are_being_validated( aj_estimates: pl.DataFrame, ) -> bool: @@ -1058,6 +1076,48 @@ def _add_hover_text_to_performance_data( ) +def _create_rtichoke_curve_list_times( + performance_data: pl.DataFrame, + stratified_by: str, + size: int = 500, + color_value=None, + curve="roc", + min_p_threshold=0, + max_p_threshold=1, +) -> dict[str, Any]: + animation_slider_cutoff_prefix = ( + "Prob. Threshold: " + if stratified_by == "probability_threshold" + else "Predicted Positives (Rate):" + ) + + x_metric, y_metric, x_label, y_label = _CURVE_CONFIG[curve] + + aj_estimates_from_performance_data = _get_aj_estimates_from_performance_data_times( + performance_data + ) + + print("aj_estimates_from_performance_data", aj_estimates_from_performance_data) + + reference_group_keys = performance_data["reference_group"].unique().to_list() + + rtichoke_curve_list = { + "size": size, + # "axes_ranges": axes_ranges, + "x_label": x_label, + "y_label": y_label, + "animation_slider_cutoff_prefix": animation_slider_cutoff_prefix, + "reference_group_keys": reference_group_keys, + # "performance_data_ready_for_curve": performance_data_ready_for_curve, + # "reference_data": reference_data, + # "cutoffs": cutoffs, + # "colors_dictionary": colors_dictionary, + # "multiple_reference_groups": multiple_reference_groups, + } + + return rtichoke_curve_list + + def _create_rtichoke_curve_list_binary( performance_data: pl.DataFrame, stratified_by: str, diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py index 6a2d3c5..9bd7499 100644 --- a/src/rtichoke/helpers/sandbox_observable_helpers.py +++ b/src/rtichoke/helpers/sandbox_observable_helpers.py @@ -912,7 +912,7 @@ def _create_list_data_to_adjust_binary( return list_data_to_adjust -def create_list_data_to_adjust( +def _create_list_data_to_adjust( aj_data_combinations: pl.DataFrame, probs_dict: Dict[str, np.ndarray], reals_dict: Union[np.ndarray, Dict[str, np.ndarray]], @@ -922,25 +922,70 @@ def create_list_data_to_adjust( ) -> Dict[str, pl.DataFrame]: # reference_groups = list(probs_dict.keys()) reference_group_labels = list(probs_dict.keys()) - num_reals = len(reals_dict) + + if isinstance(reals_dict, dict): + num_keys_reals = len(reals_dict) + else: + num_keys_reals = 1 + + # num_reals = len(reals_dict) reference_group_enum = pl.Enum(reference_group_labels) strata_enum_dtype = aj_data_combinations.schema["strata"] - # Flatten and ensure list format - data_to_adjust = pl.DataFrame( - { - "reference_group": np.repeat(reference_group_labels, num_reals), - "probs": np.concatenate( - [probs_dict[group] for group in reference_group_labels] - ), - "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)), - "times": np.tile(np.asarray(times_dict), len(reference_group_labels)), - } - ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + if len(probs_dict) == 1: + probs_array = np.asarray(probs_dict[reference_group_labels[0]]) + + if isinstance(reals_dict, dict): + reals_array = np.asarray(reals_dict[0]) + else: + reals_array = np.asarray(reals_dict) + + if isinstance(times_dict, dict): + times_array = np.asarray(times_dict[0]) + else: + times_array = np.asarray(times_dict) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, len(probs_array)), + "probs": probs_array, + "reals": reals_array, + "times": times_array, + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif num_keys_reals == 1: + reals_array = np.asarray(reals_dict) + times_array = np.asarray(times_dict) + n = len(reals_array) + + data_to_adjust = pl.DataFrame( + { + "reference_group": np.repeat(reference_group_labels, n), + "probs": np.concatenate( + [np.asarray(probs_dict[g]) for g in reference_group_labels] + ), + "reals": np.tile(reals_array, len(reference_group_labels)), + "times": np.tile(times_array, len(reference_group_labels)), + } + ).with_columns(pl.col("reference_group").cast(reference_group_enum)) + + elif isinstance(reals_dict, dict) and isinstance(times_dict, dict): + data_to_adjust = ( + pl.DataFrame( + { + "reference_group": reference_group_labels, + "probs": list(probs_dict.values()), + "reals": list(reals_dict.values()), + "times": list(times_dict.values()), + } + ) + .explode(["probs", "reals", "times"]) + .with_columns(pl.col("reference_group").cast(reference_group_enum)) + ) - # Apply strata data_to_adjust = add_cutoff_strata( data_to_adjust, by=by, stratified_by=stratified_by ) @@ -1637,6 +1682,7 @@ def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame: ) .agg([pl.col("reals_estimate").sum()]) .pivot(on="classification_outcome", values="reals_estimate") + .fill_null(0) .with_columns( (pl.col("true_positives") + pl.col("false_positives")).alias( "predicted_positives" diff --git a/src/rtichoke/performance_data/performance_data_times.py b/src/rtichoke/performance_data/performance_data_times.py index 55a462d..901bd59 100644 --- a/src/rtichoke/performance_data/performance_data_times.py +++ b/src/rtichoke/performance_data/performance_data_times.py @@ -8,7 +8,7 @@ from rtichoke.helpers.sandbox_observable_helpers import ( create_breaks_values, create_aj_data_combinations, - create_list_data_to_adjust, + _create_list_data_to_adjust, create_adjusted_data, cast_and_join_adjusted_data, _calculate_cumulative_aj_data, @@ -154,7 +154,7 @@ def prepare_binned_classification_data_times( risk_set_scope=risk_set_scope, ) - list_data_to_adjust = create_list_data_to_adjust( + list_data_to_adjust = _create_list_data_to_adjust( aj_data_combinations, probs, reals, @@ -175,6 +175,6 @@ def prepare_binned_classification_data_times( final_adjusted_data = cast_and_join_adjusted_data( aj_data_combinations, adjusted_data, - ) + ).with_columns(pl.col("reals_estimate").fill_null(0.0)) return final_adjusted_data diff --git a/src/rtichoke/summary_report/summary_report.py b/src/rtichoke/summary_report/summary_report.py index 9fc6fd3..9549260 100644 --- a/src/rtichoke/summary_report/summary_report.py +++ b/src/rtichoke/summary_report/summary_report.py @@ -4,7 +4,7 @@ from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r from rtichoke.helpers.sandbox_observable_helpers import ( - create_list_data_to_adjust, + _create_list_data_to_adjust, ) import subprocess @@ -67,8 +67,8 @@ def create_data_for_summary_report(probs, reals, times, fixed_time_horizons): stratified_by = ["probability_threshold", "ppcr"] by = 0.1 - list_data_to_adjust_polars = create_list_data_to_adjust( - probs, reals, times, stratified_by=stratified_by, by=by + list_data_to_adjust_polars = _create_list_data_to_adjust( + probs, reals, times, stratified_by=stratified_by, by=by, times_dict={} ) return list_data_to_adjust_polars From a02621377c2cb1c31ab63a6644a4f59c1e58ca2a Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 7 Dec 2025 12:19:36 +0200 Subject: [PATCH 2/4] fix: close #225 --- src/rtichoke/helpers/plotly_helper_functions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 86a42e8..1ccbd65 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -902,9 +902,10 @@ def _get_aj_estimates_from_performance_data_times( performance_data: pl.DataFrame, ) -> pl.DataFrame: return ( - performance_data.select( - "reference_group", "fixed_time_horizon", "real_positives", "n" + performance_data.filter( + (pl.col("chosen_cutoff") == 0) | (pl.col("chosen_cutoff") == 1) ) + .select("reference_group", "fixed_time_horizon", "real_positives", "n") .unique() .with_columns((pl.col("real_positives") / pl.col("n")).alias("aj_estimate")) .select( From ab9bff8748fbe04df80de1a79f4ef98b164b0a0e Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 7 Dec 2025 14:59:11 +0200 Subject: [PATCH 3/4] fix: close #227 --- .../helpers/plotly_helper_functions.py | 138 +++++++++++++++++- 1 file changed, 132 insertions(+), 6 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 1ccbd65..78630d5 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -917,6 +917,19 @@ def _get_aj_estimates_from_performance_data_times( ) +def _check_if_multiple_populations_are_being_validated_times( + aj_estimates: pl.DataFrame, +) -> bool: + max_val = ( + aj_estimates.group_by("fixed_time_horizon") + .agg(pl.col("aj_estimate").n_unique().alias("num_populations"))[ + "num_populations" + ] + .max() + ) + return max_val is not None and max_val > 1 + + def _check_if_multiple_populations_are_being_validated( aj_estimates: pl.DataFrame, ) -> bool: @@ -927,6 +940,21 @@ def _check_if_multiple_models_are_being_validated(aj_estimates: pl.DataFrame) -> return aj_estimates["reference_group"].unique().len() > 1 +def _infer_performance_data_type_times( + aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool +) -> str: + multiple_models = _check_if_multiple_populations_are_being_validated_times( + aj_estimates_from_performance_data + ) + + if multiple_populations: + return "several populations" + elif multiple_models: + return "several models" + else: + return "single model" + + def _infer_performance_data_type( aj_estimates_from_performance_data: pl.DataFrame, multiple_populations: bool ) -> str: @@ -1100,20 +1128,118 @@ def _create_rtichoke_curve_list_times( print("aj_estimates_from_performance_data", aj_estimates_from_performance_data) + multiple_populations = _check_if_multiple_populations_are_being_validated_times( + aj_estimates_from_performance_data + ) + + multiple_models = _check_if_multiple_models_are_being_validated( + aj_estimates_from_performance_data + ) + + perf_dat_type = _infer_performance_data_type_times( + aj_estimates_from_performance_data, multiple_populations + ) + + multiple_reference_groups = multiple_populations or multiple_models + + performance_data_with_hover_text = _add_hover_text_to_performance_data( + performance_data.sort("chosen_cutoff"), + performance_metric_x=x_metric, + performance_metric_y=y_metric, + stratified_by=stratified_by, + perf_dat_type=perf_dat_type, + ) + + performance_data_ready_for_curve = _select_and_rename_necessary_variables( + performance_data_with_hover_text, x_metric, y_metric + ) + + reference_data = _create_reference_lines_data( + curve=curve, + aj_estimates_from_performance_data=aj_estimates_from_performance_data, + multiple_populations=multiple_populations, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + + axes_ranges = extract_axes_ranges( + performance_data_ready_for_curve, + curve=curve, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + reference_group_keys = performance_data["reference_group"].unique().to_list() + cutoffs = ( + performance_data_ready_for_curve.select(pl.col("chosen_cutoff")) + .drop_nulls() + .unique() + .sort("chosen_cutoff") + .to_series() + .to_list() + ) + + palette = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ] + + colors_dictionary = { + **{ + key: "#BEBEBE" + for key in [ + "random_guess", + "perfect_model", + "treat_none", + "treat_all", + ] + }, + **{ + variant_key: ( + palette[group_index] if multiple_reference_groups else "#000000" + ) + for group_index, reference_group in enumerate(reference_group_keys) + for variant_key in [ + reference_group, + f"random_guess_{reference_group}", + f"perfect_model_{reference_group}", + f"treat_none_{reference_group}", + f"treat_all_{reference_group}", + ] + }, + } + rtichoke_curve_list = { "size": size, - # "axes_ranges": axes_ranges, + "axes_ranges": axes_ranges, "x_label": x_label, "y_label": y_label, "animation_slider_cutoff_prefix": animation_slider_cutoff_prefix, "reference_group_keys": reference_group_keys, - # "performance_data_ready_for_curve": performance_data_ready_for_curve, - # "reference_data": reference_data, - # "cutoffs": cutoffs, - # "colors_dictionary": colors_dictionary, - # "multiple_reference_groups": multiple_reference_groups, + "performance_data_ready_for_curve": performance_data_ready_for_curve, + "reference_data": reference_data, + "cutoffs": cutoffs, + "colors_dictionary": colors_dictionary, + "multiple_reference_groups": multiple_reference_groups, } return rtichoke_curve_list From 6836add50f1fe30090eb672f373165dd3f5843c2 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Thu, 11 Dec 2025 16:00:10 +0200 Subject: [PATCH 4/4] chore: update version number --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 10ba1ce..da65996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pyarrow>=21.0.0", ] name = "rtichoke" -version = "0.1.22" +version = "0.1.23" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/uv.lock b/uv.lock index daa9a08..557790b 100644 --- a/uv.lock +++ b/uv.lock @@ -3888,7 +3888,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.22" +version = "0.1.23" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },