Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ dependencies = [
"typing>=3.7.4.3",
"polarstate==0.1.8",
"marimo>=0.17.0",
"pyarrow>=21.0.0",
]
name = "rtichoke"
version = "0.1.21"
version = "0.1.22"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"

Expand Down
4 changes: 4 additions & 0 deletions src/rtichoke/helpers/plotly_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@ def _plot_rtichoke_curve_binary(
stratified_by: str = "probability_threshold",
curve: str = "roc",
size: int = 600,
min_p_threshold: float = 0,
max_p_threshold: float = 1,
) -> go.Figure:
rtichoke_curve_list = _create_rtichoke_curve_list_binary(
performance_data=performance_data,
stratified_by=stratified_by,
curve=curve,
size=size,
min_p_threshold=min_p_threshold,
max_p_threshold=max_p_threshold,
)

fig = _create_plotly_curve_binary(rtichoke_curve_list)
Expand Down
160 changes: 92 additions & 68 deletions src/rtichoke/utility/decision.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
"""
A module for Summary Report
A module for Decision Curves using Plotly helpers
"""

from typing import Dict, List, Optional
from pandas.core.frame import DataFrame
from typing import Dict, List, Sequence, Union
from plotly.graph_objs._figure import Figure
from rtichoke.helpers.send_post_request_to_r_rtichoke import create_rtichoke_curve
from rtichoke.helpers.send_post_request_to_r_rtichoke import plot_rtichoke_curve
from rtichoke.helpers.plotly_helper_functions import (
_create_rtichoke_plotly_curve_binary,
_plot_rtichoke_curve_binary,
)
import numpy as np
import polars as pl


