diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 0f6c886..56be1d2 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -48,10 +48,17 @@ jobs:
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Quartodoc build
+ working-directory: docs
+ run: uv run quartodoc build
+
+ - name: Clean any leftover Quarto publish worktree
+ run: rm -rf .quarto
- name: Render and Publish
working-directory: docs
- run: uv run quarto publish gh-pages --no-browser --token "${{ secrets.GITHUB_TOKEN }}"
+ run: uv run quarto publish gh-pages --no-browser --no-prompt --token "${{ secrets.GITHUB_TOKEN }}"
- name: Publish package
if: github.ref == 'refs/heads/main' && matrix.python-version == '3.10'
diff --git a/docs/.gitignore b/docs/.gitignore
index 075b254..3147a4d 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -1 +1,4 @@
/.quarto/
+
+**/*.quarto_ipynb
+_sidebar.yml
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
deleted file mode 100644
index 5172566..0000000
--- a/docs/Makefile
+++ /dev/null
@@ -1,19 +0,0 @@
-# Minimal makefile for Sphinx documentation
-
-# You can set these variables from the command line.
-SPHINXOPTS =
-SPHINXBUILD = python -msphinx
-SPHINXPROJ = rtichoke
-SOURCEDIR = .
-BUILDDIR = _build
-
-# Put it first so that "make" without argument is like "make help".
-help:
- @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
- @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 57c7443..adc9a8e 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -16,6 +16,11 @@ quartodoc:
package: rtichoke
sidebar: "_sidebar.yml"
sections:
+ - title: Performance Data
+ desc: Functions for creating performance data.
+ contents:
+ - prepare_performance_data
+ - prepare_performance_data_times
- title: Calibration
desc: Functions for Calibration.
contents:
diff --git a/docs/_sidebar.yml b/docs/_sidebar.yml
deleted file mode 100644
index cc8acc1..0000000
--- a/docs/_sidebar.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-website:
- sidebar:
- - contents:
- - reference/index.qmd
- - contents:
- - reference/create_calibration_curve.qmd
- section: Calibration
- - contents:
- - reference/create_roc_curve.qmd
- - reference/create_precision_recall_curve.qmd
- - reference/create_gains_curve.qmd
- - reference/create_lift_curve.qmd
- - reference/plot_roc_curve.qmd
- - reference/plot_precision_recall_curve.qmd
- - reference/plot_gains_curve.qmd
- - reference/plot_lift_curve.qmd
- section: Discrimination
- - contents:
- - reference/create_decision_curve.qmd
- - reference/plot_decision_curve.qmd
- section: Utility
- id: reference
- - id: dummy-sidebar
diff --git a/docs/_site/line_ppcr_04.svg b/docs/_site/line_ppcr_04.svg
deleted file mode 100644
index e8c9784..0000000
--- a/docs/_site/line_ppcr_04.svg
+++ /dev/null
@@ -1,766 +0,0 @@
-[766 lines of SVG markup elided in this listing]
diff --git a/docs/_site/mermaid_diagrams/p_threshold_trt_all.png b/docs/_site/mermaid_diagrams/p_threshold_trt_all.png
deleted file mode 100644
index 845a883..0000000
Binary files a/docs/_site/mermaid_diagrams/p_threshold_trt_all.png and /dev/null differ
diff --git a/docs/_site/mermaid_diagrams/p_threshold_trt_all.svg b/docs/_site/mermaid_diagrams/p_threshold_trt_all.svg
deleted file mode 100644
index 4074cac..0000000
--- a/docs/_site/mermaid_diagrams/p_threshold_trt_all.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-[SVG markup elided in this listing]
\ No newline at end of file
diff --git a/docs/_site/robots.txt b/docs/_site/robots.txt
deleted file mode 100644
index f5c2a44..0000000
--- a/docs/_site/robots.txt
+++ /dev/null
@@ -1 +0,0 @@
-Sitemap: https://uriahf.github.io/rtichoke_python/sitemap.xml
diff --git a/docs/_site/sandbox_files/figure-html/cell-2-output-1.png b/docs/_site/sandbox_files/figure-html/cell-2-output-1.png
deleted file mode 100644
index d6ba60e..0000000
Binary files a/docs/_site/sandbox_files/figure-html/cell-2-output-1.png and /dev/null differ
diff --git a/docs/_site/site_libs/bootstrap/bootstrap-icons.woff b/docs/_site/site_libs/bootstrap/bootstrap-icons.woff
deleted file mode 100644
index dbeeb05..0000000
Binary files a/docs/_site/site_libs/bootstrap/bootstrap-icons.woff and /dev/null differ
diff --git a/docs/contributing.md b/docs/contributing.md
deleted file mode 100644
index 435d357..0000000
--- a/docs/contributing.md
+++ /dev/null
@@ -1,2 +0,0 @@
-```{include} ../CONTRIBUTING.md
-```
\ No newline at end of file
diff --git a/docs/example.ipynb b/docs/example.ipynb
deleted file mode 100644
index 117fc5f..0000000
--- a/docs/example.ipynb
+++ /dev/null
@@ -1,421 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Example usage\n",
- "\n",
- "To use `rtichoke` in a project:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "ename": "ModuleNotFoundError",
- "evalue": "No module named 'rtichoke'",
- "output_type": "error",
- "traceback": [
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mrtichoke\u001b[39;00m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(rtichoke\u001b[38;5;241m.\u001b[39m__version__)\n",
- "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'rtichoke'"
- ]
- }
- ],
- "source": [
- "import rtichoke\n",
- "\n",
- "print(rtichoke.__version__)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from sklearn.linear_model import LogisticRegression\n",
- "import numpy as np\n",
- "\n",
- "lr = LogisticRegression()\n",
- "x = np.arange(10).reshape(-1, 1)\n",
- "y = np.array([0, 1, 0, 0, 1, 1, 1, 0, 0, 1])\n",
- "\n",
- "x_test = np.arange(7).reshape(-1, 1)\n",
- "y_test = np.array([1, 0, 1, 0, 1, 0, 0])\n",
- "\n",
- "model = LogisticRegression(solver=\"liblinear\", random_state=0)\n",
- "lasso = LogisticRegression(solver=\"liblinear\", penalty=\"l1\", random_state=0)\n",
- "\n",
- "model.fit(x, y)\n",
- "lasso.fit(x_test, y_test)\n",
- "\n",
- "probs_dict_for_examples = {\n",
- " \"One Model\": {\"Logistic Regression\": model.predict_proba(x)[:, 1].tolist()},\n",
- " \"Multiple Models\": {\n",
- " \"Logistic Regression\": model.predict_proba(x)[:, 1].tolist(),\n",
- " \"Lasso\": lasso.predict_proba(x)[:, 1].tolist(),\n",
- " },\n",
- " \"Multiple Populations\": {\n",
- " \"Train\": model.predict_proba(x)[:, 1].tolist(),\n",
- " \"Test\": model.predict_proba(x_test)[:, 1].tolist(),\n",
- " },\n",
- "}\n",
- "\n",
- "reals_dict_for_examples = {\n",
- " \"One Model\": {\"Logistic Regression\": y.tolist()},\n",
- " \"Multiple Models\": {\"Reals\": y.tolist()},\n",
- " \"Multiple Populations\": {\"Train\": y.tolist(), \"Test\": y_test.tolist()},\n",
- "}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "dict_keys(['probability_threshold', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives', 'NB', 'ppcr'])\n",
- "dict_keys(['probability_threshold', 'ppcr', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives'])\n",
- "dict_keys(['model', 'probability_threshold', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives', 'NB', 'ppcr'])\n",
- "dict_keys(['model', 'probability_threshold', 'ppcr', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives'])\n",
- "dict_keys(['population', 'probability_threshold', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives', 'NB', 'ppcr'])\n",
- "dict_keys(['population', 'probability_threshold', 'ppcr', 'TP', 'TN', 'FN', 'FP', 'sensitivity', 'FPR', 'specificity', 'PPV', 'NPV', 'lift', 'predicted_positives'])\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " probability_threshold | \n",
- " TP | \n",
- " TN | \n",
- " FN | \n",
- " FP | \n",
- " sensitivity | \n",
- " FPR | \n",
- " specificity | \n",
- " PPV | \n",
- " NPV | \n",
- " lift | \n",
- " predicted_positives | \n",
- " NB | \n",
- " ppcr | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 0 | \n",
- " 0.00 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0 | \n",
- " 5 | \n",
- " 1.0 | \n",
- " 1.0 | \n",
- " 0.0 | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 1.0 | \n",
- " 10 | \n",
- " 0.5000 | \n",
- " 1.0 | \n",
- "
\n",
- " \n",
- " | 1 | \n",
- " 0.01 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0 | \n",
- " 5 | \n",
- " 1.0 | \n",
- " 1.0 | \n",
- " 0.0 | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 1.0 | \n",
- " 10 | \n",
- " 0.4949 | \n",
- " 1.0 | \n",
- "
\n",
- " \n",
- " | 2 | \n",
- " 0.02 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0 | \n",
- " 5 | \n",
- " 1.0 | \n",
- " 1.0 | \n",
- " 0.0 | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 1.0 | \n",
- " 10 | \n",
- " 0.4898 | \n",
- " 1.0 | \n",
- "
\n",
- " \n",
- " | 3 | \n",
- " 0.03 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0 | \n",
- " 5 | \n",
- " 1.0 | \n",
- " 1.0 | \n",
- " 0.0 | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 1.0 | \n",
- " 10 | \n",
- " 0.4845 | \n",
- " 1.0 | \n",
- "
\n",
- " \n",
- " | 4 | \n",
- " 0.04 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0 | \n",
- " 5 | \n",
- " 1.0 | \n",
- " 1.0 | \n",
- " 0.0 | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 1.0 | \n",
- " 10 | \n",
- " 0.4792 | \n",
- " 1.0 | \n",
- "
\n",
- " \n",
- " | ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " | 96 | \n",
- " 0.96 | \n",
- " 0 | \n",
- " 5 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 1.0 | \n",
- " NaN | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 0 | \n",
- " 0.0000 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " | 97 | \n",
- " 0.97 | \n",
- " 0 | \n",
- " 5 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 1.0 | \n",
- " NaN | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 0 | \n",
- " 0.0000 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " | 98 | \n",
- " 0.98 | \n",
- " 0 | \n",
- " 5 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 1.0 | \n",
- " NaN | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 0 | \n",
- " 0.0000 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " | 99 | \n",
- " 0.99 | \n",
- " 0 | \n",
- " 5 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 1.0 | \n",
- " NaN | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 0 | \n",
- " 0.0000 | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- " | 100 | \n",
- " 1.00 | \n",
- " 0 | \n",
- " 5 | \n",
- " 5 | \n",
- " 0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 1.0 | \n",
- " NaN | \n",
- " 0.5 | \n",
- " NaN | \n",
- " 0 | \n",
- " NaN | \n",
- " 0.0 | \n",
- "
\n",
- " \n",
- "
\n",
- "
101 rows ร 14 columns
\n",
- "
"
- ],
- "text/plain": [
- " probability_threshold TP TN FN FP sensitivity FPR specificity \\\n",
- "0 0.00 5 0 0 5 1.0 1.0 0.0 \n",
- "1 0.01 5 0 0 5 1.0 1.0 0.0 \n",
- "2 0.02 5 0 0 5 1.0 1.0 0.0 \n",
- "3 0.03 5 0 0 5 1.0 1.0 0.0 \n",
- "4 0.04 5 0 0 5 1.0 1.0 0.0 \n",
- ".. ... .. .. .. .. ... ... ... \n",
- "96 0.96 0 5 5 0 0.0 0.0 1.0 \n",
- "97 0.97 0 5 5 0 0.0 0.0 1.0 \n",
- "98 0.98 0 5 5 0 0.0 0.0 1.0 \n",
- "99 0.99 0 5 5 0 0.0 0.0 1.0 \n",
- "100 1.00 0 5 5 0 0.0 0.0 1.0 \n",
- "\n",
- " PPV NPV lift predicted_positives NB ppcr \n",
- "0 0.5 NaN 1.0 10 0.5000 1.0 \n",
- "1 0.5 NaN 1.0 10 0.4949 1.0 \n",
- "2 0.5 NaN 1.0 10 0.4898 1.0 \n",
- "3 0.5 NaN 1.0 10 0.4845 1.0 \n",
- "4 0.5 NaN 1.0 10 0.4792 1.0 \n",
- ".. ... ... ... ... ... ... \n",
- "96 NaN 0.5 NaN 0 0.0000 0.0 \n",
- "97 NaN 0.5 NaN 0 0.0000 0.0 \n",
- "98 NaN 0.5 NaN 0 0.0000 0.0 \n",
- "99 NaN 0.5 NaN 0 0.0000 0.0 \n",
- "100 NaN 0.5 NaN 0 NaN 0.0 \n",
- "\n",
- "[101 rows x 14 columns]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "performance_datas = [\n",
- " rtichoke.prepare_performance_data(\n",
- " probs=probs_dict_for_examples[x],\n",
- " reals=reals_dict_for_examples[x],\n",
- " stratified_by=stratified_by,\n",
- " url_api=\"http://127.0.0.1:7644/\",\n",
- " )\n",
- " for x in probs_dict_for_examples.keys()\n",
- " for stratified_by in [\"probability_threshold\", \"ppcr\"]\n",
- "]\n",
- "\n",
- "performance_datas[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "roc_curves = [\n",
- " rtichoke.create_roc_curve(\n",
- " probs=probs_dict_for_examples[x],\n",
- " reals=reals_dict_for_examples[x],\n",
- " size=600,\n",
- " stratified_by=stratified_by,\n",
- " url_api=\"http://127.0.0.1:7644/\",\n",
- " )\n",
- " for x in probs_dict_for_examples.keys()\n",
- " for stratified_by in [\"probability_threshold\", \"ppcr\"]\n",
- "]\n",
- "\n",
- "# roc_curves[0].show(config={'displayModeBar': False})\n",
- "# roc_curves[1].show(config={'displayModeBar': False})\n",
- "# roc_curves[2].show(config={'displayModeBar': False})\n",
- "# roc_curves[3].show(config={'displayModeBar': False})\n",
- "roc_curves[4].show(config={\"displayModeBar\": False})\n",
- "# roc_curves[5].show(config={'displayModeBar': False})"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "rtichoke",
- "language": "python",
- "name": "rtichoke"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.13.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/docs/reference/.gitignore b/docs/reference/.gitignore
new file mode 100644
index 0000000..8a05323
--- /dev/null
+++ b/docs/reference/.gitignore
@@ -0,0 +1 @@
+**/*.qmd
diff --git a/docs/small_data_example.py b/docs/small_data_example.py
index 50573a1..ec5143f 100644
--- a/docs/small_data_example.py
+++ b/docs/small_data_example.py
@@ -32,7 +32,7 @@ def _():
@app.cell
def _(np, pl):
- probs_test = {
+ probs_dict_test = {
"small_data_set": np.array(
[0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
)
@@ -50,7 +50,7 @@ def _(np, pl):
)
data_to_adjust
- return probs_test, reals_dict_test, times_dict_test
+ return probs_dict_test, reals_dict_test, times_dict_test
@app.cell
@@ -60,6 +60,8 @@ def _(create_aj_data_combinations, create_breaks_values):
stratified_by = ["probability_threshold", "ppcr"]
# stratified_by = ["probability_threshold"]
+ fixed_time_horizons = [10.0, 20.0, 30.0, 40.0, 50.0]
+
# stratified_by = ["ppcr"]
heuristics_sets = [
@@ -100,7 +102,7 @@ def _(create_aj_data_combinations, create_breaks_values):
aj_data_combinations = create_aj_data_combinations(
["small_data_set"],
heuristics_sets=heuristics_sets,
- fixed_time_horizons=[10.0, 20.0, 30.0, 40.0, 50.0],
+ fixed_time_horizons=fixed_time_horizons,
stratified_by=stratified_by,
by=by,
breaks=breaks,
@@ -110,7 +112,14 @@ def _(create_aj_data_combinations, create_breaks_values):
# aj_data_combinations
aj_data_combinations
- return aj_data_combinations, breaks, by, heuristics_sets, stratified_by
+ return (
+ aj_data_combinations,
+ breaks,
+ by,
+ fixed_time_horizons,
+ heuristics_sets,
+ stratified_by,
+ )
@app.cell
@@ -118,14 +127,14 @@ def _(
aj_data_combinations,
by,
create_list_data_to_adjust,
- probs_test,
+ probs_dict_test,
reals_dict_test,
stratified_by,
times_dict_test,
):
list_data_to_adjust_polars_probability_threshold = create_list_data_to_adjust(
aj_data_combinations,
- probs_test,
+ probs_dict_test,
reals_dict_test,
times_dict_test,
stratified_by=stratified_by,
@@ -140,6 +149,7 @@ def _(
def _(
breaks,
create_adjusted_data,
+ fixed_time_horizons,
heuristics_sets,
list_data_to_adjust_polars_probability_threshold,
stratified_by,
@@ -147,7 +157,7 @@ def _(
adjusted_data = create_adjusted_data(
list_data_to_adjust_polars_probability_threshold,
heuristics_sets=heuristics_sets,
- fixed_time_horizons=[10.0, 20.0, 30.0, 40.0, 50.0],
+ fixed_time_horizons=fixed_time_horizons,
breaks=breaks,
stratified_by=stratified_by,
# risk_set_scope = ["pooled_by_cutoff"]
@@ -192,7 +202,50 @@ def _(cumulative_aj_data):
return
-@app.cell(column=1, hide_code=True)
+@app.cell(column=1)
+def _():
+ from rtichoke.performance_data.performance_data_times import (
+ prepare_performance_data_times,
+ )
+
+ return (prepare_performance_data_times,)
+
+
+@app.cell
+def _(
+ fixed_time_horizons,
+ prepare_performance_data_times,
+ probs_dict_test,
+ reals_dict_test,
+ times_dict_test,
+):
+ prepare_performance_data_times(
+ probs_dict_test, reals_dict_test, times_dict_test, fixed_time_horizons, by=0.1
+ )
+ return
+
+
+@app.cell
+def _(np):
+ probs_dict_test = {
+ "small_data_set": np.array(
+ [0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
+ )
+ }
+ reals_dict_test = [1, 1, 1, 1, 0, 2, 1, 2, 0, 1]
+ times_dict_test = [24.1, 9.7, 49.9, 18.6, 34.8, 14.2, 39.2, 46.0, 31.5, 4.3]
+
+ fixed_time_horizons = [10.0, 20.0, 30.0, 40.0, 50.0]
+
+ return (
+ fixed_time_horizons,
+ probs_dict_test,
+ reals_dict_test,
+ times_dict_test,
+ )
+
+
+@app.cell(column=2, hide_code=True)
def _(mo):
fill_color_radio = mo.ui.radio(
options=["classification_outcome", "reals_labels"],
@@ -275,7 +328,7 @@ def _(mo):
return (competing_heuristic_radio,)
-@app.cell(column=2, hide_code=True)
+@app.cell(column=3, hide_code=True)
def _(
by,
censoring_heuristic_radio,
diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py
index 4a32461..c2e04b7 100644
--- a/src/rtichoke/__init__.py
+++ b/src/rtichoke/__init__.py
@@ -31,6 +31,10 @@
prepare_performance_data as prepare_performance_data,
)
+from rtichoke.performance_data.performance_data_times import (
+ prepare_performance_data_times as prepare_performance_data_times,
+)
+
from rtichoke.summary_report.summary_report import (
create_summary_report as create_summary_report,
)
@@ -48,5 +52,6 @@
"create_decision_curve",
"plot_decision_curve",
"prepare_performance_data",
+ "prepare_performance_data_times",
"create_summary_report",
]
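
With both entry points now re-exported at the package root, a caller can sanity-check the public surface without touching the submodule paths. A minimal sketch, relying only on the `__all__` entries shown above:

```python
import rtichoke

# Both performance-data entry points are part of the public API after this change.
assert "prepare_performance_data" in rtichoke.__all__
assert "prepare_performance_data_times" in rtichoke.__all__
```
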
diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py
index 730712f..629bc32 100644
--- a/src/rtichoke/helpers/sandbox_observable_helpers.py
+++ b/src/rtichoke/helpers/sandbox_observable_helpers.py
@@ -137,32 +137,22 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame:
q = int(round(1 / by)) # e.g. 0.2 -> 5 bins
probs = np.asarray(probs, float)
- n = probs.size
- print(f"q = {q}, n = {n}")
- print("probs:", probs)
edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear")
- print("edges before accumulating:", edges)
edges = np.maximum.accumulate(edges)
- print("edges after accumulating:", edges)
edges[0] = 0.0
edges[-1] = 1.0
- print("edges after setting 0 and 1:", edges)
-
bin_idx = np.digitize(probs, bins=edges[1:-1], right=True)
- print("bin_idx:", bin_idx)
s = str(by)
decimals = len(s.split(".")[-1]) if "." in s else 0
labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)]
- print("bin_labels", labels)
strata_labels = np.array([labels[i] for i in bin_idx], dtype=object)
- print("strata_labels:", strata_labels)
columns_to_add.append(
pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels))
@@ -205,7 +195,6 @@ def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.Data
include_upper_bound = np.zeros_like(strata_mid, dtype=bool)
# chosen_cutoff = strata_mid
strata = np.array([fmt.format(x) for x in strata_mid], dtype=object)
- print("strata", strata)
else:
raise ValueError(f"Unsupported stratified_by: {stratified_by}")
@@ -263,6 +252,48 @@ def create_breaks_values(probs_vec, stratified_by, by):
return breaks
+def _create_aj_data_combinations_binary(
+ reference_groups: Sequence[str],
+ stratified_by: Sequence[str],
+ by: float,
+ breaks: Sequence[float],
+) -> pl.DataFrame:
+ dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by]
+
+ strata_combinations = pl.concat(dfs, how="vertical")
+
+ strata_cats = (
+ strata_combinations.select(pl.col("strata").unique(maintain_order=True))
+ .to_series()
+ .to_list()
+ )
+
+ strata_enum = pl.Enum(strata_cats)
+ stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"])
+
+ strata_combinations = strata_combinations.with_columns(
+ [
+ pl.col("strata").cast(strata_enum),
+ pl.col("stratified_by").cast(stratified_by_enum),
+ ]
+ )
+
+ # Define values for Cartesian product
+ reals_labels = ["real_negatives", "real_positives"]
+
+ combinations_frames: list[pl.DataFrame] = [
+ _enum_dataframe("reference_group", reference_groups),
+ strata_combinations,
+ _enum_dataframe("reals_labels", reals_labels),
+ ]
+
+ result = combinations_frames[0]
+ for frame in combinations_frames[1:]:
+ result = result.join(frame, how="cross")
+
+ return result
+
+
def create_aj_data_combinations(
reference_groups: Sequence[str],
heuristics_sets: list[Dict],
@@ -309,8 +340,6 @@ def create_aj_data_combinations(
"real_censored",
]
- print("heuristics_sets", pl.DataFrame(heuristics_sets))
-
heuristics_combinations = pl.DataFrame(heuristics_sets)
censoring_heuristics_enum = pl.Enum(
@@ -411,13 +440,11 @@ def create_aj_data(
fixed_time_horizons,
stratified_by: Sequence[str],
full_event_table: bool = False,
- risk_set_scope: Sequence[str] = "within_stratum",
+ risk_set_scope: Sequence[str] = ["within_stratum"],
):
"""
Create AJ estimates per strata based on censoring and competing heuristics.
"""
- print("stratified_by", stratified_by)
- print("Creating aj data")
def aj_estimates_with_cross(df, extra_cols):
return df.join(pl.DataFrame(extra_cols), how="cross")
@@ -432,8 +459,6 @@ def aj_estimates_with_cross(df, extra_cols):
event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic
)
- print("stratified_by before _aj_adjusted_events", stratified_by)
-
aj_dfs = []
for rscope in risk_set_scope:
aj_res = _aj_adjusted_events(
@@ -448,9 +473,6 @@ def aj_estimates_with_cross(df, extra_cols):
rscope,
)
- print("aj_res before select", aj_res.columns)
- print("aj_res", aj_res)
-
aj_res = aj_res.select(
[
"strata",
@@ -465,18 +487,10 @@ def aj_estimates_with_cross(df, extra_cols):
]
)
- print("aj_res columns", aj_res.columns)
- print("aj_res", aj_res)
-
aj_dfs.append(aj_res)
aj_df = pl.concat(aj_dfs, how="vertical")
- print("aj_df columns", aj_df.columns)
-
- # print("aj_df")
- # print(aj_df)
-
result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left")
return aj_estimates_with_cross(
@@ -737,8 +751,6 @@ def extract_aj_estimate_by_cutoffs(
how="vertical",
)
- print("aj_estimate_by_cutoffs", aj_estimate_by_cutoffs)
-
return aj_estimate_by_cutoffs
@@ -773,9 +785,6 @@ def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: b
[fixed_df, event_df], how="vertical"
).sort("estimate_origin", "fixed_time_horizon", "times")
- # print("aj_estimate_for_strata_polars")
- # print(aj_estimate_for_strata_polars)
-
return aj_estimate_for_strata_polars.with_columns(
[
(pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"),
@@ -807,6 +816,74 @@ def assign_and_explode_polars(
)
+def _create_list_data_to_adjust_binary(
+ aj_data_combinations: pl.DataFrame,
+ probs_dict: Dict[str, np.ndarray],
+ reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
+ stratified_by,
+ by,
+) -> Dict[str, pl.DataFrame]:
+ reference_group_labels = list(probs_dict.keys())
+ num_reals = len(reals_dict)
+
+ reference_group_enum = pl.Enum(reference_group_labels)
+
+ strata_enum_dtype = aj_data_combinations.schema["strata"]
+
+ data_to_adjust = pl.DataFrame(
+ {
+ "reference_group": np.repeat(reference_group_labels, num_reals),
+ "probs": np.concatenate(
+ [probs_dict[group] for group in reference_group_labels]
+ ),
+ "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
+ }
+ ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+ data_to_adjust = add_cutoff_strata(
+ data_to_adjust, by=by, stratified_by=stratified_by
+ )
+
+ data_to_adjust = pivot_longer_strata(data_to_adjust)
+
+ data_to_adjust = (
+ data_to_adjust.with_columns([pl.col("strata")])
+ .with_columns(pl.col("strata").cast(strata_enum_dtype))
+ .join(
+ aj_data_combinations.select(
+ pl.col("strata"),
+ pl.col("stratified_by"),
+ pl.col("upper_bound"),
+ pl.col("lower_bound"),
+ ).unique(),
+ how="left",
+ on=["strata", "stratified_by"],
+ )
+ )
+
+ reals_labels = ["real_negatives", "real_positives"]
+
+ reals_enum = pl.Enum(reals_labels)
+
+ reals_map = {0: "real_negatives", 1: "real_positives"}
+
+ data_to_adjust = data_to_adjust.with_columns(
+ pl.col("reals")
+ .replace_strict(reals_map, return_dtype=reals_enum)
+ .alias("reals_labels")
+ )
+
+ # Partition by reference_group
+ list_data_to_adjust = {
+ group[0]: df
+ for group, df in data_to_adjust.partition_by(
+ "reference_group", as_dict=True
+ ).items()
+ }
+
+ return list_data_to_adjust
+
+
def create_list_data_to_adjust(
aj_data_combinations: pl.DataFrame,
probs_dict: Dict[str, np.ndarray],
@@ -814,7 +891,7 @@ def create_list_data_to_adjust(
times_dict: Union[np.ndarray, Dict[str, np.ndarray]],
stratified_by,
by,
-):
+) -> Dict[str, pl.DataFrame]:
# reference_groups = list(probs_dict.keys())
reference_group_labels = list(probs_dict.keys())
num_reals = len(reals_dict)
@@ -863,6 +940,7 @@ def create_list_data_to_adjust(
"real_competing",
"real_censored",
]
+
reals_enum = pl.Enum(reals_labels)
# Map reals values to strings
@@ -898,22 +976,14 @@ def extract_aj_estimate_by_heuristics(
heuristics_sets: list[dict],
fixed_time_horizons: list[float],
stratified_by: Sequence[str],
- risk_set_scope: str = "within_stratum",
+ risk_set_scope: Sequence[str] = ["within_stratum"],
) -> pl.DataFrame:
aj_dfs = []
- print("stratified_by", stratified_by)
-
for heuristic in heuristics_sets:
censoring = heuristic["censoring_heuristic"]
competing = heuristic["competing_heuristic"]
- print("stratified_by", stratified_by)
-
- print("df before create_aj_data")
- print(df.columns)
- print(df.schema)
-
aj_df = create_aj_data(
df,
breaks,
@@ -932,12 +1002,8 @@ def extract_aj_estimate_by_heuristics(
aj_dfs.append(aj_df)
- # print("aj_dfs", aj_dfs)
-
aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"])
- print("aj_estimates_data", aj_estimates_data)
-
aj_estimates_unpivoted = aj_estimates_data.unpivot(
index=[
"strata",
@@ -951,22 +1017,36 @@ def extract_aj_estimate_by_heuristics(
value_name="reals_estimate",
)
- print("aj_estimates_unpivoted", aj_estimates_unpivoted)
-
return aj_estimates_unpivoted
+def _create_adjusted_data_binary(
+ list_data_to_adjust: dict[str, pl.DataFrame],
+ breaks: Sequence[float],
+ stratified_by: Sequence[str],
+) -> pl.DataFrame:
+ long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical")
+
+ adjusted_data_binary = (
+ long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"])
+ .agg(pl.sum("reals").alias("reals_estimate"))
+ .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
+ )
+
+ return adjusted_data_binary
+
+
def create_adjusted_data(
- list_data_to_adjust_polars: dict[str, pl.DataFrame],
+ list_data_to_adjust: dict[str, pl.DataFrame],
heuristics_sets: list[dict[str, str]],
fixed_time_horizons: list[float],
breaks: Sequence[float],
stratified_by: Sequence[str],
- risk_set_scope: str = "within_stratum",
+ risk_set_scope: Sequence[str] = ["within_stratum"],
) -> pl.DataFrame:
all_results = []
- reference_groups = list(list_data_to_adjust_polars.keys())
+ reference_groups = list(list_data_to_adjust.keys())
reference_group_enum = pl.Enum(reference_groups)
heuristics_df = pl.DataFrame(heuristics_sets)
@@ -977,13 +1057,11 @@ def create_adjusted_data(
heuristics_df["competing_heuristic"].unique(maintain_order=True)
)
- for reference_group, df in list_data_to_adjust_polars.items():
+ for reference_group, df in list_data_to_adjust.items():
input_df = df.select(
["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"]
)
- print("stratified_by", stratified_by)
-
aj_result = extract_aj_estimate_by_heuristics(
input_df,
breaks,
@@ -1003,8 +1081,6 @@ def create_adjusted_data(
all_results.append(aj_result_with_group)
- print("all_results", all_results)
-
reals_enum_dtype = pl.Enum(
[
"real_negatives",
@@ -1027,7 +1103,86 @@ def create_adjusted_data(
)
-def cast_and_join_adjusted_data(aj_data_combinations, aj_estimates_data):
+def _cast_and_join_adjusted_data_binary(
+ aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame
+) -> pl.DataFrame:
+ strata_enum_dtype = aj_data_combinations.schema["strata"]
+
+ aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns(
+ pl.col("strata").cast(strata_enum_dtype)
+ )
+
+ final_adjusted_data_polars = (
+ (
+ aj_data_combinations.with_columns([pl.col("strata")]).join(
+ aj_estimates_data,
+ on=[
+ "strata",
+ "stratified_by",
+ "reals_labels",
+ "reference_group",
+ "chosen_cutoff",
+ ],
+ how="left",
+ )
+ )
+ .with_columns(
+ pl.when(
+ (
+ (pl.col("chosen_cutoff") >= pl.col("upper_bound"))
+ & (pl.col("stratified_by") == "probability_threshold")
+ )
+ | (
+ ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point"))
+ & (pl.col("stratified_by") == "ppcr")
+ )
+ )
+ .then(pl.lit("predicted_negatives"))
+ .otherwise(pl.lit("predicted_positives"))
+ .cast(pl.Enum(["predicted_negatives", "predicted_positives"]))
+ .alias("prediction_label")
+ )
+ .with_columns(
+ (
+ pl.when(
+ (pl.col("prediction_label") == pl.lit("predicted_positives"))
+ & (pl.col("reals_labels") == pl.lit("real_positives"))
+ )
+ .then(pl.lit("true_positives"))
+ .when(
+ (pl.col("prediction_label") == pl.lit("predicted_positives"))
+ & (pl.col("reals_labels") == pl.lit("real_negatives"))
+ )
+ .then(pl.lit("false_positives"))
+ .when(
+ (pl.col("prediction_label") == pl.lit("predicted_negatives"))
+ & (pl.col("reals_labels") == pl.lit("real_negatives"))
+ )
+ .then(pl.lit("true_negatives"))
+ .when(
+ (pl.col("prediction_label") == pl.lit("predicted_negatives"))
+ & (pl.col("reals_labels") == pl.lit("real_positives"))
+ )
+ .then(pl.lit("false_negatives"))
+ .cast(
+ pl.Enum(
+ [
+ "true_positives",
+ "false_positives",
+ "true_negatives",
+ "false_negatives",
+ ]
+ )
+ )
+ ).alias("classification_outcome")
+ )
+ )
+ return final_adjusted_data_polars
+
+
+def cast_and_join_adjusted_data(
+ aj_data_combinations, aj_estimates_data
+) -> pl.DataFrame:
strata_enum_dtype = aj_data_combinations.schema["strata"]
aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns(
@@ -1190,18 +1345,13 @@ def _aj_adjusted_events(
horizons: list[float],
stratified_by: Sequence[str],
full_event_table: bool = False,
- risk_set_scope: str = "within_stratum",
+ risk_set_scope: Sequence[str] = ["within_stratum"],
) -> pl.DataFrame:
- print("reference_group_data")
- print(reference_group_data)
-
strata_enum_dtype = reference_group_data.schema["strata"]
# Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff
if censoring == "adjusted" and competing == "adjusted_as_negative":
if risk_set_scope == "within_stratum":
- print("reference_group_data", reference_group_data)
-
adjusted = (
reference_group_data.group_by("strata")
.map_groups(
@@ -1225,8 +1375,6 @@ def _aj_adjusted_events(
return adjusted
elif risk_set_scope == "pooled_by_cutoff":
- print("reference_group_data", reference_group_data)
-
adjusted = extract_aj_estimate_by_cutoffs(
reference_group_data, horizons, breaks, stratified_by, full_event_table
)
@@ -1260,8 +1408,6 @@ def _aj_adjusted_events(
# Special-case: competing excluded (handled by filtering out competing events)
if competing == "excluded":
- print("running for censoring adjusted and competing excluded")
-
# Use exploded to apply filters that depend on fixed_time_horizon consistently
non_competing = exploded.filter(
(pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2)
@@ -1272,8 +1418,6 @@ def _aj_adjusted_events(
.alias("reals")
)
- print("non_competing data", non_competing)
-
if risk_set_scope == "within_stratum":
adjusted = (
_aj_estimates_per_horizon(non_competing, horizons, full_event_table)
@@ -1286,8 +1430,6 @@ def _aj_adjusted_events(
non_competing, horizons, breaks, stratified_by, full_event_table
)
- print("adjusted after join cutoffs", adjusted)
-
adjusted = adjusted.with_columns(
[
pl.col("strata").cast(strata_enum_dtype),
@@ -1337,8 +1479,6 @@ def _aj_adjusted_events(
pl.DataFrame({"chosen_cutoff": breaks}), how="cross"
)
- print("adjusted after join", adjusted)
-
elif risk_set_scope == "pooled_by_cutoff":
adjusted = _aj_estimates_by_cutoff_per_horizon(
base_df, horizons, breaks, stratified_by
@@ -1382,6 +1522,65 @@ def _aj_adjusted_events(
return adjusted
+def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame:
+ cumulative_aj_data = (
+ aj_data.group_by(
+ [
+ "reference_group",
+ "stratified_by",
+ "chosen_cutoff",
+ "classification_outcome",
+ ]
+ )
+ .agg([pl.col("reals_estimate").sum()])
+ .pivot(on="classification_outcome", values="reals_estimate")
+ .with_columns(
+ (pl.col("true_positives") + pl.col("false_positives")).alias(
+ "predicted_positives"
+ ),
+ (pl.col("true_negatives") + pl.col("false_negatives")).alias(
+ "predicted_negatives"
+ ),
+ (pl.col("true_positives") + pl.col("false_negatives")).alias(
+ "real_positives"
+ ),
+ (pl.col("false_positives") + pl.col("true_negatives")).alias(
+ "real_negatives"
+ ),
+ (
+ pl.col("true_positives")
+ + pl.col("true_negatives")
+ + pl.col("false_positives")
+ + pl.col("false_negatives")
+ )
+ .alias("n")
+ .sum(),
+ )
+ .with_columns(
+ (pl.col("true_positives") + pl.col("false_positives")).alias(
+ "predicted_positives"
+ ),
+ (pl.col("true_negatives") + pl.col("false_negatives")).alias(
+ "predicted_negatives"
+ ),
+ (pl.col("true_positives") + pl.col("false_negatives")).alias(
+ "real_positives"
+ ),
+ (pl.col("false_positives") + pl.col("true_negatives")).alias(
+ "real_negatives"
+ ),
+ (
+ pl.col("true_positives")
+ + pl.col("true_negatives")
+ + pl.col("false_positives")
+ + pl.col("false_negatives")
+ ).alias("n"),
+ )
+ )
+
+ return cumulative_aj_data
+
+
def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame:
cumulative_aj_data = (
aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff")
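
For orientation: the cumulative counts built by `_calculate_cumulative_aj_data_binary` (true/false positives/negatives plus the derived `predicted_*`, `real_*`, and `n` columns) are what `_turn_cumulative_aj_to_performance_data` converts into threshold-wise metrics. The sketch below only illustrates that last step with the textbook definitions and the column names visible in this diff; it is not the package's implementation.

```python
import polars as pl


def sketch_performance_metrics(cumulative: pl.DataFrame) -> pl.DataFrame:
    # Illustrative only: standard confusion-matrix metrics from the cumulative counts.
    return cumulative.with_columns(
        (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"),
        (pl.col("false_positives") / pl.col("real_negatives")).alias("FPR"),
        (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"),
        (pl.col("true_positives") / pl.col("predicted_positives")).alias("PPV"),
        (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("NPV"),
        (pl.col("predicted_positives") / pl.col("n")).alias("ppcr"),
    )
```
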
diff --git a/src/rtichoke/performance_data/performance_data.py b/src/rtichoke/performance_data/performance_data.py
index 5a8f72d..772de6e 100644
--- a/src/rtichoke/performance_data/performance_data.py
+++ b/src/rtichoke/performance_data/performance_data.py
@@ -2,40 +2,87 @@
A module for Performance Data
"""
-from typing import Dict, List
-from pandas.core.frame import DataFrame
-import pandas as pd
-from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
+from typing import Dict, Union
+import polars as pl
+from collections.abc import Sequence
+from rtichoke.helpers.sandbox_observable_helpers import (
+ _create_aj_data_combinations_binary,
+ create_breaks_values,
+ _create_list_data_to_adjust_binary,
+ _create_adjusted_data_binary,
+ _cast_and_join_adjusted_data_binary,
+ _calculate_cumulative_aj_data_binary,
+ _turn_cumulative_aj_to_performance_data,
+)
+import numpy as np
def prepare_performance_data(
- probs: Dict[str, List[float]],
- reals: Dict[str, List[int]],
- stratified_by: str = "probability_threshold",
- url_api: str = "http://localhost:4242/",
-) -> DataFrame:
- """Prepare Performance Data
-
- Args:
- probs (Dict[str, List[float]]): _description_
- reals (Dict[str, List[int]]): _description_
- stratified_by (str, optional): _description_. Defaults to "probability_threshold".
- url_api (_type_, optional): _description_. Defaults to "http://localhost:4242/".
-
- Returns:
- DataFrame: _description_
+ probs: Dict[str, np.ndarray],
+ reals: Union[np.ndarray, Dict[str, np.ndarray]],
+ stratified_by: Sequence[str] = ["probability_threshold"],
+ by: float = 0.01,
+) -> pl.DataFrame:
+ """Prepare performance data for binary classification.
+
+ Parameters
+ ----------
+ probs : Dict[str, np.ndarray]
+ Mapping from dataset name to predicted probabilities (1-D numpy arrays).
+ reals : Union[np.ndarray, Dict[str, np.ndarray]]
+ True event labels. Can be a single array aligned to pooled probabilities
+ or a dictionary mapping each dataset name to its true-label array. Labels
+ are expected to be binary integers (0/1).
+ stratified_by : Sequence[str], optional
+ Stratification variables used to create combinations/breaks. Defaults to
+ ``["probability_threshold"]``.
+ by : float, optional
+ Step width for probability-threshold breaks (used to create the grid of
+ cutoffs). Defaults to ``0.01``.
+
+ Returns
+ -------
+ pl.DataFrame
+ A Polars DataFrame containing performance metrics computed across probability
+ thresholds. Columns include the probability cutoff and performance measures.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> probs_dict_test = {
+ ... "small_data_set": np.array(
+ ... [0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
+ ... )
+ ... }
+ >>> reals_dict_test = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1]
+
+ >>> prepare_performance_data(
+ ... probs_dict_test,
+ ... reals_dict_test,
+ ... by = 0.1
+ ... )
"""
- rtichoke_response = send_requests_to_rtichoke_r(
- dictionary_to_send={
- "probs": probs,
- "reals": reals,
- "stratified_by": stratified_by,
- },
- url_api=url_api,
- endpoint="prepare_performance_data",
+
+ breaks = create_breaks_values(None, "probability_threshold", by)
+
+ aj_data_combinations = _create_aj_data_combinations_binary(
+ list(probs.keys()), stratified_by=stratified_by, by=by, breaks=breaks
)
- performance_data = pd.DataFrame(
- rtichoke_response.json(), columns=list(rtichoke_response.json()[0].keys())
+ list_data_to_adjust = _create_list_data_to_adjust_binary(
+ aj_data_combinations, probs, reals, stratified_by=stratified_by, by=by
)
+
+ adjusted_data = _create_adjusted_data_binary(
+ list_data_to_adjust, breaks=breaks, stratified_by=stratified_by
+ )
+
+ final_adjusted_data = _cast_and_join_adjusted_data_binary(
+ aj_data_combinations, adjusted_data
+ )
+
+ cumulative_aj_data = _calculate_cumulative_aj_data_binary(final_adjusted_data)
+
+ performance_data = _turn_cumulative_aj_to_performance_data(cumulative_aj_data)
+
return performance_data
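
With the R plumber round-trip removed, `prepare_performance_data` runs entirely locally. A hedged sketch with two reference groups and both stratification modes: the group names and the second probability vector are invented for illustration, and passing both modes in `stratified_by` is assumed to work here as it does in the time-to-event path exercised in `docs/small_data_example.py`.

```python
import numpy as np
from rtichoke import prepare_performance_data

# Two models scored on the same ten observations (values are illustrative).
probs = {
    "logistic_regression": np.array([0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]),
    "lasso": np.array([0.8, 0.8, 0.9, 0.7, 0.5, 0.65, 0.4, 0.3, 0.2, 0.35]),
}
reals = np.array([1, 1, 1, 1, 0, 0, 1, 0, 0, 1])  # shared binary outcomes

performance_data = prepare_performance_data(
    probs,
    reals,
    stratified_by=["probability_threshold", "ppcr"],
    by=0.1,
)
```
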
diff --git a/src/rtichoke/performance_data/performance_data_times.py b/src/rtichoke/performance_data/performance_data_times.py
new file mode 100644
index 0000000..0cc0b9d
--- /dev/null
+++ b/src/rtichoke/performance_data/performance_data_times.py
@@ -0,0 +1,123 @@
+"""
+A module for Performance Data with a Time Dimension
+"""
+
+from typing import Dict, Union
+import polars as pl
+from collections.abc import Sequence
+from rtichoke.helpers.sandbox_observable_helpers import (
+ create_breaks_values,
+ create_aj_data_combinations,
+ create_list_data_to_adjust,
+ create_adjusted_data,
+ cast_and_join_adjusted_data,
+ _calculate_cumulative_aj_data,
+ _turn_cumulative_aj_to_performance_data,
+)
+
+import numpy as np
+
+
+def prepare_performance_data_times(
+ probs: Dict[str, np.ndarray],
+ reals: Union[np.ndarray, Dict[str, np.ndarray]],
+ times: Union[np.ndarray, Dict[str, np.ndarray]],
+ fixed_time_horizons: list[float],
+ heuristics_sets: list[Dict] = [
+ {
+ "censoring_heuristic": "adjusted",
+ "competing_heuristic": "adjusted_as_negative",
+ }
+ ],
+ stratified_by: Sequence[str] = ["probability_threshold"],
+ by: float = 0.01,
+) -> pl.DataFrame:
+ """Prepare performance data with a time dimension.
+
+ Parameters
+ ----------
+ probs : Dict[str, np.ndarray]
+ Mapping from dataset name to predicted probabilities (1-D numpy arrays).
+ reals : Union[np.ndarray, Dict[str, np.ndarray]]
+ True event labels. Can be a single array aligned to pooled probabilities
+ or a dictionary mapping each dataset name to its true-label array. Labels
+ are expected to be integers (e.g., 0/1 for binary, or competing event codes).
+ times : Union[np.ndarray, Dict[str, np.ndarray]]
+ Event or censoring times corresponding to `reals`. Either a single array
+ or a dictionary keyed like `probs` when multiple datasets are provided.
+ fixed_time_horizons : list[float]
+ Fixed time horizons (same units as `times`) at which to evaluate performance.
+ heuristics_sets : list[Dict], optional
+ List of heuristic dictionaries controlling censoring/competing-event handling.
+ Default is a single heuristic set:
+ ``[{"censoring_heuristic": "adjusted", "competing_heuristic": "adjusted_as_negative"}]``
+ stratified_by : Sequence[str], optional
+ Stratification variables used to create combinations/breaks. Defaults to
+ ``["probability_threshold"]``.
+ by : float, optional
+ Step width for probability-threshold breaks (used to create the grid of
+ cutoffs). Defaults to ``0.01``.
+
+ Returns
+ -------
+ pl.DataFrame
+ A Polars DataFrame containing performance metrics computed across probability
+ thresholds and fixed time horizons. Columns include the probability cutoff,
+ fixed time horizon, heuristic identifiers, and AJ-derived performance measures.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> probs_dict_test = {
+ ... "small_data_set": np.array(
+ ... [0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33]
+ ... )
+ ... }
+ >>> reals_dict_test = [1, 1, 1, 1, 0, 2, 1, 2, 0, 1]
+ >>> times_dict_test = [24.1, 9.7, 49.9, 18.6, 34.8, 14.2, 39.2, 46.0, 31.5, 4.3]
+ >>> fixed_time_horizons = [10.0, 20.0, 30.0, 40.0, 50.0]
+
+ >>> prepare_performance_data_times(
+ ... probs_dict_test,
+ ... reals_dict_test,
+ ... times_dict_test,
+ ... fixed_time_horizons,
+ ... by = 0.1
+ ... )
+ """
+
+ breaks = create_breaks_values(None, "probability_threshold", by)
+ risk_set_scope = ["pooled_by_cutoff"]
+
+ aj_data_combinations = create_aj_data_combinations(
+ list(probs.keys()),
+ heuristics_sets=heuristics_sets,
+ fixed_time_horizons=fixed_time_horizons,
+ stratified_by=stratified_by,
+ by=by,
+ breaks=breaks,
+ risk_set_scope=risk_set_scope,
+ )
+
+ list_data_to_adjust = create_list_data_to_adjust(
+ aj_data_combinations, probs, reals, times, stratified_by=stratified_by, by=by
+ )
+
+ adjusted_data = create_adjusted_data(
+ list_data_to_adjust,
+ heuristics_sets=heuristics_sets,
+ fixed_time_horizons=fixed_time_horizons,
+ breaks=breaks,
+ stratified_by=stratified_by,
+ risk_set_scope=risk_set_scope,
+ )
+
+ final_adjusted_data = cast_and_join_adjusted_data(
+ aj_data_combinations, adjusted_data
+ )
+
+ cumulative_aj_data = _calculate_cumulative_aj_data(final_adjusted_data)
+
+ performance_data = _turn_cumulative_aj_to_performance_data(cumulative_aj_data)
+
+ return performance_data
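
One last usage note on the new time-to-event entry point: `heuristics_sets` takes a list, so several censoring/competing combinations can be evaluated over the same cutoff grid in one call. The sketch only uses heuristic values that appear elsewhere in this diff ("adjusted", "excluded", "adjusted_as_negative"); other combinations are not assumed to be valid.

```python
import numpy as np
from rtichoke import prepare_performance_data_times

# Toy inputs copied from docs/small_data_example.py
probs = {"small_data_set": np.array([0.9, 0.85, 0.95, 0.88, 0.6, 0.7, 0.51, 0.2, 0.1, 0.33])}
reals = [1, 1, 1, 1, 0, 2, 1, 2, 0, 1]
times = [24.1, 9.7, 49.9, 18.6, 34.8, 14.2, 39.2, 46.0, 31.5, 4.3]

performance_data_times = prepare_performance_data_times(
    probs,
    reals,
    times,
    fixed_time_horizons=[10.0, 20.0, 30.0, 40.0, 50.0],
    heuristics_sets=[
        {"censoring_heuristic": "adjusted", "competing_heuristic": "adjusted_as_negative"},
        {"censoring_heuristic": "adjusted", "competing_heuristic": "excluded"},
    ],
    stratified_by=["probability_threshold", "ppcr"],
    by=0.1,
)
```
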