diff --git a/pyproject.toml b/pyproject.toml index da65996..d71c787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pyarrow>=21.0.0", ] name = "rtichoke" -version = "0.1.23" +version = "0.1.24" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 78630d5..38d1885 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1126,8 +1126,6 @@ def _create_rtichoke_curve_list_times( performance_data ) - print("aj_estimates_from_performance_data", aj_estimates_from_performance_data) - multiple_populations = _check_if_multiple_populations_are_being_validated_times( aj_estimates_from_performance_data ) @@ -1154,14 +1152,6 @@ def _create_rtichoke_curve_list_times( performance_data_with_hover_text, x_metric, y_metric ) - reference_data = _create_reference_lines_data( - curve=curve, - aj_estimates_from_performance_data=aj_estimates_from_performance_data, - multiple_populations=multiple_populations, - min_p_threshold=min_p_threshold, - max_p_threshold=max_p_threshold, - ) - axes_ranges = extract_axes_ranges( performance_data_ready_for_curve, curve=curve, @@ -1171,6 +1161,55 @@ def _create_rtichoke_curve_list_times( reference_group_keys = performance_data["reference_group"].unique().to_list() + fixed_time_horizons = ( + performance_data_ready_for_curve.select(pl.col("fixed_time_horizon")) + .unique() + .sort("fixed_time_horizon") + .to_series() + .to_list() + ) + + if "fixed_time_horizon" in performance_data_ready_for_curve.columns: + reference_data_per_horizon = [] + + for fixed_time_horizon in fixed_time_horizons: + reference_lines = _create_reference_lines_data( + curve=curve, + aj_estimates_from_performance_data=( + aj_estimates_from_performance_data.filter( + pl.col("fixed_time_horizon") == fixed_time_horizon + ) + ), + multiple_populations=multiple_populations, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ).with_columns(pl.lit(fixed_time_horizon).alias("fixed_time_horizon")) + + reference_data_per_horizon.append(reference_lines) + + reference_data = ( + pl.concat(reference_data_per_horizon, how="vertical") + if reference_data_per_horizon + else pl.DataFrame() + ) + + else: + reference_data = _create_reference_lines_data( + curve=curve, + aj_estimates_from_performance_data=aj_estimates_from_performance_data, + multiple_populations=multiple_populations, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + + if ( + "fixed_time_horizon" in performance_data_ready_for_curve.columns + and "fixed_time_horizon" not in reference_data.columns + ): + reference_data = reference_data.join( + pl.DataFrame({"fixed_time_horizon": fixed_time_horizons}), how="cross" + ) + cutoffs = ( performance_data_ready_for_curve.select(pl.col("chosen_cutoff")) .drop_nulls() @@ -1234,6 +1273,7 @@ def _create_rtichoke_curve_list_times( "x_label": x_label, "y_label": y_label, "animation_slider_cutoff_prefix": animation_slider_cutoff_prefix, + "fixed_time_horizons": fixed_time_horizons, "reference_group_keys": reference_group_keys, "performance_data_ready_for_curve": performance_data_ready_for_curve, "reference_data": reference_data, @@ -1386,16 +1426,23 @@ def _create_rtichoke_curve_list_binary( def _select_and_rename_necessary_variables( performance_data: pl.DataFrame, x_perf_metric: str, y_perf_metric: str ) -> pl.DataFrame: - return performance_data.select( + selected_columns = [ pl.col("reference_group"), pl.col("chosen_cutoff"), pl.col(x_perf_metric).alias("x"), pl.col(y_perf_metric).alias("y"), pl.col("text"), - ) + ] + + if "fixed_time_horizon" in performance_data.columns: + selected_columns.append(pl.col("fixed_time_horizon")) + return performance_data.select(*selected_columns) -def _create_slider_dict(animation_slider_prefic: str, steps: dict) -> dict[str, Any]: + +def _create_slider_dict( + animation_slider_prefic: str, steps: list[dict[str, Any]] +) -> dict[str, Any]: slider_dict = { "active": 0, "yanchor": "top", @@ -1417,6 +1464,232 @@ def _create_slider_dict(animation_slider_prefic: str, steps: dict) -> dict[str, return slider_dict +def _create_plotly_curve_times(rtichoke_curve_list: dict[str, Any]) -> go.Figure: + initial_fixed_time_horizon = rtichoke_curve_list["fixed_time_horizons"][0] + initial_cutoff = ( + rtichoke_curve_list["cutoffs"][0] if rtichoke_curve_list["cutoffs"] else None + ) + + def _xy_for_curve( + group: str, fixed_time_horizon: float + ) -> tuple[list[Any], list[Any]]: + subset = rtichoke_curve_list["performance_data_ready_for_curve"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == fixed_time_horizon) + ) + return subset["x"].to_list(), subset["y"].to_list() + + def _xy_for_reference( + group: str, fixed_time_horizon: float + ) -> tuple[list[Any], list[Any], list[Any]]: + subset = rtichoke_curve_list["reference_data"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == fixed_time_horizon) + ) + return subset["x"].to_list(), subset["y"].to_list(), subset["text"].to_list() + + def _xy_at_cutoff( + group: str, cutoff: float, fixed_time_horizon: float + ) -> tuple[Any, Any]: + row = ( + rtichoke_curve_list["performance_data_ready_for_curve"] + .filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == fixed_time_horizon) + & (pl.col("chosen_cutoff") == cutoff) + & pl.col("x").is_not_null() + & pl.col("y").is_not_null() + ) + .select(["x", "y"]) + .limit(1) + ) + if row.height == 0: + return None, None + r = row.row(0) + return r[0], r[1] + + non_interactive_curve = [] + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: + for group in rtichoke_curve_list["reference_group_keys"]: + non_interactive_curve.append( + go.Scatter( + x=_xy_for_curve(group, fixed_time_horizon)[0], + y=_xy_for_curve(group, fixed_time_horizon)[1], + text=rtichoke_curve_list["performance_data_ready_for_curve"].filter( + (pl.col("reference_group") == group) + & (pl.col("fixed_time_horizon") == fixed_time_horizon) + )["text"], + mode="markers+lines", + name=group, + legendgroup=group, + line={ + "width": 2, + "color": rtichoke_curve_list["colors_dictionary"].get(group), + }, + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", + ), + hoverinfo="text", + showlegend=fixed_time_horizon == initial_fixed_time_horizon, + visible=fixed_time_horizon == initial_fixed_time_horizon, + ) + ) + + marker_traces: list[go.Scatter] = [] + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: + for group in rtichoke_curve_list["reference_group_keys"]: + x_val, y_val = ( + _xy_at_cutoff(group, initial_cutoff, fixed_time_horizon) + if initial_cutoff is not None + else (None, None) + ) + marker_traces.append( + go.Scatter( + x=[x_val] if x_val is not None else [], + y=[y_val] if y_val is not None else [], + mode="markers", + marker={ + "size": 12, + "color": ( + rtichoke_curve_list["colors_dictionary"].get(group) + if rtichoke_curve_list["multiple_reference_groups"] + else "#f6e3be", + ), + "line": {"width": 3, "color": "black"}, + }, + name=f"{group} @ cutoff", + legendgroup=group, + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", + ), + showlegend=False, + hoverinfo="text", + visible=fixed_time_horizon == initial_fixed_time_horizon, + ) + ) + + reference_traces = [] + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: + for group in rtichoke_curve_list["colors_dictionary"].keys(): + reference_traces.append( + go.Scatter( + x=_xy_for_reference(group, fixed_time_horizon)[0], + y=_xy_for_reference(group, fixed_time_horizon)[1], + mode="lines", + name=group, + legendgroup=group, + line=dict( + dash="dot", + color=rtichoke_curve_list["colors_dictionary"].get(group), + width=1.5, + ), + hoverlabel=dict( + bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), + font_color="white", + ), + hoverinfo="text", + text=_xy_for_reference(group, fixed_time_horizon)[2], + showlegend=False, + visible=fixed_time_horizon == initial_fixed_time_horizon, + ) + ) + + num_curve_traces = len(non_interactive_curve) + num_marker_traces = len(marker_traces) + cutoff_target_indices = list( + range( + num_curve_traces, + num_curve_traces + num_marker_traces, + ) + ) + + cutoff_steps = [ + { + "method": "restyle", + "args": [ + { + "x": [ + [xy[0]] if xy[0] is not None else [] + for xy in marker_points_at_cutoff + ], + "y": [ + [xy[1]] if xy[1] is not None else [] + for xy in marker_points_at_cutoff + ], + }, + cutoff_target_indices, + ], + "label": f"{cutoff:g}", + } + for cutoff in rtichoke_curve_list["cutoffs"] + for marker_points_at_cutoff in [ + [ + _xy_at_cutoff(group, cutoff, fixed_time_horizon) + if cutoff is not None + else (None, None) + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"] + for group in rtichoke_curve_list["reference_group_keys"] + ] + ] + ] + + steps_fixed_time_horizon = [] + total_traces = num_curve_traces + num_marker_traces + len(reference_traces) + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: + visibility: list[bool] = [] + for horizon in rtichoke_curve_list["fixed_time_horizons"]: + horizon_visible = horizon == fixed_time_horizon + visibility.extend( + [horizon_visible] * len(rtichoke_curve_list["reference_group_keys"]) + ) + for horizon in rtichoke_curve_list["fixed_time_horizons"]: + horizon_visible = horizon == fixed_time_horizon + visibility.extend( + [horizon_visible] * len(rtichoke_curve_list["reference_group_keys"]) + ) + for horizon in rtichoke_curve_list["fixed_time_horizons"]: + horizon_visible = horizon == fixed_time_horizon + visibility.extend( + [horizon_visible] * len(rtichoke_curve_list["colors_dictionary"].keys()) + ) + + steps_fixed_time_horizon.append( + { + "method": "restyle", + "args": [ + {"visible": visibility}, + list(range(total_traces)), + ], + "label": f"{fixed_time_horizon:g}", + } + ) + + slider_cutoff_dict = _create_slider_dict( + rtichoke_curve_list["animation_slider_cutoff_prefix"], cutoff_steps + ) + slider_fixed_time_horizon_dict = _create_slider_dict( + "Fixed Time Horizon: ", steps_fixed_time_horizon + ) + + curve_layout = _create_curve_layout( + size=rtichoke_curve_list["size"], + slider_dict=[slider_cutoff_dict, slider_fixed_time_horizon_dict], + axes_ranges=rtichoke_curve_list["axes_ranges"], + x_label=rtichoke_curve_list["x_label"], + y_label=rtichoke_curve_list["y_label"], + ) + + return go.Figure( + data=non_interactive_curve + marker_traces + reference_traces, + layout=curve_layout, + ) + + def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figure: non_interactive_curve = [ go.Scatter( @@ -1573,11 +1846,28 @@ def xy_at_cutoff(group, c): def _create_curve_layout( size: int, - slider_dict: dict, + slider_dict: dict | list[dict], axes_ranges: dict[str, list[float]] | None = None, x_label: str | None = None, y_label: str | None = None, ) -> dict[str, Any]: + sliders = slider_dict if isinstance(slider_dict, list) else [slider_dict] + + if len(sliders) > 1: + vertical_spacing = -0.4 + for idx, slider in enumerate(sliders): + slider_y = slider.get("y", 0) + base_y = slider_y if isinstance(slider_y, (int, float)) else -0.2 + slider["y"] = base_y + vertical_spacing * float(idx) + base_pad = ( + slider.get("pad", {}) if isinstance(slider.get("pad"), dict) else {} + ) + slider["pad"] = { + "t": max(120, base_pad.get("t", 0)), + "b": max(80, base_pad.get("b", 0)), + **base_pad, + } + curve_layout = { "xaxis": {"showgrid": False}, "yaxis": {"showgrid": False}, @@ -1594,7 +1884,7 @@ def _create_curve_layout( "bgcolor": "rgba(0, 0, 0, 0)", "bordercolor": "rgba(0, 0, 0, 0)", }, - "height": size + 50, + "height": size + 100, "width": size, # "hoverlabel": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": "rgba(0,0,0,0)"}, "updatemenus": [ @@ -1610,7 +1900,7 @@ def _create_curve_layout( ], } ], - "sliders": [slider_dict], + "sliders": sliders, } if axes_ranges is not None: diff --git a/uv.lock b/uv.lock index 557790b..e8ffad5 100644 --- a/uv.lock +++ b/uv.lock @@ -3888,7 +3888,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.23" +version = "0.1.24" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },