Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"polarstate==0.1.8",
"marimo>=0.17.0",
"pyarrow>=21.0.0",
"statsmodels>=0.14.0",
]
name = "rtichoke"
version = "0.1.25"
Expand Down
268 changes: 225 additions & 43 deletions src/rtichoke/calibration/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,205 @@

from typing import Any, Dict, List, Optional

# import pandas as pd
import polars as pl
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.graph_objs._figure import Figure
# from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
from rtichoke.utility.check_performance_type import check_performance_type_by_probs_and_reals
from rtichoke.utility.create_reference_group_color_vector import create_reference_group_color_vector
import statsmodels.api as sm
import numpy as np


def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]:
    """Compute the shared x/y axis range for the calibration plot.

    The range spans the smallest and largest value found in the ``x`` and
    ``y`` columns (the lower bound is floored at 0) and is padded by 5% of
    its width on both sides.

    Args:
        deciles_dat: Frame with numeric ``x`` and ``y`` columns.

    Returns:
        ``[lower, upper]``. When the frame has at most one row, contains no
        usable values, or all values coincide, the padded unit interval
        ``[-0.05, 1.05]`` is returned instead.
    """
    # Same result the single-row case has always produced: [0, 1] padded by 5%.
    default_limits = [0 - (1 - 0) * 0.05, 1 + (1 - 0) * 0.05]

    if deciles_dat.height <= 1:
        return default_limits

    lows = [deciles_dat["x"].min(), deciles_dat["y"].min()]
    highs = [deciles_dat["x"].max(), deciles_dat["y"].max()]
    # min()/max() yield None for all-null columns; comparing None would raise.
    if any(value is None for value in lows + highs):
        return default_limits

    lower = max(0, min(lows))
    upper = max(highs)
    # A zero-width range would produce an unusable axis, so fall back.
    if upper <= lower:
        return default_limits

    padding = (upper - lower) * 0.05
    return [lower - padding, upper + padding]


def _make_smooth_dat(
    probs: List[float], reals: List[int], group: str
) -> pl.DataFrame:
    """Build the smoothed calibration curve (x, y, reference_group) for one group."""
    if len(set(probs)) == 1:
        # A single distinct score cannot be smoothed; emit one point at the
        # observed event rate instead.
        return pl.DataFrame(
            {
                "x": [float(probs[0])],
                "y": [float(np.mean(reals))],
                "reference_group": [group],
            }
        )

    fitted = sm.nonparametric.lowess(reals, probs, it=0)
    grid = np.linspace(0, 1, 101)
    # left/right=NaN stops np.interp from clamping to the end points, which
    # would draw a flat, made-up curve outside the observed score range.
    smoothed = np.interp(
        grid, fitted[:, 0], fitted[:, 1], left=np.nan, right=np.nan
    )
    keep = ~np.isnan(smoothed)
    return pl.DataFrame({"x": grid[keep], "y": smoothed[keep]}).with_columns(
        pl.lit(group).alias("reference_group")
    )


def _make_calibration_histogram(prob_values: List[float], group: str) -> pl.DataFrame:
    """Count one group's predictions in 100 equal-width bins over [0, 1]."""
    counts, edges = np.histogram(prob_values, bins=np.linspace(0, 1, 101))
    return pl.DataFrame({"mids": edges[:-1] + 0.005, "counts": counts}).with_columns(
        pl.lit(group).alias("reference_group"),
        (
            pl.col("counts").cast(pl.Utf8)
            + " observations in ["
            + (pl.col("mids") - 0.005).round(3).cast(pl.Utf8)
            + ", "
            + (pl.col("mids") + 0.005).round(3).cast(pl.Utf8)
            + "]"
        ).alias("text"),
    )


def _add_hover_text(frame: pl.DataFrame, template: str) -> pl.DataFrame:
    """Append a ``text`` column by filling ``template`` from each row's columns."""
    return frame.with_columns(
        pl.struct(frame.columns)
        # map_elements replaces Expr.apply, which no longer exists in Polars 1.x.
        .map_elements(lambda row: template.format(**row), return_dtype=pl.Utf8)
        .alias("text")
    )


