11"""
2- A module for Gains Curves
2+ A module for Gains Curves using Plotly helpers
33"""
44
5- from typing import Dict , List , Optional
6- from pandas import DataFrame
5+ from typing import Dict , List , Sequence , Union
76from plotly .graph_objs ._figure import Figure
8- from rtichoke .helpers .send_post_request_to_r_rtichoke import create_rtichoke_curve
9- from rtichoke .helpers .send_post_request_to_r_rtichoke import plot_rtichoke_curve
7+ from rtichoke .helpers .plotly_helper_functions import (
8+ _create_rtichoke_plotly_curve_binary ,
9+ _plot_rtichoke_curve_binary ,
10+ )
11+ import numpy as np
12+ import polars as pl
1013
1114
1215def create_gains_curve (
13- probs : Dict [str , List [ float ] ],
14- reals : Dict [str , List [ int ]],
16+ probs : Dict [str , np . ndarray ],
17+ reals : Union [ np . ndarray , Dict [str , np . ndarray ]],
1518 by : float = 0.01 ,
16- stratified_by : str = "probability_threshold" ,
17- size : Optional [ int ] = None ,
19+ stratified_by : Sequence [ str ] = [ "probability_threshold" ] ,
20+ size : int = 600 ,
1821 color_values : List [str ] = [
1922 "#1b9e77" ,
2023 "#d95f02" ,
@@ -37,78 +40,86 @@ def create_gains_curve(
3740 "#D1603D" ,
3841 "#585123" ,
3942 ],
40- url_api : str = "http://localhost:4242/" ,
4143) -> Figure :
42- """Create Gains Curve
44+ """Create Gains Curve.
4345
44- Args:
45- probs (Dict[str, List[float]]): _description_
46- reals (Dict[str, List[int]]): _description_
47- by (float, optional): _description_. Defaults to 0.01.
48- stratified_by (str, optional): _description_. Defaults to "probability_threshold".
49- size (Optional[int], optional): _description_. Defaults to None.
50- color_values (List[str], optional): _description_. Defaults to None.
51- url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
46+ Parameters
47+ ----------
48+ probs : Dict[str, np.ndarray]
49+ Dictionary mapping a label or group name to an array of predicted
50+ probabilities for the positive class.
51+ reals : Union[np.ndarray, Dict[str, np.ndarray]]
52+ Ground-truth binary labels (0/1) as a single array, or a dictionary
53+ mapping the same label/group keys used in ``probs`` to arrays of
54+ ground-truth labels.
55+ by : float, optional
56+ Resolution for probability thresholds when computing the curve
57+ (step size). Default is 0.01.
58+ stratified_by : Sequence[str], optional
59+ Sequence of column names to stratify the performance data by.
60+ Default is ["probability_threshold"].
61+ size : int, optional
62+ Plot size in pixels (width and height). Default is 600.
63+ color_values : List[str], optional
64+ List of color hex strings to use for the plotted lines. If not
65+ provided, a default palette is used.
5266
53- Returns:
54- Figure: _description_
67+ Returns
68+ -------
69+ Figure
70+ A Plotly ``Figure`` containing the Gains curve(s).
71+
72+ Notes
73+ -----
74+ The function delegates computation and plotting to
75+ ``_create_rtichoke_plotly_curve_binary`` and returns the resulting
76+ Plotly figure.
5577 """
56- fig = create_rtichoke_curve (
78+ fig = _create_rtichoke_plotly_curve_binary (
5779 probs ,
5880 reals ,
5981 by = by ,
6082 stratified_by = stratified_by ,
6183 size = size ,
6284 color_values = color_values ,
63- url_api = url_api ,
6485 curve = "gains" ,
6586 )
6687 return fig
6788
6889
def plot_gains_curve(
    performance_data: pl.DataFrame,
    stratified_by: Sequence[str] = ["probability_threshold"],
    size: int = 600,
) -> Figure:
    """Build a Gains curve figure from precomputed performance data.

    Thin wrapper that hands ``performance_data`` to
    ``_plot_rtichoke_curve_binary`` with the curve type fixed to
    ``"gains"``.

    Parameters
    ----------
    performance_data : pl.DataFrame
        Precomputed performance data, passed to the helper unchanged.
        The columns it must contain are decided by the helper, not here.
    stratified_by : Sequence[str], optional
        Names of the stratification columns. Default is
        ["probability_threshold"].
        NOTE(review): this argument is accepted but is not forwarded to
        ``_plot_rtichoke_curve_binary``, so it currently has no effect
        in this function -- confirm whether it should be passed through.
    size : int, optional
        Forwarded to the helper as ``size``; presumably the plot width
        and height in pixels -- verify against the helper. Default is 600.

    Returns
    -------
    Figure
        The object returned by ``_plot_rtichoke_curve_binary``.
    """
    return _plot_rtichoke_curve_binary(
        performance_data,
        curve="gains",
        size=size,
    )
0 commit comments