From 7838d62561a7b97a027f544e17b84495dd0951ca Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 16:17:36 -0500 Subject: [PATCH 1/9] Add ContinuousDiD estimator for continuous treatment dose-response Implement Callaway, Goodman-Bacon & Sant'Anna (2024) continuous treatment DiD estimator with B-spline dose-response curves, ACRT derivatives, staggered adoption support, and multiplier bootstrap inference. Validated against R contdid v0.1.0 across 6 benchmarks. Also extract shared bootstrap utilities to bootstrap_utils.py and fix plt.show() blocking in test suite via non-interactive backend. Co-Authored-By: Claude Opus 4.6 --- diff_diff/__init__.py | 10 + diff_diff/bootstrap_utils.py | 265 ++++++ diff_diff/continuous_did.py | 976 +++++++++++++++++++++++ diff_diff/continuous_did_bspline.py | 188 +++++ diff_diff/continuous_did_results.py | 326 ++++++++ diff_diff/prep.py | 1 + diff_diff/prep_dgp.py | 157 +++- diff_diff/staggered_bootstrap.py | 274 +------ docs/methodology/REGISTRY.md | 51 ++ docs/methodology/continuous-did.md | 431 ++++++++++ tests/conftest.py | 4 + tests/test_continuous_did.py | 506 ++++++++++++ tests/test_methodology_continuous_did.py | 654 +++++++++++++++ 13 files changed, 3591 insertions(+), 252 deletions(-) create mode 100644 diff_diff/bootstrap_utils.py create mode 100644 diff_diff/continuous_did.py create mode 100644 diff_diff/continuous_did_bspline.py create mode 100644 diff_diff/continuous_did_results.py create mode 100644 docs/methodology/continuous-did.md create mode 100644 tests/test_continuous_did.py create mode 100644 tests/test_methodology_continuous_did.py diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index 39545fd..b0e49b1 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -70,6 +70,7 @@ aggregate_to_cohorts, balance_panel, create_event_time, + generate_continuous_did_data, generate_did_data, generate_ddd_data, generate_event_study_data, @@ -122,6 +123,11 @@ TripleDifferenceResults, 
"""
Shared bootstrap utilities for multiplier bootstrap inference.

Provides weight generation, percentile CI, and p-value helpers used by
both CallawaySantAnna and ContinuousDiD estimators.
"""

import warnings
from typing import List, Optional, Tuple

import numpy as np

try:
    # Optional compiled backend for parallel weight generation.
    from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
except ImportError:  # pragma: no cover - backend not built/installed
    HAS_RUST_BACKEND = False
    _rust_bootstrap_weights = None

__all__ = [
    "generate_bootstrap_weights",
    "generate_bootstrap_weights_batch",
    "generate_bootstrap_weights_batch_numpy",
    "compute_percentile_ci",
    "compute_bootstrap_pvalue",
    "compute_effect_bootstrap_stats",
]


def _weight_distribution(
    weight_type: str,
) -> Tuple[np.ndarray, Optional[List[float]]]:
    """
    Return the support and probabilities of a multiplier-weight distribution.

    Shared by the single-draw and batch generators so each supported
    distribution is defined in exactly one place.

    Parameters
    ----------
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".

    Returns
    -------
    values : np.ndarray
        Support points of the distribution.
    probs : list of float or None
        Sampling probabilities, or None for the uniform cases.

    Raises
    ------
    ValueError
        If ``weight_type`` is not one of the supported names.
    """
    if weight_type == "rademacher":
        # +/-1 with equal probability: mean 0, variance 1.
        return np.array([-1.0, 1.0]), None
    if weight_type == "mammen":
        # Mammen (1993) two-point distribution: mean 0, variance 1.
        sqrt5 = np.sqrt(5)
        values = np.array([-(sqrt5 - 1) / 2, (sqrt5 + 1) / 2])
        p1 = (sqrt5 + 1) / (2 * sqrt5)
        return values, [p1, 1 - p1]
    if weight_type == "webb":
        # Webb six-point distribution: uniform over +/- sqrt(k/2), k=1,2,3.
        values = np.array([
            -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
            np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2),
        ])
        return values, None
    raise ValueError(
        f"weight_type must be 'rademacher', 'mammen', or 'webb', "
        f"got '{weight_type}'"
    )


def generate_bootstrap_weights(
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate bootstrap weights for multiplier bootstrap.

    Parameters
    ----------
    n_units : int
        Number of units (clusters) to generate weights for.
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of bootstrap weights with shape (n_units,).
    """
    values, probs = _weight_distribution(weight_type)
    return rng.choice(values, size=n_units, p=probs)


def generate_bootstrap_weights_batch(
    n_bootstrap: int,
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Generate all bootstrap weights at once (vectorized).

    Uses Rust backend if available for parallel generation.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    n_units : int
        Number of units (clusters) to generate weights for.
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of bootstrap weights with shape (n_bootstrap, n_units).
    """
    if HAS_RUST_BACKEND and _rust_bootstrap_weights is not None:
        # Derive the backend seed from the caller's rng so results stay
        # reproducible for a fixed user-supplied seed.
        seed = rng.integers(0, 2**63 - 1)
        return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)
    return generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng)


