diff --git a/CLAUDE.md b/CLAUDE.md index ad24d4f..70da7ec 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -144,6 +144,18 @@ pure Rust by default. - **`diff_diff/two_stage_bootstrap.py`** - Bootstrap inference: - `TwoStageDiDBootstrapMixin` - Mixin with GMM influence function bootstrap methods +- **`diff_diff/stacked_did.py`** - Stacked DiD estimator (Wing et al. 2024): + - `StackedDiD` - Stacked DiD with corrective Q-weights for compositional balance + - `stacked_did()` - Convenience function + - Builds sub-experiments per adoption cohort with clean controls + - IC1/IC2 trimming for compositional balance across event times + - Q-weights for aggregate, population, or sample share estimands (Table 1) + - WLS event study regression via sqrt(w) transformation + - Re-exports result class for backward compatibility + +- **`diff_diff/stacked_did_results.py`** - Result container classes: + - `StackedDiDResults` - Results with overall ATT, event study, stacked data access + - **`diff_diff/triple_diff.py`** - Triple Difference (DDD) estimator: - `TripleDifference` - Ortiz-Villavicencio & Sant'Anna (2025) estimator for DDD designs - `TripleDifferenceResults` - Results with ATT, SEs, cell means, diagnostics @@ -314,6 +326,7 @@ pure Rust by default. ├── TwoStageDiD ├── TripleDifference ├── TROP + ├── StackedDiD ├── SyntheticDiD └── BaconDecomposition ``` @@ -429,6 +442,7 @@ Tests mirror the source modules: - `tests/test_sun_abraham.py` - Tests for SunAbraham interaction-weighted estimator - `tests/test_imputation.py` - Tests for ImputationDiD (Borusyak et al. 2024) estimator - `tests/test_two_stage.py` - Tests for TwoStageDiD (Gardner 2022) estimator, including equivalence tests with ImputationDiD +- `tests/test_stacked_did.py` - Tests for Stacked DiD (Wing et al. 
2024) estimator - `tests/test_triple_diff.py` - Tests for Triple Difference (DDD) estimator - `tests/test_trop.py` - Tests for Triply Robust Panel (TROP) estimator - `tests/test_bacon.py` - Tests for Goodman-Bacon decomposition @@ -445,6 +459,8 @@ Tests mirror the source modules: Session-scoped `ci_params` fixture in `conftest.py` scales bootstrap iterations and TROP grid sizes in pure Python mode — use `ci_params.bootstrap(n)` and `ci_params.grid(values)` in new tests with `n_bootstrap >= 20`. For SE convergence tests (analytical vs bootstrap comparison), use `ci_params.bootstrap(n, min_n=199)` with a conditional tolerance: `threshold = 0.40 if n_boot < 100 else 0.15`. The `min_n` parameter is capped at 49 in pure Python mode to keep CI fast, so convergence tests use wider tolerances when running with fewer bootstrap iterations. +**Slow test suites:** `tests/test_trop.py` is very time-consuming. Only run TROP tests when changes could affect the TROP estimator (e.g., `diff_diff/trop.py`, `diff_diff/trop_results.py`, `diff_diff/linalg.py`, `diff_diff/_backend.py`, or `rust/src/trop.rs`). For unrelated changes, exclude with `pytest --ignore=tests/test_trop.py`. 
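A minimal sketch of the scaling and tolerance logic described above, using a hypothetical stand-in class (the real fixture lives in `conftest.py` and its implementation may differ):

```python
# Illustrative stand-in for the session-scoped ci_params fixture described
# above; the actual conftest.py implementation may differ.
class CIParams:
    PURE_PYTHON_CAP = 49  # min_n cap in pure Python mode, per the note above

    def __init__(self, pure_python: bool):
        self.pure_python = pure_python

    def bootstrap(self, n: int, min_n: int = 20) -> int:
        # Scale bootstrap iterations down in pure Python mode to keep CI fast.
        if self.pure_python:
            return min(max(n, min_n), self.PURE_PYTHON_CAP)
        return max(n, min_n)

    def grid(self, values):
        # Coarsen a TROP grid in pure Python mode.
        return values[::2] if self.pure_python else values


ci = CIParams(pure_python=True)
n_boot = ci.bootstrap(999, min_n=199)       # capped at 49 in pure Python mode
threshold = 0.40 if n_boot < 100 else 0.15  # wider tolerance for fewer draws
```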
+ ### Test Writing Guidelines **For fallback/error handling paths:** diff --git a/METHODOLOGY_REVIEW.md b/METHODOLOGY_REVIEW.md index afe1027..1cf8669 100644 --- a/METHODOLOGY_REVIEW.md +++ b/METHODOLOGY_REVIEW.md @@ -27,6 +27,7 @@ Each estimator in diff-diff should be periodically reviewed to ensure: | SunAbraham | `sun_abraham.py` | `fixest::sunab()` | **Complete** | 2026-02-15 | | SyntheticDiD | `synthetic_did.py` | `synthdid::synthdid_estimate()` | **Complete** | 2026-02-10 | | TripleDifference | `triple_diff.py` | `triplediff::ddd()` | **Complete** | 2026-02-18 | +| StackedDiD | `stacked_did.py` | `stacked-did-weights` | **Complete** | 2026-02-19 | | TROP | `trop.py` | (forthcoming) | Not Started | - | | BaconDecomposition | `bacon.py` | `bacondecomp::bacon()` | Not Started | - | | HonestDiD | `honest_did.py` | `HonestDiD` package | Not Started | - | @@ -379,6 +380,102 @@ variables appear to the left of the `|` separator. --- +#### StackedDiD + +| Field | Value | +|-------|-------| +| Module | `stacked_did.py` | +| Primary Reference | Wing, Freedman & Hollingsworth (2024), NBER WP 32054 | +| R Reference | `stacked-did-weights` (`create_sub_exp()` + `compute_weights()`) | +| Status | **Complete** | +| Last Review | 2026-02-19 | + +**Verified Components:** +- [x] IC1 trimming: `a - kappa_pre >= T_min AND a + kappa_post <= T_max` (matches R reference) +- [x] IC2 trimming: Three clean control modes (not_yet_treated, strict, never_treated) +- [x] Sub-experiment construction: treated + clean controls within `[a - kappa_pre, a + kappa_post]` +- [x] Q-weights aggregate: treated Q=1, control `Q = (sub_treat_n/stack_treat_n) / (sub_control_n/stack_control_n)` per (event_time, sub_exp) — matches R `compute_weights()` +- [x] Q-weights population: `Q_a = (Pop_a^D / Pop^D) / (N_a^C / N^C)` (Table 1, Row 2) +- [x] Q-weights sample_share: `Q_a = ((N_a^D + N_a^C)/(N^D+N^C)) / (N_a^C / N^C)` (Table 1, Row 3) +- [x] WLS via sqrt(w) transformation (numerically equivalent to 
weighted regression) +- [x] Event study regression: `Y = α_0 + α_1·D_sa + Σ_{h≠-1}[λ_h·1(e=h) + δ_h·D_sa·1(e=h)] + U` (Eq. 3) +- [x] Reference period e=-1-anticipation normalized to zero (omitted from design matrix) +- [x] Delta-method SE for overall ATT: `SE = sqrt(ones' @ sub_vcv @ ones) / K` +- [x] Cluster-robust SEs at unit level (default) and unit×sub-experiment level +- [x] Anticipation parameter: reference period shifts to e=-1-anticipation, post-treatment includes anticipation periods +- [x] Rank deficiency handling (warn/error/silent via `solve_ols()`) +- [x] Never-treated encoding: both `first_treat=0` and `first_treat=inf` handled +- [x] R comparison: ATT matches within machine precision (diff < 2.1e-11) +- [x] R comparison: SE matches within machine precision (diff < 4.0e-10) +- [x] R comparison: Event study effects correlation = 1.000000, max diff < 4.5e-11 +- [x] safe_inference() used for all inference fields +- [x] All REGISTRY.md edge cases tested + +**Test Coverage:** +- 72 tests in `tests/test_stacked_did.py` across 11 test classes: + - `TestStackedDiDBasic` (8): fit, event study, group/all raises, simple aggregation, known constant effect, dynamic effects + - `TestTrimming` (5): IC1 window, IC2 no-controls, trimmed groups reported, all-trimmed raises, wider window + - `TestQWeights` (4): treated=1, aggregate formula, sample_share formula, positivity + - `TestCleanControl` (5): not_yet_treated, strict, never_treated, missing never-treated raises + - `TestClustering` (2): unit, unit_subexp + - `TestStackedData` (4): accessible, required columns, event time range + - `TestEdgeCases` (8): single cohort, anticipation, unbalanced panel, NaN inference, never-treated encodings + - `TestSklearnInterface` (4): get_params, set_params, unknown raises, convenience function + - `TestResultsMethods` (7): summary, to_dataframe, is_significant, significance_stars, repr + - `TestValidation` (8): missing columns, invalid params, population required, no treated 
units +- R benchmark tests via `benchmarks/run_benchmarks.py --estimator stacked` + +**R Comparison Results (200 units, 8 periods, kappa_pre=2, kappa_post=2):** +| Metric | Python | R | Diff | +|--------|--------|---|------| +| Overall ATT | 2.277699574579 | 2.2776995746 | 2.1e-11 | +| Overall SE | 0.062045687626 | 0.062045688027 | 4.0e-10 | +| ES e=-2 ATT | 0.044517975379 | 0.044517975379 | <1e-12 | +| ES e=0 ATT | 2.104181683763 | 2.104181683800 | <1e-11 | +| ES e=1 ATT | 2.209990715130 | 2.209990715100 | <1e-11 | +| ES e=2 ATT | 2.518926324845 | 2.518926324800 | <1e-11 | +| Stacked obs | 1600 | 1600 | exact | +| Sub-experiments | 3 | 3 | exact | + +**Corrections Made:** +1. **IC1 lower bound and time window aligned with R reference** (`stacked_did.py`, + `_trim_adoption_events()` and `_build_sub_experiment()`): The paper text specifies + time window `[a - kappa_pre - 1, a + kappa_post]` (including an extra pre-period), + but the R reference implementation by co-author Hollingsworth uses + `[a - kappa_pre, a + kappa_post]`. The extra period had no event-study dummy, + altering the baseline regression. Fixed to match R: removed `-1` from both + IC1 check (`a - kappa_pre >= T_min`) and time window start. Discrepancy documented + in `docs/methodology/papers/wing-2024-review.md` Gaps section. + +2. **Q-weight computation: event-time-specific for aggregate weighting** (`stacked_did.py`, + `_compute_q_weights()`): Changed aggregate Q-weights from unit counts per sub-experiment + to observation counts per (event_time, sub_exp), matching R reference `compute_weights()`. + For balanced panels, results are unchanged. For unbalanced panels, weights now adjust for + varying observation density. Population/sample_share retain unit-count formulas (paper notation). + +3. **Anticipation parameter: reference period and dummies** (`stacked_did.py`, `fit()`): + Reference period now shifts to `e = -1 - anticipation`. 
Event-time dummies cover the + full window `[-kappa_pre - anticipation, ..., kappa_post]`. Post-treatment effects include + anticipation periods. Consistent with ImputationDiD, TwoStageDiD, SunAbraham. + +4. **Group aggregation removed** (`stacked_did.py`): `aggregate="group"` and `aggregate="all"` + removed. The pooled stacked regression cannot produce cohort-specific effects without + cohort×event-time interactions. Use CallawaySantAnna or ImputationDiD for cohort-level estimates. + +5. **n_sub_experiments metadata** (`stacked_did.py`, `fit()`): Now tracks actual built + sub-experiments, not all events in omega_kappa. Warns if any sub-experiments are empty + after data filtering. + +**Outstanding Concerns:** +- Population/sample_share Q-weights use paper's unit-count formulas (no R reference to validate) +- Anticipation not validated against R (R reference doesn't test anticipation > 0) + +**Deviations from R's stacked-did-weights:** +1. **NaN for invalid inference**: Python returns NaN for t_stat/p_value/conf_int when + SE is non-finite or zero. R would propagate through `fixest::feols()` error handling. + +--- + ### Advanced Estimators #### SyntheticDiD diff --git a/README.md b/README.md index 571753d..c2e12fe 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 
0.1 - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights - **Panel data support**: Two-way fixed effects estimator for panel designs - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects -- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, and Two-Stage DiD (Gardner 2022) estimators for heterogeneous treatment timing +- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), and Stacked DiD (Wing, Freedman & Hollingsworth 2024) estimators for heterogeneous treatment timing - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025) @@ -974,6 +974,78 @@ TwoStageDiD( Both estimators are the efficient estimator under homogeneous treatment effects, producing shorter confidence intervals than Callaway-Sant'Anna or Sun-Abraham. +### Stacked DiD (Wing, Freedman & Hollingsworth 2024) + +Stacked DiD addresses TWFE bias in staggered adoption settings by constructing a "clean" comparison dataset for each treatment cohort and stacking them together. Each cohort's sub-experiment compares units treated at that cohort's timing against units that are not yet treated (or never treated) within a symmetric event-study window. This avoids the "bad comparisons" problem in TWFE while retaining a regression-based framework that practitioners familiar with event studies will find intuitive. 
+ +```python +from diff_diff import StackedDiD, generate_staggered_data + +# Generate sample data +data = generate_staggered_data(n_units=200, n_periods=12, + cohort_periods=[4, 6, 8], seed=42) + +# Fit stacked DiD with event study +est = StackedDiD(kappa_pre=2, kappa_post=2) +results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat', + aggregate='event_study') +results.print_summary() + +# Access stacked data for custom analysis +stacked = results.stacked_data + +# Convenience function +from diff_diff import stacked_did +results = stacked_did(data, 'outcome', 'unit', 'period', 'first_treat', + kappa_pre=2, kappa_post=2, aggregate='event_study') +``` + +**Parameters:** + +```python +StackedDiD( + kappa_pre=1, # Pre-treatment event-study periods + kappa_post=1, # Post-treatment event-study periods + weighting='aggregate', # 'aggregate', 'population', or 'sample_share' + clean_control='not_yet_treated', # 'not_yet_treated', 'strict', or 'never_treated' + cluster='unit', # 'unit' or 'unit_subexp' + alpha=0.05, # Significance level + anticipation=0, # Anticipation periods + rank_deficient_action='warn', # 'warn', 'error', or 'silent' +) +``` + +> **Note:** Group aggregation (`aggregate='group'`) is not supported because the pooled +> stacked regression cannot produce cohort-specific effects. Use `CallawaySantAnna` or +> `ImputationDiD` for cohort-level estimates. 
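As background for `weighting='aggregate'`, the corrective Q-weight rule can be sketched in pandas: treated observations receive Q = 1, and controls receive the ratio of the sub-experiment's treated share to its control share within each event time. This is an illustrative re-derivation under assumed column names (`treat`, `event_time`, `sub_exp`), not the library's internal code:

```python
import pandas as pd


def aggregate_q_weights(stacked: pd.DataFrame) -> pd.Series:
    """Q = (sub_treat_n / stack_treat_n) / (sub_control_n / stack_control_n)
    per (event_time, sub_exp) cell for controls; Q = 1 for treated rows."""
    # Stack-level treated/control counts per event time
    by_et = stacked.groupby("event_time")["treat"]
    stack_treat = by_et.transform("sum")
    stack_control = by_et.transform("count") - stack_treat
    # Sub-experiment-level counts per (sub_exp, event_time) cell
    by_cell = stacked.groupby(["sub_exp", "event_time"])["treat"]
    sub_treat = by_cell.transform("sum")
    sub_control = by_cell.transform("count") - sub_treat
    q = (sub_treat / stack_treat) / (sub_control / stack_control)
    # Treated observations always get weight 1
    return q.where(stacked["treat"] == 0, 1.0)
```

For balanced panels this matches a unit-count version of the formula; for unbalanced panels the per-(event_time, sub_exp) observation counts adjust for varying observation density.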
+ +**When to use Stacked DiD vs Callaway-Sant'Anna:** + +| Aspect | Stacked DiD | Callaway-Sant'Anna | +|--------|-------------|-------------------| +| Approach | Stack cohort sub-experiments, run pooled TWFE | 2x2 DiD aggregation | +| Symmetric windows | Enforced via kappa_pre / kappa_post | Not required | +| Control group | Not-yet-treated (default) or never-treated | Never-treated or not-yet-treated | +| Covariates | Passed to pooled regression | Doubly robust / IPW | +| Intuition | Familiar event-study regression | Nonparametric aggregation | + +**Convenience function:** + +```python +# One-liner estimation +results = stacked_did( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + kappa_pre=3, + kappa_post=3, + aggregate='event_study' +) +``` + ### Triple Difference (DDD) Triple Difference (DDD) is used when treatment requires satisfying two criteria: belonging to a treated **group** AND being in an eligible **partition**. The `TripleDifference` class implements the methodology from Ortiz-Villavicencio & Sant'Anna (2025), which correctly handles covariate adjustment (unlike naive implementations). 
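As a conceptual sketch (assuming hypothetical column names and omitting the covariate adjustment that motivates the Ortiz-Villavicencio & Sant'Anna estimator), the unadjusted DDD estimand is the 2x2 DiD among the eligible partition minus the 2x2 DiD among the ineligible partition:

```python
import pandas as pd


def ddd_cell_means(df: pd.DataFrame) -> float:
    """Unadjusted triple difference of cell means: the 2x2 DiD within the
    eligible partition minus the 2x2 DiD within the ineligible partition.
    Column names (group, post, eligible, y) are illustrative."""
    def did(sub: pd.DataFrame) -> float:
        m = sub.groupby(["group", "post"])["y"].mean()
        return (m.loc[(1, 1)] - m.loc[(1, 0)]) - (m.loc[(0, 1)] - m.loc[(0, 0)])

    return did(df[df["eligible"] == 1]) - did(df[df["eligible"] == 0])
```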
@@ -2203,6 +2275,60 @@ TwoStageDiD( | `print_summary(alpha)` | Print summary to stdout | | `to_dataframe(level)` | Convert to DataFrame ('observation', 'event_study', 'group') | +### StackedDiD + +```python +StackedDiD( + kappa_pre=1, # Pre-treatment event-study periods + kappa_post=1, # Post-treatment event-study periods + weighting='aggregate', # 'aggregate', 'population', or 'sample_share' + clean_control='not_yet_treated', # 'not_yet_treated', 'strict', or 'never_treated' + cluster='unit', # 'unit' or 'unit_subexp' + alpha=0.05, # Significance level + anticipation=0, # Anticipation periods + rank_deficient_action='warn', # 'warn', 'error', or 'silent' +) +``` + +**fit() Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `data` | DataFrame | Panel data | +| `outcome` | str | Outcome variable column name | +| `unit` | str | Unit identifier column | +| `time` | str | Time period column | +| `first_treat` | str | First treatment period column (0 for never-treated) | +| `population` | str, optional | Population column (required if weighting='population') | +| `aggregate` | str | Aggregation: None, `"simple"`, or `"event_study"` | + +### StackedDiDResults + +**Attributes:** + +| Attribute | Description | +|-----------|-------------| +| `overall_att` | Overall average treatment effect on the treated | +| `overall_se` | Standard error | +| `overall_t_stat` | T-statistic | +| `overall_p_value` | P-value for H0: ATT = 0 | +| `overall_conf_int` | Confidence interval | +| `event_study_effects` | Dict of relative time -> effect dict (if `aggregate='event_study'`) | +| `stacked_data` | The stacked dataset used for estimation | +| `n_treated_obs` | Number of treated observations | +| `n_untreated_obs` | Number of untreated (clean control) observations | +| `n_cohorts` | Number of treatment cohorts | +| `kappa_pre` | Pre-treatment window used | +| `kappa_post` | Post-treatment window used | + +**Methods:** + +| Method | Description | 
+|--------|-------------| +| `summary(alpha)` | Get formatted summary string | +| `print_summary(alpha)` | Print summary to stdout | +| `to_dataframe(level)` | Convert to DataFrame ('event_study') | + ### TripleDifference ```python @@ -2689,6 +2815,8 @@ The `HonestDiD` module implements sensitivity analysis methods for relaxing the - **Goodman-Bacon, A. (2021).** "Difference-in-Differences with Variation in Treatment Timing." *Journal of Econometrics*, 225(2), 254-277. [https://doi.org/10.1016/j.jeconom.2021.03.014](https://doi.org/10.1016/j.jeconom.2021.03.014) +- **Wing, C., Freedman, S. M., & Hollingsworth, A. (2024).** "Stacked Difference-in-Differences." *NBER Working Paper* 32054. [https://www.nber.org/papers/w32054](https://www.nber.org/papers/w32054) + ### Power Analysis - **Bloom, H. S. (1995).** "Minimum Detectable Effects: A Simple Way to Report the Statistical Power of Experimental Designs." *Evaluation Review*, 19(5), 547-556. [https://doi.org/10.1177/0193841X9501900504](https://doi.org/10.1177/0193841X9501900504) diff --git a/ROADMAP.md b/ROADMAP.md index dfd8103..2227a7f 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -10,7 +10,7 @@ For past changes and release history, see [CHANGELOG.md](CHANGELOG.md). diff-diff v2.4.1 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis: -- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP, Two-Stage DiD (Gardner 2022) +- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing et al. 
2024) - **Valid inference**: Robust SEs, cluster SEs, wild bootstrap, multiplier bootstrap, placebo-based variance - **Assumption diagnostics**: Parallel trends tests, placebo tests, Goodman-Bacon decomposition - **Sensitivity analysis**: Honest DiD (Rambachan-Roth), Pre-trends power analysis (Roth 2022) @@ -24,16 +24,9 @@ diff-diff v2.4.1 is a **production-ready** DiD library with feature parity with High-value additions building on our existing foundation. -### Stacked Difference-in-Differences +### ~~Stacked Difference-in-Differences~~ (Implemented in v2.5) -An intuitive approach that explicitly constructs sub-experiments for each treatment cohort, avoiding forbidden comparisons. - -- Creates separate datasets per cohort with valid controls only -- Stacks sub-experiments and applies corrective sample weights -- Returns variance-weighted ATT with proper compositional balance -- Conceptually simpler alternative to aggregation-based methods - -**Reference**: [Wing, Freedman & Hollingsworth (2024)](https://www.nber.org/papers/w32054). *NBER Working Paper 32054*. Stata: `STACKDID`. +Implemented as `StackedDiD`. See `diff_diff/stacked_did.py`. ### Staggered Triple Difference (DDD) diff --git a/benchmarks/R/benchmark_stacked_did.R b/benchmarks/R/benchmark_stacked_did.R new file mode 100644 index 0000000..d423d80 --- /dev/null +++ b/benchmarks/R/benchmark_stacked_did.R @@ -0,0 +1,350 @@ +#!/usr/bin/env Rscript +# Benchmark: Stacked Difference-in-Differences (Wing, Freedman & Hollingsworth 2024) +# +# Uses the reference implementation functions from: +# https://github.com/hollina/stacked-did-weights +# embedded directly (not a CRAN package). +# Regression via fixest::feols. +# +# Compares against diff_diff.StackedDiD. 
+# +# Usage: +# Rscript benchmark_stacked_did.R --data path/to/data.csv --output path/to/results.json +# Rscript benchmark_stacked_did.R --data path/to/data.csv --output path/to/results.json --kappa-pre 2 --kappa-post 2 + +library(fixest) +library(jsonlite) +library(data.table) + +# ============================================================================= +# Reference implementation functions (from hollina/stacked-did-weights) +# ============================================================================= + +create_sub_exp <- function(dataset, timeID, groupID, adoptionTime, + focalAdoptionTime, kappa_pre, kappa_post) { + # Copy dataset + dt_temp <- copy(dataset) + + # Determine earliest and latest time in the data + minTime <- dt_temp[, min(get(timeID))] + maxTime <- dt_temp[, max(get(timeID))] + + # Include only treated groups and clean controls (not-yet-treated) + dt_temp <- dt_temp[ + get(adoptionTime) == focalAdoptionTime | + get(adoptionTime) > focalAdoptionTime + kappa_post | + is.na(get(adoptionTime)) + ] + + # Limit to time periods inside the event window + dt_temp <- dt_temp[ + get(timeID) %in% (focalAdoptionTime - kappa_pre):(focalAdoptionTime + kappa_post) + ] + + # Make treatment group dummy + dt_temp[, treat := 0] + dt_temp[get(adoptionTime) == focalAdoptionTime, treat := 1] + + # Make a post variable + dt_temp[, post := fifelse(get(timeID) >= focalAdoptionTime, 1, 0)] + + # Make event time variable + dt_temp[, event_time := get(timeID) - focalAdoptionTime] + + # Check feasibility (IC1) + dt_temp[, feasible := fifelse( + focalAdoptionTime - kappa_pre >= minTime & + focalAdoptionTime + kappa_post <= maxTime, 1, 0 + )] + + # Make a sub experiment ID + dt_temp[, sub_exp := focalAdoptionTime] + + return(dt_temp) +} + +compute_weights <- function(dataset, treatedVar, eventTimeVar, subexpVar) { + # Create a copy + stack_dt_temp <- copy(dataset) + + # Step 1: Compute stack-time counts for treated and control + stack_dt_temp[, `:=`( + stack_n = .N, + 
stack_treat_n = sum(get(treatedVar)), + stack_control_n = sum(1 - get(treatedVar)) + ), by = get(eventTimeVar)] + + # Step 2: Compute sub_exp-level counts + stack_dt_temp[, `:=`( + sub_n = .N, + sub_treat_n = sum(get(treatedVar)), + sub_control_n = sum(1 - get(treatedVar)) + ), by = list(get(subexpVar), get(eventTimeVar))] + + # Step 3: Compute sub-experiment share of totals + stack_dt_temp[, sub_share := sub_n / stack_n] + stack_dt_temp[, `:=`( + sub_treat_share = sub_treat_n / stack_treat_n, + sub_control_share = sub_control_n / stack_control_n + )] + + # Step 4: Compute weights (aggregate weighting) + stack_dt_temp[get(treatedVar) == 1, stack_weight := 1] + stack_dt_temp[get(treatedVar) == 0, stack_weight := sub_treat_share / sub_control_share] + + return(stack_dt_temp) +} + +# ============================================================================= +# Command line argument parsing +# ============================================================================= + +args <- commandArgs(trailingOnly = TRUE) + +parse_args <- function(args) { + result <- list( + data = NULL, + output = NULL, + kappa_pre = 2, + kappa_post = 2 + ) + + i <- 1 + while (i <= length(args)) { + if (args[i] == "--data") { + result$data <- args[i + 1] + i <- i + 2 + } else if (args[i] == "--output") { + result$output <- args[i + 1] + i <- i + 2 + } else if (args[i] == "--kappa-pre") { + result$kappa_pre <- as.integer(args[i + 1]) + i <- i + 2 + } else if (args[i] == "--kappa-post") { + result$kappa_post <- as.integer(args[i + 1]) + i <- i + 2 + } else { + i <- i + 1 + } + } + + if (is.null(result$data) || is.null(result$output)) { + stop("Usage: Rscript benchmark_stacked_did.R --data --output [--kappa-pre N] [--kappa-post N]") + } + + return(result) +} + +config <- parse_args(args) + +# ============================================================================= +# Load data +# ============================================================================= + +message(sprintf("Loading data 
from: %s", config$data)) +data <- fread(config$data) + +# Ensure proper column types +data[, unit := as.integer(unit)] +data[, time := as.integer(time)] +data[, first_treat := as.integer(first_treat)] + +# Convert never-treated (first_treat=0) to NA (R convention) +data[first_treat == 0, first_treat := NA_integer_] + +n_units <- length(unique(data$unit)) +n_periods <- length(unique(data$time)) +message(sprintf("Data: %d units, %d periods, %d obs", n_units, n_periods, nrow(data))) + +# Get unique adoption events (excluding never-treated) +events <- sort(unique(data[!is.na(first_treat), first_treat])) +message(sprintf("Adoption events: %s", paste(events, collapse = ", "))) +message(sprintf("Never-treated units: %d", sum(is.na(data[, .SD[1], by = unit]$first_treat)))) + +kappa_pre <- config$kappa_pre +kappa_post <- config$kappa_post +message(sprintf("kappa_pre=%d, kappa_post=%d", kappa_pre, kappa_post)) + +# ============================================================================= +# Build stacked dataset +# ============================================================================= + +message("Building stacked dataset...") +start_time <- Sys.time() + +sub_experiments <- list() + +for (j in events) { + sub_name <- paste0("sub_", j) + sub_experiments[[sub_name]] <- create_sub_exp( + dataset = data, + timeID = "time", + groupID = "unit", + adoptionTime = "first_treat", + focalAdoptionTime = j, + kappa_pre = kappa_pre, + kappa_post = kappa_post + ) +} + +# Vertically concatenate +stackfull <- rbindlist(sub_experiments) + +# Remove infeasible sub-experiments (IC1) +stacked_data <- stackfull[feasible == 1] + +if (nrow(stacked_data) == 0) { + stop("All sub-experiments are infeasible. 
Try smaller kappa values.") +} + +n_feasible <- length(unique(stacked_data$sub_exp)) +message(sprintf("Feasible sub-experiments: %d / %d", n_feasible, length(events))) +message(sprintf("Stacked dataset: %d obs", nrow(stacked_data))) + +# ============================================================================= +# Compute Q-weights (aggregate weighting) +# ============================================================================= + +message("Computing Q-weights...") +stacked_data2 <- compute_weights( + dataset = stacked_data, + treatedVar = "treat", + eventTimeVar = "event_time", + subexpVar = "sub_exp" +) + +stack_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs")) +message(sprintf("Stacking + weights completed in %.3f seconds", stack_time)) + +# ============================================================================= +# Run WLS regression (event study) +# ============================================================================= + +message("Running weighted event study regression...") +reg_start_time <- Sys.time() + +# Weighted event study with fixed effects (matches Equation 3 in paper) +# i(event_time, treat, ref = -1) creates treat × event_time interactions, omitting -1 +weight_stack <- feols( + outcome ~ i(event_time, treat, ref = -1) | treat + event_time, + data = stacked_data2, + cluster = ~unit, + weights = stacked_data2$stack_weight +) + +reg_time <- as.numeric(difftime(Sys.time(), reg_start_time, units = "secs")) +total_time <- stack_time + reg_time +message(sprintf("Regression completed in %.3f seconds", reg_time)) + +# ============================================================================= +# Extract results +# ============================================================================= + +# Extract event study coefficients +coef_table <- coeftable(weight_stack) +coef_names <- rownames(coef_table) + +# Parse event study effects +event_study <- data.frame( + event_time = integer(0), + att = numeric(0), + se = 
numeric(0) +) + +for (i in seq_len(nrow(coef_table))) { + name <- coef_names[i] + # Parse "event_time::X:treat" pattern + if (grepl("event_time::", name)) { + et <- as.integer(gsub("event_time::|:treat", "", name)) + event_study <- rbind(event_study, data.frame( + event_time = et, + att = coef_table[i, "Estimate"], + se = coef_table[i, "Std. Error"] + )) + } +} + +event_study <- event_study[order(event_study$event_time), ] + +message("Event study effects:") +for (i in seq_len(nrow(event_study))) { + message(sprintf(" h=%d: ATT=%.6f (SE=%.6f)", + event_study$event_time[i], + event_study$att[i], + event_study$se[i])) +} + +# Compute overall ATT (average of post-treatment effects) +post_effects <- event_study[event_study$event_time >= 0, ] +if (nrow(post_effects) > 0) { + # Use hypotheses() for delta-method SE + K <- nrow(post_effects) + hyp_terms <- paste0("`event_time::", post_effects$event_time, ":treat`") + hyp_formula <- paste0("(", paste(hyp_terms, collapse = " + "), ") / ", K, " = 0") + + hyp_result <- tryCatch({ + marginaleffects::hypotheses(weight_stack, hyp_formula) + }, error = function(e) { + # Fallback: simple average without delta-method SE + message(sprintf(" hypotheses() not available, using simple average: %s", e$message)) + NULL + }) + + if (!is.null(hyp_result)) { + overall_att <- hyp_result$estimate + overall_se <- hyp_result$std.error + } else { + overall_att <- mean(post_effects$att) + # Approximate SE: sqrt(mean(se^2) / K) + overall_se <- sqrt(sum(post_effects$se^2)) / K + } +} else { + overall_att <- NA_real_ + overall_se <- NA_real_ +} + +message(sprintf("Overall ATT: %.6f (SE: %.6f)", overall_att, overall_se)) +message(sprintf("Total time: %.3f seconds", total_time)) + +# ============================================================================= +# Format and write output +# ============================================================================= + +results <- list( + estimator = "stacked-did-weights (Wing et al. 
2024)", + + # Overall ATT + overall_att = overall_att, + overall_se = overall_se, + + # Event study + event_study = event_study, + + # Timing + timing = list( + stacking_seconds = stack_time, + regression_seconds = reg_time, + total_seconds = total_time + ), + + # Metadata + metadata = list( + r_version = R.version.string, + fixest_version = as.character(packageVersion("fixest")), + n_units = n_units, + n_periods = n_periods, + n_obs = nrow(data), + n_stacked_obs = nrow(stacked_data2), + n_sub_experiments = n_feasible, + kappa_pre = kappa_pre, + kappa_post = kappa_post, + weighting = "aggregate", + clean_control = "not_yet_treated" + ) +) + +message(sprintf("Writing results to: %s", config$output)) +dir.create(dirname(config$output), recursive = TRUE, showWarnings = FALSE) +write_json(results, config$output, auto_unbox = TRUE, pretty = TRUE, digits = 10) + +message(sprintf("Completed in %.3f seconds", total_time)) diff --git a/benchmarks/python/benchmark_stacked_did.py b/benchmarks/python/benchmark_stacked_did.py new file mode 100644 index 0000000..51de422 --- /dev/null +++ b/benchmarks/python/benchmark_stacked_did.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Benchmark: Stacked DiD Estimator (diff-diff StackedDiD). + +Compares against R's stacked-did-weights reference implementation +(Wing, Freedman & Hollingsworth 2024). 
+
+Usage:
+    python benchmark_stacked_did.py --data path/to/data.csv --output path/to/results.json
+    python benchmark_stacked_did.py --data path/to/data.csv --output path/to/results.json --kappa-pre 2 --kappa-post 2
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+
+# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
+def _get_backend_from_args():
+    """Parse --backend argument without importing diff_diff."""
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
+    args, _ = parser.parse_known_args()
+    return args.backend
+
+
+_requested_backend = _get_backend_from_args()
+if _requested_backend in ("python", "rust"):
+    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend
+
+# NOW import diff_diff and other dependencies
+import pandas as pd
+
+# Add the repository root to sys.path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from diff_diff import StackedDiD, HAS_RUST_BACKEND
+from benchmarks.python.utils import Timer
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Benchmark Stacked DiD estimator")
+    parser.add_argument("--data", required=True, help="Path to input CSV data")
+    parser.add_argument("--output", required=True, help="Path to output JSON results")
+    parser.add_argument(
+        "--backend",
+        default="auto",
+        choices=["auto", "python", "rust"],
+        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)",
+    )
+    parser.add_argument(
+        "--kappa-pre",
+        type=int,
+        default=2,
+        help="Number of pre-treatment event-study periods (default: 2)",
+    )
+    parser.add_argument(
+        "--kappa-post",
+        type=int,
+        default=2,
+        help="Number of post-treatment event-study periods (default: 2)",
+    )
+    return parser.parse_args()
+
+
+def get_actual_backend() -> str:
+    """Return the actual backend being used based on HAS_RUST_BACKEND."""
+    return "rust" if HAS_RUST_BACKEND else "python"
+
+
+def main():
+    args = parse_args()
+
+    # Get actual backend
+    actual_backend = get_actual_backend()
+    print(f"Using backend: {actual_backend}")
+
+    # Load data
+    print(f"Loading data from: {args.data}")
+    df = pd.read_csv(args.data)
+
+    kappa_pre = args.kappa_pre
+    kappa_post = args.kappa_post
+    print(f"kappa_pre={kappa_pre}, kappa_post={kappa_post}")
+
+    # Run benchmark
+    print("Running StackedDiD estimation...")
+    est = StackedDiD(
+        kappa_pre=kappa_pre,
+        kappa_post=kappa_post,
+        weighting="aggregate",
+        clean_control="not_yet_treated",
+        cluster="unit",
+    )
+
+    with Timer() as estimation_timer:
+        results = est.fit(
+            df,
+            outcome="outcome",
+            unit="unit",
+            time="time",
+            first_treat="first_treat",
+            aggregate="event_study",
+        )
+
+    estimation_time = estimation_timer.elapsed
+    total_time = estimation_time
+
+    # Store data info
+    n_units = len(df["unit"].unique())
+    n_periods = len(df["time"].unique())
+    n_obs = len(df)
+
+    # Format event study effects
+    es_effects = []
+    if results.event_study_effects:
+        for rel_t, effect_data in sorted(results.event_study_effects.items()):
+            # Skip reference period marker (n_obs == 0)
+            if effect_data.get("n_obs", 1) == 0:
+                continue
+            es_effects.append(
+                {
+                    "event_time": int(rel_t),
+                    "att": float(effect_data["effect"]),
+                    "se": float(effect_data["se"]),
+                }
+            )
+
+    # Build output
+    output = {
+        "estimator": "diff_diff.StackedDiD",
+        "backend": actual_backend,
+        # Overall ATT
+        "overall_att": float(results.overall_att),
+        "overall_se": float(results.overall_se),
+        # Event study
+        "event_study": es_effects,
+        # Timing
+        "timing": {
+            "estimation_seconds": estimation_time,
+            "total_seconds": total_time,
+        },
+        # Metadata
+        "metadata": {
+            "n_units": n_units,
+            "n_periods": n_periods,
+            "n_obs": n_obs,
+            "n_stacked_obs": results.n_stacked_obs,
+            "n_sub_experiments": results.n_sub_experiments,
+            "kappa_pre": kappa_pre,
+            "kappa_post": kappa_post,
+            "weighting": "aggregate",
+            "clean_control": "not_yet_treated",
+        },
+    }
+
+    # Write output
+    print(f"Writing results to: {args.output}")
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        json.dump(output, f, indent=2)
+
+    print(f"Overall ATT: {results.overall_att:.6f} (SE: {results.overall_se:.6f})")
+    if results.event_study_effects:
+        for h, eff in sorted(results.event_study_effects.items()):
+            if eff.get("n_obs", 1) > 0:
+                print(f"  h={h}: ATT={eff['effect']:.6f} (SE={eff['se']:.6f})")
+    print(f"Completed in {total_time:.3f} seconds")
+    return output
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_benchmarks.py
index f5a05dd..bee27a0 100644
--- a/benchmarks/run_benchmarks.py
+++ b/benchmarks/run_benchmarks.py
@@ -134,8 +134,10 @@ def run_r_benchmark(
     cmd = [
         "Rscript",
         str(r_script),
-        "--data", str(data_path),
-        "--output", str(output_path),
+        "--data",
+        str(data_path),
+        "--output",
+        str(output_path),
     ]
     if extra_args:
         cmd.extend(extra_args)
@@ -193,9 +195,12 @@ def run_python_benchmark(
     cmd = [
         sys.executable,
         str(py_script),
-        "--data", str(data_path),
-        "--output", str(output_path),
-        "--backend", backend,
+        "--data",
+        str(data_path),
+        "--output",
+        str(output_path),
+        "--backend",
+        backend,
     ]
     if extra_args:
         cmd.extend(extra_args)
@@ -254,7 +259,9 @@
     # Staggered data for Callaway-Sant'Anna
     stag_cfg = config["staggered"]
     n_obs = stag_cfg["n_units"] * stag_cfg["n_periods"]
-    print(f"  - staggered_{scale} ({stag_cfg['n_units']} units, {stag_cfg['n_periods']} periods, {n_obs:,} obs)")
+    print(
+        f"  - staggered_{scale} ({stag_cfg['n_units']} units, {stag_cfg['n_periods']} periods, {n_obs:,} obs)"
+    )
     staggered_data = generate_staggered_data(
         n_units=stag_cfg["n_units"],
         n_periods=stag_cfg["n_periods"],
@@ -269,7 +276,9 @@
     # Basic 2x2 DiD data
     basic_cfg = config["basic"]
     n_obs = basic_cfg["n_units"] * basic_cfg["n_periods"]
-    print(f"  - basic_{scale} ({basic_cfg['n_units']} units, {basic_cfg['n_periods']} periods, {n_obs:,} obs)")
+    print(
+        f"  - basic_{scale} ({basic_cfg['n_units']} units, {basic_cfg['n_periods']} periods, {n_obs:,} obs)"
+    )
     basic_data = generate_basic_did_data(
         n_units=basic_cfg["n_units"],
         n_periods=basic_cfg["n_periods"],
@@ -302,7 +311,9 @@
     mp_cfg = config["multiperiod"]
     n_periods = mp_cfg["n_pre"] + mp_cfg["n_post"]
     n_obs = mp_cfg["n_units"] * n_periods
-    print(f"  - multiperiod_{scale} ({mp_cfg['n_units']} units, {n_periods} periods, {n_obs:,} obs)")
+    print(
+        f"  - multiperiod_{scale} ({mp_cfg['n_units']} units, {n_periods} periods, {n_obs:,} obs)"
+    )
     multiperiod_data = generate_multiperiod_data(
         n_units=mp_cfg["n_units"],
         n_pre=mp_cfg["n_pre"],
@@ -348,7 +359,9 @@
     for backend in backends:
         # Map backend name to label (python -> pure, rust -> rust)
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.CallawaySantAnna, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.CallawaySantAnna, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -357,7 +370,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_callaway.py", data_path, py_output,
+                    "benchmark_callaway.py",
+                    data_path,
+                    py_output,
                     timeout=timeouts["python"],
                     backend=backend,
                 )
@@ -373,7 +388,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -390,8 +407,7 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_did.R", data_path, r_output,
-                timeout=timeouts["r"]
+                "benchmark_did.R", data_path, r_output, timeout=timeouts["r"]
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -405,13 +421,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results["python"] and results["r"]:
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "CallawaySantAnna", scale=scale,
+            results["python"],
+            results["r"],
+            "CallawaySantAnna",
+            scale=scale,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
         )
@@ -426,8 +447,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -472,7 +497,9 @@
     for backend in backends:
         # Map backend name to label (python -> pure, rust -> rust)
        backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.SyntheticDiD, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.SyntheticDiD, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -500,7 +527,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -517,8 +546,7 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_synthdid.R", data_path, r_output,
-                timeout=timeouts["r"]
+                "benchmark_synthdid.R", data_path, r_output, timeout=timeouts["r"]
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -532,13 +560,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results["python"] and results["r"]:
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "SyntheticDiD", scale=scale,
+            results["python"],
+            results["r"],
+            "SyntheticDiD",
+            scale=scale,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
         )
@@ -553,8 +586,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -599,7 +636,9 @@
     for backend in backends:
         # Map backend name to label (python -> pure, rust -> rust)
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.DifferenceInDifferences, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.DifferenceInDifferences, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -608,7 +647,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_basic.py", data_path, py_output,
+                    "benchmark_basic.py",
+                    data_path,
+                    py_output,
                     extra_args=["--type", "twfe"],
                     timeout=timeouts["python"],
                     backend=backend,
@@ -625,7 +666,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -642,9 +685,11 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_fixest.R", data_path, r_output,
+                "benchmark_fixest.R",
+                data_path,
+                r_output,
                 extra_args=["--type", "twfe"],
-                timeout=timeouts["r"]
+                timeout=timeouts["r"],
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -658,13 +703,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results["python"] and results["r"]:
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "BasicDiD/TWFE", scale=scale,
+            results["python"],
+            results["r"],
+            "BasicDiD/TWFE",
+            scale=scale,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
         )
@@ -679,8 +729,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -724,7 +778,9 @@
     # Run Python benchmark for each backend
     for backend in backends:
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.TwoWayFixedEffects, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.TwoWayFixedEffects, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -733,7 +789,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_twfe.py", data_path, py_output,
+                    "benchmark_twfe.py",
+                    data_path,
+                    py_output,
                     timeout=timeouts["python"],
                     backend=backend,
                 )
@@ -749,7 +807,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -766,8 +826,7 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_twfe.R", data_path, r_output,
-                timeout=timeouts["r"]
+                "benchmark_twfe.R", data_path, r_output, timeout=timeouts["r"]
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -781,13 +840,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results["python"] and results["r"]:
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "TWFE", scale=scale,
+            results["python"],
+            results["r"],
+            "TWFE",
+            scale=scale,
             se_rtol=0.01,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
@@ -803,8 +867,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -851,7 +919,9 @@
     # Run Python benchmark for each backend
     for backend in backends:
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.MultiPeriodDiD, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.MultiPeriodDiD, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -860,7 +930,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_multiperiod.py", data_path, py_output,
+                    "benchmark_multiperiod.py",
+                    data_path,
+                    py_output,
                     extra_args=extra_args,
                     timeout=timeouts["python"],
                     backend=backend,
@@ -877,7 +949,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -894,9 +968,11 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_multiperiod.R", data_path, r_output,
+                "benchmark_multiperiod.R",
+                data_path,
+                r_output,
                 extra_args=extra_args,
-                timeout=timeouts["r"]
+                timeout=timeouts["r"],
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -910,13 +986,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results["python"] and results["r"]:
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "MultiPeriodDiD", scale=scale,
+            results["python"],
+            results["r"],
+            "MultiPeriodDiD",
+            scale=scale,
             se_rtol=0.01,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
@@ -941,8 +1022,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -987,7 +1072,9 @@
     for backend in backends:
         # Map backend name to label (python -> pure, rust -> rust)
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.ImputationDiD, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.ImputationDiD, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -996,7 +1083,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_imputation.py", data_path, py_output,
+                    "benchmark_imputation.py",
+                    data_path,
+                    py_output,
                     timeout=timeouts["python"],
                     backend=backend,
                 )
@@ -1012,7 +1101,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -1029,8 +1120,7 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_didimputation.R", data_path, r_output,
-                timeout=timeouts["r"]
+                "benchmark_didimputation.R", data_path, r_output, timeout=timeouts["r"]
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -1044,13 +1134,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results.get("python") and results.get("r"):
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "ImputationDiD", scale=scale,
+            results["python"],
+            results["r"],
+            "ImputationDiD",
+            scale=scale,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
         )
@@ -1074,8 +1169,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -1119,7 +1218,9 @@
     # Run Python benchmark for each backend
     for backend in backends:
         backend_label = f"python_{'pure' if backend == 'python' else backend}"
-        print(f"\nRunning Python (diff_diff.SunAbraham, backend={backend}) - {n_replications} replications...")
+        print(
+            f"\nRunning Python (diff_diff.SunAbraham, backend={backend}) - {n_replications} replications..."
+        )
         py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
         py_output.parent.mkdir(parents=True, exist_ok=True)
 
@@ -1128,7 +1229,9 @@
         for rep in range(n_replications):
             try:
                 py_result = run_python_benchmark(
-                    "benchmark_sun_abraham.py", data_path, py_output,
+                    "benchmark_sun_abraham.py",
+                    data_path,
+                    py_output,
                     timeout=timeouts["python"],
                     backend=backend,
                 )
@@ -1144,7 +1247,9 @@
         timing_stats = compute_timing_stats(py_timings)
         py_result["timing"] = timing_stats
         results[backend_label] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # For backward compatibility, also store as "python" (use rust if available)
     if results.get("python_rust"):
@@ -1161,8 +1266,7 @@
     for rep in range(n_replications):
         try:
             r_result = run_r_benchmark(
-                "benchmark_sunab.R", data_path, r_output,
-                timeout=timeouts["r"]
+                "benchmark_sunab.R", data_path, r_output, timeout=timeouts["r"]
             )
             r_timings.append(r_result["timing"]["total_seconds"])
             if rep == 0:
@@ -1176,13 +1280,18 @@
         timing_stats = compute_timing_stats(r_timings)
         r_result["timing"] = timing_stats
         results["r"] = r_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
 
     # Compare results
     if results.get("python") and results.get("r"):
         print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "SunAbraham", scale=scale,
+            results["python"],
+            results["r"],
+            "SunAbraham",
+            scale=scale,
             se_rtol=0.01,
             python_pure_results=results.get("python_pure"),
             python_rust_results=results.get("python_rust"),
@@ -1207,8 +1316,12 @@
     print(f"  {'-'*54}")
 
     r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
-    pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
-    rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
 
     if r_mean:
         print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
@@ -1223,10 +1336,154 @@
     return results
 
 
-def main():
-    parser = argparse.ArgumentParser(
-        description="Run diff-diff benchmarks against R packages"
-    )
+def run_stacked_did_benchmark(
+    data_path: Path,
+    name: str = "stacked",
+    scale: str = "small",
+    n_replications: int = 1,
+    backends: Optional[List[str]] = None,
+) -> Dict[str, Any]:
+    """Run Stacked DiD benchmarks (Python and R) with replications."""
+    print(f"\n{'='*60}")
+    print(f"STACKED DID BENCHMARK ({scale})")
+    print(f"{'='*60}")
+
+    if backends is None:
+        backends = ["python", "rust"]
+
+    timeouts = TIMEOUT_CONFIGS.get(scale, TIMEOUT_CONFIGS["small"])
+    results = {
+        "name": name,
+        "scale": scale,
+        "n_replications": n_replications,
+        "python_pure": None,
+        "python_rust": None,
+        "r": None,
+        "comparison": None,
+    }
+
+    # Run Python benchmark for each backend
+    for backend in backends:
+        backend_label = f"python_{'pure' if backend == 'python' else backend}"
+        print(
+            f"\nRunning Python (diff_diff.StackedDiD, backend={backend}) - {n_replications} replications..."
+        )
+        py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
+        py_output.parent.mkdir(parents=True, exist_ok=True)
+
+        py_timings = []
+        py_result = None
+        for rep in range(n_replications):
+            try:
+                py_result = run_python_benchmark(
+                    "benchmark_stacked_did.py",
+                    data_path,
+                    py_output,
+                    timeout=timeouts["python"],
+                    backend=backend,
+                )
+                py_timings.append(py_result["timing"]["total_seconds"])
+                if rep == 0:
+                    print(f"  ATT: {py_result['overall_att']:.4f}")
+                    print(f"  SE: {py_result['overall_se']:.4f}")
+                print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
+            except Exception as e:
+                print(f"  Rep {rep+1} failed: {e}")
+
+        if py_result and py_timings:
+            timing_stats = compute_timing_stats(py_timings)
+            py_result["timing"] = timing_stats
+            results[backend_label] = py_result
+            print(
+                f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+            )
+
+    # For backward compatibility, also store as "python" (use rust if available)
+    if results.get("python_rust"):
+        results["python"] = results["python_rust"]
+    elif results.get("python_pure"):
+        results["python"] = results["python_pure"]
+
+    # R benchmark with replications
+    print(f"\nRunning R (stacked-did-weights + fixest) - {n_replications} replications...")
+    r_output = RESULTS_DIR / "accuracy" / f"r_{name}_{scale}.json"
+
+    r_timings = []
+    r_result = None
+    for rep in range(n_replications):
+        try:
+            r_result = run_r_benchmark(
+                "benchmark_stacked_did.R", data_path, r_output, timeout=timeouts["r"]
+            )
+            r_timings.append(r_result["timing"]["total_seconds"])
+            if rep == 0:
+                print(f"  ATT: {r_result['overall_att']:.4f}")
+                print(f"  SE: {r_result['overall_se']:.4f}")
+            print(f"  Rep {rep+1}/{n_replications}: {r_timings[-1]:.3f}s")
+        except Exception as e:
+            print(f"  Rep {rep+1} failed: {e}")
+
+    if r_result and r_timings:
+        timing_stats = compute_timing_stats(r_timings)
+        r_result["timing"] = timing_stats
+        results["r"] = r_result
+        print(
+            f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s"
+        )
+
+    # Compare results
+    if results.get("python") and results.get("r"):
+        print("\nComparison (Python vs R):")
+        comparison = compare_estimates(
+            results["python"],
+            results["r"],
+            "StackedDiD",
+            scale=scale,
+            python_pure_results=results.get("python_pure"),
+            python_rust_results=results.get("python_rust"),
+        )
+        results["comparison"] = comparison
+        print(f"  ATT diff: {comparison.att_diff:.2e}")
+        print(f"  SE rel diff: {comparison.se_rel_diff:.1%}")
+        print(f"  Status: {'PASS' if comparison.passed else 'FAIL'}")
+
+        # Event study comparison
+        py_effects = results["python"].get("event_study", [])
+        r_effects = results["r"].get("event_study", [])
+        if py_effects and r_effects:
+            corr, max_diff, all_close = compare_event_study(py_effects, r_effects)
+            print(f"  Event study correlation: {corr:.6f}")
+            print(f"  Event study max diff: {max_diff:.2e}")
+            print(f"  Event study all close: {all_close}")
+
+    # Print timing comparison table
+    print("\nTiming Comparison:")
+    print(f"  {'Backend':<15} {'Time (s)':<12} {'vs R':<12} {'vs Pure Python':<15}")
+    print(f"  {'-'*54}")
+
+    r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
+    pure_mean = (
+        results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+    )
+    rust_mean = (
+        results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+    )
+
+    if r_mean:
+        print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
+    if pure_mean:
+        r_speedup = f"{r_mean/pure_mean:.2f}x" if r_mean else "-"
+        print(f"  {'Python (pure)':<15} {pure_mean:<12.3f} {r_speedup:<12} {'1.00x':<15}")
+    if rust_mean:
+        r_speedup = f"{r_mean/rust_mean:.2f}x" if r_mean else "-"
+        pure_speedup = f"{pure_mean/rust_mean:.2f}x" if pure_mean else "-"
+        print(f"  {'Python (rust)':<15} {rust_mean:<12.3f} {r_speedup:<12} {pure_speedup:<15}")
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run diff-diff benchmarks against R packages")
     parser.add_argument(
         "--all",
         action="store_true",
@@ -1234,7 +1491,16 @@
     )
     parser.add_argument(
         "--estimator",
-        choices=["callaway", "synthdid", "basic", "twfe", "multiperiod", "imputation", "sunab"],
+        choices=[
+            "callaway",
+            "synthdid",
+            "basic",
+            "twfe",
+            "multiperiod",
+            "imputation",
+            "sunab",
+            "stacked",
+        ],
         help="Run specific estimator benchmark",
     )
     parser.add_argument(
@@ -1367,6 +1633,17 @@
         )
         all_results.append(results)
 
+    if args.all or args.estimator == "stacked":
+        # Stacked DiD uses the same staggered data as Callaway-Sant'Anna
+        stag_key = f"staggered_{scale}"
+        if stag_key in datasets:
+            results = run_stacked_did_benchmark(
+                datasets[stag_key],
+                scale=scale,
+                n_replications=args.replications,
+            )
+            all_results.append(results)
+
     # Generate summary report
     if all_results:
         print(f"\n{'='*60}")
@@ -1375,9 +1652,7 @@
 
     comparisons = [r["comparison"] for r in all_results if r.get("comparison")]
     if comparisons:
-        report = generate_comparison_report(
-            comparisons, RESULTS_DIR / "comparison_report.txt"
-        )
+        report = generate_comparison_report(comparisons, RESULTS_DIR / "comparison_report.txt")
         print(report)
     else:
         print("No comparisons available.")
diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py
index 1935362..29fdd25 100644
--- a/diff_diff/__init__.py
+++ b/diff_diff/__init__.py
@@ -107,6 +107,11 @@
     TwoStageDiDResults,
     two_stage_did,
 )
+from diff_diff.stacked_did import (
+    StackedDiD,
+    StackedDiDResults,
+    stacked_did,
+)
 from diff_diff.sun_abraham import (
     SABootstrapResults,
     SunAbraham,
@@ -161,6 +166,7 @@
     "TwoStageDiD",
     "TripleDifference",
     "TROP",
+    "StackedDiD",
     # Bacon Decomposition
     "BaconDecomposition",
     "BaconDecompositionResults",
@@ -187,6 +193,8 @@
     "triple_difference",
     "TROPResults",
     "trop",
+    "StackedDiDResults",
+    "stacked_did",
     # Visualization
     "plot_event_study",
     "plot_group_effects",
diff --git a/diff_diff/stacked_did.py b/diff_diff/stacked_did.py
new file mode 100644
index 0000000..9f133fd
--- /dev/null
+++ b/diff_diff/stacked_did.py
@@ -0,0 +1,871 @@
+"""
+Wing, Freedman & Hollingsworth (2024) Stacked Difference-in-Differences Estimator.
+
+Implements the stacked DiD estimator from Wing, Freedman & Hollingsworth (2024),
+NBER Working Paper 32054. The key contribution: naive stacked DiD regressions are
+biased because they implicitly weight treatment and control group trends differently
+across sub-experiments. The authors derive corrective Q-weights that make a weighted
+stacked regression identify the "trimmed aggregate ATT" — a well-defined convex
+combination of group-time ATTs with stable composition across event time.
+
+The implementation follows the R reference code at
+https://github.com/hollina/stacked-did-weights.
+
+References
+----------
+Wing, C., Freedman, S. M., & Hollingsworth, A. (2024). Stacked
+    Difference-in-Differences. NBER Working Paper 32054.
+"""
+
+import warnings
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+
+from diff_diff.linalg import solve_ols
+from diff_diff.stacked_did_results import StackedDiDResults  # noqa: F401 (re-export)
+from diff_diff.utils import safe_inference
+
+__all__ = [
+    "StackedDiD",
+    "StackedDiDResults",
+    "stacked_did",
+]
+
+
+class StackedDiD:
+    """
+    Stacked Difference-in-Differences estimator.
+
+    Implements Wing, Freedman & Hollingsworth (2024). Builds a stacked
+    dataset of sub-experiments (one per adoption cohort), applies
+    corrective Q-weights to address implicit weighting bias in naive
+    stacked regressions, and runs a weighted event-study regression.
+
+    Parameters
+    ----------
+    kappa_pre : int, default=1
+        Number of pre-treatment event-time periods in the event window.
+        The event window spans [-kappa_pre, ..., kappa_post].
+    kappa_post : int, default=1
+        Number of post-treatment event-time periods.
+ weighting : str, default="aggregate" + Target estimand weighting scheme per Table 1 of the paper: + - "aggregate": Equal weight per adoption event (trimmed aggregate ATT) + - "population": Weight by population size of treated cohort + - "sample_share": Weight by sample share of each sub-experiment + clean_control : str, default="not_yet_treated" + How to define clean controls per Appendix A of the paper: + - "not_yet_treated": Units with A_s > a + kappa_post + - "strict": Units with A_s > a + kappa_post + kappa_pre + - "never_treated": Only units with A_s = infinity + cluster : str, default="unit" + Clustering level for standard errors: + - "unit": Cluster on original unit identifier + - "unit_subexp": Cluster on (unit, sub_experiment) pairs + alpha : float, default=0.05 + Significance level for confidence intervals. + anticipation : int, default=0 + Number of anticipation periods. When anticipation > 0: + - Reference period shifts from e=-1 to e=-1-anticipation + - Post-treatment includes anticipation periods (e >= -anticipation) + - Event window expands by anticipation pre-periods + Consistent with ImputationDiD, TwoStageDiD, SunAbraham. + rank_deficient_action : str, default="warn" + Action when design matrix is rank-deficient: + - "warn": Issue warning and drop linearly dependent columns + - "error": Raise ValueError + - "silent": Drop columns silently + + Attributes + ---------- + results_ : StackedDiDResults + Estimation results after calling fit(). + is_fitted_ : bool + Whether the model has been fitted. + + Examples + -------- + Basic usage: + + >>> from diff_diff import StackedDiD, generate_staggered_data + >>> data = generate_staggered_data(n_units=200, seed=42) + >>> est = StackedDiD(kappa_pre=2, kappa_post=2) + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... time='period', first_treat='first_treat') + >>> results.print_summary() + + With event study: + + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... 
time='period', first_treat='first_treat', + ... aggregate='event_study') + >>> from diff_diff import plot_event_study + >>> plot_event_study(results) + + Notes + ----- + The stacked estimator addresses TWFE bias by: + 1. Creating one sub-experiment per adoption cohort with clean controls + 2. Applying Q-weights to reweight the stacked regression + 3. Running a single event-study WLS regression on the weighted stack + + References + ---------- + Wing, C., Freedman, S. M., & Hollingsworth, A. (2024). Stacked + Difference-in-Differences. NBER Working Paper 32054. + """ + + def __init__( + self, + kappa_pre: int = 1, + kappa_post: int = 1, + weighting: str = "aggregate", + clean_control: str = "not_yet_treated", + cluster: str = "unit", + alpha: float = 0.05, + anticipation: int = 0, + rank_deficient_action: str = "warn", + ): + if weighting not in ("aggregate", "population", "sample_share"): + raise ValueError( + f"weighting must be 'aggregate', 'population', or 'sample_share', " + f"got '{weighting}'" + ) + if clean_control not in ("not_yet_treated", "strict", "never_treated"): + raise ValueError( + f"clean_control must be 'not_yet_treated', 'strict', or " + f"'never_treated', got '{clean_control}'" + ) + if cluster not in ("unit", "unit_subexp"): + raise ValueError(f"cluster must be 'unit' or 'unit_subexp', got '{cluster}'") + if rank_deficient_action not in ("warn", "error", "silent"): + raise ValueError( + f"rank_deficient_action must be 'warn', 'error', or 'silent', " + f"got '{rank_deficient_action}'" + ) + + self.kappa_pre = kappa_pre + self.kappa_post = kappa_post + self.weighting = weighting + self.clean_control = clean_control + self.cluster = cluster + self.alpha = alpha + self.anticipation = anticipation + self.rank_deficient_action = rank_deficient_action + + self.is_fitted_ = False + self.results_: Optional[StackedDiDResults] = None + + def fit( + self, + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + aggregate: 
Optional[str] = None, + population: Optional[str] = None, + ) -> StackedDiDResults: + """ + Fit the stacked DiD estimator. + + Parameters + ---------- + data : pd.DataFrame + Panel data with unit and time identifiers. + outcome : str + Name of outcome variable column. + unit : str + Name of unit identifier column. + time : str + Name of time period column. + first_treat : str + Name of column indicating when unit was first treated. + Use 0 or np.inf for never-treated units. + aggregate : str, optional + Aggregation mode: None/"simple" (overall ATT only) or + "event_study". Group aggregation is not supported because + the pooled stacked regression cannot produce cohort-specific + effects. Use CallawaySantAnna or ImputationDiD for + cohort-level estimates. + population : str, optional + Column name for population weights. Required only when + weighting="population". + + Returns + ------- + StackedDiDResults + Object containing all estimation results. + + Raises + ------ + ValueError + If required columns are missing or data validation fails. + """ + # ---- Validate inputs ---- + if aggregate in ("group", "all"): + raise ValueError( + f"aggregate='{aggregate}' is not supported by StackedDiD. " + "The pooled stacked regression cannot produce cohort-specific " + "effects. Use CallawaySantAnna or ImputationDiD for " + "cohort-level estimates." 
+ ) + if aggregate not in (None, "simple", "event_study"): + raise ValueError( + f"aggregate must be None, 'simple', or 'event_study', " f"got '{aggregate}'" + ) + + required_cols = [outcome, unit, time, first_treat] + if population is not None: + required_cols.append(population) + missing = [c for c in required_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing columns: {missing}") + + if self.weighting == "population" and population is None: + raise ValueError("population column must be specified when weighting='population'") + + df = data.copy() + df[time] = pd.to_numeric(df[time]) + df[first_treat] = pd.to_numeric(df[first_treat]) + + # ---- Data setup ---- + # Handle never-treated encoding: both 0 and inf -> inf + df[first_treat] = df[first_treat].replace(0, np.inf) + + # Build unit_info: one row per unit + unit_info = ( + df.groupby(unit) + .agg({first_treat: "first"}) + .reset_index() + .rename(columns={first_treat: "_first_treat"}) + ) + + T_min = int(df[time].min()) + T_max = int(df[time].max()) + time_periods = sorted(df[time].unique()) + + # Extract unique adoption events (finite first_treat values) + omega_A = sorted([a for a in unit_info["_first_treat"].unique() if np.isfinite(a)]) + + if len(omega_A) == 0: + raise ValueError( + "No treated units found. Check 'first_treat' column " + "(use 0 or np.inf for never-treated units)." 
+ ) + + # ---- Trim adoption events (IC1 + IC2) ---- + omega_kappa, trimmed = self._trim_adoption_events(omega_A, T_min, T_max, unit_info) + + # ---- Build stacked dataset ---- + sub_experiments = [] + skipped_events = [] + for a in omega_kappa: + sub_exp = self._build_sub_experiment(df, unit_info, a, unit, time, first_treat, outcome) + if sub_exp is not None and len(sub_exp) > 0: + sub_experiments.append(sub_exp) + else: + skipped_events.append(a) + + if skipped_events: + warnings.warn( + f"Sub-experiments for events {skipped_events} were empty " f"after filtering.", + UserWarning, + stacklevel=2, + ) + + if len(sub_experiments) == 0: + raise ValueError( + "All sub-experiments are empty after filtering. " + "Check your data or reduce kappa values." + ) + + stacked_df = pd.concat(sub_experiments, ignore_index=True) + + # ---- Compute Q-weights ---- + stacked_df = self._compute_q_weights(stacked_df, unit, population) + + # ---- Count units ---- + treated_units = stacked_df.loc[stacked_df["_D_sa"] == 1, unit].unique() + control_units = stacked_df.loc[stacked_df["_D_sa"] == 0, unit].unique() + n_treated_units = len(treated_units) + n_control_units = len(control_units) + + # ---- Build design matrix and run WLS ---- + # Always run event study regression (Equation 3 in paper) + # Reference period: e = -1 - anticipation (shifts when anticipation > 0) + ref_period = -1 - self.anticipation + event_times = sorted( + [ + h + for h in range(-self.kappa_pre - self.anticipation, self.kappa_post + 1) + if h != ref_period + ] + ) + + n = len(stacked_df) + n_event_dummies = len(event_times) + + # Track column indices for VCV extraction + # [0] intercept, [1] D_sa, [2..K+1] event-time dummies, + # [K+2..2K+1] D_sa * event-time interactions + interaction_indices: Dict[int, int] = {} + + # Build design matrix + X = np.zeros((n, 2 + 2 * n_event_dummies)) + X[:, 0] = 1.0 # intercept + X[:, 1] = stacked_df["_D_sa"].values # treatment indicator + + et_vals = 
stacked_df["_event_time"].values + d_vals = stacked_df["_D_sa"].values + + for j, h in enumerate(event_times): + col_lambda = 2 + j # event-time dummy + col_delta = 2 + n_event_dummies + j # interaction + mask = et_vals == h + X[mask, col_lambda] = 1.0 + X[mask, col_delta] = d_vals[mask] + interaction_indices[h] = col_delta + + # WLS via sqrt(w) transformation + Q_weights = stacked_df["_Q_weight"].values + sqrt_w = np.sqrt(Q_weights) + Y = stacked_df[outcome].values + Y_t = Y * sqrt_w + X_t = X * sqrt_w[:, np.newaxis] + + # Cluster IDs + if self.cluster == "unit": + cluster_ids = stacked_df[unit].values + else: # unit_subexp + cluster_ids = ( + stacked_df[unit].astype(str) + "_" + stacked_df["_sub_exp"].astype(str) + ).values + + # Run OLS on transformed data (= WLS) + coef, residuals, vcov = solve_ols( + X_t, + Y_t, + cluster_ids=cluster_ids, + return_vcov=True, + rank_deficient_action=self.rank_deficient_action, + ) + + # ---- Extract event study effects ---- + event_study_effects: Optional[Dict[int, Dict[str, Any]]] = None + if aggregate == "event_study": + event_study_effects = {} + # Reference period (e = -1 - anticipation) + event_study_effects[ref_period] = { + "effect": 0.0, + "se": 0.0, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": 0, + } + for h in event_times: + idx = interaction_indices[h] + effect = float(coef[idx]) + se = float(np.sqrt(max(vcov[idx, idx], 0.0))) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) + n_obs_h = int(np.sum((et_vals == h) & (d_vals == 1))) + event_study_effects[h] = { + "effect": effect, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "n_obs": n_obs_h, + } + + # ---- Compute overall ATT ---- + # Average of post-treatment delta_h coefficients with delta-method SE + # Post-treatment includes anticipation periods (h >= -anticipation) + post_event_times = [ + h for h in event_times if h >= -self.anticipation and h in 
interaction_indices + ] + post_indices = [interaction_indices[h] for h in post_event_times] + K = len(post_indices) + + if K > 0: + overall_att = sum(float(coef[i]) for i in post_indices) / K + # Delta method: gradient = 1/K for each post-period coefficient + sub_vcv = vcov[np.ix_(post_indices, post_indices)] + ones = np.ones(K) + overall_se = float(np.sqrt(max(ones @ sub_vcv @ ones, 0.0))) / K + else: + overall_att = np.nan + overall_se = np.nan + + overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + + # ---- Construct results ---- + self.results_ = StackedDiDResults( + overall_att=overall_att, + overall_se=overall_se, + overall_t_stat=overall_t, + overall_p_value=overall_p, + overall_conf_int=overall_ci, + event_study_effects=event_study_effects, + group_effects=None, + stacked_data=stacked_df, + groups=list(omega_kappa), + trimmed_groups=list(trimmed), + time_periods=time_periods, + n_obs=len(data), + n_stacked_obs=n, + n_sub_experiments=len(sub_experiments), + n_treated_units=n_treated_units, + n_control_units=n_control_units, + kappa_pre=self.kappa_pre, + kappa_post=self.kappa_post, + weighting=self.weighting, + clean_control=self.clean_control, + alpha=self.alpha, + ) + + self.is_fitted_ = True + return self.results_ + + # ========================================================================= + # Trimming (IC1 + IC2) + # ========================================================================= + + def _trim_adoption_events( + self, + adoption_events: List[Any], + T_min: int, + T_max: int, + unit_info: pd.DataFrame, + ) -> Tuple[List[Any], List[Any]]: + """ + Trim adoption events based on IC1 (window) and IC2 (controls). + + IC1: a - kappa_pre >= T_min AND a + kappa_post <= T_max + (matches R reference: focalAdoptionTime - kappa_pre >= minTime + AND focalAdoptionTime + kappa_post <= maxTime) + With anticipation: a - kappa_pre - anticipation >= T_min + + IC2: Clean controls exist for this adoption event. 
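The IC1/IC2 trimming logic described above can be sketched in isolation. The adoption times, panel span, and window sizes below are hypothetical values for illustration, not taken from the estimator:

```python
import numpy as np

# Hypothetical panel span and event window (anticipation = 0)
T_min, T_max = 1, 10
kappa_pre, kappa_post, anticipation = 2, 2, 0
first_treat = np.array([2.0, 5.0, 10.0, np.inf])  # inf = never treated

def passes_ic1(a):
    # Event window [a - kappa_pre - anticipation, a + kappa_post] must fit in the data
    return (a - kappa_pre - anticipation) >= T_min and (a + kappa_post) <= T_max

def passes_ic2(a):
    # At least one not-yet-treated clean control must exist
    return bool(np.any(first_treat > a + kappa_post))

events = [a for a in first_treat.tolist() if np.isfinite(a)]
kept = [a for a in events if passes_ic1(a) and passes_ic2(a)]
trimmed = [a for a in events if a not in kept]
print(kept, trimmed)  # [5.0] [2.0, 10.0]
```

The cohort treated at t=2 fails IC1 (too few pre-periods) and the cohort treated at t=10 fails IC1 on the post side, mirroring the warning the estimator emits for trimmed events.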
+ + Parameters + ---------- + adoption_events : list + Unique finite adoption event times. + T_min, T_max : int + Min and max time periods in the data. + unit_info : pd.DataFrame + One row per unit with _first_treat column. + + Returns + ------- + omega_kappa : list + Included adoption events. + trimmed : list + Excluded adoption events. + """ + omega_kappa = [] + trimmed = [] + + for a in adoption_events: + a_int = int(a) + + # IC1: Event window fits in data + # a - kappa_pre >= T_min AND a + kappa_post <= T_max + # (matches R reference: focalAdoptionTime - kappa_pre >= minTime) + # With anticipation: shift window start earlier + lower_ok = (a_int - self.kappa_pre - self.anticipation) >= T_min + upper_ok = (a_int + self.kappa_post) <= T_max + ic1 = lower_ok and upper_ok + + # IC2: Clean controls exist + ic2 = self._check_clean_controls_exist(a_int, unit_info) + + if ic1 and ic2: + omega_kappa.append(a) + else: + trimmed.append(a) + + if trimmed: + warnings.warn( + f"Trimmed {len(trimmed)} adoption event(s) that don't satisfy " + f"inclusion criteria: {trimmed}. " + f"IC1 requires event window [{-self.kappa_pre}, {self.kappa_post}] " + f"to fit within data range [{T_min}, {T_max}]. " + f"IC2 requires clean controls to exist.", + UserWarning, + stacklevel=3, + ) + + if len(omega_kappa) == 0: + raise ValueError( + f"All {len(adoption_events)} adoption events were trimmed. " + f"No valid sub-experiments can be constructed. " + f"Consider reducing kappa_pre (currently {self.kappa_pre}) " + f"or kappa_post (currently {self.kappa_post}), or check that " + f"clean control units exist." 
+ ) + + return omega_kappa, trimmed + + def _check_clean_controls_exist(self, a: int, unit_info: pd.DataFrame) -> bool: + """Check IC2: whether clean control units exist for adoption event a.""" + ft = unit_info["_first_treat"].values + if self.clean_control == "not_yet_treated": + return bool(np.any(ft > a + self.kappa_post)) + elif self.clean_control == "strict": + return bool(np.any(ft > a + self.kappa_post + self.kappa_pre)) + else: # never_treated + return bool(np.any(np.isinf(ft))) + + # ========================================================================= + # Sub-experiment construction + # ========================================================================= + + def _build_sub_experiment( + self, + df: pd.DataFrame, + unit_info: pd.DataFrame, + a: Any, + unit: str, + time: str, + first_treat: str, + outcome: str, + ) -> Optional[pd.DataFrame]: + """ + Build a single sub-experiment for adoption event a. + + Parameters + ---------- + df : pd.DataFrame + Full panel data. + unit_info : pd.DataFrame + One row per unit with _first_treat. + a : int/float + Adoption event time. + unit, time, first_treat, outcome : str + Column names. + + Returns + ------- + pd.DataFrame or None + Sub-experiment data with _sub_exp, _event_time, _D_sa columns. 
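The sub-experiment construction documented above can be sketched standalone on a tiny hand-made panel. Column names and values here are hypothetical; this mirrors the not-yet-treated control rule and the `_sub_exp` / `_event_time` / `_D_sa` columns, not the estimator's exact code path:

```python
import numpy as np
import pandas as pd

# Tiny panel: 3 units observed over periods 1..6; unit C is never treated
df = pd.DataFrame({
    "unit": np.repeat(["A", "B", "C"], 6),
    "period": np.tile(range(1, 7), 3),
    "first_treat": np.repeat([3.0, 5.0, np.inf], 6),
})

a, kappa_pre, kappa_post = 3, 1, 1  # focal adoption event and event window

treated = set(df.loc[df["first_treat"] == a, "unit"])
controls = set(df.loc[df["first_treat"] > a + kappa_post, "unit"])  # not-yet-treated rule

window = df["period"].between(a - kappa_pre, a + kappa_post)
sub = df[df["unit"].isin(treated | controls) & window].copy()
sub["_sub_exp"] = a
sub["_event_time"] = sub["period"] - a
sub["_D_sa"] = sub["unit"].isin(treated).astype(int)

print(sub[["unit", "period", "_event_time", "_D_sa"]])
```

Unit A (treated at t=3) gets `_D_sa = 1`; units B and C qualify as clean controls because they are untreated through the end of the window (t=4); all three are restricted to periods 2–4, i.e. event times −1 to 1.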
+ """ + a_int = int(a) + ft = unit_info["_first_treat"].values + unit_ids = unit_info[unit].values + + # Treated units: A_s = a + treated_mask = ft == a + treated_units = set(unit_ids[treated_mask]) + + # Clean control units + if self.clean_control == "not_yet_treated": + control_mask = ft > a_int + self.kappa_post + elif self.clean_control == "strict": + control_mask = ft > a_int + self.kappa_post + self.kappa_pre + else: # never_treated + control_mask = np.isinf(ft) + control_units = set(unit_ids[control_mask]) + + if len(treated_units) == 0 or len(control_units) == 0: + return None + + # Time window: [a - kappa_pre - anticipation, a + kappa_post] + # Reference period a-1 (event time e=-1) is included when kappa_pre >= 1 + # Matches R reference: (focalAdoptionTime - kappa_pre):(focalAdoptionTime + kappa_post) + t_start = a_int - self.kappa_pre - self.anticipation + t_end = a_int + self.kappa_post + + all_units = treated_units | control_units + + # Filter data + mask = df[unit].isin(all_units) & (df[time] >= t_start) & (df[time] <= t_end) + sub_df = df.loc[mask].copy() + + if len(sub_df) == 0: + return None + + # Add sub-experiment columns + sub_df["_sub_exp"] = a + sub_df["_event_time"] = sub_df[time] - a_int + sub_df["_D_sa"] = sub_df[unit].isin(treated_units).astype(int) + + return sub_df + + # ========================================================================= + # Q-weight computation + # ========================================================================= + + def _compute_q_weights( + self, + stacked_df: pd.DataFrame, + unit_col: str, + population_col: Optional[str], + ) -> pd.DataFrame: + """ + Compute Q-weights per Table 1 of Wing et al. (2024). + + Treated observations always get Q = 1. + Control observations get Q based on the weighting scheme. + + For aggregate weighting, Q-weights are computed using observation + counts per (event_time, sub_exp), matching the R reference + ``compute_weights()``. 
For balanced panels this is equivalent to + unit counts per sub-experiment. For unbalanced panels the weights + adjust for varying observation density per event time. + + Population and sample_share weighting use unit counts per + sub-experiment, following the paper's notation (N_a^D, N_a^C). + + Parameters + ---------- + stacked_df : pd.DataFrame + Stacked dataset with _sub_exp, _event_time, and _D_sa columns. + unit_col : str + Unit column name. + population_col : str, optional + Population column name (for weighting="population"). + + Returns + ------- + pd.DataFrame + stacked_df with _Q_weight column added. + """ + if self.weighting == "aggregate": + return self._compute_q_weights_aggregate(stacked_df) + + # --- Population and sample_share: unit-count-based formulas --- + + # Count distinct units per sub-experiment + sub_exp_stats = ( + stacked_df.groupby(["_sub_exp", "_D_sa"])[unit_col].nunique().unstack(fill_value=0) + ) + + # N_a^D and N_a^C per sub-experiment + N_D = sub_exp_stats.get(1, pd.Series(dtype=float)).to_dict() + N_C = sub_exp_stats.get(0, pd.Series(dtype=float)).to_dict() + + # Totals + N_Omega_C = sum(N_C.values()) + + if self.weighting == "population": + # Pop_a^D: sum of population values for treated units per sub-exp + treated_pop = ( + stacked_df[stacked_df["_D_sa"] == 1] + .drop_duplicates(subset=[unit_col, "_sub_exp"]) + .groupby("_sub_exp")[population_col] + .sum() + .to_dict() + ) + Pop_D_total = sum(treated_pop.values()) + + q_control: Dict[Any, float] = {} + for a in N_D: + n_c = N_C.get(a, 0) + if n_c == 0 or N_Omega_C == 0: + q_control[a] = 1.0 + continue + control_share = n_c / N_Omega_C + pop_d = treated_pop.get(a, 0) + pop_share = pop_d / Pop_D_total if Pop_D_total > 0 else 0.0 + q_control[a] = pop_share / control_share if control_share > 0 else 1.0 + + else: # sample_share + N_Omega_D = sum(N_D.values()) + N_total = {a: N_D.get(a, 0) + N_C.get(a, 0) for a in N_D} + N_grand = N_Omega_D + N_Omega_C + + q_control = {} + for a in 
N_D: + n_c = N_C.get(a, 0) + if n_c == 0 or N_Omega_C == 0: + q_control[a] = 1.0 + continue + control_share = n_c / N_Omega_C + n_total_a = N_total.get(a, 0) + sample_share = n_total_a / N_grand if N_grand > 0 else 0.0 + q_control[a] = sample_share / control_share if control_share > 0 else 1.0 + + # Assign weights: treated=1, control=q_control[sub_exp] + sub_exp_vals = stacked_df["_sub_exp"].values + d_vals = stacked_df["_D_sa"].values + weights = np.ones(len(stacked_df)) + for i in range(len(stacked_df)): + if d_vals[i] == 0: + weights[i] = q_control.get(sub_exp_vals[i], 1.0) + + stacked_df["_Q_weight"] = weights + return stacked_df + + def _compute_q_weights_aggregate(self, stacked_df: pd.DataFrame) -> pd.DataFrame: + """ + Compute aggregate Q-weights using observation counts per (event_time, sub_exp). + + Matches the R reference ``compute_weights()`` which computes shares at the + (event_time, sub_exp) level, not the sub_exp level. For balanced panels the + two approaches are equivalent. For unbalanced panels this adjusts for varying + observation density per event time. 
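The sample-share formula above (control Q = sub-experiment sample share ÷ control share, per Table 1) can be checked on toy unit counts. All numbers here are hypothetical:

```python
# Hypothetical unit counts per sub-experiment a: treated N_D[a], control N_C[a]
N_D = {4: 10, 6: 20}
N_C = {4: 40, 6: 30}

N_Omega_C = sum(N_C.values())            # 70 control units across the stack
N_grand = sum(N_D.values()) + N_Omega_C  # 100 units in the stack overall

q_control = {}
for a in N_D:
    control_share = N_C[a] / N_Omega_C           # sub-exp's share of all controls
    sample_share = (N_D[a] + N_C[a]) / N_grand   # sub-exp's share of the full stack
    q_control[a] = sample_share / control_share  # treated rows keep Q = 1

print({a: round(q, 4) for a, q in q_control.items()})  # {4: 0.875, 6: 1.1667}
```

Sub-experiment 4 contributes more than its share of controls (40/70 vs. 50/100 of the stack), so its controls are down-weighted below 1; sub-experiment 6 is control-scarce and gets up-weighted.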
+ + R reference pattern:: + + stack_treat_n = count(D==1) BY event_time + stack_control_n = count(D==0) BY event_time + sub_treat_n = count(D==1) BY (sub_exp, event_time) + sub_control_n = count(D==0) BY (sub_exp, event_time) + sub_treat_share = sub_treat_n / stack_treat_n + sub_control_share = sub_control_n / stack_control_n + Q = sub_treat_share / sub_control_share (for controls) + Q = 1 (for treated) + """ + # Step 1: Stack-level totals by (event_time, D_sa) + stack_counts = stacked_df.groupby(["_event_time", "_D_sa"]).size().unstack(fill_value=0) + stack_treat_n = stack_counts.get(1, pd.Series(0, index=stack_counts.index)) + stack_control_n = stack_counts.get(0, pd.Series(0, index=stack_counts.index)) + + # Step 2: Sub-experiment-level counts by (event_time, sub_exp, D_sa) + sub_counts = ( + stacked_df.groupby(["_event_time", "_sub_exp", "_D_sa"]).size().unstack(fill_value=0) + ) + sub_treat_n = sub_counts.get(1, pd.Series(0, index=sub_counts.index)) + sub_control_n = sub_counts.get(0, pd.Series(0, index=sub_counts.index)) + + # Step 3: Compute shares and Q per (event_time, sub_exp) + # Q = (sub_treat_n / stack_treat_n) / (sub_control_n / stack_control_n) + q_lookup: Dict[Tuple[Any, Any], float] = {} + for et, sub_exp in sub_counts.index: + s_treat = sub_treat_n.get((et, sub_exp), 0) + s_control = sub_control_n.get((et, sub_exp), 0) + st_treat = stack_treat_n.get(et, 0) + st_control = stack_control_n.get(et, 0) + + if s_control == 0 or st_treat == 0 or st_control == 0: + q_lookup[(et, sub_exp)] = 1.0 + else: + # Guards above ensure control_share > 0 here + treat_share = s_treat / st_treat + control_share = s_control / st_control + q_lookup[(et, sub_exp)] = treat_share / control_share + + # Step 4: Assign weights row by row (treated rows keep Q = 1) + et_vals = stacked_df["_event_time"].values + sub_exp_vals = stacked_df["_sub_exp"].values + d_vals = stacked_df["_D_sa"].values + weights = np.ones(len(stacked_df)) + + for i in range(len(stacked_df)): + if d_vals[i] == 0: + weights[i] = 
q_lookup.get((et_vals[i], sub_exp_vals[i]), 1.0) + + stacked_df["_Q_weight"] = weights + return stacked_df + + # ========================================================================= + # sklearn-compatible interface + # ========================================================================= + + def get_params(self) -> Dict[str, Any]: + """Get estimator parameters (sklearn-compatible).""" + return { + "kappa_pre": self.kappa_pre, + "kappa_post": self.kappa_post, + "weighting": self.weighting, + "clean_control": self.clean_control, + "cluster": self.cluster, + "alpha": self.alpha, + "anticipation": self.anticipation, + "rank_deficient_action": self.rank_deficient_action, + } + + def set_params(self, **params: Any) -> "StackedDiD": + """Set estimator parameters (sklearn-compatible).""" + for key, value in params.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown parameter: {key}") + return self + + def summary(self) -> str: + """Get summary of estimation results.""" + if not self.is_fitted_: + raise RuntimeError("Model must be fitted before calling summary()") + assert self.results_ is not None + return self.results_.summary() + + def print_summary(self) -> None: + """Print summary to stdout.""" + print(self.summary()) + + +# ============================================================================= +# Convenience function +# ============================================================================= + + +def stacked_did( + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + kappa_pre: int = 1, + kappa_post: int = 1, + aggregate: Optional[str] = None, + population: Optional[str] = None, + **kwargs: Any, +) -> StackedDiDResults: + """ + Convenience function for stacked DiD estimation. + + This is a shortcut for creating a StackedDiD estimator and calling fit(). + + Parameters + ---------- + data : pd.DataFrame + Panel data. + outcome : str + Outcome variable column name. 
+ unit : str + Unit identifier column name. + time : str + Time period column name. + first_treat : str + Column indicating first treatment period (0 or inf for never-treated). + kappa_pre : int, default=1 + Pre-treatment event-time periods. + kappa_post : int, default=1 + Post-treatment event-time periods. + aggregate : str, optional + Aggregation mode: None, "simple", or "event_study". + population : str, optional + Population column for weighting="population". + **kwargs + Additional keyword arguments passed to StackedDiD constructor. + + Returns + ------- + StackedDiDResults + Estimation results. + + Examples + -------- + >>> from diff_diff import stacked_did, generate_staggered_data + >>> data = generate_staggered_data(seed=42) + >>> results = stacked_did(data, 'outcome', 'unit', 'period', + ... 'first_treat', kappa_pre=2, kappa_post=2, + ... aggregate='event_study') + >>> results.print_summary() + """ + est = StackedDiD(kappa_pre=kappa_pre, kappa_post=kappa_post, **kwargs) + return est.fit( + data, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + aggregate=aggregate, + population=population, + ) diff --git a/diff_diff/stacked_did_results.py b/diff_diff/stacked_did_results.py new file mode 100644 index 0000000..99b141a --- /dev/null +++ b/diff_diff/stacked_did_results.py @@ -0,0 +1,318 @@ +""" +Result containers for the Stacked DiD estimator. + +This module contains StackedDiDResults dataclass for Wing, Freedman & +Hollingsworth (2024) stacked difference-in-differences estimation. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.results import _get_significance_stars + +__all__ = [ + "StackedDiDResults", +] + + +@dataclass +class StackedDiDResults: + """ + Results from Stacked DiD estimation (Wing, Freedman & Hollingsworth 2024). 
+ + Attributes + ---------- + overall_att : float + Overall average treatment effect on the treated (average of + post-treatment event-study coefficients). + overall_se : float + Standard error of overall ATT (delta method on VCV). + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + event_study_effects : dict, optional + Dictionary mapping event time h to effect dict with keys: + 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. + group_effects : dict, optional + Dictionary mapping cohort g to effect dict. + stacked_data : pd.DataFrame + Full stacked dataset with _sub_exp, _event_time, _D_sa, + _Q_weight columns. Accessible for custom analysis. + groups : list + Adoption events in the trimmed set (Omega_kappa). + trimmed_groups : list + Adoption events excluded by IC1/IC2. + time_periods : list + All time periods in the original data. + n_obs : int + Number of observations in the original data. + n_stacked_obs : int + Number of observations in the stacked dataset. + n_sub_experiments : int + Number of sub-experiments in the stack. + n_treated_units : int + Distinct treated units across trimmed set. + n_control_units : int + Distinct control units across trimmed set. + kappa_pre : int + Pre-treatment event-time window size. + kappa_post : int + Post-treatment event-time window size. + weighting : str + Weighting scheme used. + clean_control : str + Clean control definition used. + alpha : float + Significance level used. 
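The per-event-time dicts stored in `event_study_effects` flatten directly into a tidy DataFrame, which is essentially what `to_dataframe(level="event_study")` returns. A standalone sketch with hypothetical effect values (not estimator output):

```python
import numpy as np
import pandas as pd

# Hypothetical event_study_effects payload; -1 is the zeroed reference period
event_study_effects = {
    -1: {"effect": 0.0, "se": 0.0, "conf_int": (np.nan, np.nan), "n_obs": 0},
    0: {"effect": 0.52, "se": 0.11, "conf_int": (0.30, 0.74), "n_obs": 120},
    1: {"effect": 0.61, "se": 0.13, "conf_int": (0.36, 0.86), "n_obs": 118},
}

rows = [
    {"relative_period": h, "effect": d["effect"], "se": d["se"],
     "conf_int_lower": d["conf_int"][0], "conf_int_upper": d["conf_int"][1],
     "n_obs": d["n_obs"]}
    for h, d in sorted(event_study_effects.items())
]
df = pd.DataFrame(rows)  # one row per relative period, sorted by event time
print(df)
```

This long format is what plotting helpers like `plot_event_study` generally consume: one row per relative period with point estimate and interval bounds.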
+ """ + + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + event_study_effects: Optional[Dict[int, Dict[str, Any]]] + group_effects: Optional[Dict[Any, Dict[str, Any]]] + stacked_data: pd.DataFrame = field(repr=False) + groups: List[Any] = field(default_factory=list) + trimmed_groups: List[Any] = field(default_factory=list) + time_periods: List[Any] = field(default_factory=list) + n_obs: int = 0 + n_stacked_obs: int = 0 + n_sub_experiments: int = 0 + n_treated_units: int = 0 + n_control_units: int = 0 + kappa_pre: int = 1 + kappa_post: int = 1 + weighting: str = "aggregate" + clean_control: str = "not_yet_treated" + alpha: float = 0.05 + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"StackedDiDResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_sub_exp={self.n_sub_experiments}, " + f"n_stacked_obs={self.n_stacked_obs})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. 
+ """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Stacked DiD Estimator Results (Wing, Freedman & Hollingsworth 2024)".center(85), + "=" * 85, + "", + f"{'Original observations:':<30} {self.n_obs:>10}", + f"{'Stacked observations:':<30} {self.n_stacked_obs:>10}", + f"{'Sub-experiments:':<30} {self.n_sub_experiments:>10}", + f"{'Treated units:':<30} {self.n_treated_units:>10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Treatment cohorts:':<30} {len(self.groups):>10}", + f"{'Trimmed cohorts:':<30} {len(self.trimmed_groups):>10}", + f"{'Event window:':<30} {'[' + str(-self.kappa_pre) + ', ' + str(self.kappa_post) + ']':>10}", + f"{'Weighting:':<30} {self.weighting:>10}", + f"{'Clean control:':<30} {self.clean_control:>10}", + "", + ] + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + t_str = ( + f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" + ) + p_str = ( + f"{self.overall_p_value:>10.4f}" + if np.isfinite(self.overall_p_value) + else f"{'NaN':>10}" + ) + sig = _get_significance_stars(self.overall_p_value) + + lines.extend( + [ + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{t_str} {p_str} {sig:>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. 
Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for h in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[h] + if eff.get("n_obs", 1) == 0: + # Reference period marker + lines.append( + f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + elif np.isnan(eff["effect"]): + lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + e_sig = _get_significance_stars(eff["p_value"]) + e_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + e_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{e_t} {e_p} {e_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Group (Cohort) Effects".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for g in sorted(self.group_effects.keys()): + eff = self.group_effects[g] + if np.isnan(eff["effect"]): + lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + g_sig = _get_significance_stars(eff["p_value"]) + g_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + g_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{g_t} {g_p} {g_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 
0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "event_study") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="event_study" + Level of aggregation: + - "event_study": Event study effects by relative time + - "group": Not supported; raises ValueError (use CallawaySantAnna or ImputationDiD for cohort-level effects) + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. Use aggregate='event_study'." + ) + rows = [] + for h, data in sorted(self.event_study_effects.items()): + rows.append( + { + "relative_period": h, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + elif level == "group": + raise ValueError( + "Group aggregation is not supported by StackedDiD. " + "The pooled stacked regression cannot produce cohort-specific " + "effects. Use CallawaySantAnna or ImputationDiD for " + "cohort-level estimates." + ) + + else: + raise ValueError(f"Unknown level: {level}. 
Use 'event_study' or 'group'.") + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index 71fb243..f5c3eab 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -19,6 +19,7 @@ from diff_diff.imputation import ImputationDiDResults from diff_diff.sun_abraham import SunAbrahamResults from diff_diff.two_stage import TwoStageDiDResults + from diff_diff.stacked_did import StackedDiDResults # Type alias for results that can be plotted PlottableResults = Union[ @@ -27,6 +28,7 @@ "SunAbrahamResults", "ImputationDiDResults", "TwoStageDiDResults", + "StackedDiDResults", pd.DataFrame, ] diff --git a/docs/api/index.rst b/docs/api/index.rst index c87a464..0f6f403 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -19,6 +19,7 @@ Core estimator classes for DiD analysis: diff_diff.CallawaySantAnna diff_diff.SunAbraham diff_diff.ImputationDiD + diff_diff.StackedDiD diff_diff.TripleDifference diff_diff.TROP @@ -43,6 +44,7 @@ Result containers returned by estimators: diff_diff.ImputationDiDResults diff_diff.ImputationBootstrapResults diff_diff.TripleDifferenceResults + diff_diff.StackedDiDResults diff_diff.trop.TROPResults Visualization @@ -185,6 +187,7 @@ Detailed documentation by module: estimators staggered imputation + stacked_did triple_diff trop results diff --git a/docs/api/stacked_did.rst b/docs/api/stacked_did.rst new file mode 100644 index 0000000..4ece0d8 --- /dev/null +++ b/docs/api/stacked_did.rst @@ -0,0 +1,129 @@ +Stacked Difference-in-Differences +================================== + +Stacked DiD estimator for staggered adoption designs with corrective Q-weights. 
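At its core, stacking builds one clean dataset per adoption cohort and concatenates them. Below is a minimal, illustrative pandas sketch of that construction — not the `diff_diff` implementation — assuming columns `period` and `first_treat` with never-treated units encoded as `np.inf`; the helper name `build_stacked_data` is hypothetical:

```python
import numpy as np
import pandas as pd

def build_stacked_data(df, kappa_pre=2, kappa_post=2):
    """Hypothetical sketch: one sub-experiment per adoption cohort ``a``.

    Treated units adopt exactly at ``a``; clean controls adopt strictly
    after ``a + kappa_post`` (never-treated = inf). Each sub-experiment
    keeps only the event window [a - kappa_pre, a + kappa_post].
    """
    stacks = []
    cohorts = sorted(c for c in df["first_treat"].unique() if np.isfinite(c))
    for a in cohorts:
        window = df[df["period"].between(a - kappa_pre, a + kappa_post)]
        sub = window[
            (window["first_treat"] == a)                # treated cohort
            | (window["first_treat"] > a + kappa_post)  # clean controls
        ].copy()
        sub["_sub_exp"] = a
        sub["_event_time"] = sub["period"] - a
        sub["_D_sa"] = (sub["first_treat"] == a).astype(int)
        stacks.append(sub)
    return pd.concat(stacks, ignore_index=True)
```

The real estimator additionally applies IC1/IC2 trimming and attaches the corrective `_Q_weight` column before running the weighted event-study regression.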
+ +This module implements the methodology from Wing, Freedman & Hollingsworth (2024), +which addresses bias in naive stacked DiD regressions by: + +1. **Constructing sub-experiments**: One per adoption cohort with clean controls +2. **Applying corrective Q-weights**: Ensures proper weighting of treatment and + control group trends across sub-experiments +3. **Running weighted event-study regression**: WLS with Q-weights identifies + the "trimmed aggregate ATT" + +**When to use Stacked DiD:** + +- Staggered adoption design with multiple treatment cohorts +- Want an intuitive sub-experiment-based approach (vs. aggregation methods) +- Desire compositional balance: treatment group composition fixed across event times +- Need direct access to the stacked dataset for custom analysis + +**Reference:** Wing, C., Freedman, S. M., & Hollingsworth, A. (2024). Stacked +Difference-in-Differences. *NBER Working Paper* 32054. +`<http://www.nber.org/papers/w32054>`_ + +.. module:: diff_diff.stacked_did + +StackedDiD +---------- + +Main estimator class for Stacked Difference-in-Differences. + +.. autoclass:: diff_diff.StackedDiD + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + + .. rubric:: Methods + + .. autosummary:: + + ~StackedDiD.fit + ~StackedDiD.get_params + ~StackedDiD.set_params + +StackedDiDResults +----------------- + +Results container for Stacked DiD estimation. + +.. autoclass:: diff_diff.stacked_did.StackedDiDResults + :members: + :undoc-members: + :show-inheritance: + + .. rubric:: Methods + + .. autosummary:: + + ~StackedDiDResults.summary + ~StackedDiDResults.print_summary + ~StackedDiDResults.to_dataframe + +Convenience Function +-------------------- + +.. 
autofunction:: diff_diff.stacked_did + +Example Usage +------------- + +Basic usage:: + + from diff_diff import StackedDiD, generate_staggered_data + + data = generate_staggered_data(n_units=200, n_periods=12, + cohort_periods=[4, 6, 8], seed=42) + + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat', + aggregate='event_study') + results.print_summary() + +Accessing the stacked dataset:: + + # The stacked data is available for custom analysis + stacked = results.stacked_data + print(stacked[['unit', 'period', '_sub_exp', '_event_time', '_D_sa', '_Q_weight']].head()) + +Different weighting schemes:: + + # Population-weighted ATT (requires population column) + est = StackedDiD(kappa_pre=2, kappa_post=2, weighting='population') + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat', + population='pop_size') + + # Sample-share weighted ATT + est = StackedDiD(kappa_pre=2, kappa_post=2, weighting='sample_share') + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat') + +Comparison with Other Staggered Estimators +------------------------------------------ + +.. 
list-table:: + :header-rows: 1 + :widths: 20 40 40 + + * - Feature + - Stacked DiD + - Callaway-Sant'Anna + * - Approach + - Pooled WLS on stacked sub-experiments + - Separate group-time regressions + * - Compositional balance + - Enforced by IC1/IC2 trimming + - Via balanced event study aggregation + * - Target parameter + - Trimmed aggregate ATT + - Weighted average of ATT(g,t) + * - Custom analysis + - Full stacked dataset accessible + - Group-time effects accessible + * - Covariates + - Not yet supported + - Supported (OR, IPW, DR) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 82774db..f13440e 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -13,6 +13,7 @@ This document provides the academic foundations and key implementation requireme - [SunAbraham](#sunabraham) - [ImputationDiD](#imputationdid) - [TwoStageDiD](#twostagedid) + - [StackedDiD](#stackeddid) 3. [Advanced Estimators](#advanced-estimators) - [SyntheticDiD](#syntheticdid) - [TripleDifference](#tripledifference) @@ -639,6 +640,98 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus --- +## StackedDiD + +**Primary source:** Wing, C., Freedman, S. M., & Hollingsworth, A. (2024). Stacked Difference-in-Differences. NBER Working Paper 32054. 
http://www.nber.org/papers/w32054 + +**Key implementation requirements:** + +*Assumption checks / warnings:* +- Assumption 1 (No Anticipation): ATT(a, a+e) = 0 for all e < 0 +- Assumption 2 (Common Trends): E[Y_{s,a+e}(0) - Y_{s,a-1}(0) | A_s = a] = E[Y_{s,a+e}(0) - Y_{s,a-1}(0) | A_s > a + e] +- Clean controls must exist for each sub-experiment (IC2) +- Event window must fit within observed data range (IC1) + +*Target parameter (Equation 2):* + + theta_kappa^e = sum_{a in Omega_kappa} ATT(a, a+e) * (N_a^D / N_Omega_kappa^D) + +where: +- `theta_kappa^e` = trimmed aggregate ATT at event time e +- `Omega_kappa` = trimmed set of adoption events satisfying IC1 and IC2 +- `N_a^D` = number of treated units in sub-experiment a +- `N_Omega_kappa^D` = total treated units across all sub-experiments in trimmed set + +*Estimator equation (Equation 3 — weighted saturated event study, recommended):* + + Y_sae = alpha_0 + alpha_1 * D_sa + sum_{h != -1} [lambda_h * 1(e=h) + delta_h * D_sa * 1(e=h)] + U_sae + +Estimated via WLS with Q-weights. The delta_h coefficients identify theta_kappa^e. 
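The "estimated via WLS with Q-weights" step can be reduced to OLS by scaling each row by sqrt(w). A minimal numpy sketch of that transformation — illustrative only, not the package's code, with the hypothetical helper name `wls_via_sqrt_weights`:

```python
import numpy as np

def wls_via_sqrt_weights(X, y, w):
    """Weighted least squares via the sqrt(w) transformation.

    Scaling row i of X and y by sqrt(w_i) and running OLS minimizes
    sum_i w_i * (y_i - x_i' beta)^2, which is exactly WLS.
    (Illustrative sketch, not the diff_diff implementation.)
    """
    sw = np.sqrt(np.asarray(w, dtype=float))
    Xw = X * sw[:, None]   # scale each row of the design matrix
    yw = y * sw            # scale the outcome the same way
    beta, *_ = np.linalg.lstsq(Xw, yw, rcond=None)
    return beta

# Tiny check: with an intercept-only design, WLS returns the
# w-weighted mean of y.
y = np.array([1.0, 2.0, 4.0])
w = np.array([1.0, 1.0, 2.0])
X = np.ones((3, 1))
beta = wls_via_sqrt_weights(X, y, w)
# beta[0] == (1*1 + 1*2 + 2*4) / (1 + 1 + 2) == 2.75
```

In the stacked regression the weight vector is the `_Q_weight` column, and the design matrix holds the Equation 3 dummies.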
+ +*Q-weights (Section 5.3, Table 1):* + + Q_sa = 1 if D_sa = 1 (treated) + Q_sa = (N_a^D / N^D) / (N_a^C / N^C) if D_sa = 0 (control, aggregate weighting) + Q_sa = (Pop_a^D / Pop^D) / (N_a^C / N^C) if D_sa = 0 (control, population weighting) + Q_sa = ((N_a^D + N_a^C) / (N^D + N^C)) / (N_a^C / N^C) if D_sa = 0 (control, sample share weighting) + +*Standard errors (Section 5.4):* +- Default: Cluster-robust standard errors at the group (unit) level +- Alternative: Cluster at group x sub-experiment level +- Both approaches yield approximately correct coverage when clusters > 100 (Table 2) +- No special bootstrap procedure specified; standard cluster-robust SEs recommended +- For post-period average: delta method or `lincom`/`marginaleffects` + +*Edge cases:* +- All events trimmed: `len(Omega_kappa) == 0` -> ValueError suggesting reduced kappa +- No clean controls for event a: IC2 check fails -> Trim event, warn user +- Single cohort in trimmed set: Valid — Q-weights simplify +- Duplicate observations: Same (unit, time) appears in multiple sub-experiments -> handled by clustering at unit level +- Constant treatment share across sub-exps: Unweighted FE recovers correct estimand (special case, Section 5.5) +- Anticipation > 0: Reference period shifts to e = -1 - anticipation. Post-treatment includes anticipation periods (e >= -anticipation). Window expands by anticipation pre-periods. +- Group aggregation: Not supported — pooled stacked regression cannot produce cohort-specific effects. Use CallawaySantAnna or ImputationDiD. + +*Algorithm (Section 5):* +1. Choose kappa_pre, kappa_post event window +2. Apply IC1 (window fits in data) and IC2 (clean controls exist) to get Omega_kappa +3. For each a in Omega_kappa: build sub-experiment with treated (A_s = a), clean controls (A_s > a + kappa_post), time window [a - kappa_pre, a + kappa_post] (with anticipation: [a - kappa_pre - anticipation, a + kappa_post]) +4. Stack all sub-experiments vertically +5. 
Compute Q-weights: aggregate weighting uses observation counts per (event_time, sub_exp), matching R reference. Population/sample_share use unit counts per sub_exp (paper notation). +6. Run WLS regression of Equation 3 with Q-weights +7. Extract delta_h coefficients as event-study ATTs +8. Compute cluster-robust SEs at unit level + +*IC1 (Adoption Event Window, Section 3):* + + IC1_a = 1[a - kappa_pre >= T_min AND a + kappa_post <= T_max] + +Note: Matches R reference implementation (`focalAdoptionTime - kappa_pre >= minTime`). +The reference period a-1 is included in the window [a-kappa_pre, a+kappa_post] when kappa_pre >= 1. +The paper text states a stricter bound (T_min + 1) but the R code by the co-author uses T_min. + +*IC2 (Clean Controls Exist, Section 3):* + + IC2_a = 1[exists s with A_s > a + kappa_post] (not_yet_treated) + IC2_a = 1[exists s with A_s > a + kappa_post + kappa_pre] (strict) + IC2_a = 1[exists s with A_s = infinity] (never_treated) + +**Reference implementation(s):** +- R: https://github.com/hollina/stacked-did-weights (`create_sub_exp()`, `compute_weights()`) +- No Stata or Python package; Stata estimation via standard `reghdfe` with Q-weight column + +**Requirements checklist:** +- [x] Sub-experiment construction with treated + clean controls + time window +- [x] IC1 and IC2 trimming with warnings +- [x] Q-weight computation for all three weighting schemes (Table 1) +- [x] WLS via sqrt(w) transformation +- [x] Event study regression (Equation 3) with reference period e=-1 +- [x] Cluster-robust SEs at unit or unit x sub-exp level +- [x] Overall ATT as average of post-treatment delta_h with delta-method SE +- [x] Anticipation parameter support +- [x] Never-treated encoding (0 and inf) + +--- + # Advanced Estimators ## SyntheticDiD @@ -1374,6 +1467,7 @@ should be a deliberate user choice. 
| TwoStageDiD | GMM sandwich (Newey & McFadden 1994) | Multiplier bootstrap on GMM influence function | | SyntheticDiD | Placebo variance (Alg 4) | Unit-level bootstrap (fixed weights) | | TripleDifference | Influence function (all methods) | SE = std(IF) / sqrt(n) | +| StackedDiD | Cluster-robust (unit) | Cluster at unit × sub-experiment | | TROP | Block bootstrap | — | | BaconDecomposition | N/A (exact decomposition) | Individual 2×2 SEs | | HonestDiD | Inherited from event study | FLCI, C-LF | @@ -1395,6 +1489,7 @@ should be a deliberate user choice. | TwoStageDiD | did2s | `did2s()` | | SyntheticDiD | synthdid | `synthdid_estimate()` | | TripleDifference | triplediff | `ddd()` | +| StackedDiD | stacked-did-weights | `create_sub_exp()` + `compute_weights()` | | TROP | - | (forthcoming) | | BaconDecomposition | bacondecomp | `bacon()` | | HonestDiD | HonestDiD | `createSensitivityResults()` | diff --git a/tests/test_stacked_did.py b/tests/test_stacked_did.py new file mode 100644 index 0000000..34a3a15 --- /dev/null +++ b/tests/test_stacked_did.py @@ -0,0 +1,927 @@ +""" +Tests for Stacked DiD estimator (Wing, Freedman & Hollingsworth 2024). 
+""" + +import warnings + +import numpy as np +import pandas as pd +import pytest + +from diff_diff import StackedDiD, StackedDiDResults, stacked_did +from diff_diff.prep_dgp import generate_staggered_data + +# ============================================================================= +# Test Data Fixtures +# ============================================================================= + + +@pytest.fixture +def staggered_data(): + """Standard staggered adoption data for testing.""" + return generate_staggered_data( + n_units=200, + n_periods=12, + cohort_periods=[4, 6, 8], + never_treated_frac=0.3, + treatment_effect=5.0, + dynamic_effects=True, + seed=42, + ) + + +@pytest.fixture +def constant_effect_data(): + """Staggered data with constant treatment effect (no dynamics).""" + return generate_staggered_data( + n_units=200, + n_periods=12, + cohort_periods=[4, 6, 8], + never_treated_frac=0.3, + treatment_effect=5.0, + dynamic_effects=False, + seed=42, + ) + + +@pytest.fixture +def no_never_treated_data(): + """Staggered data without never-treated units.""" + return generate_staggered_data( + n_units=200, + n_periods=12, + cohort_periods=[4, 6, 8], + never_treated_frac=0.0, + treatment_effect=5.0, + dynamic_effects=True, + seed=42, + ) + + +# ============================================================================= +# TestStackedDiDBasic +# ============================================================================= + + +class TestStackedDiDBasic: + """Basic functionality tests.""" + + def test_basic_fit(self, staggered_data): + """Default parameters produce valid results.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + assert isinstance(results, StackedDiDResults) + assert np.isfinite(results.overall_att) + assert np.isfinite(results.overall_se) + assert results.overall_se > 0 + assert results.n_stacked_obs > 0 + assert 
results.n_sub_experiments > 0 + + def test_event_study(self, staggered_data): + """Event study aggregation populates event_study_effects.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", + ) + assert results.event_study_effects is not None + assert -1 in results.event_study_effects # reference period + # Reference period effect should be zero + ref = results.event_study_effects[-1] + assert ref["effect"] == 0.0 + assert ref["n_obs"] == 0 + + # Post-treatment periods should have effects + for h in range(0, 3): + if h in results.event_study_effects: + assert results.event_study_effects[h]["n_obs"] > 0 + + def test_group_aggregate_raises(self, staggered_data): + """aggregate='group' raises ValueError.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + with pytest.raises(ValueError, match="group.*not supported"): + est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="group", + ) + + def test_all_aggregate_raises(self, staggered_data): + """aggregate='all' raises ValueError.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + with pytest.raises(ValueError, match="all.*not supported"): + est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="all", + ) + + def test_simple_att(self, staggered_data): + """aggregate='simple' produces overall ATT only.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + ) + assert np.isfinite(results.overall_att) + assert results.event_study_effects is None + assert results.group_effects is None + + def test_known_constant_effect(self, constant_effect_data): + """With constant treatment effect, estimated ATT should be 
close.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + constant_effect_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + # Treatment effect is 5.0; allow generous tolerance + assert ( + abs(results.overall_att - 5.0) < 1.5 + ), f"Estimated ATT {results.overall_att:.2f} too far from true effect 5.0" + + def test_dynamic_effects(self, staggered_data): + """With dynamic effects, post-treatment coefficients should increase.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", + ) + assert results.event_study_effects is not None + # Post-treatment effects should generally increase + post_effects = [ + results.event_study_effects[h]["effect"] + for h in sorted(results.event_study_effects.keys()) + if h >= 0 and results.event_study_effects[h]["n_obs"] > 0 + ] + if len(post_effects) >= 2: + # Last post should be larger than first post (dynamic growth) + assert post_effects[-1] > post_effects[0] + + +# ============================================================================= +# TestTrimming +# ============================================================================= + + +class TestTrimming: + """Tests for IC1/IC2 trimming logic.""" + + def test_ic1_window_trimming(self, staggered_data): + """Events outside the observation window are trimmed.""" + # With very large kappa, early/late events should be trimmed + est = StackedDiD(kappa_pre=4, kappa_post=4) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + # With kappa_pre=4, kappa_post=4 on 12 periods, some events should trim + if len(results.trimmed_groups) > 0: + assert any("Trimmed" in str(wi.message) for wi in w) + + def 
test_ic2_no_controls_trimming(self, no_never_treated_data): + """Events without clean controls are trimmed with never_treated mode.""" + est = StackedDiD(kappa_pre=1, kappa_post=1, clean_control="never_treated") + # No never-treated units exist → all events should be trimmed + with pytest.raises(ValueError, match="All.*adoption events were trimmed"): + est.fit( + no_never_treated_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + def test_trimmed_groups_reported(self, staggered_data): + """Trimmed groups are reported in results.""" + est = StackedDiD(kappa_pre=5, kappa_post=5) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + try: + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + # If some groups survive, check trimmed_groups + assert isinstance(results.trimmed_groups, list) + except ValueError: + # All trimmed — expected for large kappa + pass + + def test_all_trimmed_raises(self, staggered_data): + """ValueError when all events are eliminated by trimming.""" + est = StackedDiD(kappa_pre=10, kappa_post=10) + with pytest.raises(ValueError, match="All.*adoption events were trimmed"): + est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + def test_wider_window_more_trimming(self, staggered_data): + """Larger kappa values should trim more (or equal) events.""" + est1 = StackedDiD(kappa_pre=1, kappa_post=1) + results1 = est1.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + est2 = StackedDiD(kappa_pre=2, kappa_post=2) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + results2 = est2.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + assert len(results2.trimmed_groups) >= len(results1.trimmed_groups) + + 
+# ============================================================================= +# TestQWeights +# ============================================================================= + + +class TestQWeights: + """Tests for Q-weight computation.""" + + def test_treated_weight_is_one(self, staggered_data): + """All treated observations should have Q=1.""" + est = StackedDiD(kappa_pre=2, kappa_post=2) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + treated_weights = results.stacked_data.loc[results.stacked_data["_D_sa"] == 1, "_Q_weight"] + assert np.allclose(treated_weights, 1.0) + + def test_aggregate_weighting_formula(self, staggered_data): + """Q-weights match R's observation-count formula at (event_time, sub_exp) level.""" + est = StackedDiD(kappa_pre=2, kappa_post=2, weighting="aggregate") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + sd = results.stacked_data + + # Compute expected Q per R formula at (event_time, sub_exp) level + for et in sd["_event_time"].unique(): + et_data = sd[sd["_event_time"] == et] + stack_treat_n = (et_data["_D_sa"] == 1).sum() + stack_control_n = (et_data["_D_sa"] == 0).sum() + for sub_exp in results.groups: + sub_et = et_data[et_data["_sub_exp"] == sub_exp] + sub_treat_n = (sub_et["_D_sa"] == 1).sum() + sub_control_n = (sub_et["_D_sa"] == 0).sum() + if sub_control_n > 0 and stack_treat_n > 0 and stack_control_n > 0: + expected_q = (sub_treat_n / stack_treat_n) / (sub_control_n / stack_control_n) + actual_q = sub_et.loc[sub_et["_D_sa"] == 0, "_Q_weight"].iloc[0] + assert ( + abs(actual_q - expected_q) < 1e-10 + ), f"Sub-exp {sub_exp}, et={et}: expected Q={expected_q:.6f}, got {actual_q:.6f}" + + def test_sample_share_weighting(self, staggered_data): + """Verify sample_share Q formula.""" + est = StackedDiD(kappa_pre=2, kappa_post=2, weighting="sample_share") + results = est.fit( + 
staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + sd = results.stacked_data + + # All weights should be positive and finite + assert np.all(sd["_Q_weight"] > 0) + assert np.all(np.isfinite(sd["_Q_weight"])) + + def test_weights_positive(self, staggered_data): + """All Q-weights should be positive.""" + for w in ["aggregate", "sample_share"]: + est = StackedDiD(kappa_pre=2, kappa_post=2, weighting=w) + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + assert np.all(results.stacked_data["_Q_weight"] > 0) + + +# ============================================================================= +# TestCleanControl +# ============================================================================= + + +class TestCleanControl: + """Tests for clean control group definitions.""" + + def test_not_yet_treated_default(self, staggered_data): + """Default includes not-yet-treated and never-treated as controls.""" + est = StackedDiD(kappa_pre=1, kappa_post=1, clean_control="not_yet_treated") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + assert results.n_control_units > 0 + + def test_strict_excludes_more(self, staggered_data): + """Strict mode should have fewer (or equal) controls than not_yet_treated.""" + est_nyt = StackedDiD(kappa_pre=2, kappa_post=2, clean_control="not_yet_treated") + results_nyt = est_nyt.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + est_strict = StackedDiD(kappa_pre=2, kappa_post=2, clean_control="strict") + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + try: + results_strict = est_strict.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + # Strict should have fewer or equal stacked obs + assert 
results_strict.n_stacked_obs <= results_nyt.n_stacked_obs + except ValueError: + # Strict may trim all events — that's valid behavior + pass + + def test_never_treated_only(self, staggered_data): + """never_treated mode only uses never-treated as controls.""" + est = StackedDiD(kappa_pre=2, kappa_post=2, clean_control="never_treated") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + sd = results.stacked_data + # All control units should have first_treat = inf + control_ft = sd.loc[sd["_D_sa"] == 0, "first_treat"].unique() + assert all(np.isinf(ft) for ft in control_ft) + + def test_never_treated_no_nevertreated_raises(self, no_never_treated_data): + """Error when no never-treated units exist with never_treated mode.""" + est = StackedDiD(kappa_pre=1, kappa_post=1, clean_control="never_treated") + with pytest.raises(ValueError, match="All.*adoption events were trimmed"): + est.fit( + no_never_treated_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + + +# ============================================================================= +# TestClustering +# ============================================================================= + + +class TestClustering: + """Tests for clustering standard errors.""" + + def test_unit_clustering(self, staggered_data): + """Default unit clustering produces finite SEs.""" + est = StackedDiD(kappa_pre=2, kappa_post=2, cluster="unit") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + assert np.isfinite(results.overall_se) + assert results.overall_se > 0 + + def test_unit_subexp_clustering(self, staggered_data): + """unit_subexp clustering produces finite SEs.""" + est = StackedDiD(kappa_pre=2, kappa_post=2, cluster="unit_subexp") + results = est.fit( + staggered_data, + outcome="outcome", + unit="unit", + time="period", + 
+            first_treat="first_treat",
+        )
+        assert np.isfinite(results.overall_se)
+        assert results.overall_se > 0
+
+
+# =============================================================================
+# TestStackedData
+# =============================================================================
+
+
+class TestStackedData:
+    """Tests for the stacked dataset."""
+
+    def test_stacked_data_accessible(self, staggered_data):
+        """results.stacked_data is a DataFrame."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert isinstance(results.stacked_data, pd.DataFrame)
+
+    def test_required_columns(self, staggered_data):
+        """Stacked data has _sub_exp, _event_time, _D_sa, _Q_weight."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        required = {"_sub_exp", "_event_time", "_D_sa", "_Q_weight"}
+        assert required.issubset(results.stacked_data.columns)
+
+    def test_event_time_range(self, staggered_data):
+        """Event times span [-kappa_pre, ..., kappa_post]."""
+        kp, kq = 2, 2
+        est = StackedDiD(kappa_pre=kp, kappa_post=kq)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        et = results.stacked_data["_event_time"]
+        # The event-time window should cover at least [-kappa_pre, kappa_post]
+        assert et.min() <= -kp
+        assert et.max() >= kq
+
+
+# =============================================================================
+# TestEdgeCases
+# =============================================================================
+
+
+class TestEdgeCases:
+    """Tests for edge cases and boundary conditions."""
+
+    def test_single_cohort(self):
+        """Works with only one adoption cohort."""
+        data = generate_staggered_data(
+            n_units=100,
+            n_periods=10,
+            cohort_periods=[5],
+            never_treated_frac=0.5,
+            treatment_effect=3.0,
+            dynamic_effects=False,
+            seed=99,
+        )
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert results.n_sub_experiments == 1
+        assert np.isfinite(results.overall_att)
+
+    def test_anticipation_reference_period(self):
+        """anticipation=1 shifts the reference period to e=-2."""
+        data = generate_staggered_data(
+            n_units=200,
+            n_periods=12,
+            cohort_periods=[5, 7],
+            never_treated_frac=0.3,
+            treatment_effect=5.0,
+            seed=42,
+        )
+        est = StackedDiD(kappa_pre=2, kappa_post=2, anticipation=1)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+            aggregate="event_study",
+        )
+
+        # Reference period is -2 (not -1)
+        assert -2 in results.event_study_effects
+        assert results.event_study_effects[-2]["effect"] == 0.0
+        assert results.event_study_effects[-2]["n_obs"] == 0  # sentinel
+
+        # -1 is NOT the reference; it gets its own estimated (non-sentinel) effect
+        assert -1 in results.event_study_effects
+        assert results.event_study_effects[-1]["n_obs"] > 0
+
+        # Extra pre-period -3 should have a dummy
+        assert -3 in results.event_study_effects
+        assert results.event_study_effects[-3]["n_obs"] > 0
+
+        # Post-treatment includes the anticipation period (-1):
+        # the overall ATT averages h in {-1, 0, 1, 2}
+        assert np.isfinite(results.overall_att)
+
+    def test_unbalanced_panel(self):
+        """Works with missing observations within the window."""
+        data = generate_staggered_data(
+            n_units=200,
+            n_periods=12,
+            cohort_periods=[4, 6, 8],
+            never_treated_frac=0.3,
+            treatment_effect=5.0,
+            seed=42,
+        )
+        # Remove some random rows to create an unbalanced panel
+        rng = np.random.default_rng(42)
+        drop_idx = rng.choice(len(data), size=50, replace=False)
+        data = data.drop(data.index[drop_idx]).reset_index(drop=True)
+
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        sd = results.stacked_data
+        assert np.isfinite(results.overall_att)
+        assert np.all(sd["_Q_weight"] > 0)
+        assert np.all(np.isfinite(sd["_Q_weight"]))
+
+    def test_nan_inference(self):
+        """Near-degenerate data must not crash; inference fields may be NaN."""
+        # Create small data where estimation might degenerate.
+        # Need n > k to avoid division by zero in the cluster-robust VCV:
+        # the design matrix has 4 columns (intercept, D_sa, lambda_0, delta_0),
+        # so we need > 4 observations (3 units × 2 periods = 6).
+        data = pd.DataFrame(
+            {
+                "unit": [1, 1, 2, 2, 3, 3],
+                "period": [1, 2, 1, 2, 1, 2],
+                "outcome": [1.0, 2.0, 1.0, 2.0, 1.5, 2.5],
+                "first_treat": [2, 2, 0, 0, 0, 0],
+            }
+        )
+        est = StackedDiD(kappa_pre=1, kappa_post=0)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        # Should produce finite results or NaN (not crash)
+        assert isinstance(results, StackedDiDResults)
+
+    def test_never_treated_encoding_zero(self):
+        """first_treat=0 is treated the same as first_treat=inf (never-treated)."""
+        data = generate_staggered_data(
+            n_units=100,
+            n_periods=10,
+            cohort_periods=[5],
+            never_treated_frac=0.5,
+            treatment_effect=5.0,
+            seed=42,
+        )
+        # The generator uses 0 for never-treated
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert results.n_control_units > 0
+
+    def test_never_treated_encoding_inf(self):
+        """first_treat=inf works for never-treated units."""
+        data = generate_staggered_data(
+            n_units=100,
+            n_periods=10,
+            cohort_periods=[5],
+            never_treated_frac=0.5,
+            treatment_effect=5.0,
+            seed=42,
+        )
+        # Replace 0 with inf for never-treated
+        data["first_treat"] = data["first_treat"].replace(0, np.inf)
+
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert results.n_control_units > 0
+
+
+# =============================================================================
+# TestSklearnInterface
+# =============================================================================
+
+
+class TestSklearnInterface:
+    """Tests for the sklearn-compatible API."""
+
+    def test_get_params(self):
+        """All init params are present in get_params."""
+        est = StackedDiD(
+            kappa_pre=3,
+            kappa_post=2,
+            weighting="population",
+            clean_control="strict",
+            cluster="unit_subexp",
+            alpha=0.10,
+            anticipation=1,
+            rank_deficient_action="error",
+        )
+        params = est.get_params()
+        assert params["kappa_pre"] == 3
+        assert params["kappa_post"] == 2
+        assert params["weighting"] == "population"
+        assert params["clean_control"] == "strict"
+        assert params["cluster"] == "unit_subexp"
+        assert params["alpha"] == 0.10
+        assert params["anticipation"] == 1
+        assert params["rank_deficient_action"] == "error"
+
+    def test_set_params(self):
+        """set_params modifies attributes correctly."""
+        est = StackedDiD()
+        est.set_params(kappa_pre=5, weighting="sample_share")
+        assert est.kappa_pre == 5
+        assert est.weighting == "sample_share"
+
+    def test_set_params_unknown_raises(self):
+        """set_params raises on an unknown parameter."""
+        est = StackedDiD()
+        with pytest.raises(ValueError, match="Unknown parameter"):
+            est.set_params(nonexistent_param=42)
+
+    def test_convenience_function(self, staggered_data):
+        """The stacked_did() convenience function works."""
+        results = stacked_did(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+            kappa_pre=2,
+            kappa_post=2,
+        )
+        assert isinstance(results, StackedDiDResults)
+        assert np.isfinite(results.overall_att)
+
+
+# =============================================================================
+# TestResultsMethods
+# =============================================================================
+
+
+class TestResultsMethods:
+    """Tests for StackedDiDResults methods."""
+
+    def test_summary(self, staggered_data):
+        """summary() returns a formatted string."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+            aggregate="event_study",
+        )
+        summary = results.summary()
+        assert "Stacked DiD" in summary
+        assert "ATT" in summary
+
+    def test_to_dataframe_event_study(self, staggered_data):
+        """to_dataframe(level='event_study') returns a DataFrame."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+            aggregate="event_study",
+        )
+        df = results.to_dataframe(level="event_study")
+        assert isinstance(df, pd.DataFrame)
+        assert "relative_period" in df.columns
+        assert "effect" in df.columns
+
+    def test_to_dataframe_group_raises(self, staggered_data):
+        """to_dataframe(level='group') raises ValueError."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        with pytest.raises(ValueError, match="Group aggregation is not supported"):
+            results.to_dataframe(level="group")
+
+    def test_to_dataframe_no_event_study_raises(self, staggered_data):
+        """to_dataframe raises when event_study was not computed."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        with pytest.raises(ValueError, match="Event study effects not computed"):
+            results.to_dataframe(level="event_study")
+
+    def test_is_significant(self, staggered_data):
+        """is_significant property works."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert isinstance(results.is_significant, bool)
+
+    def test_significance_stars(self, staggered_data):
+        """significance_stars property works."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        assert isinstance(results.significance_stars, str)
+
+    def test_repr(self, staggered_data):
+        """__repr__ returns a formatted string."""
+        est = StackedDiD(kappa_pre=2, kappa_post=2)
+        results = est.fit(
+            staggered_data,
+            outcome="outcome",
+            unit="unit",
+            time="period",
+            first_treat="first_treat",
+        )
+        r = repr(results)
+        assert "StackedDiDResults" in r
+        assert "ATT=" in r
+
+
+# =============================================================================
+# TestValidation
+# =============================================================================
+
+
+class TestValidation:
+    """Tests for input validation."""
+
+    def test_missing_columns(self, staggered_data):
+        """Raises on missing required columns."""
+        est = StackedDiD()
+        with pytest.raises(ValueError, match="Missing columns"):
+            est.fit(
+                staggered_data,
+                outcome="nonexistent",
+                unit="unit",
+                time="period",
+                first_treat="first_treat",
+            )
+
+    def test_invalid_weighting(self):
+        """Raises on an invalid weighting parameter."""
+        with pytest.raises(ValueError, match="weighting"):
+            StackedDiD(weighting="invalid")
+
+    def test_invalid_clean_control(self):
+        """Raises on an invalid clean_control parameter."""
+        with pytest.raises(ValueError, match="clean_control"):
+            StackedDiD(clean_control="invalid")
+
+    def test_invalid_cluster(self):
+        """Raises on an invalid cluster parameter."""
+        with pytest.raises(ValueError, match="cluster"):
+            StackedDiD(cluster="invalid")
+
+    def test_invalid_aggregate(self, staggered_data):
+        """Raises on an invalid aggregate parameter."""
+        est = StackedDiD()
+        with pytest.raises(ValueError, match="aggregate"):
+            est.fit(
+                staggered_data,
+                outcome="outcome",
+                unit="unit",
+                time="period",
+                first_treat="first_treat",
+                aggregate="invalid",
+            )
+
+    def test_population_required_for_population_weighting(self, staggered_data):
+        """Raises when the population column is not given with weighting='population'."""
+        est = StackedDiD(weighting="population")
+        with pytest.raises(ValueError, match="population"):
+            est.fit(
+                staggered_data,
+                outcome="outcome",
+                unit="unit",
+                time="period",
+                first_treat="first_treat",
+            )
+
+    def test_no_treated_units(self):
+        """Raises when no treated units exist."""
+        data = pd.DataFrame(
+            {
+                "unit": [1, 1, 2, 2],
+                "period": [1, 2, 1, 2],
+                "outcome": [1.0, 2.0, 1.0, 2.0],
+                "first_treat": [0, 0, 0, 0],
+            }
+        )
+        est = StackedDiD()
+        with pytest.raises(ValueError, match="No treated units"):
+            est.fit(
+                data,
+                outcome="outcome",
+                unit="unit",
+                time="period",
+                first_treat="first_treat",
+            )