diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index c60435b..cc5b47a 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -4,27 +4,40 @@ __version__ = version("rtichoke") -from rtichoke.discrimination.roc import create_roc_curve as create_roc_curve +from rtichoke.discrimination.roc import ( + create_roc_curve as create_roc_curve, + create_roc_curve_times as create_roc_curve_times, +) from rtichoke.discrimination.roc import plot_roc_curve as plot_roc_curve -from rtichoke.discrimination.lift import create_lift_curve as create_lift_curve +from rtichoke.discrimination.lift import ( + create_lift_curve as create_lift_curve, + create_lift_curve_times as create_lift_curve_times, +) from rtichoke.discrimination.lift import plot_lift_curve as plot_lift_curve from rtichoke.discrimination.precision_recall import ( create_precision_recall_curve as create_precision_recall_curve, + create_precision_recall_curve_times as create_precision_recall_curve_times, ) from rtichoke.discrimination.precision_recall import ( plot_precision_recall_curve as plot_precision_recall_curve, ) -from rtichoke.discrimination.gains import create_gains_curve as create_gains_curve +from rtichoke.discrimination.gains import ( + create_gains_curve as create_gains_curve, + create_gains_curve_times as create_gains_curve_times, +) from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve # from rtichoke.calibration.calibration import ( # create_calibration_curve as create_calibration_curve, # ) -from rtichoke.utility.decision import create_decision_curve as create_decision_curve +from rtichoke.utility.decision import ( + create_decision_curve as create_decision_curve, + create_decision_curve_times as create_decision_curve_times, +) from rtichoke.utility.decision import plot_decision_curve as plot_decision_curve from rtichoke.performance_data.performance_data import ( diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py index 39334ef..59366f4 100644 --- a/src/rtichoke/discrimination/gains.py +++ b/src/rtichoke/discrimination/gains.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_gains_curve( curve="gains", ) return fig + + +def create_gains_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="gains", + ) + + return fig diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py index 65f2553..5f358af 100644 --- a/src/rtichoke/discrimination/lift.py +++ b/src/rtichoke/discrimination/lift.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_lift_curve( curve="lift", ) return fig + + +def create_lift_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="lift", + ) + + return fig diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py index 06314d7..565cf5c 100644 --- a/src/rtichoke/discrimination/precision_recall.py +++ b/src/rtichoke/discrimination/precision_recall.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_precision_recall_curve( curve="precision recall", ) return fig + + +def create_precision_recall_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="precision recall", + ) + + return fig diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py index ae2d7e4..d8a8ed0 100644 --- a/src/rtichoke/discrimination/roc.py +++ b/src/rtichoke/discrimination/roc.py @@ -5,6 +5,7 @@ from typing import Dict, List, Union, Sequence from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -124,3 +125,58 @@ def plot_roc_curve( ) return fig + + +def create_roc_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="roc", + ) + + return fig diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 38d1885..ab475be 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -8,6 +8,9 @@ from typing import Any, Dict, Union, Sequence import numpy as np from rtichoke.performance_data.performance_data import prepare_performance_data +from rtichoke.performance_data.performance_data_times import ( + prepare_performance_data_times, +) _HOVER_LABELS = { "false_positive_rate": "1 - Specificity (FPR)", @@ -22,6 +25,21 @@ "ppcr": "Predicted Positives", } +DEFAULT_MODEBAR_BUTTONS_TO_REMOVE = [ + "zoom2d", + "pan2d", + "select2d", + "lasso2d", + "zoomIn2d", + "zoomOut2d", + "autoScale2d", + "resetScale2d", + "hoverClosestCartesian", + "hoverCompareCartesian", + "toggleSpikelines", + "toImage", +] + def _create_rtichoke_plotly_curve_binary( probs: Dict[str, np.ndarray], @@ -51,6 +69,44 @@ def _create_rtichoke_plotly_curve_binary( return fig +def _create_rtichoke_plotly_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + min_p_threshold: float = 0, + max_p_threshold: float = 1, + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values=None, + curve: str = "roc", +) -> go.Figure: + performance_data = prepare_performance_data_times( + probs, + reals, + times, + by=by, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + stratified_by=stratified_by, + ) + + rtichoke_curve_list_times = _create_rtichoke_curve_list_times( + performance_data, stratified_by=stratified_by[0], curve=curve + ) + + fig = _create_plotly_curve_times(rtichoke_curve_list_times) + + return fig + + def _plot_rtichoke_curve_binary( performance_data: pl.DataFrame, stratified_by: str = "probability_threshold", @@ -1490,7 +1546,7 @@ def _xy_for_reference( def _xy_at_cutoff( group: str, cutoff: float, fixed_time_horizon: float - ) -> tuple[Any, Any]: + ) -> tuple[Any, Any, Any]: row = ( rtichoke_curve_list["performance_data_ready_for_curve"] .filter( @@ -1500,13 +1556,13 @@ def _xy_at_cutoff( & pl.col("x").is_not_null() & pl.col("y").is_not_null() ) - .select(["x", "y"]) + .select(["x", "y", "text"]) .limit(1) ) if row.height == 0: - return None, None + return None, None, None r = row.row(0) - return r[0], r[1] + return r[0], r[1], r[2] non_interactive_curve = [] for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: @@ -1532,7 +1588,7 @@ def _xy_at_cutoff( font_color="white", ), hoverinfo="text", - showlegend=fixed_time_horizon == initial_fixed_time_horizon, + showlegend=rtichoke_curve_list["multiple_reference_groups"], visible=fixed_time_horizon == initial_fixed_time_horizon, ) ) @@ -1540,10 +1596,10 @@ def _xy_at_cutoff( marker_traces: list[go.Scatter] = [] for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: for group in rtichoke_curve_list["reference_group_keys"]: - x_val, y_val = ( + x_val, y_val, text_val = ( _xy_at_cutoff(group, initial_cutoff, fixed_time_horizon) if initial_cutoff is not None - else (None, None) + else (None, None, None) ) marker_traces.append( go.Scatter( @@ -1555,19 +1611,26 @@ def _xy_at_cutoff( "color": ( rtichoke_curve_list["colors_dictionary"].get(group) if rtichoke_curve_list["multiple_reference_groups"] - else "#f6e3be", + else "#f6e3be" ), "line": {"width": 3, "color": "black"}, }, name=f"{group} @ cutoff", legendgroup=group, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), - bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), - font_color="white", + bgcolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + font_color="black" + if not rtichoke_curve_list["multiple_reference_groups"] + else "white", ), showlegend=False, hoverinfo="text", + text=text_val, visible=fixed_time_horizon == initial_fixed_time_horizon, ) ) @@ -1608,35 +1671,40 @@ def _xy_at_cutoff( ) ) - cutoff_steps = [ - { - "method": "restyle", - "args": [ - { - "x": [ - [xy[0]] if xy[0] is not None else [] - for xy in marker_points_at_cutoff - ], - "y": [ - [xy[1]] if xy[1] is not None else [] - for xy in marker_points_at_cutoff - ], - }, - cutoff_target_indices, - ], - "label": f"{cutoff:g}", - } - for cutoff in rtichoke_curve_list["cutoffs"] - for marker_points_at_cutoff in [ - [ - _xy_at_cutoff(group, cutoff, fixed_time_horizon) - if cutoff is not None - else (None, None) - for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"] - for group in rtichoke_curve_list["reference_group_keys"] - ] + def marker_values_for_cutoff( + cutoff: float, + ) -> tuple[list[list], list[list], list[list]]: + marker_values = [ + _xy_at_cutoff(group, cutoff, fixed_time_horizon) + if cutoff is not None + else (None, None, None) + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"] + for group in rtichoke_curve_list["reference_group_keys"] ] - ] + + xs = [[x] if x is not None else [] for x, _, _ in marker_values] + ys = [[y] if y is not None else [] for _, y, _ in marker_values] + texts = [[text] if text is not None else [] for _, _, text in marker_values] + + return xs, ys, texts + + cutoff_steps = [] + for cutoff in rtichoke_curve_list["cutoffs"]: + xs, ys, texts = marker_values_for_cutoff(cutoff) + cutoff_steps.append( + { + "method": "restyle", + "args": [ + { + "x": xs, + "y": ys, + "text": texts, + }, + cutoff_target_indices, + ], + "label": f"{cutoff:g}", + } + ) steps_fixed_time_horizon = [] total_traces = num_curve_traces + num_marker_traces + len(reference_traces) @@ -1682,6 +1750,7 @@ def _xy_at_cutoff( axes_ranges=rtichoke_curve_list["axes_ranges"], x_label=rtichoke_curve_list["x_label"], y_label=rtichoke_curve_list["y_label"], + show_legend=rtichoke_curve_list["multiple_reference_groups"], ) return go.Figure( @@ -1691,6 +1760,10 @@ def _xy_at_cutoff( def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figure: + initial_cutoff = ( + rtichoke_curve_list["cutoffs"][0] if rtichoke_curve_list["cutoffs"] else None + ) + non_interactive_curve = [ go.Scatter( x=rtichoke_curve_list["performance_data_ready_for_curve"] @@ -1715,15 +1788,58 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur font_color="white", ), hoverinfo="text", - showlegend=True, + showlegend=rtichoke_curve_list["multiple_reference_groups"], ) for group in rtichoke_curve_list["reference_group_keys"] ] + def xy_at_cutoff(group, c): + row = ( + rtichoke_curve_list["performance_data_ready_for_curve"] + .filter( + (pl.col("reference_group") == group) + & (pl.col("chosen_cutoff") == c) + & pl.col("x").is_not_null() + & pl.col("y").is_not_null() + ) + .select(["x", "y", "text"]) + .limit(1) + ) + if row.height == 0: + return None, None, None + r = row.row(0) + return r[0], r[1], r[2] + + def marker_values_for_cutoff( + cutoff: float, + ) -> tuple[list[list], list[list], list[list]]: + marker_values = [ + xy_at_cutoff(group, cutoff) + for group in rtichoke_curve_list["reference_group_keys"] + ] + + xs = [[x] if x is not None else [] for x, _, _ in marker_values] + ys = [[y] if y is not None else [] for _, y, _ in marker_values] + texts = [[text] if text is not None else [] for _, _, text in marker_values] + + return xs, ys, texts + + initial_xs, initial_ys, initial_texts = ( + marker_values_for_cutoff(initial_cutoff) + if initial_cutoff is not None + else ( + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + ) + ) + initial_interactive_markers = [ go.Scatter( - x=[], - y=[], + x=initial_xs[idx], + y=initial_ys[idx], + text=initial_texts[idx], + # hovertext=initial_texts[idx], mode="markers", marker={ "size": 12, @@ -1737,14 +1853,20 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur name=f"{group} @ cutoff", legendgroup=group, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), - bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), - font_color="white", + bgcolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + font_color="black" + if not rtichoke_curve_list["multiple_reference_groups"] + else "white", ), showlegend=False, hoverinfo="text", ) - for group in rtichoke_curve_list["reference_group_keys"] + for idx, group in enumerate(rtichoke_curve_list["reference_group_keys"]) ] reference_traces = [ @@ -1784,47 +1906,24 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur ) ) - def xy_at_cutoff(group, c): - row = ( - rtichoke_curve_list["performance_data_ready_for_curve"] - .filter( - (pl.col("reference_group") == group) - & (pl.col("chosen_cutoff") == c) - & pl.col("x").is_not_null() - & pl.col("y").is_not_null() - ) - .select(["x", "y"]) - .limit(1) + steps = [] + for cutoff in rtichoke_curve_list["cutoffs"]: + xs, ys, texts = marker_values_for_cutoff(cutoff) + steps.append( + { + "method": "restyle", + "args": [ + { + "x": xs, + "y": ys, + "text": texts, + # "hovertext": texts, + }, + dyn_idx, + ], + "label": f"{cutoff:g}", + } ) - if row.height == 0: - return None, None - r = row.row(0) # (x, y) - return r[0], r[1] - - steps = [ - { - "method": "restyle", - "args": [ - { - "x": [ - [xy_at_cutoff(group, cutoff)[0]] - if xy_at_cutoff(group, cutoff)[0] is not None - else [] - for group in rtichoke_curve_list["reference_group_keys"] - ], - "y": [ - [xy_at_cutoff(group, cutoff)[1]] - if xy_at_cutoff(group, cutoff)[1] is not None - else [] - for group in rtichoke_curve_list["reference_group_keys"] - ], - }, - dyn_idx, - ], - "label": f"{cutoff:g}", - } - for cutoff in rtichoke_curve_list["cutoffs"] - ] slider_dict = _create_slider_dict( rtichoke_curve_list["animation_slider_prefix"], steps @@ -1836,6 +1935,7 @@ def xy_at_cutoff(group, c): axes_ranges=rtichoke_curve_list["axes_ranges"], x_label=rtichoke_curve_list["x_label"], y_label=rtichoke_curve_list["y_label"], + show_legend=rtichoke_curve_list["multiple_reference_groups"], ) return go.Figure( @@ -1850,6 +1950,7 @@ def _create_curve_layout( axes_ranges: dict[str, list[float]] | None = None, x_label: str | None = None, y_label: str | None = None, + show_legend: bool = True, ) -> dict[str, Any]: sliders = slider_dict if isinstance(slider_dict, list) else [slider_dict] @@ -1901,6 +2002,7 @@ def _create_curve_layout( } ], "sliders": sliders, + "modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)}, } if axes_ranges is not None: diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py index ca2b71c..50f0e6d 100644 --- a/src/rtichoke/utility/decision.py +++ b/src/rtichoke/utility/decision.py @@ -6,6 +6,7 @@ from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( _create_rtichoke_plotly_curve_binary, + _create_rtichoke_plotly_curve_times, _plot_rtichoke_curve_binary, ) import numpy as np @@ -162,3 +163,68 @@ def plot_decision_curve( max_p_threshold=max_p_threshold, ) return fig + + +def create_decision_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + decision_type: str = "conventional", + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + min_p_threshold: float = 0, + max_p_threshold: float = 1, + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Decision Curve.""" + + if decision_type == "conventional": + curve = "decision" + else: + curve = "interventions avoided" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve=curve, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, + ) + + return fig