Skip to content

Commit d9c402f

Browse files
authored
Merge pull request #201 from uriahf/observable_report
Observable report
2 parents 8beb473 + a8e0f54 commit d9c402f

File tree

9 files changed

+1471
-261
lines changed

9 files changed

+1471
-261
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ dependencies = [
2121
"polarstate==0.1.8",
2222
]
2323
name = "rtichoke"
24-
version = "0.1.13"
24+
version = "0.1.14"
2525
description = "interactive visualizations for performance of predictive models"
2626
readme = "README.md"
2727

Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
"""
2-
A module for Gains Curves
2+
A module for Gains Curves using Plotly helpers
33
"""
44

5-
from typing import Dict, List, Optional
6-
from pandas import DataFrame
5+
from typing import Dict, List, Sequence, Union
76
from plotly.graph_objs._figure import Figure
8-
from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve
9-
from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve
7+
from rtichoke.helpers.plotly_helper_functions import (
8+
_create_rtichoke_plotly_curve_binary,
9+
_plot_rtichoke_curve_binary,
10+
)
11+
import numpy as np
12+
import polars as pl
1013

1114

1215
def create_gains_curve(
13-
probs: Dict[str, List[float]],
14-
reals: Dict[str, List[int]],
16+
probs: Dict[str, np.ndarray],
17+
reals: Union[np.ndarray, Dict[str, np.ndarray]],
1518
by: float = 0.01,
16-
stratified_by: str = "probability_threshold",
17-
size: Optional[int] = None,
19+
stratified_by: Sequence[str] = ["probability_threshold"],
20+
size: int = 600,
1821
color_values: List[str] = [
1922
"#1b9e77",
2023
"#d95f02",
@@ -37,78 +40,86 @@ def create_gains_curve(
3740
"#D1603D",
3841
"#585123",
3942
],
40-
url_api: str = "http://localhost:4242/",
4143
) -> Figure:
42-
"""Create Gains Curve
44+
"""Create Gains Curve.
4345
44-
Args:
45-
probs (Dict[str, List[float]]): _description_
46-
reals (Dict[str, List[int]]): _description_
47-
by (float, optional): _description_. Defaults to 0.01.
48-
stratified_by (str, optional): _description_. Defaults to "probability_threshold".
49-
size (Optional[int], optional): _description_. Defaults to None.
50-
color_values (List[str], optional): _description_. Defaults to None.
51-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
46+
Parameters
47+
----------
48+
probs : Dict[str, np.ndarray]
49+
Dictionary mapping a label or group name to an array of predicted
50+
probabilities for the positive class.
51+
reals : Union[np.ndarray, Dict[str, np.ndarray]]
52+
Ground-truth binary labels (0/1) as a single array, or a dictionary
53+
mapping the same label/group keys used in ``probs`` to arrays of
54+
ground-truth labels.
55+
by : float, optional
56+
Resolution for probability thresholds when computing the curve
57+
(step size). Default is 0.01.
58+
stratified_by : Sequence[str], optional
59+
Sequence of column names to stratify the performance data by.
60+
Default is ["probability_threshold"].
61+
size : int, optional
62+
Plot size in pixels (width and height). Default is 600.
63+
color_values : List[str], optional
64+
List of color hex strings to use for the plotted lines. If not
65+
provided, a default palette is used.
5266
53-
Returns:
54-
Figure: _description_
67+
Returns
68+
-------
69+
Figure
70+
A Plotly ``Figure`` containing the Gains curve(s).
71+
72+
Notes
73+
-----
74+
The function delegates computation and plotting to
75+
``_create_rtichoke_plotly_curve_binary`` and returns the resulting
76+
Plotly figure.
5577
"""
56-
fig = create_rtichoke_curve(
78+
fig = _create_rtichoke_plotly_curve_binary(
5779
probs,
5880
reals,
5981
by=by,
6082
stratified_by=stratified_by,
6183
size=size,
6284
color_values=color_values,
63-
url_api=url_api,
6485
curve="gains",
6586
)
6687
return fig
6788

6889

6990
def plot_gains_curve(
70-
performance_data: DataFrame,
71-
size: Optional[int] = None,
72-
color_values: List[str] = [
73-
"#1b9e77",
74-
"#d95f02",
75-
"#7570b3",
76-
"#e7298a",
77-
"#07004D",
78-
"#E6AB02",
79-
"#FE5F55",
80-
"#54494B",
81-
"#006E90",
82-
"#BC96E6",
83-
"#52050A",
84-
"#1F271B",
85-
"#BE7C4D",
86-
"#63768D",
87-
"#08A045",
88-
"#320A28",
89-
"#82FF9E",
90-
"#2176FF",
91-
"#D1603D",
92-
"#585123",
93-
],
94-
url_api: str = "http://localhost:4242/",
91+
performance_data: pl.DataFrame,
92+
stratified_by: Sequence[str] = ["probability_threshold"],
93+
size: int = 600,
9594
) -> Figure:
96-
"""Plot Gains Curve
95+
"""Plot Gains curve from performance data.
96+
97+
Parameters
98+
----------
99+
performance_data : pl.DataFrame
100+
A Polars DataFrame containing performance metrics for the Gains curve.
101+
Expected columns include (but may not be limited to)
102+
``probability_threshold`` and gains-related metrics, plus any
103+
stratification columns.
104+
stratified_by : Sequence[str], optional
105+
Sequence of column names used for stratification in the
106+
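``performance_data``. Default is ["probability_threshold"]. Note: this argument is accepted but is not forwarded to ``_plot_rtichoke_curve_binary`` in the call below.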
107+
size : int, optional
108+
Plot size in pixels (width and height). Default is 600.
97109
98-
Args:
99-
performance_data (DataFrame): _description_
100-
size (Optional[int], optional): _description_. Defaults to None.
101-
color_values (List[str], optional): _description_. Defaults to None.
102-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
110+
Returns
111+
-------
112+
Figure
113+
A Plotly ``Figure`` containing the Gains plot.
103114
104-
Returns:
105-
Figure: _description_
115+
Notes
116+
-----
117+
This function wraps ``_plot_rtichoke_curve_binary`` to produce a
118+
ready-to-render Plotly figure from precomputed performance data.
106119
"""
107-
fig = plot_rtichoke_curve(
120+
fig = _plot_rtichoke_curve_binary(
108121
performance_data,
109122
size=size,
110-
color_values=color_values,
111-
url_api=url_api,
112123
curve="gains",
113124
)
114125
return fig
Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
"""
2-
A module for Lift Curves
2+
A module for Lift Curves using Plotly helpers
33
"""
44

5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Sequence, Union
66
from plotly.graph_objs._figure import Figure
7-
from pandas import DataFrame
8-
from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve
9-
from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve
7+
from rtichoke.helpers.plotly_helper_functions import (
8+
_create_rtichoke_plotly_curve_binary,
9+
_plot_rtichoke_curve_binary,
10+
)
11+
import numpy as np
12+
import polars as pl
1013

1114

1215
def create_lift_curve(
13-
probs: Dict[str, List[float]],
14-
reals: Dict[str, List[int]],
16+
probs: Dict[str, np.ndarray],
17+
reals: Union[np.ndarray, Dict[str, np.ndarray]],
1518
by: float = 0.01,
16-
stratified_by: str = "probability_threshold",
17-
size: Optional[int] = None,
19+
stratified_by: Sequence[str] = ["probability_threshold"],
20+
size: int = 600,
1821
color_values: List[str] = [
1922
"#1b9e77",
2023
"#d95f02",
@@ -37,78 +40,86 @@ def create_lift_curve(
3740
"#D1603D",
3841
"#585123",
3942
],
40-
url_api: str = "http://localhost:4242/",
4143
) -> Figure:
42-
"""Create Lift Curve
44+
"""Create Lift Curve.
4345
44-
Args:
45-
probs (Dict[str, List[float]]): _description_
46-
reals (Dict[str, List[int]]): _description_
47-
by (float, optional): _description_. Defaults to 0.01.
48-
stratified_by (str, optional): _description_. Defaults to "probability_threshold".
49-
size (Optional[int], optional): _description_. Defaults to None.
50-
color_values (List[str], optional): _description_. Defaults to None.
51-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
46+
Parameters
47+
----------
48+
probs : Dict[str, np.ndarray]
49+
Dictionary mapping a label or group name to an array of predicted
50+
probabilities for the positive class.
51+
reals : Union[np.ndarray, Dict[str, np.ndarray]]
52+
Ground-truth binary labels (0/1) as a single array, or a dictionary
53+
mapping the same label/group keys used in ``probs`` to arrays of
54+
ground-truth labels.
55+
by : float, optional
56+
Resolution for probability thresholds when computing the curve
57+
(step size). Default is 0.01.
58+
stratified_by : Sequence[str], optional
59+
Sequence of column names to stratify the performance data by.
60+
Default is ["probability_threshold"].
61+
size : int, optional
62+
Plot size in pixels (width and height). Default is 600.
63+
color_values : List[str], optional
64+
List of color hex strings to use for the plotted lines. If not
65+
provided, a default palette is used.
5266
53-
Returns:
54-
Figure: _description_
67+
Returns
68+
-------
69+
Figure
70+
A Plotly ``Figure`` containing the Lift curve(s).
71+
72+
Notes
73+
-----
74+
The function delegates computation and plotting to
75+
``_create_rtichoke_plotly_curve_binary`` and returns the resulting
76+
Plotly figure.
5577
"""
56-
fig = create_rtichoke_curve(
78+
fig = _create_rtichoke_plotly_curve_binary(
5779
probs,
5880
reals,
5981
by=by,
6082
stratified_by=stratified_by,
6183
size=size,
6284
color_values=color_values,
63-
url_api=url_api,
6485
curve="lift",
6586
)
6687
return fig
6788

6889

6990
def plot_lift_curve(
70-
performance_data: DataFrame,
71-
size: Optional[int] = None,
72-
color_values: List[str] = [
73-
"#1b9e77",
74-
"#d95f02",
75-
"#7570b3",
76-
"#e7298a",
77-
"#07004D",
78-
"#E6AB02",
79-
"#FE5F55",
80-
"#54494B",
81-
"#006E90",
82-
"#BC96E6",
83-
"#52050A",
84-
"#1F271B",
85-
"#BE7C4D",
86-
"#63768D",
87-
"#08A045",
88-
"#320A28",
89-
"#82FF9E",
90-
"#2176FF",
91-
"#D1603D",
92-
"#585123",
93-
],
94-
url_api: str = "http://localhost:4242/",
91+
performance_data: pl.DataFrame,
92+
stratified_by: Sequence[str] = ["probability_threshold"],
93+
size: int = 600,
9594
) -> Figure:
96-
"""Plot Lift Curve
95+
"""Plot Lift curve from performance data.
96+
97+
Parameters
98+
----------
99+
performance_data : pl.DataFrame
100+
A Polars DataFrame containing performance metrics for the Lift curve.
101+
Expected columns include (but may not be limited to)
102+
``probability_threshold`` and lift-related metrics, plus any
103+
stratification columns.
104+
stratified_by : Sequence[str], optional
105+
Sequence of column names used for stratification in the
106+
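``performance_data``. Default is ["probability_threshold"]. Note: this argument is accepted but is not forwarded to ``_plot_rtichoke_curve_binary`` in the call below.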
107+
size : int, optional
108+
Plot size in pixels (width and height). Default is 600.
97109
98-
Args:
99-
performance_data (DataFrame): _description_
100-
size (Optional[int], optional): _description_. Defaults to None.
101-
color_values (List[str], optional): _description_. Defaults to None.
102-
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
110+
Returns
111+
-------
112+
Figure
113+
A Plotly ``Figure`` containing the Lift plot.
103114
104-
Returns:
105-
Figure: _description_
115+
Notes
116+
-----
117+
This function wraps ``_plot_rtichoke_curve_binary`` to produce a
118+
ready-to-render Plotly figure from precomputed performance data.
106119
"""
107-
fig = plot_rtichoke_curve(
120+
fig = _plot_rtichoke_curve_binary(
108121
performance_data,
109122
size=size,
110-
color_values=color_values,
111-
url_api=url_api,
112123
curve="lift",
113124
)
114125
return fig

0 commit comments

Comments
 (0)