def generate_bootstrap_weights_batch_numpy(
    n_bootstrap: int,
    n_units: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    NumPy fallback implementation of :func:`generate_bootstrap_weights_batch`.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    n_units : int
        Number of units (clusters) to generate weights for.
    weight_type : str
        Type of weights: "rademacher", "mammen", or "webb".
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Array of bootstrap weights with shape (n_bootstrap, n_units).
    """
    values, probs = _weight_distribution(weight_type)
    return rng.choice(values, size=(n_bootstrap, n_units), p=probs)


def compute_percentile_ci(
    boot_dist: np.ndarray,
    alpha: float,
) -> Tuple[float, float]:
    """
    Compute percentile confidence interval from bootstrap distribution.

    Parameters
    ----------
    boot_dist : np.ndarray
        Bootstrap distribution (1-D array).
    alpha : float
        Significance level (e.g., 0.05 for 95% CI).

    Returns
    -------
    tuple of float
        ``(lower, upper)`` confidence interval bounds.
    """
    lower = float(np.percentile(boot_dist, alpha / 2 * 100))
    upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
    return (lower, upper)


def compute_bootstrap_pvalue(
    original_effect: float,
    boot_dist: np.ndarray,
    n_valid: Optional[int] = None,
) -> float:
    """
    Compute two-sided bootstrap p-value using the percentile method.

    Parameters
    ----------
    original_effect : float
        Original point estimate.
    boot_dist : np.ndarray
        Bootstrap distribution of the effect.
    n_valid : int, optional
        Number of valid bootstrap samples for p-value floor.
        If None, uses ``len(boot_dist)``.

    Returns
    -------
    float
        Two-sided bootstrap p-value.
    """
    # One-sided tail mass on the side of zero opposite the estimate's sign.
    if original_effect >= 0:
        p_one_sided = np.mean(boot_dist <= 0)
    else:
        p_one_sided = np.mean(boot_dist >= 0)

    p_value = min(2 * p_one_sided, 1.0)
    # Floor at 1/(B+1) so a p-value of exactly 0 is never reported.
    n_for_floor = n_valid if n_valid is not None else len(boot_dist)
    p_value = max(p_value, 1 / (n_for_floor + 1))
    return float(p_value)


def compute_effect_bootstrap_stats(
    original_effect: float,
    boot_dist: np.ndarray,
    alpha: float = 0.05,
    context: str = "bootstrap distribution",
) -> Tuple[float, Tuple[float, float], float]:
    """
    Compute bootstrap statistics for a single effect.

    Filters non-finite samples, returning NaN for all statistics if
    fewer than 50% of samples are valid.

    Parameters
    ----------
    original_effect : float
        Original point estimate.
    boot_dist : np.ndarray
        Bootstrap distribution of the effect.
    alpha : float, default=0.05
        Significance level.
    context : str, optional
        Description for warning messages.

    Returns
    -------
    se : float
        Bootstrap standard error.
    ci : tuple of float
        Percentile confidence interval.
    p_value : float
        Bootstrap p-value.
    """
    finite_mask = np.isfinite(boot_dist)
    n_valid = np.sum(finite_mask)
    n_total = len(boot_dist)

    if n_valid < n_total:
        n_nonfinite = n_total - n_valid
        warnings.warn(
            f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples "
            f"in {context}. Bootstrap estimates based on remaining valid samples.",
            RuntimeWarning,
            stacklevel=3,
        )

    if n_valid < n_total * 0.5:
        # Majority of draws invalid: inference would be meaningless.
        warnings.warn(
            f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
            "Returning NaN for SE/CI/p-value to signal invalid inference.",
            RuntimeWarning,
            stacklevel=3,
        )
        return np.nan, (np.nan, np.nan), np.nan

    valid_dist = boot_dist[finite_mask]
    se = float(np.std(valid_dist, ddof=1))
    ci = compute_percentile_ci(valid_dist, alpha)
    p_value = compute_bootstrap_pvalue(
        original_effect, valid_dist, n_valid=len(valid_dist)
    )
    return se, ci, p_value
+ dvals : array-like, optional + Custom dose evaluation grid. If None, uses quantile-based default. + control_group : str, default="never_treated" + ``"never_treated"`` or ``"not_yet_treated"``. + anticipation : int, default=0 + Number of periods of treatment anticipation. + base_period : str, default="varying" + ``"varying"`` or ``"universal"``. + alpha : float, default=0.05 + Significance level for confidence intervals. + n_bootstrap : int, default=0 + Number of multiplier bootstrap iterations. 0 for analytical SEs only. + bootstrap_weights : str, default="rademacher" + Bootstrap weight type: ``"rademacher"``, ``"mammen"``, or ``"webb"``. + seed : int, optional + Random seed for reproducibility. + rank_deficient_action : str, default="warn" + Action for rank-deficient B-spline OLS: ``"warn"``, ``"error"``, or ``"silent"``. + + Examples + -------- + >>> from diff_diff import ContinuousDiD, generate_continuous_did_data + >>> data = generate_continuous_did_data(n_units=200, seed=42) + >>> est = ContinuousDiD(n_bootstrap=199, seed=42) + >>> results = est.fit(data, outcome="outcome", unit="unit", + ... time="period", first_treat="first_treat", + ... 
    def __init__(
        self,
        degree: int = 3,
        num_knots: int = 0,
        dvals: Optional[np.ndarray] = None,
        control_group: str = "never_treated",
        anticipation: int = 0,
        base_period: str = "varying",
        alpha: float = 0.05,
        n_bootstrap: int = 0,
        bootstrap_weights: str = "rademacher",
        seed: Optional[int] = None,
        rank_deficient_action: str = "warn",
    ):
        # Store configuration. dvals is coerced to a float array once here so
        # fit() can use it directly as the dose-evaluation grid.
        self.degree = degree
        self.num_knots = num_knots
        self.dvals = np.asarray(dvals, dtype=float) if dvals is not None else None
        self.control_group = control_group
        self.anticipation = anticipation
        self.base_period = base_period
        self.alpha = alpha
        self.n_bootstrap = n_bootstrap
        self.bootstrap_weights = bootstrap_weights
        self.seed = seed
        self.rank_deficient_action = rank_deficient_action

    def get_params(self) -> Dict[str, Any]:
        """Return estimator parameters as a dictionary."""
        return {
            "degree": self.degree,
            "num_knots": self.num_knots,
            "dvals": self.dvals,
            "control_group": self.control_group,
            "anticipation": self.anticipation,
            "base_period": self.base_period,
            "alpha": self.alpha,
            "n_bootstrap": self.n_bootstrap,
            "bootstrap_weights": self.bootstrap_weights,
            "seed": self.seed,
            "rank_deficient_action": self.rank_deficient_action,
        }

    def set_params(self, **params) -> "ContinuousDiD":
        """Set estimator parameters and return self."""
        for key, value in params.items():
            # Only attributes created in __init__ are settable.
            if not hasattr(self, key):
                raise ValueError(f"Invalid parameter: {key}")
            setattr(self, key, value)
        return self

    # ------------------------------------------------------------------
    # Main fit
    # ------------------------------------------------------------------

    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        dose: str,
        aggregate: Optional[str] = None,
    ) -> ContinuousDiDResults:
        """
        Fit the continuous DiD estimator.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data.
        outcome : str
            Outcome column name.
        unit : str
            Unit identifier column.
        time : str
            Time period column.
        first_treat : str
            First treatment period column (0 or inf for never-treated).
        dose : str
            Continuous dose column.
        aggregate : str, optional
            ``"dose"`` for dose-response aggregation, ``"eventstudy"`` for
            binarized event study.

        Returns
        -------
        ContinuousDiDResults

        Raises
        ------
        ValueError
            If required columns are missing, the dose varies within a unit,
            the panel is unbalanced, or no treated/control units remain.
        """
        # 1. Validate & prepare
        df = data.copy()
        for col in [outcome, unit, time, first_treat, dose]:
            if col not in df.columns:
                raise ValueError(f"Column '{col}' not found in data.")

        # Verify dose is time-invariant
        dose_nunique = df.groupby(unit)[dose].nunique()
        if dose_nunique.max() > 1:
            bad_units = dose_nunique[dose_nunique > 1].index.tolist()
            raise ValueError(
                f"Dose must be time-invariant. Units with varying dose: {bad_units[:5]}"
            )

        # Normalize first_treat: inf → 0
        df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0)

        # Drop units with positive first_treat but zero dose (R convention)
        unit_info = df.groupby(unit).first()[[first_treat, dose]]
        drop_units = unit_info[
            (unit_info[first_treat] > 0) & (unit_info[dose] == 0)
        ].index
        if len(drop_units) > 0:
            warnings.warn(
                f"Dropping {len(drop_units)} units with positive first_treat but zero dose.",
                UserWarning,
                stacklevel=2,
            )
            df = df[~df[unit].isin(drop_units)]

        # Force dose=0 for never-treated units with nonzero dose
        never_treated_mask = df[first_treat] == 0
        if (df.loc[never_treated_mask, dose] != 0).any():
            df.loc[never_treated_mask, dose] = 0.0

        # Verify balanced panel (every unit observed in every period)
        obs_per_unit = df.groupby(unit)[time].nunique()
        if obs_per_unit.nunique() > 1:
            raise ValueError(
                "Unbalanced panel detected. ContinuousDiD requires a balanced panel."
            )

        # Identify groups and time periods
        unit_cohort = df.groupby(unit)[first_treat].first()
        treatment_groups = sorted([g for g in unit_cohort.unique() if g > 0])
        time_periods = sorted(df[time].unique())

        if len(treatment_groups) == 0:
            raise ValueError("No treated units found (all first_treat == 0).")

        n_control = int((unit_cohort == 0).sum())
        if self.control_group == "never_treated" and n_control == 0:
            raise ValueError(
                "No never-treated units found. Use control_group='not_yet_treated' "
                "or add never-treated units."
            )

        # 2. Precompute structures (wide outcome matrix, cohort/dose vectors)
        precomp = self._precompute_structures(
            df, outcome, unit, time, first_treat, dose, time_periods
        )

        # Compute dvals (evaluation grid); user-supplied grid takes precedence
        all_treated_doses = precomp["dose_vector"][precomp["dose_vector"] > 0]
        if self.dvals is not None:
            dvals = self.dvals
        else:
            dvals = default_dose_grid(all_treated_doses)

        # Build B-spline knots from all treated doses
        knots, degree = build_bspline_basis(
            all_treated_doses, degree=self.degree, num_knots=self.num_knots
        )

        # 3. Iterate over (g,t) cells; each returns a dose-response estimate
        gt_results = {}
        gt_bootstrap_info = {}

        for g in treatment_groups:
            for t in time_periods:
                result = self._compute_dose_response_gt(
                    precomp, g, t, knots, degree, dvals
                )
                if result is not None:
                    gt_results[(g, t)] = result
                    gt_bootstrap_info[(g, t)] = result.get("_bootstrap_info", {})

        if len(gt_results) == 0:
            raise ValueError("No valid (g,t) cells computed.")

        # 4. Aggregate: keep only post-treatment cells
        post_gt = {
            (g, t): r
            for (g, t), r in gt_results.items()
            if t >= g - self.anticipation
        }

        # Compute cell weights: group-proportional (matching R's contdid convention).
        # Each group g gets weight proportional to its number of treated units.
        # Within each group, weight is divided equally among post-treatment cells.
        group_n_treated = {}
        group_n_post_cells = {}
        for (g, t), r in post_gt.items():
            if g not in group_n_treated:
                group_n_treated[g] = float(r["n_treated"])
                group_n_post_cells[g] = 0
            group_n_post_cells[g] += 1

        total_treated = sum(group_n_treated.values())
        cell_weights = {}
        if total_treated > 0:
            for (g, t), r in post_gt.items():
                pg = group_n_treated[g] / total_treated
                cell_weights[(g, t)] = pg / group_n_post_cells[g]

        # Dose-response aggregation: weighted average across (g,t) cells
        n_grid = len(dvals)
        agg_att_d = np.zeros(n_grid)
        agg_acrt_d = np.zeros(n_grid)
        overall_att = 0.0
        overall_acrt = 0.0

        for gt, w in cell_weights.items():
            r = post_gt[gt]
            agg_att_d += w * r["att_d"]
            agg_acrt_d += w * r["acrt_d"]
            overall_att += w * r["att_glob"]
            overall_acrt += w * r["acrt_glob"]

        # Event study aggregation (binarized)
        event_study_effects = None
        if aggregate == "eventstudy":
            event_study_effects = self._aggregate_event_study(
                gt_results, treatment_groups
            )

        # 5. Inference: initialize everything to NaN, then fill from either
        # the multiplier bootstrap or the analytical influence functions.
        att_d_se = np.full(n_grid, np.nan)
        att_d_ci_lower = np.full(n_grid, np.nan)
        att_d_ci_upper = np.full(n_grid, np.nan)
        acrt_d_se = np.full(n_grid, np.nan)
        acrt_d_ci_lower = np.full(n_grid, np.nan)
        acrt_d_ci_upper = np.full(n_grid, np.nan)
        overall_att_se = np.nan
        overall_att_t = np.nan
        overall_att_p = np.nan
        overall_att_ci = (np.nan, np.nan)
        overall_acrt_se = np.nan
        overall_acrt_t = np.nan
        overall_acrt_p = np.nan
        overall_acrt_ci = (np.nan, np.nan)

        if self.n_bootstrap > 0:
            boot_result = self._run_bootstrap(
                precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights,
                knots, degree, dvals, overall_att, overall_acrt,
                agg_att_d, agg_acrt_d,
                event_study_effects,
            )
            att_d_se = boot_result["att_d_se"]
            att_d_ci_lower = boot_result["att_d_ci_lower"]
            att_d_ci_upper = boot_result["att_d_ci_upper"]
            acrt_d_se = boot_result["acrt_d_se"]
            acrt_d_ci_lower = boot_result["acrt_d_ci_lower"]
            acrt_d_ci_upper = boot_result["acrt_d_ci_upper"]

            overall_att_se = boot_result["overall_att_se"]
            overall_att_t, overall_att_p, overall_att_ci = safe_inference(
                overall_att, overall_att_se, self.alpha
            )
            overall_acrt_se = boot_result["overall_acrt_se"]
            overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference(
                overall_acrt, overall_acrt_se, self.alpha
            )
            if event_study_effects is not None:
                # Attach bootstrap SEs to each event-time bin where available
                for e, info in event_study_effects.items():
                    if e in boot_result.get("es_se", {}):
                        info["se"] = boot_result["es_se"][e]
                        info["t_stat"], info["p_value"], info["conf_int"] = (
                            safe_inference(info["effect"], info["se"], self.alpha)
                        )
        else:
            # Analytical SEs via influence functions
            analytic = self._compute_analytical_se(
                precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights,
                knots, degree, dvals, agg_att_d, agg_acrt_d,
            )
            att_d_se = analytic["att_d_se"]
            acrt_d_se = analytic["acrt_d_se"]
            overall_att_se = analytic["overall_att_se"]
            overall_acrt_se = analytic["overall_acrt_se"]

            overall_att_t, overall_att_p, overall_att_ci = safe_inference(
                overall_att, overall_att_se, self.alpha
            )
            overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference(
                overall_acrt, overall_acrt_se, self.alpha
            )

            # Per-grid-point inference for dose-response
            for idx in range(n_grid):
                _, _, ci = safe_inference(
                    agg_att_d[idx], att_d_se[idx], self.alpha
                )
                att_d_ci_lower[idx] = ci[0]
                att_d_ci_upper[idx] = ci[1]

                _, _, ci = safe_inference(
                    agg_acrt_d[idx], acrt_d_se[idx], self.alpha
                )
                acrt_d_ci_lower[idx] = ci[0]
                acrt_d_ci_upper[idx] = ci[1]

        # 6. Assemble results
        dose_response_att = DoseResponseCurve(
            dose_grid=dvals,
            effects=agg_att_d,
            se=att_d_se,
            conf_int_lower=att_d_ci_lower,
            conf_int_upper=att_d_ci_upper,
            target="att",
        )
        dose_response_acrt = DoseResponseCurve(
            dose_grid=dvals,
            effects=agg_acrt_d,
            se=acrt_d_se,
            conf_int_lower=acrt_d_ci_lower,
            conf_int_upper=acrt_d_ci_upper,
            target="acrt",
        )

        # Strip bootstrap internals (underscore-prefixed keys) from gt_results
        clean_gt = {}
        for gt, r in gt_results.items():
            clean_gt[gt] = {
                k: v for k, v in r.items() if not k.startswith("_")
            }

        return ContinuousDiDResults(
            dose_response_att=dose_response_att,
            dose_response_acrt=dose_response_acrt,
            overall_att=overall_att,
            overall_att_se=overall_att_se,
            overall_att_t_stat=overall_att_t,
            overall_att_p_value=overall_att_p,
            overall_att_conf_int=overall_att_ci,
            overall_acrt=overall_acrt,
            overall_acrt_se=overall_acrt_se,
            overall_acrt_t_stat=overall_acrt_t,
            overall_acrt_p_value=overall_acrt_p,
            overall_acrt_conf_int=overall_acrt_ci,
            group_time_effects=clean_gt,
            dose_grid=dvals,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_units=int((unit_cohort > 0).sum()),
            n_control_units=n_control,
            alpha=self.alpha,
            control_group=self.control_group,
            degree=self.degree,
            num_knots=self.num_knots,
            event_study_effects=event_study_effects,
        )
enumerate(time_periods)} + + # Outcome matrix: (n_units, n_periods) + outcome_matrix = np.full((n_units, n_periods), np.nan) + for _, row in df.iterrows(): + i = unit_to_idx[row[unit]] + j = period_to_col[row[time]] + outcome_matrix[i, j] = row[outcome] + + # Per-unit cohort and dose + unit_cohorts = np.zeros(n_units, dtype=float) + dose_vector = np.zeros(n_units, dtype=float) + unit_first = df.groupby(unit).first() + for u in all_units: + i = unit_to_idx[u] + unit_cohorts[i] = unit_first.loc[u, first_treat] + dose_vector[i] = unit_first.loc[u, dose] + + # Cohort masks + cohort_masks = {} + unique_cohorts = np.unique(unit_cohorts) + for c in unique_cohorts: + cohort_masks[c] = unit_cohorts == c + + never_treated_mask = unit_cohorts == 0 + + return { + "all_units": all_units, + "unit_to_idx": unit_to_idx, + "outcome_matrix": outcome_matrix, + "period_to_col": period_to_col, + "unit_cohorts": unit_cohorts, + "dose_vector": dose_vector, + "cohort_masks": cohort_masks, + "never_treated_mask": never_treated_mask, + "time_periods": time_periods, + "n_units": n_units, + } + + def _compute_dose_response_gt( + self, + precomp: Dict[str, Any], + g: Any, + t: Any, + knots: np.ndarray, + degree: int, + dvals: np.ndarray, + ) -> Optional[Dict[str, Any]]: + """Compute dose-response for a single (g,t) cell.""" + period_to_col = precomp["period_to_col"] + outcome_matrix = precomp["outcome_matrix"] + unit_cohorts = precomp["unit_cohorts"] + dose_vector = precomp["dose_vector"] + never_treated_mask = precomp["never_treated_mask"] + time_periods = precomp["time_periods"] + + # Base period selection + is_post = t >= g - self.anticipation + if self.base_period == "varying": + if is_post: + base_t = g - 1 - self.anticipation + else: + # Pre-treatment: use t-1 + t_idx = time_periods.index(t) + if t_idx == 0: + return None # No prior period + base_t = time_periods[t_idx - 1] + else: + # Universal base period + base_t = g - 1 - self.anticipation + + if base_t not in period_to_col or t not 
in period_to_col: + return None + + col_t = period_to_col[t] + col_base = period_to_col[base_t] + + # Treated units: first_treat == g and dose > 0 + treated_mask = (unit_cohorts == g) & (dose_vector > 0) + n_treated = int(np.sum(treated_mask)) + if n_treated == 0: + return None + + # Control units + if self.control_group == "never_treated": + control_mask = never_treated_mask + else: + # Not-yet-treated: never-treated + first_treat > t + control_mask = never_treated_mask | (unit_cohorts > t) + n_control = int(np.sum(control_mask)) + if n_control == 0: + warnings.warn( + f"No control units for (g={g}, t={t}). Skipping.", + UserWarning, + stacklevel=3, + ) + return None + + # Outcome changes + delta_y_treated = outcome_matrix[treated_mask, col_t] - outcome_matrix[treated_mask, col_base] + delta_y_control = outcome_matrix[control_mask, col_t] - outcome_matrix[control_mask, col_base] + + # Control counterfactual + mu_0 = float(np.mean(delta_y_control)) + + # Demean + delta_tilde_y = delta_y_treated - mu_0 + + # Treated doses + treated_doses = dose_vector[treated_mask] + + # B-spline OLS + Psi = bspline_design_matrix(treated_doses, knots, degree, include_intercept=True) + n_basis = Psi.shape[1] + + # Check for all-same dose + if np.all(treated_doses == treated_doses[0]): + warnings.warn( + f"All treated doses identical in (g={g}, t={t}). " + "ACRT(d) will be 0 everywhere.", + UserWarning, + stacklevel=3, + ) + + # Skip if not enough treated units for OLS (need n > K for residual df) + if n_treated <= n_basis: + warnings.warn( + f"Not enough treated units ({n_treated}) for {n_basis} basis functions " + f"in (g={g}, t={t}). 
Skipping cell.", + UserWarning, + stacklevel=3, + ) + return None + + # OLS regression + beta_hat, residuals, vcov = solve_ols( + Psi, delta_tilde_y, + return_vcov=True, + rank_deficient_action=self.rank_deficient_action, + ) + + # Evaluate ATT(d) and ACRT(d) at dvals + Psi_eval = bspline_design_matrix(dvals, knots, degree, include_intercept=True) + dPsi_eval = bspline_derivative_design_matrix(dvals, knots, degree, include_intercept=True) + + att_d = Psi_eval @ beta_hat + acrt_d = dPsi_eval @ beta_hat + + # Summary parameters + att_glob = float(np.mean(delta_y_treated) - mu_0) + + # ACRT^{glob}: plug-in average of ACRT(D_i) for treated + dPsi_treated = bspline_derivative_design_matrix( + treated_doses, knots, degree, include_intercept=True + ) + acrt_glob = float(np.mean(dPsi_treated @ beta_hat)) + + # Store bootstrap info for influence function computation + # bread = (Psi'Psi / n_treated)^{-1} + PtP = Psi.T @ Psi + try: + bread = np.linalg.inv(PtP / n_treated) + except np.linalg.LinAlgError: + bread = np.linalg.pinv(PtP / n_treated) + + # ee_treated: per-unit estimating equation vectors (K-vector per unit) + ee_treated = Psi * residuals[:, np.newaxis] # (n_treated, K) + + # ee_control: per-unit deviation from control mean + ee_control = delta_y_control - mu_0 # (n_control,) + + # psi_bar: mean basis vector for treated + psi_bar = np.mean(Psi, axis=0) # (K,) + + # Unit indices for bootstrap + treated_indices = np.where(treated_mask)[0] + control_indices = np.where(control_mask)[0] + + bootstrap_info = { + "bread": bread, + "ee_treated": ee_treated, + "ee_control": ee_control, + "psi_bar": psi_bar, + "beta_hat": beta_hat, + "treated_indices": treated_indices, + "control_indices": control_indices, + "n_treated": n_treated, + "n_control": n_control, + "Psi_eval": Psi_eval, + "dPsi_eval": dPsi_eval, + "dPsi_treated": dPsi_treated, + "delta_y_treated": delta_y_treated, + "delta_y_control": delta_y_control, + "mu_0": mu_0, + "att_glob": att_glob, + } + + return { + 
"att_d": att_d, + "acrt_d": acrt_d, + "att_glob": att_glob, + "acrt_glob": acrt_glob, + "beta_hat": beta_hat, + "n_treated": n_treated, + "n_control": n_control, + "_bootstrap_info": bootstrap_info, + } + + def _aggregate_event_study( + self, + gt_results: Dict[Tuple, Dict], + treatment_groups: List[Any], + ) -> Dict[int, Dict[str, Any]]: + """Aggregate binarized ATT_glob by relative period.""" + effects_by_e: Dict[int, List[Tuple[float, float]]] = {} + + for (g, t), r in gt_results.items(): + e = t - g + if e not in effects_by_e: + effects_by_e[e] = [] + effects_by_e[e].append((r["att_glob"], float(r["n_treated"]))) + + result = {} + for e, entries in sorted(effects_by_e.items()): + effects = np.array([x[0] for x in entries]) + weights = np.array([x[1] for x in entries]) + if np.sum(weights) > 0: + w = weights / np.sum(weights) + agg = float(np.sum(w * effects)) + else: + agg = np.nan + result[e] = { + "effect": agg, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + } + return result + + def _compute_analytical_se( + self, + precomp: Dict[str, Any], + gt_results: Dict[Tuple, Dict], + gt_bootstrap_info: Dict[Tuple, Dict], + post_gt: Dict[Tuple, Dict], + cell_weights: Dict[Tuple, float], + knots: np.ndarray, + degree: int, + dvals: np.ndarray, + agg_att_d: np.ndarray, + agg_acrt_d: np.ndarray, + ) -> Dict[str, Any]: + """Compute analytical SEs using influence functions.""" + n_units = precomp["n_units"] + n_grid = len(dvals) + + # Build per-unit influence functions for aggregated parameters + # IF_i for overall ATT_glob (binarized) + if_att_glob = np.zeros(n_units) + if_acrt_glob = np.zeros(n_units) + if_att_d = np.zeros((n_units, n_grid)) + if_acrt_d = np.zeros((n_units, n_grid)) + + for gt, w in cell_weights.items(): + if w == 0: + continue + info = gt_bootstrap_info[gt] + if not info: + continue + treated_idx = info["treated_indices"] + control_idx = info["control_indices"] + n_t = info["n_treated"] + n_c = 
info["n_control"] + bread = info["bread"] + ee_treated = info["ee_treated"] + ee_control = info["ee_control"] + psi_bar = info["psi_bar"] + Psi_eval = info["Psi_eval"] + dPsi_eval = info["dPsi_eval"] + dPsi_treated = info["dPsi_treated"] + att_glob_gt = info["att_glob"] + mu_0 = info["mu_0"] + delta_y_treated = info["delta_y_treated"] + + n_total = n_t + n_c + p_1 = n_t / n_total + p_0 = n_c / n_total + + # IF for ATT_glob (binarized DiD) + for k, idx in enumerate(treated_idx): + if_att_glob[idx] += w * (delta_y_treated[k] - att_glob_gt - mu_0) / p_1 / n_total + for k, idx in enumerate(control_idx): + if_att_glob[idx] -= w * ee_control[k] / p_0 / n_total + + # IF for beta perturbation → ATT(d) and ACRT(d) + # beta perturbation from treated: bread @ (1/n_t) * sum w_i * ee_treated_i + # beta perturbation from control: -bread @ psi_bar * (1/n_c) * sum w_i * ee_control_i + # ATT_b(d) = Psi_eval @ beta_b => IF_i(d) contribution + + # Treated unit contributions to beta + for k, idx in enumerate(treated_idx): + beta_pert = bread @ ee_treated[k] / n_t + if_att_d[idx] += w * (Psi_eval @ beta_pert) + if_acrt_d[idx] += w * (dPsi_eval @ beta_pert) + + # Control unit contributions to beta (through mu_0) + for k, idx in enumerate(control_idx): + beta_pert = -bread @ psi_bar * ee_control[k] / n_c + if_att_d[idx] += w * (Psi_eval @ beta_pert) + if_acrt_d[idx] += w * (dPsi_eval @ beta_pert) + + # ACRT_glob IF: (1/n_t) sum_j dpsi(D_j)' @ beta_pert + dpsi_bar = np.mean(dPsi_treated, axis=0) + for k, idx in enumerate(treated_idx): + beta_pert = bread @ ee_treated[k] / n_t + if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert) + for k, idx in enumerate(control_idx): + beta_pert = -bread @ psi_bar * ee_control[k] / n_c + if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert) + + # SE = sqrt(mean(IF_i^2)) + overall_att_se = float(np.sqrt(np.mean(if_att_glob**2))) + overall_acrt_se = float(np.sqrt(np.mean(if_acrt_glob**2))) + + att_d_se = np.sqrt(np.mean(if_att_d**2, axis=0)) + acrt_d_se = 
    def _run_bootstrap(
        self,
        precomp: Dict[str, Any],
        gt_results: Dict[Tuple, Dict],
        gt_bootstrap_info: Dict[Tuple, Dict],
        post_gt: Dict[Tuple, Dict],
        cell_weights: Dict[Tuple, float],
        knots: np.ndarray,
        degree: int,
        dvals: np.ndarray,
        original_att: float,
        original_acrt: float,
        original_att_d: np.ndarray,
        original_acrt_d: np.ndarray,
        event_study_effects: Optional[Dict[int, Dict]],
    ) -> Dict[str, Any]:
        """
        Run multiplier bootstrap inference.

        Draws one multiplier weight per unit (shared across all (g, t)
        cells so within-unit dependence is preserved), perturbs each
        cell's estimating equations, and aggregates the perturbed cell
        estimates with ``cell_weights`` to obtain bootstrap distributions
        of ATT(d), ACRT(d), ATT_glob, ACRT_glob and, when requested,
        binarized event-study effects.

        Parameters
        ----------
        precomp : dict
            Precomputed quantities; only ``"n_units"`` is read here.
        gt_results : dict
            Point estimates per (g, t) cell; ``"n_treated"`` is read to
            build event-time bin weights.
        gt_bootstrap_info : dict
            Per-cell influence-function ingredients (unit indices,
            residuals, bread matrix, spline evaluations).
        post_gt : dict
            NOTE(review): not referenced in this body — confirm whether
            it (and ``knots``/``degree`` below) can be dropped.
        cell_weights : dict
            Aggregation weight per post-treatment (g, t) cell.
        knots, degree :
            NOTE(review): also unused in this body — confirm.
        dvals : np.ndarray
            Dose evaluation grid.
        original_att, original_acrt : float
            Point estimates passed to the percentile-CI/p-value helper.
        original_att_d, original_acrt_d : np.ndarray
            Point estimates on the dose grid.
        event_study_effects : dict or None
            If not None, event-study SEs are bootstrapped as well.

        Returns
        -------
        dict
            ``att_d_se``, ``att_d_ci_lower``/``_upper``, ``acrt_d_se``,
            ``acrt_d_ci_lower``/``_upper``, ``overall_att_se``,
            ``overall_acrt_se`` and, when applicable, ``es_se``.
        """
        if self.n_bootstrap < 50:
            warnings.warn(
                f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
                "for reliable inference.",
                UserWarning,
                stacklevel=3,
            )

        rng = np.random.default_rng(self.seed)
        n_units = precomp["n_units"]
        n_grid = len(dvals)

        # Generate all weights upfront: shape (B, n_units). Drawing them
        # once keeps the same weight vector per replicate across cells.
        all_weights = generate_bootstrap_weights_batch(
            self.n_bootstrap, n_units, self.bootstrap_weights, rng
        )

        boot_att_glob = np.zeros(self.n_bootstrap)
        boot_acrt_glob = np.zeros(self.n_bootstrap)
        boot_att_d = np.zeros((self.n_bootstrap, n_grid))
        boot_acrt_d = np.zeros((self.n_bootstrap, n_grid))

        # Event study bootstrap — compute weights per event-time bin
        es_keys = sorted(event_study_effects.keys()) if event_study_effects else []
        boot_es = {e: np.zeros(self.n_bootstrap) for e in es_keys}
        # Per-(g,t) weight within event-time bin
        es_cell_weights: Dict[Tuple, float] = {}
        if event_study_effects is not None:
            # Build event-time bin weights from n_treated: each cell's
            # share of the treated count within its relative period e.
            from collections import defaultdict
            es_bin_total: Dict[int, float] = defaultdict(float)
            for gt, r in gt_results.items():
                g_val, t_val = gt
                e = t_val - g_val
                es_bin_total[e] += float(r["n_treated"])
            for gt, r in gt_results.items():
                g_val, t_val = gt
                e = t_val - g_val
                if es_bin_total[e] > 0:
                    es_cell_weights[gt] = float(r["n_treated"]) / es_bin_total[e]

        # Helper to bootstrap a single (g,t) cell
        def _bootstrap_gt_cell(gt, info):
            """Return (att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, acrt_glob_pt).

            The first four are (B,)- or (B, n_grid)-shaped bootstrap
            perturbations for this cell; the last is the cell's point
            estimate of ACRT_glob (0.0 when absent from ``info``).
            """
            treated_idx = info["treated_indices"]
            control_idx = info["control_indices"]
            n_t = info["n_treated"]
            n_c = info["n_control"]
            bread = info["bread"]
            ee_treated = info["ee_treated"]
            ee_control = info["ee_control"]
            psi_bar = info["psi_bar"]
            beta_hat = info["beta_hat"]
            Psi_eval = info["Psi_eval"]
            dPsi_eval = info["dPsi_eval"]
            dPsi_treated = info["dPsi_treated"]
            delta_y_treated = info["delta_y_treated"]
            mu_0 = info["mu_0"]
            att_glob_gt = info["att_glob"]

            # Slice this cell's units out of the shared weight matrix.
            w_treated = all_weights[:, treated_idx]
            w_control = all_weights[:, control_idx]

            # Extreme multiplier weights can overflow intermediates; the
            # non-finite replicates are filtered downstream by
            # compute_effect_bootstrap_stats.
            with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
                treated_sum = w_treated @ ee_treated / n_t
                control_sum = (w_control @ ee_control) / n_c
                psi_bar_outer = psi_bar[np.newaxis, :]

                # First-order perturbation of the spline coefficients:
                # beta_b = beta_hat + bread @ (perturbed score).
                delta_beta = (treated_sum - control_sum[:, np.newaxis] * psi_bar_outer) @ bread.T
                beta_b = beta_hat[np.newaxis, :] + delta_beta

                att_d_b = beta_b @ Psi_eval.T
                acrt_d_b = beta_b @ dPsi_eval.T

                # NOTE(review): mu_0_pert recomputes the same quantity as
                # control_sum above — confirm and reuse if intended.
                mu_0_pert = (w_control @ ee_control) / n_c
                mean_dy_treated_pert = (w_treated @ (delta_y_treated - att_glob_gt - mu_0)) / n_t
                att_glob_b = att_glob_gt + mean_dy_treated_pert - mu_0_pert

                # ACRT_glob perturbation: average derivative basis over
                # treated doses, contracted with the beta perturbation.
                dpsi_mean = np.mean(dPsi_treated, axis=0)
                acrt_glob_b = delta_beta @ dpsi_mean

            return att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, info.get("acrt_glob", 0.0)

        # Iterate over post-treatment cells for dose-response/overall aggregation
        for gt, w in cell_weights.items():
            if w == 0:
                continue
            info = gt_bootstrap_info[gt]
            if not info:
                continue

            att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, acrt_glob_pt = _bootstrap_gt_cell(gt, info)

            boot_att_d += w * att_d_b
            boot_acrt_d += w * acrt_d_b
            boot_att_glob += w * att_glob_b
            # ACRT_glob replicate = point estimate + perturbation per cell.
            boot_acrt_glob += w * (acrt_glob_pt + acrt_glob_b)

        # Event study bootstrap — iterate over ALL (g,t) cells (not just
        # the post-treatment ones used for the dose-response aggregate).
        if event_study_effects is not None:
            for gt, r in gt_results.items():
                info = gt_bootstrap_info[gt]
                if not info:
                    continue
                g_val, t_val = gt
                e = t_val - g_val
                if e not in boot_es:
                    continue
                es_w = es_cell_weights.get(gt, 0.0)
                if es_w == 0:
                    continue
                _, _, att_glob_b, _, _ = _bootstrap_gt_cell(gt, info)
                boot_es[e] += es_w * att_glob_b

        # Compute statistics
        result: Dict[str, Any] = {}

        # Per-grid-point SEs and percentile CIs
        att_d_se = np.full(n_grid, np.nan)
        att_d_ci_lower = np.full(n_grid, np.nan)
        att_d_ci_upper = np.full(n_grid, np.nan)
        acrt_d_se = np.full(n_grid, np.nan)
        acrt_d_ci_lower = np.full(n_grid, np.nan)
        acrt_d_ci_upper = np.full(n_grid, np.nan)

        for idx in range(n_grid):
            se, ci, _ = compute_effect_bootstrap_stats(
                original_att_d[idx], boot_att_d[:, idx],
                alpha=self.alpha, context=f"ATT(d) at grid point {idx}",
            )
            att_d_se[idx] = se
            att_d_ci_lower[idx] = ci[0]
            att_d_ci_upper[idx] = ci[1]

            se, ci, _ = compute_effect_bootstrap_stats(
                original_acrt_d[idx], boot_acrt_d[:, idx],
                alpha=self.alpha, context=f"ACRT(d) at grid point {idx}",
            )
            acrt_d_se[idx] = se
            acrt_d_ci_lower[idx] = ci[0]
            acrt_d_ci_upper[idx] = ci[1]

        result["att_d_se"] = att_d_se
        result["att_d_ci_lower"] = att_d_ci_lower
        result["att_d_ci_upper"] = att_d_ci_upper
        result["acrt_d_se"] = acrt_d_se
        result["acrt_d_ci_lower"] = acrt_d_ci_lower
        result["acrt_d_ci_upper"] = acrt_d_ci_upper

        # Overall
        se, _, _ = compute_effect_bootstrap_stats(
            original_att, boot_att_glob, alpha=self.alpha, context="overall ATT_glob",
        )
        result["overall_att_se"] = se

        se, _, _ = compute_effect_bootstrap_stats(
            original_acrt, boot_acrt_glob, alpha=self.alpha, context="overall ACRT_glob",
        )
        result["overall_acrt_se"] = se

        # Event study SEs
        if event_study_effects is not None:
            es_se = {}
            for e in es_keys:
                se_e, _, _ = compute_effect_bootstrap_stats(
                    event_study_effects[e]["effect"], boot_es[e],
                    alpha=self.alpha, context=f"event study e={e}",
                )
                es_se[e] = se_e
            result["es_se"] = es_se

        return result
def build_bspline_basis(dose, degree=3, num_knots=0):
    """
    Construct a clamped B-spline knot vector from positive dose values.

    Interior knots sit at evenly spaced quantiles of the dose
    distribution (matching R's ``choose_knots_quantile`` convention);
    the boundary knots are each repeated ``degree + 1`` times so the
    basis is clamped at the observed dose range.

    Parameters
    ----------
    dose : array-like
        Positive dose values from treated units.
    degree : int, default=3
        Degree of the B-spline (3 = cubic).
    num_knots : int, default=0
        Number of interior knots.

    Returns
    -------
    knots : np.ndarray
        Full knot vector with boundary clamping.
    degree : int
        The B-spline degree (echoed back for convenience).
    """
    values = np.asarray(dose, dtype=float)
    lo = float(values.min())
    hi = float(values.max())

    # Interior knots at evenly spaced quantiles; empty when num_knots <= 0.
    if num_knots <= 0:
        inner = np.array([])
    else:
        quantile_levels = np.linspace(0, 1, num_knots + 2)[1:-1]
        inner = np.quantile(values, quantile_levels)

    # Clamp: repeat each boundary knot degree + 1 times.
    boundary_lo = np.full(degree + 1, lo)
    boundary_hi = np.full(degree + 1, hi)
    knots = np.concatenate([boundary_lo, inner, boundary_hi])

    return knots, degree
def bspline_derivative_design_matrix(x, knots, degree, include_intercept=True):
    """
    Evaluate first derivatives of B-spline basis functions at points ``x``.

    All basis derivatives are evaluated in a single vectorized call by
    building one :class:`scipy.interpolate.BSpline` whose coefficient
    array is the identity matrix, instead of constructing one spline per
    basis function in a Python loop.

    Parameters
    ----------
    x : array-like
        Evaluation points, shape ``(n,)``.
    knots : np.ndarray
        Full knot vector.
    degree : int
        B-spline degree.
    include_intercept : bool, default=True
        If True, drop derivative of first B-spline (replaced by intercept
        whose derivative is 0) and prepend a zeros column.

    Returns
    -------
    np.ndarray
        Derivative design matrix, shape ``(n, n_cols)``.
    """
    x = np.asarray(x, dtype=float)

    # Number of basis functions implied by the knot vector
    n_basis = len(knots) - degree - 1

    # Clamp evaluation points to the boundary knots so we never evaluate
    # outside the support of the basis (extrapolation is undefined here).
    t_min = knots[degree]
    t_max = knots[-(degree + 1)]
    x_clamped = np.clip(x, t_min, t_max)

    dB = np.zeros((len(x), n_basis))

    # Degenerate knot vector (all knots identical, e.g. a single unique
    # dose): every basis derivative is zero, so leave dB as zeros.
    if knots[0] != knots[-1]:
        try:
            # Identity coefficients: column j of the output is the
            # derivative of basis function j. BSpline broadcasts over the
            # trailing coefficient axis, giving shape (n, n_basis).
            basis_spline = BSpline(knots, np.eye(n_basis), degree)
            dB = basis_spline.derivative()(x_clamped)
        except ValueError:
            # Knot vector scipy rejects as degenerate: derivative is zero.
            dB = np.zeros((len(x), n_basis))

    if include_intercept:
        # Drop first column (intercept derivative = 0), prepend zeros
        dB = np.column_stack([np.zeros(len(x)), dB[:, 1:]])

    return dB
@dataclass
class DoseResponseCurve:
    """
    Dose-response curve from continuous DiD estimation.

    Attributes
    ----------
    dose_grid : np.ndarray
        Evaluation points, shape ``(n_grid,)``.
    effects : np.ndarray
        ATT(d) or ACRT(d) values, shape ``(n_grid,)``.
    se : np.ndarray
        Standard errors, shape ``(n_grid,)``.
    conf_int_lower : np.ndarray
        Lower CI bounds, shape ``(n_grid,)``.
    conf_int_upper : np.ndarray
        Upper CI bounds, shape ``(n_grid,)``.
    target : str
        ``"att"`` or ``"acrt"``.
    """

    dose_grid: np.ndarray
    effects: np.ndarray
    se: np.ndarray
    conf_int_lower: np.ndarray
    conf_int_upper: np.ndarray
    target: str

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with dose, effect, se, CI, t_stat, p_value.

        t-statistics are NaN wherever the SE is non-finite or
        non-positive. P-values use the normal survival function,
        ``2 * sf(|t|)``, which stays accurate for large ``|t|`` where
        ``2 * (1 - cdf(|t|))`` underflows to exactly 0.
        """
        from scipy import stats

        valid = np.isfinite(self.se) & (self.se > 0)
        # Guarded division: only evaluate effects / se where se is a
        # usable positive number, avoiding divide-by-zero warnings.
        t_stat = np.full(np.shape(self.effects), np.nan, dtype=float)
        np.divide(self.effects, self.se, out=t_stat, where=valid)

        p_value = np.where(
            np.isfinite(t_stat), 2 * stats.norm.sf(np.abs(t_stat)), np.nan
        )
        return pd.DataFrame(
            {
                "dose": self.dose_grid,
                "effect": self.effects,
                "se": self.se,
                "conf_int_lower": self.conf_int_lower,
                "conf_int_upper": self.conf_int_upper,
                "t_stat": t_stat,
                "p_value": p_value,
            }
        )
    def summary(self, alpha: Optional[float] = None) -> str:
        """Generate formatted summary.

        Parameters
        ----------
        alpha : float, optional
            Significance level for the reported confidence level;
            defaults to ``self.alpha``.
            NOTE(review): ``alpha or self.alpha`` treats an explicit
            ``alpha=0.0`` as "not given" (0.0 is falsy) — confirm intended.

        Returns
        -------
        str
            Multi-section, fixed-width report: header counts, overall
            ATT_glob/ACRT_glob with inference, selected dose-response
            points, and event-study effects when available.
        """
        alpha = alpha or self.alpha
        conf_level = int((1 - alpha) * 100)
        w = 85  # report width in characters

        lines = [
            "=" * w,
            "Continuous Difference-in-Differences Results".center(w),
            "(Callaway, Goodman-Bacon & Sant'Anna 2024)".center(w),
            "=" * w,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            f"{'B-spline degree:':<30} {self.degree:>10}",
            f"{'Interior knots:':<30} {self.num_knots:>10}",
            "",
        ]

        # Overall summary parameters
        lines.extend([
            "-" * w,
            "Overall Summary Parameters".center(w),
            "-" * w,
            f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
            "-" * w,
        ])
        for label, est, se, t, p in [
            (
                "ATT_glob",
                self.overall_att,
                self.overall_att_se,
                self.overall_att_t_stat,
                self.overall_att_p_value,
            ),
            (
                "ACRT_glob",
                self.overall_acrt,
                self.overall_acrt_se,
                self.overall_acrt_t_stat,
                self.overall_acrt_p_value,
            ),
        ]:
            # NaN t/p (e.g. failed bootstrap) render as right-aligned "NaN"
            t_str = f"{t:>10.3f}" if np.isfinite(t) else f"{'NaN':>10}"
            p_str = f"{p:>10.4f}" if np.isfinite(p) else f"{'NaN':>10}"
            sig = _get_significance_stars(p)
            lines.append(
                f"{label:<15} {est:>12.4f} {se:>12.4f} {t_str} {p_str} {sig:>6}"
            )
        lines.extend([
            "-" * w,
            "",
            f"{conf_level}% CI for ATT_glob: "
            f"[{self.overall_att_conf_int[0]:.4f}, {self.overall_att_conf_int[1]:.4f}]",
            f"{conf_level}% CI for ACRT_glob: "
            f"[{self.overall_acrt_conf_int[0]:.4f}, {self.overall_acrt_conf_int[1]:.4f}]",
            "",
        ])

        # Dose-response curve summary (first/mid/last points)
        if len(self.dose_grid) > 0:
            lines.extend([
                "-" * w,
                "Dose-Response Curve (selected points)".center(w),
                "-" * w,
                f"{'Dose':>10} {'ATT(d)':>12} {'SE':>10} "
                f"{'ACRT(d)':>12} {'SE':>10}",
                "-" * w,
            ])
            n_grid = len(self.dose_grid)
            # Quartile indices; set() dedupes when the grid is tiny
            indices = sorted(set([0, n_grid // 4, n_grid // 2, 3 * n_grid // 4, n_grid - 1]))
            for idx in indices:
                if idx < n_grid:
                    lines.append(
                        f"{self.dose_grid[idx]:>10.3f} "
                        f"{self.dose_response_att.effects[idx]:>12.4f} "
                        f"{self.dose_response_att.se[idx]:>10.4f} "
                        f"{self.dose_response_acrt.effects[idx]:>12.4f} "
                        f"{self.dose_response_acrt.se[idx]:>10.4f}"
                    )
            lines.extend(["-" * w, ""])

        # Event study effects if available
        if self.event_study_effects:
            lines.extend([
                "-" * w,
                "Event Study (Dynamic) Effects (Binarized ATT)".center(w),
                "-" * w,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * w,
            ])
            for rel_t in sorted(self.event_study_effects.keys()):
                eff = self.event_study_effects[rel_t]
                sig = _get_significance_stars(eff["p_value"])
                t_str = (
                    f"{eff['t_stat']:>10.3f}"
                    if np.isfinite(eff["t_stat"])
                    else f"{'NaN':>10}"
                )
                p_str = (
                    f"{eff['p_value']:>10.4f}"
                    if np.isfinite(eff["p_value"])
                    else f"{'NaN':>10}"
                )
                lines.append(
                    f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                    f"{t_str} {p_str} {sig:>6}"
                )
            lines.extend(["-" * w, ""])

        lines.extend([
            "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
            "=" * w,
        ])
        return "\n".join(lines)
    def to_dataframe(self, level: str = "dose_response") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="dose_response"
            ``"dose_response"``, ``"group_time"``, or ``"event_study"``.

        Returns
        -------
        pd.DataFrame
            One row per dose-grid point, per (g, t) cell, or per relative
            period, depending on ``level``.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or if ``level="event_study"`` while
            event-study effects were not computed.
        """
        if level == "dose_response":
            # Merge the ATT(d) and ACRT(d) curves on the shared dose grid.
            att_df = self.dose_response_att.to_dataframe()
            acrt_df = self.dose_response_acrt.to_dataframe()
            return pd.DataFrame(
                {
                    "dose": att_df["dose"],
                    "att": att_df["effect"],
                    "att_se": att_df["se"],
                    "att_ci_lower": att_df["conf_int_lower"],
                    "att_ci_upper": att_df["conf_int_upper"],
                    "acrt": acrt_df["effect"],
                    "acrt_se": acrt_df["se"],
                    "acrt_ci_lower": acrt_df["conf_int_lower"],
                    "acrt_ci_upper": acrt_df["conf_int_upper"],
                }
            )
        elif level == "group_time":
            rows = []
            # sorted() gives a deterministic row order across runs.
            for (g, t), data in sorted(self.group_time_effects.items()):
                rows.append(
                    {
                        "group": g,
                        "time": t,
                        "att_glob": data.get("att_glob", np.nan),
                        "acrt_glob": data.get("acrt_glob", np.nan),
                        "n_treated": data.get("n_treated", 0),
                        "n_control": data.get("n_control", 0),
                    }
                )
            return pd.DataFrame(rows)
        elif level == "event_study":
            if self.event_study_effects is None:
                raise ValueError(
                    "Event study effects not computed. Use aggregate='eventstudy'."
                )
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                rows.append(
                    {
                        "relative_period": rel_t,
                        "att_glob": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                    }
                )
            return pd.DataFrame(rows)
        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'dose_response', 'group_time', or 'event_study'."
            )
def generate_continuous_did_data(
    n_units: int = 500,
    n_periods: int = 4,
    cohort_periods: Optional[List[int]] = None,
    never_treated_frac: float = 0.3,
    dose_distribution: str = "lognormal",
    dose_params: Optional[Dict] = None,
    att_function: str = "linear",
    att_slope: float = 2.0,
    att_intercept: float = 1.0,
    unit_fe_sd: float = 2.0,
    time_trend: float = 0.5,
    noise_sd: float = 1.0,
    seed: Optional[int] = None,
) -> pd.DataFrame:
    """
    Generate synthetic data for continuous DiD analysis with known dose-response.

    Creates a balanced panel with continuous treatment doses and known ATT(d)
    function, satisfying strong parallel trends by construction.

    Parameters
    ----------
    n_units : int, default=500
        Number of units in the panel.
    n_periods : int, default=4
        Number of time periods (1-indexed).
    cohort_periods : list of int, optional
        Treatment cohort periods. Default: ``[2]`` (single cohort).
    never_treated_frac : float, default=0.3
        Fraction of units that are never-treated. Must be in [0, 1].
    dose_distribution : str, default="lognormal"
        Distribution for dose: ``"lognormal"``, ``"uniform"``, ``"exponential"``.
    dose_params : dict, optional
        Distribution-specific parameters. Defaults:
        lognormal: ``{"mean": 0.5, "sigma": 0.5}``
        uniform: ``{"low": 0.5, "high": 5.0}``
        exponential: ``{"scale": 2.0}``
    att_function : str, default="linear"
        Functional form of ATT(d): ``"linear"``, ``"quadratic"``, ``"log"``.
    att_slope : float, default=2.0
        Slope parameter for ATT function.
    att_intercept : float, default=1.0
        Intercept parameter for ATT function.
    unit_fe_sd : float, default=2.0
        Standard deviation of unit fixed effects.
    time_trend : float, default=0.5
        Linear time trend coefficient.
    noise_sd : float, default=1.0
        Standard deviation of idiosyncratic noise.
    seed : int, optional
        Random seed for reproducibility.

    Returns
    -------
    pd.DataFrame
        Panel data with columns: ``unit``, ``period``, ``outcome``,
        ``first_treat``, ``dose``, ``true_att``.

    Raises
    ------
    ValueError
        If ``never_treated_frac``, ``att_function``, or
        ``dose_distribution`` is invalid (checked eagerly, before any
        data is generated).
    """
    rng = np.random.default_rng(seed)

    if cohort_periods is None:
        cohort_periods = [2]

    # Fail fast on invalid configuration. Previously att_function was only
    # validated lazily inside the row loop, so an invalid name could pass
    # silently when no treated post-period observation existed.
    if not 0.0 <= never_treated_frac <= 1.0:
        raise ValueError(
            f"never_treated_frac must be in [0, 1], got {never_treated_frac}"
        )
    if att_function not in ("linear", "quadratic", "log"):
        raise ValueError(
            f"att_function must be 'linear', 'quadratic', or 'log', "
            f"got '{att_function}'"
        )
    if dose_distribution not in ("lognormal", "uniform", "exponential"):
        raise ValueError(
            f"dose_distribution must be 'lognormal', 'uniform', or 'exponential', "
            f"got '{dose_distribution}'"
        )

    # Assign units to cohorts: first n_never units never treated, the rest
    # split as evenly as possible (last cohort absorbs the remainder).
    n_never = int(n_units * never_treated_frac)
    n_treated_total = n_units - n_never
    n_per_cohort = n_treated_total // len(cohort_periods)

    cohort_assignments = np.zeros(n_units, dtype=int)
    idx = 0
    for i, g in enumerate(cohort_periods):
        n_this = n_per_cohort if i < len(cohort_periods) - 1 else n_treated_total - idx
        cohort_assignments[n_never + idx: n_never + idx + n_this] = g
        idx += n_this

    # Draw treatment doses for treated units only (never-treated keep 0).
    default_params = {
        "lognormal": {"mean": 0.5, "sigma": 0.5},
        "uniform": {"low": 0.5, "high": 5.0},
        "exponential": {"scale": 2.0},
    }
    params = dose_params or default_params[dose_distribution]

    dose_per_unit = np.zeros(n_units)
    treated_mask = cohort_assignments > 0
    n_treated_actual = int(np.sum(treated_mask))

    if dose_distribution == "lognormal":
        dose_per_unit[treated_mask] = rng.lognormal(
            mean=params.get("mean", 0.5),
            sigma=params.get("sigma", 0.5),
            size=n_treated_actual,
        )
    elif dose_distribution == "uniform":
        dose_per_unit[treated_mask] = rng.uniform(
            low=params.get("low", 0.5),
            high=params.get("high", 5.0),
            size=n_treated_actual,
        )
    else:  # "exponential" — guaranteed by the validation above
        dose_per_unit[treated_mask] = rng.exponential(
            scale=params.get("scale", 2.0),
            size=n_treated_actual,
        )

    def _att_func(d):
        # Known dose-response ATT(d); att_function already validated.
        if att_function == "linear":
            return att_intercept + att_slope * d
        elif att_function == "quadratic":
            return att_intercept + att_slope * d**2
        return att_intercept + att_slope * np.log1p(d)

    # Unit fixed effects
    unit_fe = rng.normal(0, unit_fe_sd, size=n_units)

    # Build panel. Noise is drawn inside the loop, one draw per
    # (unit, period) in row order, so seeded output matches the original
    # implementation exactly.
    periods = np.arange(1, n_periods + 1)
    records = []
    for i in range(n_units):
        g_i = cohort_assignments[i]
        d_i = dose_per_unit[i]
        for t in periods:
            # Untreated potential outcome: unit FE + linear trend + noise
            y0 = unit_fe[i] + time_trend * t + rng.normal(0, noise_sd)
            # Treatment effect switches on from the cohort period onward
            if g_i > 0 and t >= g_i:
                att_d = _att_func(d_i)
            else:
                att_d = 0.0

            records.append({
                "unit": i,
                "period": int(t),
                "outcome": y0 + att_d,
                "first_treat": int(g_i) if g_i > 0 else 0,
                "dose": d_i,
                "true_att": att_d,
            })

    return pd.DataFrame(records)
""" import warnings from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np -# Import Rust backend if available (from _backend to avoid circular imports) -from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights +from diff_diff.bootstrap_utils import ( + compute_bootstrap_pvalue as _compute_bootstrap_pvalue_func, +) +from diff_diff.bootstrap_utils import ( + compute_effect_bootstrap_stats as _compute_effect_bootstrap_stats_func, +) +from diff_diff.bootstrap_utils import ( + compute_percentile_ci as _compute_percentile_ci_func, +) +from diff_diff.bootstrap_utils import ( + generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch, +) if TYPE_CHECKING: pass -# ============================================================================= -# Bootstrap Weight Generators -# ============================================================================= - - -def _generate_bootstrap_weights( - n_units: int, - weight_type: str, - rng: np.random.Generator, -) -> np.ndarray: - """ - Generate bootstrap weights for multiplier bootstrap. - - Parameters - ---------- - n_units : int - Number of units (clusters) to generate weights for. - weight_type : str - Type of weights: "rademacher", "mammen", or "webb". - rng : np.random.Generator - Random number generator. - - Returns - ------- - np.ndarray - Array of bootstrap weights with shape (n_units,). 
- """ - if weight_type == "rademacher": - # Rademacher: +1 or -1 with equal probability - return rng.choice([-1.0, 1.0], size=n_units) - - elif weight_type == "mammen": - # Mammen's two-point distribution - # E[v] = 0, E[v^2] = 1, E[v^3] = 1 - sqrt5 = np.sqrt(5) - val1 = -(sqrt5 - 1) / 2 # ≈ -0.618 - val2 = (sqrt5 + 1) / 2 # ≈ 1.618 (golden ratio) - p1 = (sqrt5 + 1) / (2 * sqrt5) # ≈ 0.724 - return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1]) - - elif weight_type == "webb": - # Webb's 6-point distribution (recommended for few clusters) - # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each) - # This matches R's did package: E[w]=0, Var(w)=1.0 - values = np.array([ - -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2), - np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2) - ]) - return rng.choice(values, size=n_units) # Equal probs (1/6 each) - - else: - raise ValueError( - f"weight_type must be 'rademacher', 'mammen', or 'webb', " - f"got '{weight_type}'" - ) - - -def _generate_bootstrap_weights_batch( - n_bootstrap: int, - n_units: int, - weight_type: str, - rng: np.random.Generator, -) -> np.ndarray: - """ - Generate all bootstrap weights at once (vectorized). - - Parameters - ---------- - n_bootstrap : int - Number of bootstrap iterations. - n_units : int - Number of units (clusters) to generate weights for. - weight_type : str - Type of weights: "rademacher", "mammen", or "webb". - rng : np.random.Generator - Random number generator. - - Returns - ------- - np.ndarray - Array of bootstrap weights with shape (n_bootstrap, n_units). 
- """ - # Use Rust backend if available (parallel + fast RNG) - if HAS_RUST_BACKEND and _rust_bootstrap_weights is not None: - # Get seed from the NumPy RNG for reproducibility - seed = rng.integers(0, 2**63 - 1) - return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed) - - # Fallback to NumPy implementation - return _generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng) - - -def _generate_bootstrap_weights_batch_numpy( - n_bootstrap: int, - n_units: int, - weight_type: str, - rng: np.random.Generator, -) -> np.ndarray: - """ - NumPy fallback implementation of _generate_bootstrap_weights_batch. - - Generates multiplier bootstrap weights for wild cluster bootstrap. - All weight distributions satisfy E[w] = 0, E[w^2] = 1. - - Parameters - ---------- - n_bootstrap : int - Number of bootstrap iterations. - n_units : int - Number of units (clusters) to generate weights for. - weight_type : str - Type of weights: "rademacher" (+-1), "mammen" (2-point), - or "webb" (6-point). - rng : np.random.Generator - Random number generator for reproducibility. - - Returns - ------- - np.ndarray - Array of bootstrap weights with shape (n_bootstrap, n_units). 
- """ - if weight_type == "rademacher": - # Rademacher: +1 or -1 with equal probability - return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units)) - - elif weight_type == "mammen": - # Mammen's two-point distribution - sqrt5 = np.sqrt(5) - val1 = -(sqrt5 - 1) / 2 - val2 = (sqrt5 + 1) / 2 - p1 = (sqrt5 + 1) / (2 * sqrt5) - return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1]) - - elif weight_type == "webb": - # Webb's 6-point distribution - # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each) - # This matches R's did package: E[w]=0, Var(w)=1.0 - values = np.array([ - -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2), - np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2) - ]) - return rng.choice(values, size=(n_bootstrap, n_units)) # Equal probs (1/6 each) - - else: - raise ValueError( - f"weight_type must be 'rademacher', 'mammen', or 'webb', " - f"got '{weight_type}'" - ) - - # ============================================================================= # Bootstrap Results Container # ============================================================================= @@ -633,9 +494,7 @@ def _compute_percentile_ci( alpha: float, ) -> Tuple[float, float]: """Compute percentile confidence interval from bootstrap distribution.""" - lower = float(np.percentile(boot_dist, alpha / 2 * 100)) - upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) - return (lower, upper) + return _compute_percentile_ci_func(boot_dist, alpha) def _compute_bootstrap_pvalue( self, @@ -646,41 +505,9 @@ def _compute_bootstrap_pvalue( """ Compute two-sided bootstrap p-value. - Uses the percentile method: p-value is the proportion of bootstrap - estimates on the opposite side of zero from the original estimate, - doubled for two-sided test. - - Parameters - ---------- - original_effect : float - Original point estimate. - boot_dist : np.ndarray - Bootstrap distribution of the effect. - n_valid : int, optional - Number of valid bootstrap samples. 
If None, uses self.n_bootstrap. - Use this when boot_dist has already been filtered for non-finite values - to ensure the p-value floor is based on the actual valid sample count. - - Returns - ------- - float - Two-sided bootstrap p-value. + Delegates to :func:`bootstrap_utils.compute_bootstrap_pvalue`. """ - if original_effect >= 0: - # Proportion of bootstrap estimates <= 0 - p_one_sided = np.mean(boot_dist <= 0) - else: - # Proportion of bootstrap estimates >= 0 - p_one_sided = np.mean(boot_dist >= 0) - - # Two-sided p-value - p_value = min(2 * p_one_sided, 1.0) - - # Ensure minimum p-value using n_valid if provided, otherwise n_bootstrap - n_for_floor = n_valid if n_valid is not None else self.n_bootstrap - p_value = max(p_value, 1 / (n_for_floor + 1)) - - return float(p_value) + return _compute_bootstrap_pvalue_func(original_effect, boot_dist, n_valid=n_valid) def _compute_effect_bootstrap_stats( self, @@ -691,63 +518,8 @@ def _compute_effect_bootstrap_stats( """ Compute bootstrap statistics for a single effect. - Non-finite bootstrap samples are dropped and a warning is issued if any - are present. If too few valid samples remain (<50%), returns NaN for all - statistics to signal invalid inference. - - Parameters - ---------- - original_effect : float - Original point estimate. - boot_dist : np.ndarray - Bootstrap distribution of the effect. - context : str, optional - Description for warning messages, by default "bootstrap distribution". - - Returns - ------- - se : float - Bootstrap standard error. - ci : Tuple[float, float] - Percentile confidence interval. - p_value : float - Bootstrap p-value. + Delegates to :func:`bootstrap_utils.compute_effect_bootstrap_stats`. 
""" - # Filter out non-finite values - finite_mask = np.isfinite(boot_dist) - n_valid = np.sum(finite_mask) - n_total = len(boot_dist) - - if n_valid < n_total: - import warnings - n_nonfinite = n_total - n_valid - warnings.warn( - f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples in {context}. " - "This may occur with very small samples or extreme weights. " - "Bootstrap estimates based on remaining valid samples.", - RuntimeWarning, - stacklevel=3 - ) - - # Check if we have enough valid samples - if n_valid < n_total * 0.5: - import warnings - warnings.warn( - f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. " - "Returning NaN for SE/CI/p-value to signal invalid inference.", - RuntimeWarning, - stacklevel=3 - ) - return np.nan, (np.nan, np.nan), np.nan - - # Use only valid samples - valid_dist = boot_dist[finite_mask] - n_valid_bootstrap = len(valid_dist) - - se = float(np.std(valid_dist, ddof=1)) - ci = self._compute_percentile_ci(valid_dist, self.alpha) - - # Compute p-value using shared method with correct floor based on valid sample count - p_value = self._compute_bootstrap_pvalue(original_effect, valid_dist, n_valid=n_valid_bootstrap) - - return se, ci, p_value + return _compute_effect_bootstrap_stats_func( + original_effect, boot_dist, alpha=self.alpha, context=context + ) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index f13440e..06abfe0 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -10,6 +10,7 @@ This document provides the academic foundations and key implementation requireme - [TwoWayFixedEffects](#twowayfixedeffects) 2. 
[Modern Staggered Estimators](#modern-staggered-estimators) - [CallawaySantAnna](#callawaysantanna) + - [ContinuousDiD](#continuousdid) - [SunAbraham](#sunabraham) - [ImputationDiD](#imputationdid) - [TwoStageDiD](#twostagedid) @@ -392,6 +393,55 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: --- +## ContinuousDiD + +**Primary Source:** Callaway, Goodman-Bacon & Sant'Anna (2024), "Difference-in-Differences with a Continuous Treatment," NBER Working Paper 32117. + +**R Reference:** `contdid` v0.1.0 (CRAN). + +### Identification + +Under **Strong Parallel Trends** (SPT): for all doses d in D_+, +`E[Y_t(0) - Y_{t-1}(0) | D = d] = E[Y_t(0) - Y_{t-1}(0) | D = 0]`. + +This is stronger than standard PT because it conditions on specific dose values. + +### Key Equations + +**Target parameters:** +- `ATT(d) = E[Y_t(d) - Y_t(0) | D > 0]` — dose-response curve +- `ACRT(d) = dATT(d)/dd` — average causal response (marginal effect) +- `ATT^{glob} = E[Delta Y | D > 0] - E[Delta Y | D = 0]` — binarized ATT +- `ACRT^{glob} = E[ACRT(D_i) | D > 0]` — plug-in average marginal effect + +**Estimation via B-spline OLS:** +1. Compute `Delta_tilde_Y = (Y_t - Y_{t-1})_treated - mean((Y_t - Y_{t-1})_control)` +2. Build B-spline basis `Psi(D_i)` from treated doses +3. OLS: `beta = (Psi'Psi)^{-1} Psi' Delta_tilde_Y` +4. `ATT(d) = Psi(d)' beta`, `ACRT(d) = dPsi(d)/dd' beta` + +### Edge Cases + +- **No untreated group**: Remark 3.1 (lowest-dose-as-control) not implemented; requires P(D=0) > 0. +- **Discrete treatment**: Detect integer-valued dose and warn; saturated regression deferred. +- **All-same dose**: B-spline basis collapses; ACRT(d) = 0 everywhere. +- **Rank deficiency**: When n_treated <= n_basis, cell is skipped. +- **Balanced panel required**: Matches R `contdid` v0.1.0. 
+ +### Implementation Checklist + +- [x] B-spline basis construction matching R's `splines2::bSpline` +- [x] Multi-period (g,t) cell iteration with base period selection +- [x] Dose-response and event-study aggregation with n_treated weights +- [x] Multiplier bootstrap for inference +- [x] Analytical SEs via influence functions +- [x] Equation verification tests (linear, quadratic, multi-period) +- [ ] Covariate support (deferred, matching R v0.1.0) +- [ ] Discrete treatment saturated regression +- [ ] Lowest-dose-as-control (Remark 3.1) + +--- + ## SunAbraham **Primary source:** [Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in event studies with heterogeneous treatment effects. *Journal of Econometrics*, 225(2), 175-199.](https://doi.org/10.1016/j.jeconom.2020.09.006) @@ -1487,6 +1537,7 @@ should be a deliberate user choice. | SunAbraham | fixest | `sunab()` | | ImputationDiD | didimputation | `did_imputation()` | | TwoStageDiD | did2s | `did2s()` | +| ContinuousDiD | contdid | `cont_did()` | | SyntheticDiD | synthdid | `synthdid_estimate()` | | TripleDifference | triplediff | `ddd()` | | StackedDiD | stacked-did-weights | `create_sub_exp()` + `compute_weights()` | diff --git a/docs/methodology/continuous-did.md b/docs/methodology/continuous-did.md new file mode 100644 index 0000000..1c26cb5 --- /dev/null +++ b/docs/methodology/continuous-did.md @@ -0,0 +1,431 @@ +# Continuous Difference-in-Differences: Methodology Reference + +**Paper:** Callaway, Goodman-Bacon & Sant'Anna (2024/2025), "Difference-in-Differences +with a Continuous Treatment," NBER Working Paper 32117 (arXiv: 2107.02637v8). + +**R reference implementation:** `contdid` v0.1.0 (CRAN) + +--- + +## 1. Problem Statement + +In many DiD applications, treatment has a **dose** or intensity rather than binary on/off. +Examples: pollution exposure varying by distance, different minimum wage levels, varying +tax rates, different subsidy shares. + +Binary DiD cannot: +1. 
Handle settings where all units receive *some* treatment (no clean untreated group) +2. Estimate dose-response relationships +3. Identify effects of marginal changes in treatment intensity + +### What goes wrong with TWFE + +The standard TWFE regression for continuous DiD: + +``` +Y_{i,t} = theta_t + eta_i + beta^{twfe} * D_i * Post_t + v_{i,t} (Eq. 1.1) +``` + +beta^{twfe} suffers from multiple simultaneous problems: +- **Negative weights on levels**: Weights on ATT(l|l) integrate to *zero*, not one. + Below-average dose units get negative weight. +- **Selection bias under standard PT**: A contamination term persists even under + standard parallel trends. +- **Non-representative weighting under strong PT**: Even without selection bias, + weights don't match the dose density, making beta^{twfe} sensitive to the untreated + group share. +- **Scale dependence**: Rescaling dose (0-1 to 0-100) changes beta^{twfe} proportionally, + while natural target parameters are invariant. + +--- + +## 2. Setup and Notation + +### Two-period baseline case + +- Two periods: t=1 (pre), t=2 (post) +- In t=1, no unit is treated (all have dose 0) +- In t=2, some units receive dose D_i in D_+ subset (0, inf), rest remain at D_i = 0 + +### Key variables + +| Symbol | Definition | +|--------|-----------| +| Y_{i,t} | Outcome for unit i at time t | +| D_i | Dose (treatment intensity). D_i in D = {0} union D_+ | +| D_+ = (d_L, d_U) | Support of continuous dose among treated | +| Y_{i,t}(d) | Potential outcome under dose d | +| Delta Y | Y_{t=2} - Y_{t=1} | +| f_{D\|D>0}(d) | Dose density conditional on being treated | + +### Multi-period staggered notation + +| Symbol | Definition | +|--------|-----------| +| G_i | Timing group: period when unit i first treated | +| G = inf | Never-treated units | +| W_{i,t} = D_i * 1{t >= G_i} | Treatment exposure at time t | +| Y_{i,t}(g,d) | Potential outcome if first treated in period g with dose d | + +### What is a "dose"? What is a "group"? 
+ +- **Dose**: The amount/intensity of treatment. Can be continuous or multi-valued discrete. + Crucially, dose is **time-invariant** for each unit (the amount doesn't change once assigned). +- **Group**: In two-period case, all units with the same dose d. In multi-period case, + characterized by both timing (G_i) and dose (D_i). + +--- + +## 3. Target Parameters / Estimands + +### 3.1 Level treatment effects + +**Local ATT** (effect of dose d on units who received dose d'): +``` +ATT(d|d') = E[Y_{t=2}(d) - Y_{t=2}(0) | D = d'] +``` + +When d' = d (the "own-dose" case): +``` +ATT(d|d) = E[Y_{t=2}(d) - Y_{t=2}(0) | D = d] +``` + +**Global ATT** (effect of dose d across all treated): +``` +ATT(d) = E[Y_{t=2}(d) - Y_{t=2}(0) | D > 0] +``` + +Key distinction: ATT(d|d) != ATT(d) when there is selection into dose groups on +treatment effects. Under standard PT, only ATT(d|d) is identified. Under strong PT, +ATT(d|d) = ATT(d). + +### 3.2 Average causal response (ACRT) + +For **continuous** treatment: +``` +ACRT(d|d') = d/dl ATT(l|d') |_{l=d} +ACRT(d) = d/dd ATT(d) +``` + +For **discrete** (multi-valued) treatment: +``` +ACRT(d_j|d_k) = [ATT(d_j|d_k) - ATT(d_{j-1}|d_k)] / (d_j - d_{j-1}) +``` + +Level effects = height of dose-response curve (switching from 0 to d). +Causal responses = slope of dose-response curve (marginal increase in d). +These coincide for binary treatment; they diverge for continuous treatment. + +### 3.3 Summary parameters + +Four natural scalar summaries: +``` +ATT^{loc} = E[ATT(D|D) | D > 0] -- local levels, weighted by dose density +ATT^{glob} = E[ATT(D) | D > 0] -- global levels, weighted by dose density +ACRT^{loc} = E[ACRT(D|D) | D > 0] -- local slopes, weighted by dose density +ACRT^{glob} = E[ACRT(D) | D > 0] -- global slopes, weighted by dose density +``` + +`loc` = each dose group's own curve. `glob` = single curve for all treated. 
+ +### 3.4 Multi-period parameters + +Group-time-dose specific: +``` +ATT(g,t,d|g,d) = E[Y_t(g,d) - Y_t(0) | G=g, D=d] +``` + +Aggregated across groups/periods by dose: +``` +ATT^{dose}(d|d) = weighted average of ATT(g,t,d|g,d) across (g,t) cells +``` + +Event-study versions: +``` +ATT^{es}_{loc}(e) = E[ATT^{dose,es}(D|D,e) | G+e in [2,T], G <= T] +``` + +where e = event time (periods since treatment). + +--- + +## 4. Identifying Assumptions + +### Assumption PT (Standard / Weak Parallel Trends) + +For all d in D_+: +``` +E[Y_{t=2}(0) - Y_{t=1}(0) | D = d] = E[Y_{t=2}(0) - Y_{t=1}(0) | D = 0] +``` + +Untreated potential outcome paths are the same across all dose groups and the +untreated group. Direct analog of binary DiD parallel trends. + +**Identifies**: ATT(d|d), ATT^{loc}. Does NOT identify ATT(d), ACRT, or any +cross-dose comparison. + +### Assumption SPT (Strong Parallel Trends) + +For all d in D: +``` +E[Y_{t=2}(d) - Y_{t=1}(0) | D > 0] = E[Y_{t=2}(d) - Y_{t=1}(0) | D = d] +``` + +No selection into dose groups on the basis of treatment effects. Implies +ATT(d|d) = ATT(d) for all d. + +**Additionally identifies**: ATT(d), ACRT(d), ACRT^{glob}, and cross-dose +comparisons have causal interpretation. + +### Other assumptions + +- **No anticipation**: Y_{i,t=1} = Y_{i,t=1}(0) for all units +- **Overlap**: P(D=0) > 0 (untreated units exist); f_{D|D>0}(d) > 0 on D_+ +- **Multi-period PT-MP**: E[Delta Y_t(0) | G=g, D=d] = E[Delta Y_t(0) | G=inf, D=0] + +### Comparison to binary DiD + +When D in {0,1}: PT = SPT, ATT(1|1) = ATT(1) = ATT^{loc} = ATT^{glob}, +ACRT = ATT. Everything collapses to standard Callaway & Sant'Anna (2021). + +--- + +## 5. Estimation Procedures + +### 5.1 Discrete treatment: saturated regression + +When dose takes values d_1, ..., d_J (Eq. 
4.1): +``` +Delta Y_i = beta_0 + sum_{j=1}^{J} 1{D_i = d_j} * beta_j + epsilon_i +``` + +- beta_j estimates ATT(d_j) +- (beta_j - beta_{j-1}) / (d_j - d_{j-1}) estimates ACRT(d_j) +- Standard OLS inference applies + +### 5.2 Continuous treatment: parametric (B-spline sieve) + +**This is the default in the R package and the recommended starting point.** + +For each (g,t) 2x2 cell: + +1. Compute untreated counterfactual: E_n[Delta Y | D=0] +2. Demean treated outcomes: Delta tilde{Y}_i = Delta Y_i - E_n[Delta Y | D=0] for D_i > 0 +3. Construct B-spline basis psi^K(D) of degree `degree` with `num_knots` interior knots +4. OLS regression among treated (Eq. 4.4): + ``` + Delta tilde{Y}_i = psi^K(D_i)' beta_K + epsilon_i + ``` +5. Evaluate at dose grid (Eq. 4.5): + ``` + ATT(d) = psi^K(d)' * hat{beta}_K + ACRT(d) = (d/dd psi^K(d))' * hat{beta}_K + ``` + +**R package defaults**: degree=3 (cubic), num_knots=0 (global polynomial), +dose grid = quantiles P10 to P99 in 1% steps (90 points). + +### 5.3 Continuous treatment: nonparametric CCK + +Adapts Chen, Christensen & Kankanala (2025) sieve estimator with data-driven +dimension selection. Same regression framework as 5.2 but: + +- Sieve dimension K is selected automatically via Lepski-type method (Algorithm 1) +- Provides honest, sup-norm rate-adaptive uniform confidence bands +- **Restricted to two-period settings** (no staggered adoption) + +Algorithm for sieve dimension selection: +1. Define candidate set K = {2^k + 3 : k in N_+} +2. Compute K_max based on sample size +3. For each K in candidate set, test stability of estimates across K values +4. Select smallest K where estimates are stable (within bootstrap critical value) + +### 5.4 Summary parameter estimation + +**ATT^{glob} (binarized DiD)**: Under SPT (Eq. 4.6): +``` +ATT^{glob} = E[Delta Y | D > 0] - E[Delta Y | D = 0] +``` +Simple difference in means between any-treated and untreated. 
+ +**ACRT^{glob} (plug-in)**: +``` +ACRT^{glob} = (1/n_{D>0}) * sum_{i: D_i > 0} ACRT(D_i) +``` +Average the estimated ACRT curve over treated units' doses. + +### 5.5 Multi-period estimation + +For staggered adoption: apply two-period estimation to each (g,t) cell separately, +then aggregate. The R package handles this via the `ptetools` framework. + +For event-study ATT^{es}_{loc}(e): can binarize treatment and use standard +Callaway & Sant'Anna (2021) machinery. + +### 5.6 No untreated group (Remark 3.1) + +When P(D=0) = 0 (all units receive some treatment), use the lowest dose group d_L +as comparison. Under PT, this recovers ATT(d|d) - ATT(d_L|d_L). Under SPT, +recovers ATT(d) - ATT(d_L). + +--- + +## 6. Inference + +### Parametric / discrete case + +Standard OLS inference applies. Can cluster as needed. + +### Nonparametric CCK case + +**Multiplier (Gaussian) bootstrap** — NOT standard nonparametric bootstrap: +1. Draw omega_i iid N(0,1) for i=1,...,n +2. Compute weighted sums of influence functions using these weights +3. Repeat B=1000 times +4. No re-estimation needed per bootstrap draw (computationally efficient) + +**Pointwise confidence intervals**: +``` +ATT(d) +/- z_{0.975} * sigma_K(d) / sqrt(n) +``` + +**Uniform confidence bands (UCBs)**: +1. Compute bootstrap distribution of sup-t statistic across all dose values +2. Critical value c_alpha from (1-alpha) quantile +3. Band: ATT(d) +/- (c_alpha + A * gamma) * sigma(d) / sqrt(n) +4. These are honest and rate-adaptive + +### Summary parameter inference + +ACRT^{glob} plug-in estimator is sqrt(n)-consistent and asymptotically normal. +Standard errors via delta method or bootstrap. + +--- + +## 7. TWFE Decomposition (Theorem 3.4) + +Four decompositions of beta^{twfe}, each revealing a different pathology: + +| Decomposition | Weights positive? | Sum to 1? | Selection bias (under PT)? 
| +|:---|:---:|:---:|:---:| +| (a) Causal response | Yes | Yes | Yes | +| (b) Levels | No | No (sum to 0) | N/A | +| (c) Scaled levels | No | Yes | N/A | +| (d) Scaled 2x2 | Yes | Yes | Yes | + +Even under SPT (best case), decomposition (a) uses TWFE-specific weights that +don't match the dose density, making beta^{twfe} an unappealing summary. + +--- + +## 8. Key Theorems + +| Theorem | Statement (plain English) | +|---------|--------------------------| +| 3.1 | Under PT: ATT(d\|d) = E[Delta Y \| D=d] - E[Delta Y \| D=0]. The local ATT for each dose group is identified by the standard DiD comparison. | +| 3.2 | Under PT: cross-dose comparisons mix causal effects with selection bias. The derivative of E[Delta Y\|D=d] does NOT identify ACRT without stronger assumptions. | +| 3.3 | Under SPT: ATT(d) = E[Delta Y\|D=d] - E[Delta Y\|D=0], and ACRT(d) = d/dd E[Delta Y\|D=d]. Cross-dose comparisons are causal. | +| 3.4 | TWFE decomposition: beta^{twfe} admits four representations, all problematic. | +| Cor 3.1 | ATT^{glob} = binarized DiD. ACRT^{glob} = weighted average of dose-specific slopes. | +| C.1 | Multi-period: ATT(g,t,d\|g,d) = E[Y_t - Y_{g-1}\|G=g,D=d] - E[Y_t - Y_{g-1}\|W_t=0]. | + +--- + +## 9. R Package Implementation Details + +### API surface + +Main function: `cont_did()` returns `pte_dose_results` or `dose_obj`. + +Key parameters: +``` +cont_did(yname, dname, gname, tname, idname, data, + target_parameter = "level"|"slope", + aggregation = "dose"|"eventstudy"|"none", + treatment_type = "continuous"|"discrete", + dose_est_method = "parametric"|"cck", + dvals = NULL, + degree = 3, num_knots = 0, + control_group = "notyettreated"|"nevertreated"|"eventuallytreated", + anticipation = 0, + bstrap = TRUE, boot_type = "multiplier", biters = 1000, + cband = FALSE, alp = 0.05, + base_period = "varying", + ...) +``` + +### Core algorithm per (g,t) cell + +1. Extract 2x2 subset: target group (g) + control group, pre-period + post-period +2. 
Construct B-spline basis from treated units' doses using `splines2::bSpline()` +3. OLS: regress Delta Y on B-spline basis +4. Evaluate fitted spline at `dvals` -> ATT(d) vector +5. Evaluate derivative of spline at `dvals` -> ACRT(d) vector +6. Return estimates + influence functions for bootstrap + +### Aggregation + +- **`"dose"`**: Average across (g,t) cells at each dose point -> dose-response curve +- **`"eventstudy"`**: Average across dose at each event time e -> dynamic effects +- **`"none"`**: Return disaggregated (g,t,d) results + +### Data conventions + +- **Dose is time-invariant**: Set to actual value in ALL periods (pre and post) +- **Never-treated**: G=0, dose forced to 0 internally +- **Balanced panel required** in v0.1.0 +- Units with treatment timing but zero dose are dropped + +### Default dose grid + +``` +dvals = quantile(dose[dose > 0], probs = seq(0.10, 0.99, 0.01)) +``` +90 evaluation points, P10 to P99. Provides implicit tail trimming. + +### Knot placement + +Quantile-based by default: `choose_knots_quantile(dose[dose > 0], num_knots)`. +With `num_knots=0`, no interior knots (global polynomial of given degree). + +### Dependencies mapping (R -> Python) + +| R Package | Purpose | Python Equivalent | +|-----------|---------|-------------------| +| `splines2` | B-spline basis + derivatives | `scipy.interpolate.BSpline` + custom derivative | +| `sandwich` | Robust variance | Already in diff-diff `linalg.py` | +| `ptetools` | Group-time iteration, aggregation, bootstrap | Reimplement (mirrors existing CS framework) | +| `MASS::ginv` | Pseudo-inverse | `numpy.linalg.pinv` | +| `npiv` | CCK nonparametric | Reimplement for CCK method | + +### Current limitations (v0.1.0) + +- Covariates not supported (xformula = ~1 only) +- Discrete treatment not yet implemented +- Unbalanced panels not supported +- CCK restricted to 2-period case +- Repeated cross-sections not supported + +--- + +## 10. 
Implementation Priorities for diff-diff + +### Phase 1 (Core) +1. Parametric B-spline estimation for two-period case +2. ATT(d) and ACRT(d) dose-response curves +3. Summary parameters: ATT^{glob}, ACRT^{glob} +4. Bootstrap inference (multiplier) + +### Phase 2 (Staggered) +5. Multi-period extension via (g,t) cell iteration +6. Dose aggregation and event-study aggregation +7. Control group options (not-yet-treated, never-treated) + +### Phase 3 (Advanced) +8. CCK nonparametric estimation +9. Uniform confidence bands +10. Covariates support (DR/IPW/OR) + +### Defer +- Discrete treatment (saturated regression — simpler, add later) +- TWFE decomposition diagnostics diff --git a/tests/conftest.py b/tests/conftest.py index 596fa1f..4d4ff41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,10 @@ import os import subprocess +# Force non-interactive matplotlib backend before any test imports it. +# Prevents plt.show() from blocking the test suite on a GUI window. +os.environ.setdefault("MPLBACKEND", "Agg") + import pytest diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py new file mode 100644 index 0000000..40cfc28 --- /dev/null +++ b/tests/test_continuous_did.py @@ -0,0 +1,506 @@ +""" +Unit and integration tests for ContinuousDiD estimator. 
+""" + +import numpy as np +import pandas as pd +import pytest + +from diff_diff.continuous_did import ContinuousDiD +from diff_diff.continuous_did_bspline import ( + bspline_derivative_design_matrix, + bspline_design_matrix, + build_bspline_basis, + default_dose_grid, +) +from diff_diff.continuous_did_results import ContinuousDiDResults +from diff_diff.prep_dgp import generate_continuous_did_data + +# ============================================================================= +# B-Spline Basis Tests +# ============================================================================= + + +class TestBSplineBasis: + """Test B-spline utility functions.""" + + def test_knot_construction_no_interior(self): + dose = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=0) + assert deg == 3 + # Boundary knots repeated degree+1 times + assert knots[0] == 1.0 + assert knots[-1] == 5.0 + assert len(knots) == 3 + 1 + 3 + 1 # (degree+1)*2 + + def test_knot_construction_with_interior(self): + dose = np.linspace(1, 10, 100) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=2) + # Interior knots at 1/3 and 2/3 quantiles + n_expected = 2 * (3 + 1) + 2 # boundary + interior + assert len(knots) == n_expected + + def test_design_matrix_shape(self): + dose = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=0) + B = bspline_design_matrix(dose, knots, deg, include_intercept=True) + n_basis = len(knots) - deg - 1 # Total basis functions + assert B.shape == (5, n_basis) # Same columns (intercept replaces first) + + def test_design_matrix_intercept_column(self): + dose = np.linspace(1, 5, 20) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=0) + B = bspline_design_matrix(dose, knots, deg, include_intercept=True) + # First column should be all ones + np.testing.assert_array_equal(B[:, 0], np.ones(20)) + + def test_design_matrix_no_intercept(self): + dose = np.linspace(1, 5, 20) + 
knots, deg = build_bspline_basis(dose, degree=3, num_knots=0) + B_no = bspline_design_matrix(dose, knots, deg, include_intercept=False) + n_basis = len(knots) - deg - 1 + assert B_no.shape == (20, n_basis) + # First column should NOT be all ones + assert not np.allclose(B_no[:, 0], 1.0) + + def test_derivative_numerical_check(self): + """Verify B-spline derivatives match finite differences.""" + dose = np.linspace(1, 5, 50) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=1) + + # Evaluate at interior points (avoid boundaries) + x = np.linspace(1.5, 4.5, 30) + dB = bspline_derivative_design_matrix(x, knots, deg, include_intercept=True) + + # Finite difference check + h = 1e-6 + x_plus = x + h + x_minus = x - h + B_plus = bspline_design_matrix(x_plus, knots, deg, include_intercept=True) + B_minus = bspline_design_matrix(x_minus, knots, deg, include_intercept=True) + fd = (B_plus - B_minus) / (2 * h) + + # Intercept derivative should be 0 + np.testing.assert_allclose(dB[:, 0], 0.0, atol=1e-10) + # Other columns should match finite differences + np.testing.assert_allclose(dB[:, 1:], fd[:, 1:], atol=1e-4) + + def test_partition_of_unity(self): + """B-spline basis without intercept should sum to ~1 at interior points.""" + dose = np.linspace(1, 5, 50) + knots, deg = build_bspline_basis(dose, degree=3, num_knots=2) + x = np.linspace(1.1, 4.9, 30) + B = bspline_design_matrix(x, knots, deg, include_intercept=False) + row_sums = B.sum(axis=1) + np.testing.assert_allclose(row_sums, 1.0, atol=1e-10) + + def test_linear_basis(self): + """Degree 1 with 0 knots: 2 basis functions (intercept + linear).""" + dose = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + knots, deg = build_bspline_basis(dose, degree=1, num_knots=0) + B = bspline_design_matrix(dose, knots, deg, include_intercept=True) + assert B.shape[1] == 2 # intercept + 1 basis fn + + +class TestDoseGrid: + """Test dose grid computation.""" + + def test_default_grid_size(self): + dose = 
np.random.default_rng(42).lognormal(0.5, 0.5, size=100) + grid = default_dose_grid(dose) + assert len(grid) == 90 # quantiles 0.10 to 0.99 + + def test_default_grid_sorted(self): + dose = np.random.default_rng(42).lognormal(0.5, 0.5, size=100) + grid = default_dose_grid(dose) + assert np.all(np.diff(grid) >= 0) + + def test_custom_grid_passthrough(self): + custom = np.array([1.0, 2.0, 3.0]) + est = ContinuousDiD(dvals=custom) + np.testing.assert_array_equal(est.dvals, custom) + + def test_empty_dose(self): + grid = default_dose_grid(np.array([0.0, 0.0])) + assert len(grid) == 0 + + +# ============================================================================= +# ContinuousDiD Estimator Tests +# ============================================================================= + + +class TestContinuousDiDInit: + """Test constructor, get_params, set_params.""" + + def test_default_params(self): + est = ContinuousDiD() + params = est.get_params() + assert params["degree"] == 3 + assert params["num_knots"] == 0 + assert params["control_group"] == "never_treated" + assert params["alpha"] == 0.05 + assert params["n_bootstrap"] == 0 + + def test_set_params(self): + est = ContinuousDiD() + est.set_params(degree=1, num_knots=2) + assert est.degree == 1 + assert est.num_knots == 2 + + def test_set_invalid_param(self): + est = ContinuousDiD() + with pytest.raises(ValueError, match="Invalid parameter"): + est.set_params(nonexistent_param=42) + + +class TestContinuousDiDDataValidation: + """Test data validation in fit().""" + + def test_missing_column(self): + data = pd.DataFrame({"unit": [1], "period": [1], "outcome": [1.0]}) + est = ContinuousDiD() + with pytest.raises(ValueError, match="Column.*not found"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_non_time_invariant_dose(self): + data = pd.DataFrame({ + "unit": [1, 1, 2, 2], + "period": [1, 2, 1, 2], + "outcome": [1.0, 2.0, 1.0, 2.0], + "first_treat": [2, 2, 0, 0], + "dose": [1.0, 2.0, 
0.0, 0.0], # Dose changes over time! + }) + est = ContinuousDiD() + with pytest.raises(ValueError, match="time-invariant"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_drop_zero_dose_treated(self): + """Units with positive first_treat but zero dose should be dropped.""" + # Need enough treated units for OLS: degree=1 → 2 basis fns → need >2 treated + rows = [] + uid = 0 + # 1 treated unit with zero dose (should be dropped) + rows += [{"unit": uid, "period": 1, "outcome": 1.0, "first_treat": 2, "dose": 0.0}, + {"unit": uid, "period": 2, "outcome": 3.0, "first_treat": 2, "dose": 0.0}] + uid += 1 + # 4 treated units with positive dose (should remain) + for d in [1.0, 2.0, 3.0, 4.0]: + rows += [{"unit": uid, "period": 1, "outcome": 0.0, "first_treat": 2, "dose": d}, + {"unit": uid, "period": 2, "outcome": 2 * d, "first_treat": 2, "dose": d}] + uid += 1 + # 3 control units + for _ in range(3): + rows += [{"unit": uid, "period": 1, "outcome": 0.0, "first_treat": 0, "dose": 0.0}, + {"unit": uid, "period": 2, "outcome": 0.0, "first_treat": 0, "dose": 0.0}] + uid += 1 + + data = pd.DataFrame(rows) + est = ContinuousDiD(degree=1, num_knots=0) + with pytest.warns(UserWarning, match="Dropping.*units"): + results = est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + # Unit 0 dropped (zero dose but treated), 4 treated remain + assert results.n_treated_units == 4 + + def test_unbalanced_panel_error(self): + data = pd.DataFrame({ + "unit": [1, 1, 2], + "period": [1, 2, 1], + "outcome": [1.0, 2.0, 1.0], + "first_treat": [2, 2, 0], + "dose": [1.0, 1.0, 0.0], + }) + est = ContinuousDiD() + with pytest.raises(ValueError, match="[Uu]nbalanced"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_no_never_treated_error(self): + data = pd.DataFrame({ + "unit": [1, 1, 2, 2], + "period": [1, 2, 1, 2], + "outcome": [1.0, 3.0, 1.0, 4.0], + "first_treat": [2, 2, 2, 2], + "dose": [1.0, 1.0, 2.0, 2.0], + }) + est 
= ContinuousDiD(control_group="never_treated") + with pytest.raises(ValueError, match="[Nn]ever-treated"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + +class TestContinuousDiDFit: + """Test basic fit returns correct types and shapes.""" + + @pytest.fixture + def basic_data(self): + return generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5 + ) + + def test_fit_returns_results(self, basic_data): + est = ContinuousDiD() + results = est.fit( + basic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert isinstance(results, ContinuousDiDResults) + + def test_dose_response_shapes(self, basic_data): + est = ContinuousDiD() + results = est.fit( + basic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + n_grid = len(results.dose_grid) + assert results.dose_response_att.effects.shape == (n_grid,) + assert results.dose_response_acrt.effects.shape == (n_grid,) + assert results.dose_response_att.se.shape == (n_grid,) + assert results.dose_response_acrt.se.shape == (n_grid,) + + def test_overall_parameters(self, basic_data): + est = ContinuousDiD() + results = est.fit( + basic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert np.isfinite(results.overall_att) + assert np.isfinite(results.overall_acrt) + + def test_group_time_effects_populated(self, basic_data): + est = ContinuousDiD() + results = est.fit( + basic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert len(results.group_time_effects) > 0 + + def test_not_yet_treated_control(self): + data = generate_continuous_did_data( + n_units=100, n_periods=4, cohort_periods=[2, 3], seed=42, + ) + est = ContinuousDiD(control_group="not_yet_treated") + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert isinstance(results, ContinuousDiDResults) + + +class TestContinuousDiDResults: + """Test results object methods.""" + + @pytest.fixture + def results(self): + data = 
generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.1 + ) + est = ContinuousDiD(n_bootstrap=49, seed=42) + return est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + + def test_summary(self, results): + s = results.summary() + assert "ATT_glob" in s + assert "ACRT_glob" in s + assert "Continuous" in s + + def test_print_summary(self, results, capsys): + results.print_summary() + captured = capsys.readouterr() + assert "ATT_glob" in captured.out + + def test_to_dataframe_dose_response(self, results): + df = results.to_dataframe(level="dose_response") + assert "dose" in df.columns + assert "att" in df.columns + assert "acrt" in df.columns + assert len(df) == len(results.dose_grid) + + def test_to_dataframe_group_time(self, results): + df = results.to_dataframe(level="group_time") + assert "group" in df.columns + assert "time" in df.columns + assert "att_glob" in df.columns + + def test_to_dataframe_event_study_error(self, results): + """Should error if event study not computed.""" + with pytest.raises(ValueError, match="[Ee]vent study"): + results.to_dataframe(level="event_study") + + def test_to_dataframe_invalid_level(self, results): + with pytest.raises(ValueError, match="Unknown level"): + results.to_dataframe(level="invalid") + + def test_is_significant(self, results): + assert isinstance(results.is_significant, bool) + + def test_significance_stars(self, results): + stars = results.significance_stars + assert stars in ("", ".", "*", "**", "***") + + def test_repr(self, results): + r = repr(results) + assert "ContinuousDiDResults" in r + + +class TestDoseAggregation: + """Test dose-response aggregation across (g,t) cells.""" + + def test_multi_period_aggregation(self): + data = generate_continuous_did_data( + n_units=200, n_periods=5, cohort_periods=[2, 4], + seed=42, noise_sd=0.1, + ) + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + 
aggregate="dose", + ) + # With linear DGP (ATT(d) = 1 + 2d) and degree=1, should recover well + # ACRT should be close to 2.0 + assert abs(results.overall_acrt - 2.0) < 0.3 + + def test_single_cohort_aggregation(self): + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.1, + ) + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="dose", + ) + assert len(results.groups) == 1 + assert np.isfinite(results.overall_att) + + +class TestEventStudyAggregation: + """Test event-study aggregation path.""" + + def test_event_study_computed(self): + data = generate_continuous_did_data( + n_units=200, n_periods=5, cohort_periods=[2, 4], + seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=49, seed=42) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) + assert results.event_study_effects is not None + # Should have pre and post relative periods + rel_periods = sorted(results.event_study_effects.keys()) + assert min(rel_periods) < 0 # Pre-treatment + assert max(rel_periods) >= 0 # Post-treatment + + def test_event_study_to_dataframe(self): + data = generate_continuous_did_data( + n_units=200, n_periods=4, cohort_periods=[2, 3], + seed=42, noise_sd=0.5, + ) + est = ContinuousDiD() + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) + df = results.to_dataframe(level="event_study") + assert "relative_period" in df.columns + assert "att_glob" in df.columns + + +class TestBootstrap: + """Test bootstrap inference.""" + + def test_bootstrap_ses_positive(self, ci_params): + n_boot = ci_params.bootstrap(99) + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert 
results.overall_att_se > 0 + assert results.overall_acrt_se > 0 + # Dose-response SEs should be positive + assert np.all(results.dose_response_att.se > 0) + + def test_bootstrap_ci_contains_estimate(self, ci_params): + n_boot = ci_params.bootstrap(99) + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + lo, hi = results.overall_att_conf_int + assert lo <= results.overall_att <= hi + + def test_bootstrap_p_values_valid(self, ci_params): + n_boot = ci_params.bootstrap(99) + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert 0 <= results.overall_att_p_value <= 1 + assert 0 <= results.overall_acrt_p_value <= 1 + + +class TestAnalyticalSE: + """Test analytical standard errors (n_bootstrap=0).""" + + def test_analytical_se_positive(self): + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert results.overall_att_se > 0 + assert results.overall_acrt_se > 0 + + def test_analytical_ci(self): + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + lo, hi = results.overall_att_conf_int + assert lo < results.overall_att < hi + + +class TestEdgeCases: + """Test edge cases.""" + + def test_few_treated_units(self): + """Estimator should handle very few treated units.""" + data = generate_continuous_did_data( + n_units=30, n_periods=3, seed=42, + never_treated_frac=0.8, # Only ~6 treated + ) + est = 
ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert isinstance(results, ContinuousDiDResults) + + def test_inf_first_treat_normalization(self): + """first_treat=inf should be treated as never-treated.""" + data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) + data.loc[data["first_treat"] == 0, "first_treat"] = np.inf + est = ContinuousDiD() + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert results.n_control_units > 0 + + def test_custom_dvals(self): + data = generate_continuous_did_data(n_units=100, n_periods=3, seed=42) + custom_grid = np.array([1.0, 2.0, 3.0]) + est = ContinuousDiD(dvals=custom_grid) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + np.testing.assert_array_equal(results.dose_grid, custom_grid) + assert len(results.dose_response_att.effects) == 3 diff --git a/tests/test_methodology_continuous_did.py b/tests/test_methodology_continuous_did.py new file mode 100644 index 0000000..596ae09 --- /dev/null +++ b/tests/test_methodology_continuous_did.py @@ -0,0 +1,654 @@ +""" +Equation verification and R benchmark tests for ContinuousDiD. + +Phase 1: Hand-calculable cases verifying the estimator recovers known truths. +Phase 2: R `contdid` benchmarks (skipped if R not installed). +""" + +import json +import subprocess +import tempfile + +import numpy as np +import pandas as pd +import pytest + +from diff_diff.continuous_did import ContinuousDiD +from diff_diff.prep_dgp import generate_continuous_did_data + +# ============================================================================= +# Phase 1: Hand-calculable equation verification +# ============================================================================= + + +class TestLinearDoseResponse: + """Two-period case with linear dose-response ATT(d) = 2d.""" + + @pytest.fixture + def linear_data(self): + """6 treated, 4 control. 
True ATT(d) = 2d. No noise.""" + treated_doses = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + n_control = 4 + + rows = [] + # Control units: Delta Y = 0 (no treatment) + for i in range(n_control): + rows.append({"unit": i, "period": 1, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + rows.append({"unit": i, "period": 2, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + + # Treated units: Delta Y = ATT(d) = 2*d + for j, d in enumerate(treated_doses): + uid = n_control + j + rows.append({"unit": uid, "period": 1, "outcome": 0.0, "first_treat": 2, "dose": d}) + rows.append({"unit": uid, "period": 2, "outcome": 2 * d, "first_treat": 2, "dose": d}) + + return pd.DataFrame(rows) + + def test_linear_att_recovery(self, linear_data): + """With degree=1 and linear truth, ATT(d) should be exactly 2d.""" + est = ContinuousDiD(degree=1, num_knots=0, dvals=np.array([1.0, 3.0, 5.0])) + results = est.fit( + linear_data, "outcome", "unit", "period", "first_treat", "dose" + ) + expected = np.array([2.0, 6.0, 10.0]) + np.testing.assert_allclose(results.dose_response_att.effects, expected, atol=1e-10) + + def test_linear_acrt(self, linear_data): + """ACRT(d) should be constant = 2 for linear truth.""" + est = ContinuousDiD(degree=1, num_knots=0, dvals=np.array([1.5, 3.0, 4.5])) + results = est.fit( + linear_data, "outcome", "unit", "period", "first_treat", "dose" + ) + # Derivative of 2d is 2 + np.testing.assert_allclose(results.dose_response_acrt.effects, 2.0, atol=1e-6) + + def test_att_glob_binarized(self, linear_data): + """ATT_glob = mean(Delta_Y | treated) - mean(Delta_Y | control).""" + treated_doses = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + mean_delta_treated = np.mean(2 * treated_doses) # = 7.0 + mean_delta_control = 0.0 + expected_att_glob = mean_delta_treated - mean_delta_control + + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + linear_data, "outcome", "unit", "period", "first_treat", "dose" + ) + np.testing.assert_allclose(results.overall_att, 
expected_att_glob, atol=1e-10) + + def test_acrt_glob_plugin(self, linear_data): + """ACRT_glob = mean(ACRT(D_i)) over treated = 2.""" + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + linear_data, "outcome", "unit", "period", "first_treat", "dose" + ) + np.testing.assert_allclose(results.overall_acrt, 2.0, atol=1e-6) + + +class TestQuadraticWithCubicBasis: + """ATT(d) = d^2. Cubic B-spline can represent quadratic exactly.""" + + @pytest.fixture + def quadratic_data(self): + doses = np.linspace(1, 5, 20) + n_control = 10 + + rows = [] + for i in range(n_control): + rows.append({"unit": i, "period": 1, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + rows.append({"unit": i, "period": 2, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + + for j, d in enumerate(doses): + uid = n_control + j + rows.append({"unit": uid, "period": 1, "outcome": 0.0, "first_treat": 2, "dose": d}) + rows.append({"unit": uid, "period": 2, "outcome": d**2, "first_treat": 2, "dose": d}) + + return pd.DataFrame(rows) + + def test_quadratic_recovery(self, quadratic_data): + """Cubic basis should recover d^2 exactly.""" + eval_grid = np.array([1.5, 2.5, 3.5, 4.5]) + est = ContinuousDiD(degree=3, num_knots=0, dvals=eval_grid) + results = est.fit( + quadratic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + expected = eval_grid**2 + np.testing.assert_allclose( + results.dose_response_att.effects, expected, atol=1e-6 + ) + + +class TestMultiPeriodAggregation: + """4 periods, 2 cohorts. 
Verify (g,t) cells and aggregation weights.""" + + @pytest.fixture + def staggered_data(self): + return generate_continuous_did_data( + n_units=200, + n_periods=4, + cohort_periods=[2, 3], + seed=42, + noise_sd=0.0, # No noise for exact verification + att_function="linear", + att_slope=2.0, + att_intercept=1.0, + ) + + def test_multiple_groups(self, staggered_data): + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + staggered_data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert len(results.groups) == 2 + assert 2 in results.groups + assert 3 in results.groups + + def test_gt_cell_count(self, staggered_data): + est = ContinuousDiD(degree=1, num_knots=0) + results = est.fit( + staggered_data, "outcome", "unit", "period", "first_treat", "dose" + ) + # Group 2: periods 1(pre-via-varying),2,3,4; Group 3: periods 2(pre),3,4 + # Exact count depends on base period logic + assert len(results.group_time_effects) >= 4 + + +class TestEdgeCasesMethodology: + """Edge cases: all-same dose, single treated unit, boundary doses.""" + + def test_all_same_dose(self): + """When all treated have same dose, OLS can only recover mean effect.""" + n_control = 10 + n_treated = 5 + dose_val = 3.0 + rows = [] + for i in range(n_control): + rows.append({"unit": i, "period": 1, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + rows.append({"unit": i, "period": 2, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + for j in range(n_treated): + uid = n_control + j + rows.append({"unit": uid, "period": 1, "outcome": 0.0, "first_treat": 2, "dose": dose_val}) + rows.append({"unit": uid, "period": 2, "outcome": 5.0, "first_treat": 2, "dose": dose_val}) + + data = pd.DataFrame(rows) + est = ContinuousDiD(degree=1, num_knots=0, rank_deficient_action="silent") + with pytest.warns(UserWarning, match="[Ii]dentical"): + results = est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + # ATT_glob should be 5.0 + np.testing.assert_allclose(results.overall_att, 5.0, 
atol=1e-10) + + def test_single_treated_unit(self): + """Single treated unit: not enough for OLS → no valid cells → ValueError.""" + rows = [] + for i in range(5): + rows.append({"unit": i, "period": 1, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + rows.append({"unit": i, "period": 2, "outcome": 0.0, "first_treat": 0, "dose": 0.0}) + rows.append({"unit": 5, "period": 1, "outcome": 0.0, "first_treat": 2, "dose": 2.0}) + rows.append({"unit": 5, "period": 2, "outcome": 4.0, "first_treat": 2, "dose": 2.0}) + + data = pd.DataFrame(rows) + est = ContinuousDiD(degree=1, num_knots=0, rank_deficient_action="silent") + with pytest.raises(ValueError, match="No valid"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + +# ============================================================================= +# Phase 2: R `contdid` benchmarks +# ============================================================================= + + +def _check_r_contdid(): + """Check if R and contdid package are available.""" + try: + result = subprocess.run( + ["Rscript", "-e", "library(contdid); cat('OK')"], + capture_output=True, text=True, timeout=10, + ) + return result.stdout.strip() == "OK" + except (FileNotFoundError, subprocess.TimeoutExpired): + return False + + +_HAS_R_CONTDID = _check_r_contdid() + +require_contdid = pytest.mark.skipif( + not _HAS_R_CONTDID, + reason="R or contdid package not installed", +) + + +def _run_r_contdid(csv_path, degree=3, num_knots=0, control_group="nevertreated", + aggregation="dose", staggered=False): + """Run R's cont_did() and return results for comparison. + + For 2-period data (staggered=False): recomputes ATT(d)/ACRT(d) with consistent + boundary knots, fixing R's contdid v0.1.0 quirk of using range(dvals) instead + of range(dose) for the evaluation basis. + + For multi-period data (staggered=True): compares only overall ATT/ACRT, which + are not affected by the boundary knot issue. 
+ """ + cg = "nevertreated" if control_group == "never_treated" else "notyettreated" + + if staggered: + # For staggered data, compare overall ATT/ACRT only + r_code = f""" + library(contdid) + library(jsonlite) + + data <- read.csv("{csv_path}") + res_level <- cont_did( + yname = "outcome", tname = "period", idname = "unit", + gname = "first_treat", dname = "dose", data = data, + target_parameter = "level", aggregation = "{aggregation}", + treatment_type = "continuous", control_group = "{cg}", + degree = {degree}, num_knots = {num_knots}, + bstrap = FALSE, print_details = FALSE + ) + res_slope <- cont_did( + yname = "outcome", tname = "period", idname = "unit", + gname = "first_treat", dname = "dose", data = data, + target_parameter = "slope", aggregation = "{aggregation}", + treatment_type = "continuous", control_group = "{cg}", + degree = {degree}, num_knots = {num_knots}, + bstrap = FALSE, print_details = FALSE + ) + out <- list( + overall_att = res_level$overall_att, + overall_att_se = res_level$overall_att_se, + overall_acrt = res_slope$overall_acrt, + overall_acrt_se = res_slope$overall_acrt_se, + dvals = as.numeric(res_level$dose) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + else: + # For 2-period data, recompute dose-response with consistent knots + r_code = f""" + library(contdid) + library(jsonlite) + library(splines2) + + data <- read.csv("{csv_path}") + res <- cont_did( + yname = "outcome", tname = "period", idname = "unit", + gname = "first_treat", dname = "dose", data = data, + target_parameter = "level", aggregation = "{aggregation}", + treatment_type = "continuous", control_group = "{cg}", + degree = {degree}, num_knots = {num_knots}, + bstrap = FALSE, print_details = FALSE + ) + res_slope <- cont_did( + yname = "outcome", tname = "period", idname = "unit", + gname = "first_treat", dname = "dose", data = data, + target_parameter = "slope", aggregation = "{aggregation}", + treatment_type = "continuous", control_group = "{cg}", + 
degree = {degree}, num_knots = {num_knots}, + bstrap = FALSE, print_details = FALSE + ) + + dvals <- as.numeric(res$dose) + first_period <- min(data[["period"]]) + fp_data <- data[data[["period"]] == first_period,] + treated_doses <- fp_data[["dose"]][fp_data[["first_treat"]] > 0 & fp_data[["dose"]] > 0] + bknots <- range(treated_doses) + interior_knots <- as.numeric(res$pte_params$knots) + + # Rebuild OLS with consistent boundary knots + bs_train <- bSpline(treated_doses, degree = {degree}, + knots = interior_knots, Boundary.knots = bknots, + intercept = FALSE) + post_period <- sort(unique(data[["period"]]))[2] + pre_data <- data[data[["period"]] == first_period,] + post_data <- data[data[["period"]] == post_period,] + pre_data <- pre_data[order(pre_data[["unit"]]),] + post_data <- post_data[order(post_data[["unit"]]),] + dy <- post_data[["outcome"]] - pre_data[["outcome"]] + dy_treated <- dy[pre_data[["first_treat"]] > 0 & pre_data[["dose"]] > 0] + dy_control <- dy[pre_data[["first_treat"]] == 0] + mu_0 <- mean(dy_control) + + bs_df <- as.data.frame(bs_train) + colnames(bs_df) <- paste0("V", seq_len(ncol(bs_df))) + bs_df$dy <- dy_treated + reg <- lm(dy ~ ., data = bs_df) + beta <- coef(reg) + + bs_grid <- bSpline(dvals, degree = {degree}, knots = interior_knots, + Boundary.knots = bknots, intercept = FALSE) + bs_grid_df <- as.data.frame(bs_grid) + colnames(bs_grid_df) <- paste0("V", seq_len(ncol(bs_grid_df))) + att_d <- predict(reg, newdata = bs_grid_df) - mu_0 + + dbs_grid <- dbs(dvals, degree = {degree}, knots = interior_knots, + Boundary.knots = bknots) + acrt_d <- as.numeric(dbs_grid %*% beta[-1]) + + out <- list( + overall_att = res$overall_att, + overall_att_se = res$overall_att_se, + overall_acrt = res_slope$overall_acrt, + overall_acrt_se = res_slope$overall_acrt_se, + att_d = as.numeric(att_d), + acrt_d = acrt_d, + dvals = dvals, + beta = as.numeric(beta) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + 
["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + return json.loads(result.stdout) + + +@require_contdid +class TestRBenchmark: + """R `contdid` v0.1.0 benchmark comparisons.""" + + def _compare_with_r(self, data, degree=3, num_knots=0, + control_group="never_treated", aggregation="dose", + staggered=False, att_tol=0.01, acrt_tol=0.02): + """Helper: run both Python and R, compare.""" + with tempfile.NamedTemporaryFile(suffix=".csv", mode="w", delete=False) as f: + data.to_csv(f, index=False) + csv_path = f.name + + r_out = _run_r_contdid( + csv_path, degree=degree, num_knots=num_knots, + control_group=control_group, aggregation=aggregation, + staggered=staggered, + ) + + # Map R aggregation names to Python aggregate parameter + py_aggregate = None + if aggregation == "dose": + py_aggregate = "dose" + elif aggregation == "eventstudy": + py_aggregate = "eventstudy" + + # Python estimation using R's dvals for exact grid match + dvals = np.array(r_out["dvals"]) + est = ContinuousDiD( + degree=degree, num_knots=num_knots, dvals=dvals, + control_group=control_group, + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate=py_aggregate, + ) + + # Compare overall ATT + r_overall_att = r_out["overall_att"] + py_overall_att = results.overall_att + overall_att_diff = abs(py_overall_att - r_overall_att) / (abs(r_overall_att) + 1e-10) + assert overall_att_diff < att_tol, ( + f"Overall ATT diff: {overall_att_diff:.4f} " + f"(R={r_overall_att:.6f}, Py={py_overall_att:.6f})" + ) + + # Compare ATT(d) and ACRT(d) only for non-staggered cases + # (staggered cases have the R boundary knot quirk in aggregated curves) + if not staggered: + r_att_d = np.array(r_out["att_d"]) + py_att_d = results.dose_response_att.effects + rel_diff_att = np.abs(py_att_d - r_att_d) / (np.abs(r_att_d) + 1e-10) + max_att_diff = np.max(rel_diff_att) + 
assert max_att_diff < att_tol, ( + f"ATT(d) max relative diff: {max_att_diff:.4f}\n" + f" R: {r_att_d[:5]}...\n" + f" Py: {py_att_d[:5]}..." + ) + + r_acrt_d = np.array(r_out["acrt_d"]) + py_acrt_d = results.dose_response_acrt.effects + rel_diff_acrt = np.abs(py_acrt_d - r_acrt_d) / (np.abs(r_acrt_d) + 1e-10) + max_acrt_diff = np.max(rel_diff_acrt) + assert max_acrt_diff < acrt_tol, ( + f"ACRT(d) max relative diff: {max_acrt_diff:.4f}\n" + f" R: {r_acrt_d[:5]}...\n" + f" Py: {py_acrt_d[:5]}..." + ) + + return results, r_out + + def test_benchmark_1_basic_cubic(self): + """2 periods, 1 cohort, degree=3, no knots, never_treated.""" + data = generate_continuous_did_data( + n_units=300, n_periods=2, cohort_periods=[2], + seed=100, noise_sd=0.5, + ) + self._compare_with_r(data, degree=3, num_knots=0) + + def test_benchmark_2_linear(self): + """2 periods, 1 cohort, degree=1 (linear), never_treated.""" + data = generate_continuous_did_data( + n_units=300, n_periods=2, cohort_periods=[2], + seed=101, noise_sd=0.5, + ) + self._compare_with_r(data, degree=1, num_knots=0) + + def test_benchmark_3_interior_knots(self): + """2 periods, 1 cohort, degree=3, 2 interior knots.""" + data = generate_continuous_did_data( + n_units=300, n_periods=2, cohort_periods=[2], + seed=102, noise_sd=0.5, + ) + self._compare_with_r(data, degree=3, num_knots=2) + + def test_benchmark_4_staggered_dose(self): + """4 periods, 3 cohorts, degree=3, dose aggregation. + + Uses R's simulate_contdid_data() to generate data compatible with + contdid's internal aggregation. Compares overall_att and overall_acrt + via pte_default (with consistent control_group). 
+ """ + r_code = """ + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(42) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 2, dose_quadratic_effect = 0.5 + ) + + # Overall ACRT via cont_did (dose aggregation) + res_slope <- cont_did( + yname = "Y", tname = "time_period", idname = "id", + gname = "G", dname = "D", data = df, + target_parameter = "slope", aggregation = "dose", + treatment_type = "continuous", control_group = "nevertreated", + degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE + ) + + # Overall ATT via pte_default (with matching control_group) + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "nevertreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "/tmp/r_bench4.csv", row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att, + overall_acrt = res_slope$overall_acrt, + dvals = as.numeric(res_slope$dose) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv("/tmp/r_bench4.csv") + data = data.rename(columns={ + "id": "unit", "time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + dvals = np.array(r_out["dvals"]) + est = ContinuousDiD( + degree=3, num_knots=0, dvals=dvals, + control_group="never_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="dose", + ) + + # Overall ATT + att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + 
f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) + + # Overall ACRT + acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) + assert acrt_diff < 0.01, ( + f"Overall ACRT diff: {acrt_diff:.4f} " + f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" + ) + + def test_benchmark_5_not_yet_treated(self): + """4 periods, 3 cohorts, not-yet-treated control.""" + r_code = """ + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(123) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 1.5, dose_quadratic_effect = 0 + ) + + res_slope <- cont_did( + yname = "Y", tname = "time_period", idname = "id", + gname = "G", dname = "D", data = df, + target_parameter = "slope", aggregation = "dose", + treatment_type = "continuous", control_group = "notyettreated", + degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE + ) + + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "notyettreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "/tmp/r_bench5.csv", row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att, + overall_acrt = res_slope$overall_acrt, + dvals = as.numeric(res_slope$dose) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv("/tmp/r_bench5.csv") + data = data.rename(columns={ + "id": "unit", "time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + dvals = np.array(r_out["dvals"]) + est = ContinuousDiD( + degree=3, num_knots=0, dvals=dvals, + 
control_group="not_yet_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="dose", + ) + + att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) + + acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) + assert acrt_diff < 0.01, ( + f"Overall ACRT diff: {acrt_diff:.4f} " + f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" + ) + + def test_benchmark_6_event_study(self): + """4 periods, 3 cohorts, event study aggregation (binarized ATT). + + R's event study uses ptetools::did_attgt (standard binary DiD) for + per-cell estimation, then aggregates by relative period. We compare + overall ATT (binarized) via pte_default with matching control_group. + """ + r_code = """ + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(99) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 2, dose_quadratic_effect = 0 + ) + + # Overall ATT via pte_default (matching control_group) + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "nevertreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "/tmp/r_bench6.csv", row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv("/tmp/r_bench6.csv") + data = data.rename(columns={ + "id": "unit", 
"time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + est = ContinuousDiD( + degree=3, num_knots=0, + control_group="never_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) + + # Compare overall ATT (binarized) + att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) From b9e97f057b84657c0ba9ae6fbdfb87967f1a3eec Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 17:18:06 -0500 Subject: [PATCH 2/9] Fix PR #177 review issues: control group bug, safe_inference, dose validation - Fix not_yet_treated control group to exclude cohort g from its own control set (matches staggered.py behavior) - Replace inline t_stat/p_value computation in DoseResponseCurve.to_dataframe() with safe_inference() loop per project convention - Add validation rejecting negative doses among treated units - Fix test_inf_first_treat_normalization CI failure (cast to float before inf) - Add test for not_yet_treated control group correctness and negative dose Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 11 +++- diff_diff/continuous_did_results.py | 16 +++--- tests/test_continuous_did.py | 78 +++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 10 deletions(-) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index f8f9851..7ad00a7 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -197,6 +197,15 @@ def fit( ) df = df[~df[unit].isin(drop_units)] + # Validate no negative doses among treated units + treated_doses = df.loc[df[first_treat] > 0, dose] + if (treated_doses < 0).any(): + n_neg = int((treated_doses < 0).sum()) + raise ValueError( + f"Found {n_neg} treated unit(s) with negative dose. 
" + f"Dose must be strictly positive for treated units (D > 0)." + ) + # Force dose=0 for never-treated units with nonzero dose never_treated_mask = df[first_treat] == 0 if (df.loc[never_treated_mask, dose] != 0).any(): @@ -538,7 +547,7 @@ def _compute_dose_response_gt( control_mask = never_treated_mask else: # Not-yet-treated: never-treated + first_treat > t - control_mask = never_treated_mask | (unit_cohorts > t) + control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g)) n_control = int(np.sum(control_mask)) if n_control == 0: warnings.warn( diff --git a/diff_diff/continuous_did_results.py b/diff_diff/continuous_did_results.py index f3aadd0..318d7ff 100644 --- a/diff_diff/continuous_did_results.py +++ b/diff_diff/continuous_did_results.py @@ -46,16 +46,14 @@ class DoseResponseCurve: def to_dataframe(self) -> pd.DataFrame: """Convert to DataFrame with dose, effect, se, CI, t_stat, p_value.""" - t_stat = np.where( - (np.isfinite(self.se) & (self.se > 0)), - self.effects / self.se, - np.nan, - ) - from scipy import stats + from diff_diff.utils import safe_inference - p_value = np.where( - np.isfinite(t_stat), 2 * (1 - stats.norm.cdf(np.abs(t_stat))), np.nan - ) + t_stat = np.full(len(self.effects), np.nan) + p_value = np.full(len(self.effects), np.nan) + for i in range(len(self.effects)): + t_i, p_i, _ = safe_inference(self.effects[i], self.se[i]) + t_stat[i] = t_i + p_value[i] = p_i return pd.DataFrame( { "dose": self.dose_grid, diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index 40cfc28..2cd5d75 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -488,6 +488,7 @@ def test_few_treated_units(self): def test_inf_first_treat_normalization(self): """first_treat=inf should be treated as never-treated.""" data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) + data["first_treat"] = data["first_treat"].astype(float) data.loc[data["first_treat"] == 0, "first_treat"] = np.inf est = 
ContinuousDiD() results = est.fit( @@ -504,3 +505,80 @@ def test_custom_dvals(self): ) np.testing.assert_array_equal(results.dose_grid, custom_grid) assert len(results.dose_response_att.effects) == 3 + + def test_negative_dose_raises(self): + """Negative doses among treated units should raise ValueError.""" + data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) + # Set one treated unit's dose to negative + treated_units = data.loc[data["first_treat"] > 0, "unit"].unique() + data.loc[data["unit"] == treated_units[0], "dose"] = -1.0 + est = ContinuousDiD() + with pytest.raises(ValueError, match="negative dose"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_not_yet_treated_excludes_own_cohort(self): + """not_yet_treated control group must not include the treated cohort itself. + + Construct a panel where contamination from including cohort g=2 in its own + control set would produce a biased pre-treatment effect. With the fix, + the pre-treatment ATT(g=2,t=1) should be near zero. 
+ """ + rng = np.random.RandomState(99) + n_per_group = 20 + periods = [1, 2, 3, 4] + + rows = [] + # Group 1: never-treated (first_treat=0, dose=0) + for i in range(n_per_group): + uid = i + for t in periods: + rows.append({ + "unit": uid, "period": t, "first_treat": 0, "dose": 0.0, + "outcome": rng.normal(0, 0.5), + }) + # Group 2: treated at period 2 (g=2), moderate dose + for i in range(n_per_group): + uid = n_per_group + i + dose_i = rng.uniform(1, 3) + for t in periods: + y = rng.normal(0, 0.5) + if t >= 2: + y += 5.0 * dose_i # strong treatment effect + rows.append({ + "unit": uid, "period": t, "first_treat": 2, "dose": dose_i, + "outcome": y, + }) + # Group 3: treated at period 3 (g=3), high dose + for i in range(n_per_group): + uid = 2 * n_per_group + i + dose_i = rng.uniform(1, 3) + for t in periods: + y = rng.normal(0, 0.5) + if t >= 3: + y += 5.0 * dose_i + rows.append({ + "unit": uid, "period": t, "first_treat": 3, "dose": dose_i, + "outcome": y, + }) + + data = pd.DataFrame(rows) + est = ContinuousDiD( + control_group="not_yet_treated", degree=1, num_knots=0, n_bootstrap=0, + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + ) + + # Pre-treatment cells for g=2 should be near zero (t=1 is pre-treatment) + # If cohort g=2 were included in its own control set, the pre-treatment + # difference would be contaminated by the cohort's own outcomes + pre_treatment_effects = { + (g, t): v for (g, t), v in results.group_time_effects.items() + if t < g + } + for (g, t), cell in pre_treatment_effects.items(): + att_glob = cell.get("att_glob", 0) + assert abs(att_glob) < 2.0, ( + f"Pre-treatment ATT(g={g},t={t}) = {att_glob:.4f} is too large; " + f"cohort may be contaminating its own control group" + ) From 8c3980b413a2aa20ecf5747d17bcb3ee091e4bcd Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 18:17:42 -0500 Subject: [PATCH 3/9] Fix analytical SE scaling, add empty post_gt guard, and validate params - Fix analytical 
SE: use sqrt(sum(IF^2)) instead of sqrt(mean(IF^2)) to match CallawaySantAnna's influence function convention (P0) - Add discrete dose detection warning for integer-valued treatments (P1) - Guard empty post-treatment cells: warn and return NaN instead of 0.0 (P1) - Validate control_group and base_period params in __init__ and set_params (P2) - Add 7 new tests: SE parity, discrete dose, anticipation event study, empty post_gt, and parameter validation Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 239 +++++++++++++++++++++-------------- tests/test_continuous_did.py | 106 ++++++++++++++++ 2 files changed, 250 insertions(+), 95 deletions(-) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index 7ad00a7..809f53b 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -78,6 +78,9 @@ class ContinuousDiD: >>> results.overall_att # doctest: +SKIP """ + _VALID_CONTROL_GROUPS = {"never_treated", "not_yet_treated"} + _VALID_BASE_PERIODS = {"varying", "universal"} + def __init__( self, degree: int = 3, @@ -103,6 +106,20 @@ def __init__( self.bootstrap_weights = bootstrap_weights self.seed = seed self.rank_deficient_action = rank_deficient_action + self._validate_constrained_params() + + def _validate_constrained_params(self) -> None: + """Validate control_group and base_period values.""" + if self.control_group not in self._VALID_CONTROL_GROUPS: + raise ValueError( + f"Invalid control_group: '{self.control_group}'. " + f"Must be one of {self._VALID_CONTROL_GROUPS}." + ) + if self.base_period not in self._VALID_BASE_PERIODS: + raise ValueError( + f"Invalid base_period: '{self.base_period}'. " + f"Must be one of {self._VALID_BASE_PERIODS}." 
+ ) def get_params(self) -> Dict[str, Any]: """Return estimator parameters as a dictionary.""" @@ -126,6 +143,7 @@ def set_params(self, **params) -> "ContinuousDiD": if not hasattr(self, key): raise ValueError(f"Invalid parameter: {key}") setattr(self, key, value) + self._validate_constrained_params() return self # ------------------------------------------------------------------ @@ -206,6 +224,21 @@ def fit( f"Dose must be strictly positive for treated units (D > 0)." ) + # Detect discrete (integer-valued) dose among treated units + unit_doses = df.loc[df[first_treat] > 0].groupby(unit)[dose].first() + unique_pos_doses = unit_doses[unit_doses > 0].unique() + is_integer = len(unique_pos_doses) > 0 and np.allclose( + unique_pos_doses, np.round(unique_pos_doses) + ) + if is_integer: + warnings.warn( + f"Dose appears discrete ({len(unique_pos_doses)} unique integer values). " + "B-spline smoothing may be inappropriate for discrete treatments. " + "Consider a saturated regression approach (not yet implemented).", + UserWarning, + stacklevel=2, + ) + # Force dose=0 for never-treated units with nonzero dose never_treated_mask = df[first_treat] == 0 if (df.loc[never_treated_mask, dose] != 0).any(): @@ -273,46 +306,10 @@ def fit( if t >= g - self.anticipation } - # Compute cell weights: group-proportional (matching R's contdid convention). - # Each group g gets weight proportional to its number of treated units. - # Within each group, weight is divided equally among post-treatment cells. 
- group_n_treated = {} - group_n_post_cells = {} - for (g, t), r in post_gt.items(): - if g not in group_n_treated: - group_n_treated[g] = float(r["n_treated"]) - group_n_post_cells[g] = 0 - group_n_post_cells[g] += 1 - - total_treated = sum(group_n_treated.values()) - cell_weights = {} - if total_treated > 0: - for (g, t), r in post_gt.items(): - pg = group_n_treated[g] / total_treated - cell_weights[(g, t)] = pg / group_n_post_cells[g] - # Dose-response aggregation n_grid = len(dvals) - agg_att_d = np.zeros(n_grid) - agg_acrt_d = np.zeros(n_grid) - overall_att = 0.0 - overall_acrt = 0.0 - - for gt, w in cell_weights.items(): - r = post_gt[gt] - agg_att_d += w * r["att_d"] - agg_acrt_d += w * r["acrt_d"] - overall_att += w * r["att_glob"] - overall_acrt += w * r["acrt_glob"] - - # Event study aggregation (binarized) - event_study_effects = None - if aggregate == "eventstudy": - event_study_effects = self._aggregate_event_study( - gt_results, treatment_groups - ) - # 5. Bootstrap + # NaN-initialized SE/CI fields (used when post_gt is empty or as defaults) att_d_se = np.full(n_grid, np.nan) att_d_ci_lower = np.full(n_grid, np.nan) att_d_ci_upper = np.full(n_grid, np.nan) @@ -328,65 +325,116 @@ def fit( overall_acrt_p = np.nan overall_acrt_ci = (np.nan, np.nan) - if self.n_bootstrap > 0: - boot_result = self._run_bootstrap( - precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights, - knots, degree, dvals, overall_att, overall_acrt, - agg_att_d, agg_acrt_d, - event_study_effects, - ) - att_d_se = boot_result["att_d_se"] - att_d_ci_lower = boot_result["att_d_ci_lower"] - att_d_ci_upper = boot_result["att_d_ci_upper"] - acrt_d_se = boot_result["acrt_d_se"] - acrt_d_ci_lower = boot_result["acrt_d_ci_lower"] - acrt_d_ci_upper = boot_result["acrt_d_ci_upper"] - overall_att_se = boot_result["overall_att_se"] - overall_att_t, overall_att_p, overall_att_ci = safe_inference( - overall_att, overall_att_se, self.alpha - ) - overall_acrt_se = boot_result["overall_acrt_se"] 
- overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference( - overall_acrt, overall_acrt_se, self.alpha - ) - if event_study_effects is not None: - for e, info in event_study_effects.items(): - if e in boot_result.get("es_se", {}): - info["se"] = boot_result["es_se"][e] - info["t_stat"], info["p_value"], info["conf_int"] = ( - safe_inference(info["effect"], info["se"], self.alpha) - ) - else: - # Analytical SEs via influence functions - analytic = self._compute_analytical_se( - precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights, - knots, degree, dvals, agg_att_d, agg_acrt_d, + # Event study aggregation (binarized) — runs on ALL (g,t) cells + event_study_effects = None + if aggregate == "eventstudy": + event_study_effects = self._aggregate_event_study( + gt_results, treatment_groups ) - att_d_se = analytic["att_d_se"] - acrt_d_se = analytic["acrt_d_se"] - overall_att_se = analytic["overall_att_se"] - overall_acrt_se = analytic["overall_acrt_se"] - overall_att_t, overall_att_p, overall_att_ci = safe_inference( - overall_att, overall_att_se, self.alpha - ) - overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference( - overall_acrt, overall_acrt_se, self.alpha + if len(post_gt) == 0: + warnings.warn( + "No post-treatment (g,t) cells available for aggregation. " + "This can occur when all treatments start after the last observed " + "period or all cells were skipped due to insufficient data.", + UserWarning, + stacklevel=2, ) - - # Per-grid-point inference for dose-response - for idx in range(n_grid): - _, _, ci = safe_inference( - agg_att_d[idx], att_d_se[idx], self.alpha + overall_att = np.nan + overall_acrt = np.nan + agg_att_d = np.full(n_grid, np.nan) + agg_acrt_d = np.full(n_grid, np.nan) + else: + # Compute cell weights: group-proportional (matching R's contdid convention). + # Each group g gets weight proportional to its number of treated units. + # Within each group, weight is divided equally among post-treatment cells. 
+ group_n_treated = {} + group_n_post_cells = {} + for (g, t), r in post_gt.items(): + if g not in group_n_treated: + group_n_treated[g] = float(r["n_treated"]) + group_n_post_cells[g] = 0 + group_n_post_cells[g] += 1 + + total_treated = sum(group_n_treated.values()) + cell_weights = {} + if total_treated > 0: + for (g, t), r in post_gt.items(): + pg = group_n_treated[g] / total_treated + cell_weights[(g, t)] = pg / group_n_post_cells[g] + + agg_att_d = np.zeros(n_grid) + agg_acrt_d = np.zeros(n_grid) + overall_att = 0.0 + overall_acrt = 0.0 + + for gt, w in cell_weights.items(): + r = post_gt[gt] + agg_att_d += w * r["att_d"] + agg_acrt_d += w * r["acrt_d"] + overall_att += w * r["att_glob"] + overall_acrt += w * r["acrt_glob"] + + # 5. Bootstrap / Analytical SE + if self.n_bootstrap > 0: + boot_result = self._run_bootstrap( + precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights, + knots, degree, dvals, overall_att, overall_acrt, + agg_att_d, agg_acrt_d, + event_study_effects, + ) + att_d_se = boot_result["att_d_se"] + att_d_ci_lower = boot_result["att_d_ci_lower"] + att_d_ci_upper = boot_result["att_d_ci_upper"] + acrt_d_se = boot_result["acrt_d_se"] + acrt_d_ci_lower = boot_result["acrt_d_ci_lower"] + acrt_d_ci_upper = boot_result["acrt_d_ci_upper"] + overall_att_se = boot_result["overall_att_se"] + overall_att_t, overall_att_p, overall_att_ci = safe_inference( + overall_att, overall_att_se, self.alpha + ) + overall_acrt_se = boot_result["overall_acrt_se"] + overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference( + overall_acrt, overall_acrt_se, self.alpha ) - att_d_ci_lower[idx] = ci[0] - att_d_ci_upper[idx] = ci[1] + if event_study_effects is not None: + for e, info in event_study_effects.items(): + if e in boot_result.get("es_se", {}): + info["se"] = boot_result["es_se"][e] + info["t_stat"], info["p_value"], info["conf_int"] = ( + safe_inference(info["effect"], info["se"], self.alpha) + ) + else: + # Analytical SEs via influence functions + 
analytic = self._compute_analytical_se( + precomp, gt_results, gt_bootstrap_info, post_gt, cell_weights, + knots, degree, dvals, agg_att_d, agg_acrt_d, + ) + att_d_se = analytic["att_d_se"] + acrt_d_se = analytic["acrt_d_se"] + overall_att_se = analytic["overall_att_se"] + overall_acrt_se = analytic["overall_acrt_se"] - _, _, ci = safe_inference( - agg_acrt_d[idx], acrt_d_se[idx], self.alpha + overall_att_t, overall_att_p, overall_att_ci = safe_inference( + overall_att, overall_att_se, self.alpha + ) + overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference( + overall_acrt, overall_acrt_se, self.alpha ) - acrt_d_ci_lower[idx] = ci[0] - acrt_d_ci_upper[idx] = ci[1] + + # Per-grid-point inference for dose-response + for idx in range(n_grid): + _, _, ci = safe_inference( + agg_att_d[idx], att_d_se[idx], self.alpha + ) + att_d_ci_lower[idx] = ci[0] + att_d_ci_upper[idx] = ci[1] + + _, _, ci = safe_inference( + agg_acrt_d[idx], acrt_d_se[idx], self.alpha + ) + acrt_d_ci_lower[idx] = ci[0] + acrt_d_ci_upper[idx] = ci[1] # 6. 
Assemble results dose_response_att = DoseResponseCurve( @@ -780,12 +828,13 @@ def _compute_analytical_se( beta_pert = -bread @ psi_bar * ee_control[k] / n_c if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert) - # SE = sqrt(mean(IF_i^2)) - overall_att_se = float(np.sqrt(np.mean(if_att_glob**2))) - overall_acrt_se = float(np.sqrt(np.mean(if_acrt_glob**2))) + # SE = sqrt(sum(IF_i^2)), matching CallawaySantAnna's convention + # (per-unit IFs already contain 1/n_t, 1/n_c scaling) + overall_att_se = float(np.sqrt(np.sum(if_att_glob**2))) + overall_acrt_se = float(np.sqrt(np.sum(if_acrt_glob**2))) - att_d_se = np.sqrt(np.mean(if_att_d**2, axis=0)) - acrt_d_se = np.sqrt(np.mean(if_acrt_d**2, axis=0)) + att_d_se = np.sqrt(np.sum(if_att_d**2, axis=0)) + acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0)) return { "overall_att_se": overall_att_se, diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index 2cd5d75..9da41ba 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -582,3 +582,109 @@ def test_not_yet_treated_excludes_own_cohort(self): f"Pre-treatment ATT(g={g},t={t}) = {att_glob:.4f} is too large; " f"cohort may be contaminating its own control group" ) + + +class TestAnalyticalSEParity: + """Test analytical SE vs bootstrap SE agreement.""" + + def test_analytical_se_matches_bootstrap(self, ci_params): + """Analytical SEs should be within ~50% of bootstrap SEs.""" + n_boot = ci_params.bootstrap(999, min_n=199) + data = generate_continuous_did_data( + n_units=200, n_periods=3, seed=42, noise_sd=1.0, + ) + est_boot = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results_boot = est_boot.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + est_analytic = ContinuousDiD(n_bootstrap=0) + results_analytic = est_analytic.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + threshold = 0.50 if n_boot < 100 else 0.30 + ratio = results_analytic.overall_att_se / results_boot.overall_att_se + assert (1 
- threshold) < ratio < (1 + threshold) / (1 - threshold), ( + f"Analytical/bootstrap SE ratio = {ratio:.3f}, " + f"expected within [{1 - threshold:.2f}, {(1 + threshold) / (1 - threshold):.2f}]" + ) + + +class TestDiscreteDoseWarning: + """Test discrete dose detection warning.""" + + def test_discrete_dose_warning(self): + """Integer-valued doses should trigger a discrete dose warning.""" + data = generate_continuous_did_data( + n_units=100, n_periods=3, seed=42, + ) + data["dose"] = data["dose"].round().astype(float) + data.loc[data["first_treat"] == 0, "dose"] = 0.0 + est = ContinuousDiD() + with pytest.warns(UserWarning, match="[Dd]iscrete"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + +class TestAnticipationEventStudy: + """Test event study with anticipation > 0.""" + + def test_anticipation_event_study(self): + """Event study with anticipation > 0 should include anticipation periods.""" + data = generate_continuous_did_data( + n_units=100, n_periods=5, cohort_periods=[3], seed=42, + ) + est = ContinuousDiD(anticipation=1, n_bootstrap=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) + assert results.event_study_effects is not None + # With anticipation=1 and g=3, post-treatment starts at t=2 (g - anticipation). + # Relative times e = t - g, so t=2 → e=-1 (the anticipation period). 
+ rel_times = sorted(results.event_study_effects.keys()) + assert -1 in rel_times, ( + f"Anticipation period e=-1 missing from event study; got {rel_times}" + ) + assert np.isfinite(results.event_study_effects[-1]["effect"]) + + +class TestEmptyPostTreatment: + """Test guard for empty post-treatment cells.""" + + def test_no_post_treatment_cells_warns(self): + """When no post-treatment cells exist, should warn and return NaN.""" + data = generate_continuous_did_data( + n_units=50, n_periods=3, cohort_periods=[5], seed=42, + ) + est = ContinuousDiD() + with pytest.warns(UserWarning, match="[Nn]o post-treatment"): + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert np.isnan(results.overall_att) + assert np.isnan(results.overall_acrt) + + +class TestParameterValidation: + """Test parameter validation for constrained values.""" + + def test_invalid_control_group_raises(self): + """Invalid control_group should raise ValueError.""" + with pytest.raises(ValueError, match="control_group"): + ContinuousDiD(control_group="invalid") + + def test_invalid_base_period_raises(self): + """Invalid base_period should raise ValueError.""" + with pytest.raises(ValueError, match="base_period"): + ContinuousDiD(base_period="invalid") + + def test_set_params_invalid_control_group_raises(self): + """set_params with invalid control_group should raise ValueError.""" + est = ContinuousDiD() + with pytest.raises(ValueError, match="control_group"): + est.set_params(control_group="NEVER_TREATED") + + def test_set_params_invalid_base_period_raises(self): + """set_params with invalid base_period should raise ValueError.""" + est = ContinuousDiD() + with pytest.raises(ValueError, match="base_period"): + est.set_params(base_period="VARYING") From 40d22d4530fb45aed6315b92a27d7a33f7d27e93 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 18:50:53 -0500 Subject: [PATCH 4/9] Fix bootstrap percentile inference, add P(D=0) warning, and analytical 
event-study SEs - Use percentile CI/p-value from bootstrap (not normal approx) for overall ATT/ACRT and event-study effects, matching CallawaySantAnna convention - Add P(D=0)>0 warning when control_group='not_yet_treated' has no never-treated units (Remark 3.1 in Callaway et al.) - Compute IF-based analytical SEs for event-study bins when n_bootstrap=0 (previously yielded NaN) - Add tests for all three fixes Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 104 +++++++++++++++++++++++++++++++---- tests/test_continuous_did.py | 77 ++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 10 deletions(-) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index 809f53b..b3c2600 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -266,6 +266,16 @@ def fit( "or add never-treated units." ) + if self.control_group == "not_yet_treated" and n_control == 0: + warnings.warn( + "No never-treated (D=0) units found. With control_group='not_yet_treated', " + "not-yet-treated units (D>0) serve as controls, but dose-response curve " + "identification requires P(D=0) > 0 (see Remark 3.1 in Callaway et al.). " + "Estimates may be unreliable.", + UserWarning, + stacklevel=2, + ) + # 2. 
Precompute structures precomp = self._precompute_structures( df, outcome, unit, time, first_treat, dose, time_periods @@ -390,20 +400,26 @@ def fit( acrt_d_ci_lower = boot_result["acrt_d_ci_lower"] acrt_d_ci_upper = boot_result["acrt_d_ci_upper"] overall_att_se = boot_result["overall_att_se"] - overall_att_t, overall_att_p, overall_att_ci = safe_inference( + overall_att_t = safe_inference( overall_att, overall_att_se, self.alpha - ) + )[0] + overall_att_p = boot_result["overall_att_p"] + overall_att_ci = boot_result["overall_att_ci"] overall_acrt_se = boot_result["overall_acrt_se"] - overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference( + overall_acrt_t = safe_inference( overall_acrt, overall_acrt_se, self.alpha - ) + )[0] + overall_acrt_p = boot_result["overall_acrt_p"] + overall_acrt_ci = boot_result["overall_acrt_ci"] if event_study_effects is not None: for e, info in event_study_effects.items(): if e in boot_result.get("es_se", {}): info["se"] = boot_result["es_se"][e] - info["t_stat"], info["p_value"], info["conf_int"] = ( - safe_inference(info["effect"], info["se"], self.alpha) - ) + info["t_stat"] = safe_inference( + info["effect"], info["se"], self.alpha + )[0] + info["p_value"] = boot_result["es_p"][e] + info["conf_int"] = boot_result["es_ci"][e] else: # Analytical SEs via influence functions analytic = self._compute_analytical_se( @@ -436,6 +452,64 @@ def fit( acrt_d_ci_lower[idx] = ci[0] acrt_d_ci_upper[idx] = ci[1] + # Event study analytical SEs + if event_study_effects is not None: + n_units = precomp["n_units"] + for e_val, info_e in event_study_effects.items(): + # Collect (g,t) cells for this event-time bin + e_gts = [gt for gt in gt_results if gt[1] - gt[0] == e_val] + if not e_gts: + continue + # n_treated-proportional weights within this bin + ns = np.array( + [gt_results[gt]["n_treated"] for gt in e_gts], + dtype=float, + ) + total_n = ns.sum() + if total_n == 0: + continue + ws = ns / total_n + + # Build per-unit IF for this 
event-time bin + if_es = np.zeros(n_units) + for idx_cell, gt in enumerate(e_gts): + b_info = gt_bootstrap_info.get(gt, {}) + if not b_info: + continue + w = ws[idx_cell] + treated_idx = b_info["treated_indices"] + control_idx = b_info["control_indices"] + n_t = b_info["n_treated"] + n_c = b_info["n_control"] + n_total_gt = n_t + n_c + p_1 = n_t / n_total_gt + p_0 = n_c / n_total_gt + att_glob_gt = b_info["att_glob"] + mu_0 = b_info["mu_0"] + delta_y_treated = b_info["delta_y_treated"] + ee_control = b_info["ee_control"] + + for k, uid in enumerate(treated_idx): + if_es[uid] += ( + w + * (delta_y_treated[k] - att_glob_gt - mu_0) + / p_1 + / n_total_gt + ) + for k, uid in enumerate(control_idx): + if_es[uid] -= ( + w * ee_control[k] / p_0 / n_total_gt + ) + + es_se = float(np.sqrt(np.sum(if_es**2))) + t_stat, p_val, ci_es = safe_inference( + info_e["effect"], es_se, self.alpha + ) + info_e["se"] = es_se + info_e["t_stat"] = t_stat + info_e["p_value"] = p_val + info_e["conf_int"] = ci_es + # 6. 
Assemble results dose_response_att = DoseResponseCurve( dose_grid=dvals, @@ -1010,25 +1084,35 @@ def _bootstrap_gt_cell(gt, info): result["acrt_d_ci_upper"] = acrt_d_ci_upper # Overall - se, _, _ = compute_effect_bootstrap_stats( + se, ci, p = compute_effect_bootstrap_stats( original_att, boot_att_glob, alpha=self.alpha, context="overall ATT_glob", ) result["overall_att_se"] = se + result["overall_att_ci"] = ci + result["overall_att_p"] = p - se, _, _ = compute_effect_bootstrap_stats( + se, ci, p = compute_effect_bootstrap_stats( original_acrt, boot_acrt_glob, alpha=self.alpha, context="overall ACRT_glob", ) result["overall_acrt_se"] = se + result["overall_acrt_ci"] = ci + result["overall_acrt_p"] = p # Event study SEs if event_study_effects is not None: es_se = {} + es_ci = {} + es_p = {} for e in es_keys: - se_e, _, _ = compute_effect_bootstrap_stats( + se_e, ci_e, p_e = compute_effect_bootstrap_stats( event_study_effects[e]["effect"], boot_es[e], alpha=self.alpha, context=f"event study e={e}", ) es_se[e] = se_e + es_ci[e] = ci_e + es_p[e] = p_e result["es_se"] = es_se + result["es_ci"] = es_ci + result["es_p"] = es_p return result diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index 9da41ba..e5b4b6a 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -688,3 +688,80 @@ def test_set_params_invalid_base_period_raises(self): est = ContinuousDiD() with pytest.raises(ValueError, match="base_period"): est.set_params(base_period="VARYING") + + +class TestBootstrapPercentileInference: + """Test that bootstrap uses percentile CI/p-value, not normal approximation.""" + + def test_bootstrap_percentile_ci(self, ci_params): + """Bootstrap CIs should use percentile method (generally asymmetric).""" + n_boot = ci_params.bootstrap(499, min_n=199) + data = generate_continuous_did_data( + n_units=200, n_periods=3, seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, 
"outcome", "unit", "period", "first_treat", "dose" + ) + lo, hi = results.overall_att_conf_int + estimate = results.overall_att + # CI should contain estimate + assert lo <= estimate <= hi + # p-value should be finite and in [0, 1] + assert 0 <= results.overall_att_p_value <= 1 + # Percentile CIs are generally asymmetric around the estimate. + # With enough bootstrap reps, the upper and lower distances differ. + upper_dist = hi - estimate + lower_dist = estimate - lo + # Just verify both distances are positive (CI is non-degenerate) + assert upper_dist > 0 + assert lower_dist > 0 + + +class TestNotYetTreatedNoDZeroWarning: + """Test P(D=0)>0 warning for not_yet_treated with no never-treated units.""" + + def test_no_never_treated_warns(self): + """not_yet_treated with zero never-treated units should warn.""" + data = generate_continuous_did_data( + n_units=100, + n_periods=4, + cohort_periods=[2, 3], + never_treated_frac=0.0, + seed=42, + ) + est = ContinuousDiD(control_group="not_yet_treated", degree=1, num_knots=0) + with pytest.warns(UserWarning, match="No never-treated.*D=0"): + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + # Estimation should still complete (warn-and-continue) + assert isinstance(results, ContinuousDiDResults) + assert np.isfinite(results.overall_att) + + +class TestEventStudyAnalyticalSE: + """Test analytical SEs for event study aggregation (n_bootstrap=0).""" + + def test_event_study_analytical_se_finite(self): + """Event study with n_bootstrap=0 should produce finite SE/t/p for all bins.""" + data = generate_continuous_did_data( + n_units=200, n_periods=5, cohort_periods=[2, 4], + seed=42, noise_sd=0.5, + ) + est = ContinuousDiD(n_bootstrap=0) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) + assert results.event_study_effects is not None + for e, info in results.event_study_effects.items(): + assert np.isfinite(info["se"]), f"SE is NaN for 
e={e}" + assert info["se"] > 0, f"SE is non-positive for e={e}" + assert np.isfinite(info["t_stat"]), f"t_stat is NaN for e={e}" + assert np.isfinite(info["p_value"]), f"p_value is NaN for e={e}" + assert 0 <= info["p_value"] <= 1, f"p_value out of range for e={e}" + lo, hi = info["conf_int"] + assert np.isfinite(lo) and np.isfinite(hi), ( + f"conf_int contains NaN for e={e}" + ) From 1e01fd45b73f49bd7b04939b293c7a6f86b4ae72 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 19:20:48 -0500 Subject: [PATCH 5/9] Address round-4 review: harden validation and fix tempfile usage - Upgrade P(D=0)=0 warning to ValueError for not_yet_treated (P1) - Strengthen balanced-panel check to verify identical time sets (P1) - Add aggregate parameter validation at fit() entry (P3) - Replace hardcoded /tmp paths with tempfile in R benchmarks (P3) Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 27 +- tests/test_continuous_did.py | 42 ++- tests/test_methodology_continuous_did.py | 391 ++++++++++++----------- 3 files changed, 256 insertions(+), 204 deletions(-) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index b3c2600..ec403e8 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -186,6 +186,13 @@ def fit( ContinuousDiDResults """ # 1. Validate & prepare + _VALID_AGGREGATES = (None, "dose", "eventstudy") + if aggregate not in _VALID_AGGREGATES: + raise ValueError( + f"Invalid aggregate: '{aggregate}'. " + f"Must be one of {_VALID_AGGREGATES}." 
+ ) + df = data.copy() for col in [outcome, unit, time, first_treat, dose]: if col not in df.columns: @@ -245,10 +252,14 @@ def fit( df.loc[never_treated_mask, dose] = 0.0 # Verify balanced panel - obs_per_unit = df.groupby(unit)[time].nunique() - if obs_per_unit.nunique() > 1: + all_periods = set(df[time].unique()) + unit_periods = df.groupby(unit)[time].apply(set) + is_unbalanced = unit_periods.apply(lambda s: s != all_periods) + if is_unbalanced.any(): + n_bad = int(is_unbalanced.sum()) raise ValueError( - "Unbalanced panel detected. ContinuousDiD requires a balanced panel." + "Unbalanced panel detected. ContinuousDiD requires a balanced panel. " + f"{n_bad} unit(s) have missing periods." ) # Identify groups and time periods @@ -267,13 +278,11 @@ def fit( ) if self.control_group == "not_yet_treated" and n_control == 0: - warnings.warn( + raise ValueError( "No never-treated (D=0) units found. With control_group='not_yet_treated', " - "not-yet-treated units (D>0) serve as controls, but dose-response curve " - "identification requires P(D=0) > 0 (see Remark 3.1 in Callaway et al.). " - "Estimates may be unreliable.", - UserWarning, - stacklevel=2, + "dose-response curve identification requires P(D=0) > 0 " + "(Remark 3.1 in Callaway et al. is not yet implemented). " + "Add never-treated units or use a dataset with D=0 observations." ) # 2. 
Precompute structures diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index e5b4b6a..f5bad0b 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -214,6 +214,33 @@ def test_unbalanced_panel_error(self): with pytest.raises(ValueError, match="[Uu]nbalanced"): est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + def test_unbalanced_panel_same_count_different_periods(self): + """Units with same period count but different periods should be caught.""" + data = pd.DataFrame({ + "unit": [1, 1, 1, 2, 2, 2], + "period": [1, 2, 3, 1, 2, 4], # Same count (3) but unit 2 has {1,2,4} vs {1,2,3} + "outcome": [1.0, 2.0, 3.0, 1.0, 2.0, 3.0], + "first_treat": [2, 2, 2, 0, 0, 0], + "dose": [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + }) + est = ContinuousDiD() + with pytest.raises(ValueError, match="[Uu]nbalanced"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_invalid_aggregate_raises(self): + """Invalid aggregate value should raise ValueError.""" + data = pd.DataFrame({ + "unit": [1, 1, 2, 2], + "period": [1, 2, 1, 2], + "outcome": [1.0, 2.0, 1.0, 2.0], + "first_treat": [2, 2, 0, 0], + "dose": [1.0, 1.0, 0.0, 0.0], + }) + est = ContinuousDiD() + with pytest.raises(ValueError, match="Invalid aggregate"): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="event_study") + def test_no_never_treated_error(self): data = pd.DataFrame({ "unit": [1, 1, 2, 2], @@ -718,11 +745,11 @@ def test_bootstrap_percentile_ci(self, ci_params): assert lower_dist > 0 -class TestNotYetTreatedNoDZeroWarning: - """Test P(D=0)>0 warning for not_yet_treated with no never-treated units.""" +class TestNotYetTreatedNoDZeroError: + """Test P(D=0)>0 error for not_yet_treated with no never-treated units.""" - def test_no_never_treated_warns(self): - """not_yet_treated with zero never-treated units should warn.""" + def test_no_never_treated_raises(self): + """not_yet_treated with zero never-treated 
units should raise ValueError.""" data = generate_continuous_did_data( n_units=100, n_periods=4, @@ -731,13 +758,10 @@ def test_no_never_treated_warns(self): seed=42, ) est = ContinuousDiD(control_group="not_yet_treated", degree=1, num_knots=0) - with pytest.warns(UserWarning, match="No never-treated.*D=0"): - results = est.fit( + with pytest.raises(ValueError, match="D=0"): + est.fit( data, "outcome", "unit", "period", "first_treat", "dose" ) - # Estimation should still complete (warn-and-continue) - assert isinstance(results, ContinuousDiDResults) - assert np.isfinite(results.overall_att) class TestEventStudyAnalyticalSE: diff --git a/tests/test_methodology_continuous_did.py b/tests/test_methodology_continuous_did.py index 596ae09..2b9fab7 100644 --- a/tests/test_methodology_continuous_did.py +++ b/tests/test_methodology_continuous_did.py @@ -6,6 +6,7 @@ """ import json +import os import subprocess import tempfile @@ -445,151 +446,163 @@ def test_benchmark_4_staggered_dose(self): contdid's internal aggregation. Compares overall_att and overall_acrt via pte_default (with consistent control_group). 
""" - r_code = """ - library(contdid) - library(ptetools) - library(jsonlite) - - set.seed(42) - df <- simulate_contdid_data( - n = 200, num_time_periods = 4, num_groups = 4, - dose_linear_effect = 2, dose_quadratic_effect = 0.5 - ) - - # Overall ACRT via cont_did (dose aggregation) - res_slope <- cont_did( - yname = "Y", tname = "time_period", idname = "id", - gname = "G", dname = "D", data = df, - target_parameter = "slope", aggregation = "dose", - treatment_type = "continuous", control_group = "nevertreated", - degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE - ) + tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False) + tmp_path = tmp.name + tmp.close() + try: + r_code = f""" + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(42) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 2, dose_quadratic_effect = 0.5 + ) - # Overall ATT via pte_default (with matching control_group) - att_res <- suppressWarnings(pte_default( - yname = "Y", gname = "G", tname = "time_period", - idname = "id", data = df, d_outcome = TRUE, - anticipation = 0, base_period = "varying", - control_group = "nevertreated", - biters = 100, alp = 0.05 - )) + # Overall ACRT via cont_did (dose aggregation) + res_slope <- cont_did( + yname = "Y", tname = "time_period", idname = "id", + gname = "G", dname = "D", data = df, + target_parameter = "slope", aggregation = "dose", + treatment_type = "continuous", control_group = "nevertreated", + degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE + ) - write.csv(df, "/tmp/r_bench4.csv", row.names = FALSE) - out <- list( - overall_att = att_res$overall_att$overall.att, - overall_acrt = res_slope$overall_acrt, - dvals = as.numeric(res_slope$dose) - ) - cat(toJSON(out, auto_unbox = TRUE, digits = 10)) - """ - result = subprocess.run( - ["Rscript", "-e", r_code], - capture_output=True, text=True, timeout=120, - ) - if result.returncode != 0: - 
pytest.skip(f"R contdid failed: {result.stderr[:500]}") - r_out = json.loads(result.stdout) - - data = pd.read_csv("/tmp/r_bench4.csv") - data = data.rename(columns={ - "id": "unit", "time_period": "period", - "Y": "outcome", "G": "first_treat", "D": "dose", - }) - dvals = np.array(r_out["dvals"]) - est = ContinuousDiD( - degree=3, num_knots=0, dvals=dvals, - control_group="never_treated", - ) - results = est.fit( - data, "outcome", "unit", "period", "first_treat", "dose", - aggregate="dose", - ) + # Overall ATT via pte_default (with matching control_group) + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "nevertreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "{tmp_path}", row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att, + overall_acrt = res_slope$overall_acrt, + dvals = as.numeric(res_slope$dose) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv(tmp_path) + data = data.rename(columns={ + "id": "unit", "time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + dvals = np.array(r_out["dvals"]) + est = ContinuousDiD( + degree=3, num_knots=0, dvals=dvals, + control_group="never_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="dose", + ) - # Overall ATT - att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) - assert att_diff < 0.01, ( - f"Overall ATT diff: {att_diff:.4f} " - f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" - ) + # Overall ATT + att_diff = 
abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) - # Overall ACRT - acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) - assert acrt_diff < 0.01, ( - f"Overall ACRT diff: {acrt_diff:.4f} " - f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" - ) + # Overall ACRT + acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) + assert acrt_diff < 0.01, ( + f"Overall ACRT diff: {acrt_diff:.4f} " + f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" + ) + finally: + os.unlink(tmp_path) def test_benchmark_5_not_yet_treated(self): """4 periods, 3 cohorts, not-yet-treated control.""" - r_code = """ - library(contdid) - library(ptetools) - library(jsonlite) - - set.seed(123) - df <- simulate_contdid_data( - n = 200, num_time_periods = 4, num_groups = 4, - dose_linear_effect = 1.5, dose_quadratic_effect = 0 - ) - - res_slope <- cont_did( - yname = "Y", tname = "time_period", idname = "id", - gname = "G", dname = "D", data = df, - target_parameter = "slope", aggregation = "dose", - treatment_type = "continuous", control_group = "notyettreated", - degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE - ) + tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False) + tmp_path = tmp.name + tmp.close() + try: + r_code = f""" + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(123) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 1.5, dose_quadratic_effect = 0 + ) - att_res <- suppressWarnings(pte_default( - yname = "Y", gname = "G", tname = "time_period", - idname = "id", data = df, d_outcome = TRUE, - anticipation = 0, base_period = "varying", - control_group = "notyettreated", - biters = 100, alp = 
0.05 - )) + res_slope <- cont_did( + yname = "Y", tname = "time_period", idname = "id", + gname = "G", dname = "D", data = df, + target_parameter = "slope", aggregation = "dose", + treatment_type = "continuous", control_group = "notyettreated", + degree = 3, num_knots = 0, bstrap = FALSE, print_details = FALSE + ) - write.csv(df, "/tmp/r_bench5.csv", row.names = FALSE) - out <- list( - overall_att = att_res$overall_att$overall.att, - overall_acrt = res_slope$overall_acrt, - dvals = as.numeric(res_slope$dose) - ) - cat(toJSON(out, auto_unbox = TRUE, digits = 10)) - """ - result = subprocess.run( - ["Rscript", "-e", r_code], - capture_output=True, text=True, timeout=120, - ) - if result.returncode != 0: - pytest.skip(f"R contdid failed: {result.stderr[:500]}") - r_out = json.loads(result.stdout) - - data = pd.read_csv("/tmp/r_bench5.csv") - data = data.rename(columns={ - "id": "unit", "time_period": "period", - "Y": "outcome", "G": "first_treat", "D": "dose", - }) - dvals = np.array(r_out["dvals"]) - est = ContinuousDiD( - degree=3, num_knots=0, dvals=dvals, - control_group="not_yet_treated", - ) - results = est.fit( - data, "outcome", "unit", "period", "first_treat", "dose", - aggregate="dose", - ) + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "notyettreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "{tmp_path}", row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att, + overall_acrt = res_slope$overall_acrt, + dvals = as.numeric(res_slope$dose) + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv(tmp_path) + data = 
data.rename(columns={ + "id": "unit", "time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + dvals = np.array(r_out["dvals"]) + est = ContinuousDiD( + degree=3, num_knots=0, dvals=dvals, + control_group="not_yet_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="dose", + ) - att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) - assert att_diff < 0.01, ( - f"Overall ATT diff: {att_diff:.4f} " - f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" - ) + att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) - acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) - assert acrt_diff < 0.01, ( - f"Overall ACRT diff: {acrt_diff:.4f} " - f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" - ) + acrt_diff = abs(results.overall_acrt - r_out["overall_acrt"]) / (abs(r_out["overall_acrt"]) + 1e-10) + assert acrt_diff < 0.01, ( + f"Overall ACRT diff: {acrt_diff:.4f} " + f"(R={r_out['overall_acrt']:.6f}, Py={results.overall_acrt:.6f})" + ) + finally: + os.unlink(tmp_path) def test_benchmark_6_event_study(self): """4 periods, 3 cohorts, event study aggregation (binarized ATT). @@ -598,57 +611,63 @@ def test_benchmark_6_event_study(self): per-cell estimation, then aggregates by relative period. We compare overall ATT (binarized) via pte_default with matching control_group. 
""" - r_code = """ - library(contdid) - library(ptetools) - library(jsonlite) - - set.seed(99) - df <- simulate_contdid_data( - n = 200, num_time_periods = 4, num_groups = 4, - dose_linear_effect = 2, dose_quadratic_effect = 0 - ) - - # Overall ATT via pte_default (matching control_group) - att_res <- suppressWarnings(pte_default( - yname = "Y", gname = "G", tname = "time_period", - idname = "id", data = df, d_outcome = TRUE, - anticipation = 0, base_period = "varying", - control_group = "nevertreated", - biters = 100, alp = 0.05 - )) + tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False) + tmp_path = tmp.name + tmp.close() + try: + r_code = f""" + library(contdid) + library(ptetools) + library(jsonlite) + + set.seed(99) + df <- simulate_contdid_data( + n = 200, num_time_periods = 4, num_groups = 4, + dose_linear_effect = 2, dose_quadratic_effect = 0 + ) - write.csv(df, "/tmp/r_bench6.csv", row.names = FALSE) - out <- list( - overall_att = att_res$overall_att$overall.att - ) - cat(toJSON(out, auto_unbox = TRUE, digits = 10)) - """ - result = subprocess.run( - ["Rscript", "-e", r_code], - capture_output=True, text=True, timeout=120, - ) - if result.returncode != 0: - pytest.skip(f"R contdid failed: {result.stderr[:500]}") - r_out = json.loads(result.stdout) - - data = pd.read_csv("/tmp/r_bench6.csv") - data = data.rename(columns={ - "id": "unit", "time_period": "period", - "Y": "outcome", "G": "first_treat", "D": "dose", - }) - est = ContinuousDiD( - degree=3, num_knots=0, - control_group="never_treated", - ) - results = est.fit( - data, "outcome", "unit", "period", "first_treat", "dose", - aggregate="eventstudy", - ) + # Overall ATT via pte_default (matching control_group) + att_res <- suppressWarnings(pte_default( + yname = "Y", gname = "G", tname = "time_period", + idname = "id", data = df, d_outcome = TRUE, + anticipation = 0, base_period = "varying", + control_group = "nevertreated", + biters = 100, alp = 0.05 + )) + + write.csv(df, "{tmp_path}", 
row.names = FALSE) + out <- list( + overall_att = att_res$overall_att$overall.att + ) + cat(toJSON(out, auto_unbox = TRUE, digits = 10)) + """ + result = subprocess.run( + ["Rscript", "-e", r_code], + capture_output=True, text=True, timeout=120, + ) + if result.returncode != 0: + pytest.skip(f"R contdid failed: {result.stderr[:500]}") + r_out = json.loads(result.stdout) + + data = pd.read_csv(tmp_path) + data = data.rename(columns={ + "id": "unit", "time_period": "period", + "Y": "outcome", "G": "first_treat", "D": "dose", + }) + est = ContinuousDiD( + degree=3, num_knots=0, + control_group="never_treated", + ) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose", + aggregate="eventstudy", + ) - # Compare overall ATT (binarized) - att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) - assert att_diff < 0.01, ( - f"Overall ATT diff: {att_diff:.4f} " - f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" - ) + # Compare overall ATT (binarized) + att_diff = abs(results.overall_att - r_out["overall_att"]) / (abs(r_out["overall_att"]) + 1e-10) + assert att_diff < 0.01, ( + f"Overall ATT diff: {att_diff:.4f} " + f"(R={r_out['overall_att']:.6f}, Py={results.overall_att:.6f})" + ) + finally: + os.unlink(tmp_path) From 0f2849bf7b0091802a6c1053c6cd2ed56564cbc1 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 19:42:03 -0500 Subject: [PATCH 6/9] Fix bootstrap ACRT^{glob} centering bug and add regression test Store acrt_glob in _bootstrap_info so the bootstrap distribution is centered at the point estimate instead of 0. Add test that verifies bootstrap ACRT CI brackets the estimate rather than zero. 
Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 1 + tests/test_continuous_did.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index ec403e8..af9935d 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -785,6 +785,7 @@ def _compute_dose_response_gt( "delta_y_control": delta_y_control, "mu_0": mu_0, "att_glob": att_glob, + "acrt_glob": acrt_glob, } return { diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index f5bad0b..74cc802 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -458,6 +458,30 @@ def test_bootstrap_ci_contains_estimate(self, ci_params): lo, hi = results.overall_att_conf_int assert lo <= results.overall_att <= hi + def test_bootstrap_acrt_ci_centered(self, ci_params): + """Bootstrap ACRT CI should bracket the point estimate, not zero.""" + n_boot = ci_params.bootstrap(99) + data = generate_continuous_did_data( + n_units=200, n_periods=3, seed=42, noise_sd=0.5, + att_function="linear", att_slope=2.0, att_intercept=1.0, + ) + est = ContinuousDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) + lo, hi = results.overall_acrt_conf_int + assert lo <= results.overall_acrt <= hi, ( + f"ACRT CI [{lo:.4f}, {hi:.4f}] does not bracket " + f"point estimate {results.overall_acrt:.4f}" + ) + # CI midpoint should be closer to estimate than to 0 + midpoint = (lo + hi) / 2 + assert abs(midpoint - results.overall_acrt) < abs(midpoint), ( + f"CI midpoint {midpoint:.4f} is closer to 0 than to " + f"estimate {results.overall_acrt:.4f} — bootstrap distribution " + f"may still be mis-centered" + ) + def test_bootstrap_p_values_valid(self, ci_params): n_boot = ci_params.bootstrap(99) data = generate_continuous_did_data( From f737ab6b99eb4ad6fbc2a001586bc59ec87c465a Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 21 Feb 2026 
20:39:18 -0500 Subject: [PATCH 7/9] Address PR #177 review round 6: boundary knot docs, results provenance, drop unused vcov - Document B-spline boundary knot deviation in REGISTRY.md and continuous-did.md - Add base_period, anticipation, n_bootstrap, bootstrap_weights, seed, rank_deficient_action fields to ContinuousDiDResults with passthrough from fit() - Switch per-cell OLS to return_vcov=False to skip unused covariance computation - Add test_results_contain_init_params verifying param roundtrip Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did.py | 10 ++++++++-- diff_diff/continuous_did_results.py | 20 ++++++++++++++++++++ docs/methodology/REGISTRY.md | 3 ++- docs/methodology/continuous-did.md | 7 +++++++ tests/test_continuous_did.py | 19 +++++++++++++++++++ 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index af9935d..2754756 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -568,6 +568,12 @@ def fit( control_group=self.control_group, degree=self.degree, num_knots=self.num_knots, + base_period=self.base_period, + anticipation=self.anticipation, + n_bootstrap=self.n_bootstrap, + bootstrap_weights=self.bootstrap_weights, + seed=self.seed, + rank_deficient_action=self.rank_deficient_action, event_study_effects=event_study_effects, ) @@ -725,9 +731,9 @@ def _compute_dose_response_gt( return None # OLS regression - beta_hat, residuals, vcov = solve_ols( + beta_hat, residuals, _ = solve_ols( Psi, delta_tilde_y, - return_vcov=True, + return_vcov=False, rank_deficient_action=self.rank_deficient_action, ) diff --git a/diff_diff/continuous_did_results.py b/diff_diff/continuous_did_results.py index 318d7ff..e467c4b 100644 --- a/diff_diff/continuous_did_results.py +++ b/diff_diff/continuous_did_results.py @@ -86,6 +86,18 @@ class ContinuousDiDResults: Plug-in overall ACRT^{glob}. group_time_effects : dict Per (g,t) cell results. 
+ base_period : str + Base period strategy (``"varying"`` or ``"universal"``). + anticipation : int + Number of anticipation periods. + n_bootstrap : int + Number of bootstrap iterations used. + bootstrap_weights : str + Bootstrap weight type (``"rademacher"`` or ``"mammen"``). + seed : int or None + Random seed used for bootstrap. + rank_deficient_action : str + How rank deficiency is handled (``"warn"``, ``"error"``, ``"silent"``). """ dose_response_att: DoseResponseCurve @@ -111,6 +123,12 @@ class ContinuousDiDResults: control_group: str = "never_treated" degree: int = 3 num_knots: int = 0 + base_period: str = "varying" + anticipation: int = 0 + n_bootstrap: int = 0 + bootstrap_weights: str = "rademacher" + seed: Optional[int] = None + rank_deficient_action: str = "warn" event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) def __repr__(self) -> str: @@ -144,6 +162,8 @@ def summary(self, alpha: Optional[float] = None) -> str: f"{'Control group:':<30} {self.control_group:>10}", f"{'B-spline degree:':<30} {self.degree:>10}", f"{'Interior knots:':<30} {self.num_knots:>10}", + f"{'Base period:':<30} {self.base_period:>10}", + f"{'Anticipation:':<30} {self.anticipation:>10}", "", ] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 06abfe0..1fd9e7d 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -427,10 +427,11 @@ This is stronger than standard PT because it conditions on specific dose values. - **All-same dose**: B-spline basis collapses; ACRT(d) = 0 everywhere. - **Rank deficiency**: When n_treated <= n_basis, cell is skipped. - **Balanced panel required**: Matches R `contdid` v0.1.0. +- **Boundary knots**: Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. 
Our approach avoids extrapolation and is methodologically sound. ### Implementation Checklist -- [x] B-spline basis construction matching R's `splines2::bSpline` +- [x] B-spline basis construction matching R's `splines2::bSpline` (boundary knots use training-dose range; see deviation note below) - [x] Multi-period (g,t) cell iteration with base period selection - [x] Dose-response and event-study aggregation with n_treated weights - [x] Multiplier bootstrap for inference diff --git a/docs/methodology/continuous-did.md b/docs/methodology/continuous-did.md index 1c26cb5..3ae882e 100644 --- a/docs/methodology/continuous-did.md +++ b/docs/methodology/continuous-did.md @@ -358,6 +358,13 @@ cont_did(yname, dname, gname, tname, idname, data, 1. Extract 2x2 subset: target group (g) + control group, pre-period + post-period 2. Construct B-spline basis from treated units' doses using `splines2::bSpline()` + + > **Boundary knot note**: The B-spline boundary knots are set from the + > training doses (`range(dose_treated)`). Evaluation at `dvals` is clamped + > to these boundaries. R's `contdid` v0.1.0 uses `range(dvals)` as boundary + > knots when evaluating, which can cause extrapolation artifacts. This is an + > intentional deviation. + 3. OLS: regress Delta Y on B-spline basis 4. Evaluate fitted spline at `dvals` -> ATT(d) vector 5. 
Evaluate derivative of spline at `dvals` -> ACRT(d) vector diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index 74cc802..9046105 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -296,6 +296,25 @@ def test_group_time_effects_populated(self, basic_data): ) assert len(results.group_time_effects) > 0 + def test_results_contain_init_params(self, basic_data): + est = ContinuousDiD( + base_period="universal", + anticipation=0, + n_bootstrap=49, + bootstrap_weights="mammen", + seed=123, + rank_deficient_action="error", + ) + results = est.fit( + basic_data, "outcome", "unit", "period", "first_treat", "dose" + ) + assert results.base_period == "universal" + assert results.anticipation == 0 + assert results.n_bootstrap == 49 + assert results.bootstrap_weights == "mammen" + assert results.seed == 123 + assert results.rank_deficient_action == "error" + def test_not_yet_treated_control(self): data = generate_continuous_did_data( n_units=100, n_periods=4, cohort_periods=[2, 3], seed=42, From ce8240e94719555f0e30ec9f54bcd86ea19daf4b Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Feb 2026 06:38:33 -0500 Subject: [PATCH 8/9] Address PR #177 review round 7: clarify global knots and aggregation weights in docs, add "webb" to docstring Co-Authored-By: Claude Opus 4.6 --- diff_diff/continuous_did_results.py | 2 +- docs/methodology/REGISTRY.md | 6 +++--- docs/methodology/continuous-did.md | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/diff_diff/continuous_did_results.py b/diff_diff/continuous_did_results.py index e467c4b..fedbd3f 100644 --- a/diff_diff/continuous_did_results.py +++ b/diff_diff/continuous_did_results.py @@ -93,7 +93,7 @@ class ContinuousDiDResults: n_bootstrap : int Number of bootstrap iterations used. bootstrap_weights : str - Bootstrap weight type (``"rademacher"`` or ``"mammen"``). + Bootstrap weight type (``"rademacher"``, ``"mammen"``, or ``"webb"``). 
seed : int or None Random seed used for bootstrap. rank_deficient_action : str diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 1fd9e7d..ec4fe9e 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -427,13 +427,13 @@ This is stronger than standard PT because it conditions on specific dose values. - **All-same dose**: B-spline basis collapses; ACRT(d) = 0 everywhere. - **Rank deficiency**: When n_treated <= n_basis, cell is skipped. - **Balanced panel required**: Matches R `contdid` v0.1.0. -- **Boundary knots**: Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. Our approach avoids extrapolation and is methodologically sound. +- **Boundary knots**: Knots are built once from all treated doses (global, not per-cell) to ensure a common basis across (g,t) cells for aggregation. Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. Our approach avoids extrapolation and is methodologically sound. 
### Implementation Checklist -- [x] B-spline basis construction matching R's `splines2::bSpline` (boundary knots use training-dose range; see deviation note below) +- [x] B-spline basis construction matching R's `splines2::bSpline` (global knots from all treated doses; boundary knots use training-dose range; see deviation note above) - [x] Multi-period (g,t) cell iteration with base period selection -- [x] Dose-response and event-study aggregation with n_treated weights +- [x] Dose-response and event-study aggregation with group-proportional weights (n_treated/n_total per group, divided among post-treatment cells; R `ptetools` convention) - [x] Multiplier bootstrap for inference - [x] Analytical SEs via influence functions - [x] Equation verification tests (linear, quadratic, multi-period) diff --git a/docs/methodology/continuous-did.md b/docs/methodology/continuous-did.md index 3ae882e..15b6307 100644 --- a/docs/methodology/continuous-did.md +++ b/docs/methodology/continuous-did.md @@ -394,6 +394,9 @@ dvals = quantile(dose[dose > 0], probs = seq(0.10, 0.99, 0.01)) Quantile-based by default: `choose_knots_quantile(dose[dose > 0], num_knots)`. With `num_knots=0`, no interior knots (global polynomial of given degree). +Knots are built **once globally** from all positive doses, not per (g,t) cell. +This ensures a common basis space across cells so that dose-response vectors +can be meaningfully aggregated. ### Dependencies mapping (R -> Python) From 98082803a6039e5328ddd61a1296dce0ad34e458 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Feb 2026 07:04:12 -0500 Subject: [PATCH 9/9] Guard bootstrap NaN propagation: SE/CI/p-value all NaN when SE invalid MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When n_valid=1 (ddof=1 → NaN) or all-identical samples (SE=0), compute_effect_bootstrap_stats now returns NaN for all inference fields instead of mixing finite CI/p-value with NaN SE. 
Adds regression tests and updates REGISTRY.md edge-case documentation. Co-Authored-By: Claude Opus 4.6 --- diff_diff/bootstrap_utils.py | 11 ++++++ docs/methodology/REGISTRY.md | 4 +- tests/test_bootstrap_utils.py | 72 +++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 tests/test_bootstrap_utils.py diff --git a/diff_diff/bootstrap_utils.py b/diff_diff/bootstrap_utils.py index eb6837b..c4d44bf 100644 --- a/diff_diff/bootstrap_utils.py +++ b/diff_diff/bootstrap_utils.py @@ -258,6 +258,17 @@ def compute_effect_bootstrap_stats( valid_dist = boot_dist[finite_mask] se = float(np.std(valid_dist, ddof=1)) + + # Guard: if SE is not finite or zero, all inference fields must be NaN. + if not np.isfinite(se) or se <= 0: + warnings.warn( + f"Bootstrap SE is non-finite or zero (n_valid={n_valid}) in {context}. " + "Returning NaN for SE/CI/p-value.", + RuntimeWarning, + stacklevel=3, + ) + return np.nan, (np.nan, np.nan), np.nan + ci = compute_percentile_ci(valid_dist, alpha) p_value = compute_bootstrap_pvalue( original_effect, valid_dist, n_valid=len(valid_dist) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index ec4fe9e..519e417 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -347,7 +347,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - Parameter: `rank_deficient_action` controls behavior: "warn" (default), "error", or "silent" - Non-finite inference values: - Analytic SE: Returns NaN to signal invalid inference (not biased via zeroing) - - Bootstrap: Drops non-finite samples, warns, and adjusts p-value floor accordingly + - Bootstrap: Drops non-finite samples, warns, and adjusts p-value floor accordingly. 
SE, CI, and p-value are all NaN if SE is non-finite or zero (e.g., n_valid=1 with ddof=1, or identical samples) - Threshold: Returns NaN if <50% of bootstrap samples are valid - Per-effect t_stat: Uses NaN (not 0.0) when SE is non-finite or zero (consistent with overall_t_stat) - **Note**: This is a defensive enhancement over reference implementations (R's `did::att_gt`, Stata's `csdid`) which may error or produce unhandled inf/nan in edge cases without informative warnings @@ -488,7 +488,7 @@ where weights ŵ_{g,e} = n_{g,e} / Σ_g n_{g,e} (sample share of cohort g at eve - NaN inference for undefined statistics: - t_stat: Uses NaN (not 0.0) when SE is non-finite or zero - Analytical inference: p_value and CI also NaN when t_stat is NaN (NaN propagates through `compute_p_value` and `compute_confidence_interval`) - - Bootstrap inference: p_value and CI computed from bootstrap distribution, may be valid even when SE/t_stat is NaN (only NaN if <50% of bootstrap samples are valid) + - Bootstrap inference: p_value and CI computed from bootstrap distribution. 
SE, CI, and p-value are all NaN if SE is non-finite or zero, or if <50% of bootstrap samples are valid - Applies to overall ATT, per-effect event study, and aggregated event study - **Note**: Defensive enhancement matching CallawaySantAnna behavior; R's `fixest::sunab()` may produce Inf/NaN without warning - Inference distribution: diff --git a/tests/test_bootstrap_utils.py b/tests/test_bootstrap_utils.py new file mode 100644 index 0000000..aaee64e --- /dev/null +++ b/tests/test_bootstrap_utils.py @@ -0,0 +1,72 @@ +"""Tests for bootstrap utility edge cases (NaN propagation).""" + +import numpy as np +import pytest + +from diff_diff.bootstrap_utils import compute_effect_bootstrap_stats + + +class TestBootstrapStatsNaNPropagation: + """Regression tests for compute_effect_bootstrap_stats NaN guard.""" + + def test_bootstrap_stats_single_valid_sample(self): + """Single valid sample: ddof=1 produces NaN SE -> all NaN.""" + boot_dist = np.array([1.5]) + with pytest.warns(RuntimeWarning, match="non-finite or zero"): + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect=1.0, boot_dist=boot_dist + ) + assert np.isnan(se) + assert np.isnan(ci[0]) + assert np.isnan(ci[1]) + assert np.isnan(p_value) + + def test_bootstrap_stats_all_nonfinite(self): + """All non-finite samples: fails 50% validity check -> all NaN.""" + boot_dist = np.array([np.nan, np.nan, np.inf]) + with pytest.warns(RuntimeWarning): + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect=1.0, boot_dist=boot_dist + ) + assert np.isnan(se) + assert np.isnan(ci[0]) + assert np.isnan(ci[1]) + assert np.isnan(p_value) + + def test_bootstrap_stats_identical_values(self): + """All identical values: se=0 -> all NaN.""" + boot_dist = np.array([2.0] * 100) + with pytest.warns(RuntimeWarning, match="non-finite or zero"): + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect=2.0, boot_dist=boot_dist + ) + assert np.isnan(se) + assert np.isnan(ci[0]) + assert np.isnan(ci[1]) + 
assert np.isnan(p_value) + + def test_bootstrap_stats_mostly_valid_but_identical(self): + """67% valid (passes 50% check) but identical values: se=0 -> all NaN.""" + boot_dist = np.array([2.0, 2.0, np.nan]) + with pytest.warns(RuntimeWarning, match="non-finite or zero"): + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect=2.0, boot_dist=boot_dist + ) + assert np.isnan(se) + assert np.isnan(ci[0]) + assert np.isnan(ci[1]) + assert np.isnan(p_value) + + def test_bootstrap_stats_normal_case(self): + """Normal case with varied values: all fields finite.""" + boot_dist = np.arange(100.0) + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect=50.0, boot_dist=boot_dist + ) + assert np.isfinite(se) + assert se > 0 + assert np.isfinite(ci[0]) + assert np.isfinite(ci[1]) + assert ci[0] < ci[1] + assert np.isfinite(p_value) + assert 0 < p_value <= 1