def _create_calibration_curve_list(
    probs: Dict[str, List[float]],
    reals: Dict[str, List[int]],
    color_values: List[str],
    size: Optional[int],
) -> Dict[str, Any]:
    """Assemble every data set needed to draw a calibration plot.

    Args:
        probs: Predicted probabilities per reference group.
        reals: Observed outcomes. For the "several populations" performance
            type it must contain an entry for every key of ``probs``;
            otherwise its first entry is shared by all groups.
        color_values: Colours forwarded to
            ``create_reference_group_color_vector``.
        size: Plot size, passed through unchanged.

    Returns:
        An empty dict when ``probs`` is empty. Otherwise a dict with the keys
        ``performance_type``, ``size``, ``deciles_dat``, ``smooth_dat``,
        ``group_colors_vec``, ``axes_ranges``, ``reference_data``,
        ``histogram_for_calibration`` and ``histogram_opacity``.

    Raises:
        KeyError: If the performance type is "several populations" and a
            group of ``probs`` has no matching entry in ``reals``.
    """
    if not probs:
        return {}

    performance_type = check_performance_type_by_probs_and_reals(probs, reals)
    reference_groups = list(probs.keys())
    group_colors_vec = create_reference_group_color_vector(
        reference_groups, performance_type, color_values
    )

    # Resolve once which outcomes belong to which group so the per-group work
    # below is identical for every performance type.
    if performance_type == "several populations":
        missing = [group for group in reference_groups if group not in reals]
        if missing:
            raise KeyError(
                f"reals has no entry for reference group(s) {missing}; "
                "probs and reals must share the same keys for several populations"
            )
        reals_by_group = {group: reals[group] for group in reference_groups}
    else:
        shared_reals = next(iter(reals.values()))
        reals_by_group = {group: shared_reals for group in reference_groups}

    # Normalising to one schema lets frames with a differing column order or
    # integer width be concatenated vertically.
    deciles_schema = {
        "quintile": pl.Int64,
        "y": pl.Float64,
        "x": pl.Float64,
        "sum_reals": pl.Int64,
        "total_obs": pl.Int64,
    }
    deciles_dfs = []
    smooth_dfs = []
    for group in reference_groups:
        group_reals = reals_by_group[group]
        deciles_dfs.append(
            _make_deciles_dat(probs[group], group_reals)
            .select(
                [pl.col(name).cast(dtype) for name, dtype in deciles_schema.items()]
            )
            .with_columns(pl.lit(group).alias("reference_group"))
        )
        smooth_dfs.append(_make_smooth_dat(probs[group], group_reals, group))

    hover_text_discrete = "Predicted: {x:.3f}<br>Observed: {y:.3f} ({sum_reals} / {total_obs})"
    hover_text_smooth = "Predicted: {x:.3f}<br>Observed: {y:.3f}"
    if performance_type != "one model":
        hover_text_discrete = "<b>{reference_group}</b><br>" + hover_text_discrete
        hover_text_smooth = "<b>{reference_group}</b><br>" + hover_text_smooth

    deciles_dat = _add_hover_text(pl.concat(deciles_dfs), hover_text_discrete)
    smooth_dat = _add_hover_text(pl.concat(smooth_dfs), hover_text_smooth)

    limits = _define_limits_for_calibration_plot(deciles_dat)
    axes_ranges = {"xaxis": limits, "yaxis": limits}

    # Diagonal reference line for a perfectly calibrated model.
    x_ref = np.linspace(0, 1, 101)
    reference_data = pl.DataFrame({"x": x_ref, "y": x_ref}).with_columns(
        (
            pl.lit("<b>Perfectly Calibrated</b><br>Predicted: ")
            + pl.col("x").round(3).cast(pl.Utf8)
            + pl.lit("<br>Observed: ")
            + pl.col("y").round(3).cast(pl.Utf8)
        ).alias("text")
    )

    histogram_for_calibration = pl.concat(
        [
            _make_calibration_histogram(prob_values, group)
            for group, prob_values in probs.items()
        ]
    )

    return {
        "performance_type": [performance_type],
        "size": [[size]],
        "deciles_dat": deciles_dat,
        "smooth_dat": smooth_dat,
        "group_colors_vec": group_colors_vec,
        "axes_ranges": axes_ranges,
        "reference_data": reference_data,
        "histogram_for_calibration": histogram_for_calibration,
        "histogram_opacity": [1 / len(probs)],
    }