def create_decision_curve(
probs: Dict[str, List[float]],
reals: Dict[str, List[int]],
probs: Dict[str, np.ndarray],
reals: Union[np.ndarray, Dict[str, np.ndarray]],
decision_type: str = "conventional",
min_p_threshold: float = 0,
max_p_threshold: float = 1,
by: float = 0.01,
stratified_by: str = "probability_threshold",
size: Optional[int] = None,
stratified_by: Sequence[str] = ["probability_threshold"],
size: int = 600,
color_values: List[str] = [
"#1b9e77",
"#d95f02",
Expand All @@ -40,38 +43,63 @@ def create_decision_curve(
"#D1603D",
"#585123",
],
url_api: str = "http://localhost:4242/",
) -> Figure:
"""Create Decision Curve
"""Create Decision Curve.

Args:
probs (Dict[str, List[float]]): _description_
reals (Dict[str, List[int]]): _description_
decision_type (str, optional): _description_. Defaults to "conventional".
min_p_threshold (float, optional): _description_. Defaults to 0.
max_p_threshold (float, optional): _description_. Defaults to 1.
by (float, optional): _description_. Defaults to 0.01.
stratified_by (str, optional): _description_. Defaults to "probability_threshold".
size (Optional[int], optional): _description_. Defaults to None.
color_values (List[str], optional): _description_. Defaults to None.
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
Parameters
----------
probs : Dict[str, np.ndarray]
Dictionary mapping a label or group name to an array of predicted
probabilities for the positive class.
reals : Union[np.ndarray, Dict[str, np.ndarray]]
Ground-truth binary labels (0/1) as a single array, or a dictionary
mapping the same label/group keys used in ``probs`` to arrays of
ground-truth labels.
decision_type : str, optional
Either ``"conventional"`` (decision curve) or another value that
implies the "interventions avoided" variant. Default is
``"conventional"``.
min_p_threshold : float, optional
Minimum probability threshold to include in the curve. Default is 0.
max_p_threshold : float, optional
Maximum probability threshold to include in the curve. Default is 1.
by : float, optional
Resolution for probability thresholds when computing the curve
(step size). Default is 0.01.
stratified_by : Sequence[str], optional
Sequence of column names to stratify the performance data by.
Default is ["probability_threshold"].
size : int, optional
Plot size in pixels (width and height). Default is 600.
color_values : List[str], optional
List of color hex strings to use for the plotted lines. If not
provided, a default palette is used.

Returns:
Figure: _description_
Returns
-------
Figure
A Plotly ``Figure`` containing the Decision curve.

Notes
-----
The function selects the appropriate curve name based on
``decision_type`` and delegates computation and plotting to
``_create_rtichoke_plotly_curve_binary``. Additional keyword arguments
(like ``min_p_threshold`` and ``max_p_threshold``) are forwarded to
the helper.
"""
if decision_type == "conventional":
curve = "decision"
else:
curve = "interventions avoided"

fig = create_rtichoke_curve(
fig = _create_rtichoke_plotly_curve_binary(
probs,
reals,
by=by,
stratified_by=stratified_by,
size=size,
color_values=color_values,
url_api=url_api,
curve=curve,
min_p_threshold=min_p_threshold,
max_p_threshold=max_p_threshold,
Expand All @@ -80,59 +108,55 @@ def create_decision_curve(


def plot_decision_curve(
performance_data: DataFrame,
decision_type: str,
min_p_threshold: int = 0,
max_p_threshold: int = 1,
size: Optional[int] = None,
color_values: List[str] = [
"#1b9e77",
"#d95f02",
"#7570b3",
"#e7298a",
"#07004D",
"#E6AB02",
"#FE5F55",
"#54494B",
"#006E90",
"#BC96E6",
"#52050A",
"#1F271B",
"#BE7C4D",
"#63768D",
"#08A045",
"#320A28",
"#82FF9E",
"#2176FF",
"#D1603D",
"#585123",
],
url_api: str = "http://localhost:4242/",
performance_data: pl.DataFrame,
decision_type: str = "conventional",
min_p_threshold: float = 0,
max_p_threshold: float = 1,
stratified_by: Sequence[str] = ["probability_threshold"],
size: int = 600,
) -> Figure:
"""Plot Decision Curve
"""Plot Decision Curve from performance data.

Parameters
----------
performance_data : pl.DataFrame
A Polars DataFrame containing performance metrics for the Decision
curve. Expected columns include (but may not be limited to)
``probability_threshold`` and decision-curve metrics, plus any
stratification columns.
decision_type : str
``"conventional"`` for decision curves, otherwise the
"interventions avoided" variant will be used.
min_p_threshold : float, optional
Minimum probability threshold to include in the curve. Default is 0.
max_p_threshold : float, optional
Maximum probability threshold to include in the curve. Default is 1.
stratified_by : Sequence[str], optional
Sequence of column names used for stratification in the
``performance_data``. Default is ["probability_threshold"].
size : int, optional
Plot size in pixels (width and height). Default is 600.

Args:
performance_data (DataFrame): _description_
decision_type (str): _description_
min_p_threshold (int, optional): _description_. Defaults to 0.
max_p_threshold (int, optional): _description_. Defaults to 1.
size (Optional[int], optional): _description_. Defaults to None.
color_values (List[str], optional): _description_. Defaults to None.
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
Returns
-------
Figure
A Plotly ``Figure`` containing the Decision plot.

Returns:
Figure: _description_
Notes
-----
This function wraps ``_plot_rtichoke_curve_binary`` to produce a
ready-to-render Plotly figure from precomputed performance data.
Additional keyword arguments (``min_p_threshold``, ``max_p_threshold``)
are forwarded to the helper.
"""
if decision_type == "conventional":
curve = "decision"
else:
curve = "interventions avoided"

fig = plot_rtichoke_curve(
fig = _plot_rtichoke_curve_binary(
performance_data,
size=size,
color_values=color_values,
url_api=url_api,
curve=curve,
min_p_threshold=min_p_threshold,
max_p_threshold=max_p_threshold,
Expand Down
Loading