def _make_deciles_dat(probs: List[float], reals: List[int]) -> pl.DataFrame:
    """Summarise predictions and outcomes into up to ten rank-based bins.

    Args:
        probs: Predicted probabilities.
        reals: Observed outcomes, aligned with ``probs``.

    Returns:
        One row per bin with the columns ``quintile`` (bin index; the name is
        historical), ``y`` (observed event rate), ``x`` (mean predicted
        probability), ``sum_reals`` (number of events) and ``total_obs``
        (number of observations). When every prediction is identical a single
        row with ``quintile == 1`` is returned. Column order and dtypes are
        the same in both cases so results can be concatenated.
    """
    if len(set(probs)) == 1:
        n_events = sum(reals)
        n_obs = len(reals)
        deciles = pl.DataFrame(
            {
                "quintile": [1],
                "y": [n_events / n_obs],
                "x": [probs[0]],
                "sum_reals": [n_events],
                "total_obs": [n_obs],
            }
        )
    else:
        # Approximates dplyr's ntile(10): the ordinal rank is scaled into the
        # bins 0..9. Bin sizes can differ slightly from ntile's.
        bin_index = (
            ((pl.col("probs").rank("ordinal") * 10) / (pl.len() + 1))
            .floor()
            .cast(pl.Int64)
            .alias("quintile")
        )
        deciles = (
            pl.DataFrame({"probs": probs, "reals": reals})
            .with_columns(bin_index)
            .group_by("quintile")
            .agg(
                (pl.col("reals").sum() / pl.len()).alias("y"),
                pl.col("probs").mean().alias("x"),
                pl.col("reals").sum().alias("sum_reals"),
                pl.len().alias("total_obs"),
            )
            .sort("quintile")
        )

    # pl.len() yields an unsigned integer and literal inputs may be ints, so
    # pin the dtypes to keep both branches schema-compatible.
    return deciles.with_columns(
        pl.col("quintile").cast(pl.Int64),
        pl.col("y").cast(pl.Float64),
        pl.col("x").cast(pl.Float64),
        pl.col("sum_reals").cast(pl.Int64),
        pl.col("total_obs").cast(pl.Int64),
    )


def create_calibration_curve(
Expand Down Expand Up @@ -38,54 +232,42 @@ def create_calibration_curve(
"#D1603D",
"#585123",
],
url_api: str = "http://localhost:4242/",
) -> Figure:
"""Creates Calibration Curve

Args:
probs (Dict[str, List[float]]): _description_
reals (Dict[str, List[int]]): _description_
calibration_type (str, optional): _description_. Defaults to "discrete".
size (Optional[int], optional): _description_. Defaults to None.
color_values (List[str], optional): _description_. Defaults to None.
url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
probs (Dict[str, List[float]]): A dictionary where keys are model names and values are lists of predicted probabilities.
reals (Dict[str, List[int]]): A dictionary where keys are population names and values are lists of actual outcomes (0 or 1).
calibration_type (str, optional): The type of calibration curve to create, either "discrete" or "smooth". Defaults to "discrete".
size (Optional[int], optional): The size of the plot. Defaults to None.
color_values (List[str], optional): A list of hex color codes for the plot. Defaults to a predefined list.

Returns:
Figure: _description_
Figure: A Plotly Figure object representing the calibration curve.
"""
pass

# rtichoke_response = send_requests_to_rtichoke_r(
# dictionary_to_send={
# "probs": probs,
# "reals": reals,
# "size": size,
# "color_values ": color_values,
# },
# url_api=url_api,
# endpoint="create_calibration_curve_list",
# )

# calibration_curve_list = rtichoke_response.json()

# calibration_curve_list["deciles_dat"] = pd.DataFrame.from_dict(
# calibration_curve_list["deciles_dat"]
# )
# calibration_curve_list["smooth_dat"] = pd.DataFrame.from_dict(
# calibration_curve_list["smooth_dat"]
# )
# calibration_curve_list["reference_data"] = pd.DataFrame.from_dict(
# calibration_curve_list["reference_data"]
# )
# calibration_curve_list["histogram_for_calibration"] = pd.DataFrame.from_dict(
# calibration_curve_list["histogram_for_calibration"]
# )

# calibration_curve = create_plotly_curve_from_calibration_curve_list(
# calibration_curve_list=calibration_curve_list, calibration_type=calibration_type
# )

# return calibration_curve
calibration_curve_list = _create_calibration_curve_list(
probs=probs, reals=reals, color_values=color_values, size=size
)

calibration_curve_list["deciles_dat"] = calibration_curve_list[
"deciles_dat"
].to_pandas()
calibration_curve_list["smooth_dat"] = calibration_curve_list[
"smooth_dat"
].to_pandas()
calibration_curve_list["reference_data"] = calibration_curve_list[
"reference_data"
].to_pandas()
calibration_curve_list["histogram_for_calibration"] = calibration_curve_list[
"histogram_for_calibration"
].to_pandas()

calibration_curve = create_plotly_curve_from_calibration_curve_list(
calibration_curve_list=calibration_curve_list,
calibration_type=calibration_type,
)

return calibration_curve


def create_plotly_curve_from_calibration_curve_list(
Expand Down
69 changes: 69 additions & 0 deletions tests/test_calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import polars as pl
from polars.testing import assert_frame_equal
from rtichoke.calibration.calibration import (
_make_deciles_dat,
_create_calibration_curve_list,
)


def test_make_deciles_dat():
    """Ten distinct scores land in ten single-observation bins."""
    probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    reals = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    result = _make_deciles_dat(probs, reals)
    expected = pl.DataFrame(
        {
            "quintile": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "y": [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            "x": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            "sum_reals": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            "total_obs": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        }
    )
    # The bin index is not computed exactly like R's ntile, so only the
    # aggregated columns are compared. Dtypes are ignored because the count
    # column may be an unsigned integer while the literals above are Int64.
    assert_frame_equal(
        result.drop("quintile"),
        expected.drop("quintile"),
        check_row_order=False,
        check_dtypes=False,
    )


def test_create_calibration_curve_list_single_population():
    """A single model on a single population yields non-empty plot data."""
    predicted = {"model_1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]}
    observed = {"pop_1": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]}

    curve_list = _create_calibration_curve_list(predicted, observed, [], 500)

    assert curve_list["performance_type"][0] == "one model"
    # Same checks, same order, as three separate assertions would make.
    for frame_key in ("deciles_dat", "smooth_dat", "histogram_for_calibration"):
        assert len(curve_list[frame_key]) > 0


def test_create_calibration_curve_list_multiple_populations():
    """Two populations each contribute their own rows to the decile data."""
    scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    # probs and reals share keys: for several populations the outcomes are
    # looked up by the probs key.
    probs = {"pop_1": list(scores), "pop_2": list(scores)}
    reals = {
        "pop_1": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        "pop_2": [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    }
    result = _create_calibration_curve_list(probs, reals, [], 500)

    assert result["performance_type"][0] == "several populations"
    assert len(result["deciles_dat"]) > 0
    assert len(result["smooth_dat"]) > 0
    assert len(result["histogram_for_calibration"]) > 0

    # Check that the data is correctly grouped
    deciles_dat = result["deciles_dat"]
    for group in probs:
        assert len(deciles_dat.filter(pl.col("reference_group") == group)) > 0
Loading