{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Staggered Difference-in-Differences\n",
"\n",
"This notebook demonstrates the `StaggeredDifferenceInDifferences` estimator in CausalPy, which handles **staggered adoption** settings where different units receive treatment at different times.\n",
"\n",
"## The Staggered Adoption Problem\n",
"\n",
"In many real-world settings, treatment is not applied to all units at the same time. Instead, different cohorts of units adopt treatment at different times. This creates a \"staggered adoption\" pattern.\n",
"\n",
"### Why Standard Two-Way Fixed Effects (TWFE) Can Fail\n",
"\n",
"The standard TWFE regression with a single treatment indicator:\n",
"\n",
"$$Y_{it} = \\alpha_i + \\lambda_t + \\tau D_{it} + \\varepsilon_{it}$$\n",
"\n",
"can produce biased estimates in staggered settings {cite:p}`goodman2021difference` because:\n",
"\n",
"1. Already-treated units serve as implicit controls for later-treated units\n",
"2. The estimate is a weighted average of treatment effects that can include negative weights\n",
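"3. When treatment effects change with time since treatment (dynamic effects), the outcomes of already-treated \"controls\" keep moving because of the treatment itself, so those comparisons subtract part of the effect and bias the estimate\n",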
"\n",
":::{note}\n",
"Notice that $\\tau$ in the TWFE formulation above is a **scalar**—a single number meant to summarize the treatment effect across all units and time periods. This is a fundamental limitation: it assumes the treatment effect is constant, regardless of when a unit was treated or how long they've been treated.\n",
"\n",
"Event-study and modern staggered DiD approaches address this by treating treatment effects as **dynamic**—allowing $\\tau$ to vary by event-time (time relative to treatment). This captures realistic patterns like effects that build up over time, decay, or differ across treatment cohorts.\n",
":::\n",
"\n",
"### The Imputation-Based Solution\n",
"\n",
"CausalPy's `StaggeredDifferenceInDifferences` uses an imputation-based approach inspired by {cite:t}`borusyak2024revisiting`:\n",
"\n",
"1. **Fit a model on untreated observations only** - using pre-treatment periods for eventually-treated units plus all periods for never-treated units\n",
"2. **Predict counterfactual outcomes** for all observations\n",
"3. **Compute treatment effects** as the difference between observed and predicted outcomes\n",
"4. **Aggregate effects** by event-time (time relative to treatment) for an event-study curve\n"
]
},
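{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the negative-weight problem and the imputation steps concrete, here is a minimal NumPy sketch on a hypothetical, noise-free panel. It has 3 units observed over 4 periods: unit 0 adopts at $t=1$, unit 1 adopts at $t=3$, and unit 2 is never treated. The true effect is $1 + e$ at event-time $e$. The sketch fits the TWFE regression above with plain least squares, then runs the four imputation steps. It illustrates the idea only and is not CausalPy's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical noise-free panel: 3 units x 4 periods.\n",
"# Unit 0 adopts at t=1, unit 1 at t=3, unit 2 never (np.inf).\n",
"n_units, n_periods = 3, 4\n",
"adopt = np.array([1.0, 3.0, np.inf])\n",
"alpha = np.array([1.0, 2.0, 3.0])  # unit fixed effects\n",
"lam = np.array([0.0, 0.5, 1.0, 1.5])  # time fixed effects\n",
"\n",
"unit = np.repeat(np.arange(n_units), n_periods)\n",
"time = np.tile(np.arange(n_periods), n_units)\n",
"event_time = time - adopt[unit]  # -inf for the never-treated unit\n",
"D = (event_time >= 0).astype(float)\n",
"tau = np.where(D == 1, 1.0 + event_time, 0.0)  # dynamic effect 1 + e\n",
"y = alpha[unit] + lam[time] + tau\n",
"\n",
"\n",
"def fe_design(u, t):\n",
"    \"\"\"Intercept plus unit and time dummies (first level of each dropped).\"\"\"\n",
"    cols = [np.ones(len(u))]\n",
"    cols += [(u == k).astype(float) for k in range(1, n_units)]\n",
"    cols += [(t == s).astype(float) for s in range(1, n_periods)]\n",
"    return np.column_stack(cols)\n",
"\n",
"\n",
"X = fe_design(unit, time)\n",
"\n",
"# TWFE: a single scalar tau shared by all treated observations\n",
"beta_twfe, *_ = np.linalg.lstsq(np.column_stack([X, D]), y, rcond=None)\n",
"tau_twfe = beta_twfe[-1]\n",
"\n",
"# Imputation: (1) fit on untreated rows, (2) predict Y(0) for every row,\n",
"# (3) effect = observed - predicted, (4) average by event-time\n",
"untreated = D == 0\n",
"beta0, *_ = np.linalg.lstsq(X[untreated], y[untreated], rcond=None)\n",
"effects = (y - X @ beta0)[D == 1]\n",
"treated_et = event_time[D == 1]\n",
"by_event_time = {\n",
"    int(e): round(float(effects[treated_et == e].mean()), 3)\n",
"    for e in np.unique(treated_et)\n",
"}\n",
"\n",
"print(f\"true ATT       = {tau[D == 1].mean():.3f}\")  # 1.750\n",
"print(f\"imputation ATT = {effects.mean():.3f}\")  # 1.750\n",
"print(f\"TWFE tau       = {tau_twfe:.3f}\")  # 1.100\n",
"print(by_event_time)  # {0: 1.0, 1: 2.0, 2: 3.0}\n",
"```\n",
"\n",
"TWFE returns 1.100, while the true average is 1.750. Residualizing the treatment dummy on the unit and time effects gives the four treated observations these weights: 0.3 (unit 0, $t=1$), 0.3 (unit 0, $t=2$), -0.1 (unit 0, $t=3$) and 0.5 (unit 1, $t=3$). The observation with the largest effect therefore enters with a negative weight: 0.3(1) + 0.3(2) - 0.1(3) + 0.5(1) = 1.1. The imputation estimate recovers each individual effect exactly because the untreated model is correctly specified and there is no noise.\n"
]
},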
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import causalpy as cp\n",
"from causalpy.data.simulate_data import generate_staggered_did_data\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate Synthetic Staggered Panel Data\n",
"\n",
"We'll create synthetic data with:\n",
"- 50 units observed over 20 time periods (t = 0 to 19)\n",
"- 3 treatment cohorts of 12 units each, adopting at times 5, 10, and 15\n",
"- 14 never-treated units\n",
"- A known dynamic treatment effect that we can check the estimates against\n"
]
},
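{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook uses `df` and `TRUE_EFFECTS` without showing the code that creates them. `generate_staggered_did_data` is imported above, but its call and arguments are not shown. To run the remaining cells without that call, the sketch below builds a stand-in panel with the columns the later cells use (`unit`, `time`, `y`, `treated`, `treatment_time`). Everything specific in it is an assumption:\n",
"\n",
"- 12 units per cohort and 14 never-treated units, matching the summary output further down\n",
"- random unit and time effects\n",
"- a noise scale of 0.3\n",
"- a hypothetical `TRUE_EFFECTS` path that ramps up and then plateaus at 2.5, the default the plotting cell falls back to\n",
"\n",
"Its estimates will not match the outputs shown.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"rng = np.random.default_rng(42)\n",
"\n",
"n_units, n_periods = 50, 20\n",
"cohorts = [5.0, 10.0, 15.0]\n",
"units_per_cohort = 12  # 3 x 12 = 36 treated units; the other 14 are never treated\n",
"\n",
"# Adoption time per unit; NaN marks never-treated units (an assumption)\n",
"treatment_time = np.full(n_units, np.nan)\n",
"for k, g in enumerate(cohorts):\n",
"    treatment_time[k * units_per_cohort : (k + 1) * units_per_cohort] = g\n",
"\n",
"# Hypothetical effect by event-time; event-times not listed default to 2.5\n",
"TRUE_EFFECTS = {0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5}\n",
"\n",
"unit_effect = rng.normal(0.0, 2.0, n_units)\n",
"time_effect = rng.normal(0.0, 1.0, n_periods)\n",
"\n",
"rows = []\n",
"for i in range(n_units):\n",
"    g = treatment_time[i]\n",
"    for t in range(n_periods):\n",
"        treated = int(not np.isnan(g) and t >= g)\n",
"        effect = TRUE_EFFECTS.get(int(t - g), 2.5) if treated else 0.0\n",
"        y = unit_effect[i] + time_effect[t] + effect + rng.normal(0.0, 0.3)\n",
"        rows.append(\n",
"            {\"unit\": i, \"time\": t, \"y\": y, \"treated\": treated, \"treatment_time\": g}\n",
"        )\n",
"\n",
"df = pd.DataFrame(rows)\n",
"```\n"
]
},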
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
},
"metadata": {
"image/png": {
"height": 592,
"width": 1097
}
},
"output_type": "display_data"
}
],
"source": [
"from matplotlib.ticker import MultipleLocator\n",
"\n",
"# Create a heatmap of treatment status\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# Pivot to get unit x time matrix of treatment status\n",
"treatment_matrix = df.pivot(index=\"unit\", columns=\"time\", values=\"treated\")\n",
"\n",
"# Sort by treatment time for better visualization\n",
"unit_treatment_times = df.groupby(\"unit\")[\"treatment_time\"].first().sort_values()\n",
"treatment_matrix = treatment_matrix.loc[unit_treatment_times.index]\n",
"\n",
"im = ax.imshow(\n",
" treatment_matrix.values,\n",
" aspect=\"auto\",\n",
" cmap=\"Greys\",\n",
" interpolation=\"nearest\",\n",
" vmin=0,\n",
" vmax=1,\n",
")\n",
"ax.set_xlabel(\"Time Period\")\n",
"ax.set_ylabel(\"Unit (sorted by treatment time)\")\n",
"ax.set_title(\n",
" \"Staggered Treatment Adoption Pattern\\n(White = Untreated, Black = Treated)\"\n",
")\n",
"ax.yaxis.set_minor_locator(MultipleLocator(1))\n",
"plt.colorbar(im, ax=ax, label=\"Treatment Status\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit the Staggered DiD Model\n",
"\n",
"We'll use a model with unit and time fixed effects, which is the baseline counterfactual model for imputation.\n",
"\n",
"The formula defines a model of **untreated potential outcomes**. Crucially, this model is fitted only on observations that are not yet treated: pre-treatment periods for units that eventually receive treatment, plus all periods for never-treated units. The fitted model is then used to predict what the treated units *would have* experienced in the absence of treatment (the counterfactual). Treatment effects are computed as the difference between observed outcomes and these counterfactual predictions.\n",
"\n",
"The formula `y ~ 1 + C(unit) + C(time)` specifies a two-way fixed effects model, but you're not limited to this specification. You can add covariates that help explain variation in the outcome, such as weather conditions, seasonality indicators, economic indicators, or other time-varying controls. For example, `y ~ 1 + C(unit) + C(time) + temperature + holiday` would add temperature and holiday effects to the model. Relevant covariates can improve the precision of the treatment effect estimates. They can also make parallel trends more plausible once they are controlled for, provided the covariates are not themselves affected by the treatment.\n"
]
},
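{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pandas terms, the fitting sample is a boolean mask over rows. The sketch below applies such a mask to a hypothetical six-row panel. It assumes never-treated units carry `NaN` in `treatment_time`. CausalPy makes this selection internally, so this is for intuition only.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical panel: unit 0 adopts at t=1, unit 1 is never treated\n",
"panel = pd.DataFrame(\n",
"    {\n",
"        \"unit\": [0, 0, 0, 1, 1, 1],\n",
"        \"time\": [0, 1, 2, 0, 1, 2],\n",
"        \"treatment_time\": [1.0, 1.0, 1.0, np.nan, np.nan, np.nan],\n",
"    }\n",
")\n",
"\n",
"# Never-treated rows (every period) plus treated units' pre-adoption rows.\n",
"# Comparisons with NaN are False, so isna() is what keeps unit 1.\n",
"not_yet_treated = panel[\"treatment_time\"].isna() | (\n",
"    panel[\"time\"] < panel[\"treatment_time\"]\n",
")\n",
"fit_sample = panel[not_yet_treated]\n",
"print(fit_sample.index.tolist())  # [0, 3, 4, 5]\n",
"```\n",
"\n",
"The counterfactual model is fitted on rows 0, 3, 4 and 5 only. It then predicts all six rows, including unit 0 at $t=1$ and $t=2$, whose residuals become the treatment effect estimates.\n"
]
},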
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [beta, y_hat_sigma]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.\n",
"The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n",
"The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details\n",
"Sampling: [beta, y_hat, y_hat_sigma]\n",
"Sampling: [y_hat]\n",
"Sampling: [y_hat]\n"
]
}
],
"source": [
"# Fit the staggered DiD model with PyMC\n",
"result = cp.StaggeredDifferenceInDifferences(\n",
" df,\n",
" formula=\"y ~ 1 + C(unit) + C(time)\",\n",
" unit_variable_name=\"unit\",\n",
" time_variable_name=\"time\",\n",
" treated_variable_name=\"treated\",\n",
" treatment_time_variable_name=\"treatment_time\",\n",
" model=cp.pymc_models.LinearRegression(\n",
" sample_kwargs={\n",
" \"progressbar\": True,\n",
" \"random_seed\": 42,\n",
" }\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Event-Study Plot\n",
"\n",
"The event-study plot shows estimates at each event-time (time relative to treatment). Key features:\n",
"\n",
"- **Pre-treatment placebo estimates** (event-time < 0, gray squares): These should be close to zero if the parallel trends assumption holds. They are computed as residuals (observed minus predicted) for eventually-treated units before they receive treatment. These are **not** treatment effects—they are fit diagnostics.\n",
"- **Post-treatment ATT estimates** (event-time ≥ 0, blue circles): These are the actual Average Treatment effect on the Treated (ATT) estimates showing the dynamic treatment effect over time since treatment.\n",
"- **Error bars**: 94% Highest Density Intervals (HDI) from the Bayesian posterior\n",
"- **Gray shaded region**: Pre-treatment period (placebo check zone)\n",
"\n",
":::{note}\n",
"Since we generated synthetic data with known treatment effects, we can overlay the true effects on the plot to validate the estimator's performance. In real applications, the true effects are unknown.\n",
":::"
]
},
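{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 94% HDI is the narrowest interval that contains 94% of the posterior draws. A common way to compute it is to sort the draws, slide a window that spans 94% of them, and keep the narrowest window. The NumPy sketch below does this on simulated draws, which stand in for a posterior. For a skewed posterior, the HDI shifts toward the mode instead of cutting 3% from each tail.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def hdi(draws, prob=0.94):\n",
"    \"\"\"Narrowest interval holding `prob` of the draws (assumes one mode).\"\"\"\n",
"    x = np.sort(np.asarray(draws))\n",
"    n = len(x)\n",
"    k = int(np.floor(prob * n))  # index span of each candidate window\n",
"    widths = x[k:] - x[: n - k]  # width of every candidate window\n",
"    i = int(np.argmin(widths))\n",
"    return float(x[i]), float(x[i + k])\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"sym = rng.normal(0.0, 1.0, 100_000)  # symmetric draws\n",
"skew = rng.exponential(1.0, 100_000)  # right-skewed draws, mode at 0\n",
"\n",
"print(hdi(sym))  # roughly (-1.88, 1.88)\n",
"print(hdi(skew))  # roughly (0.0, 2.81)\n",
"print(np.quantile(skew, [0.03, 0.97]))  # equal-tailed: roughly [0.03, 3.51]\n",
"```\n"
]
},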
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
},
"metadata": {
"image/png": {
"height": 611,
"width": 1011
}
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = result.plot()\n",
"\n",
"# Overlay true treatment effects (only possible because we simulated the data)\n",
"att_et = result.att_event_time_\n",
"post_treatment = att_et[att_et[\"event_time\"] >= 0]\n",
"true_vals = [TRUE_EFFECTS.get(e, 2.5) for e in post_treatment[\"event_time\"]]\n",
"ax[0].scatter(\n",
" post_treatment[\"event_time\"],\n",
" true_vals,\n",
" color=\"red\",\n",
" marker=\"x\",\n",
" s=100,\n",
" linewidths=2,\n",
" zorder=5,\n",
" label=\"True Effect\",\n",
")\n",
"ax[0].legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## View Summary Statistics\n",
"\n",
"The `n_obs` column in the event-time table is the number of unit-time observations contributing to each event-time estimate. It varies because each cohort is observed for a different span of time relative to its adoption. Units treated at t=5 contribute to every post-treatment event-time from 0 to 14. Units treated at t=15 contribute only to event-times 0 to 4, because the panel ends at t=19. The same logic applies before treatment: only the t=15 cohort is observed 11 to 15 periods before adoption, so those placebo rows rest on 12 observations.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================Staggered Difference in Differences=======================\n",
"Formula: y ~ 1 + C(unit) + C(time)\n",
"Number of units: 50\n",
"Number of time periods: 20\n",
"Treatment cohorts: [np.float64(5.0), np.float64(10.0), np.float64(15.0)]\n",
"Never-treated units: 14\n",
"\n",
"Event-time estimates:\n",
" event_time type att att_lower att_upper n_obs\n",
" -15 placebo -0.089727 -0.182651 0.001678 12\n",
" -14 placebo -0.076510 -0.165531 0.015464 12\n",
" -13 placebo -0.117002 -0.208039 -0.026421 12\n",
" -12 placebo 0.030745 -0.059782 0.121330 12\n",
" -11 placebo 0.074539 -0.016959 0.164344 12\n",
" -10 placebo -0.076136 -0.143237 -0.012652 24\n",
" -9 placebo 0.081386 0.014587 0.148938 24\n",
" -8 placebo -0.013679 -0.080336 0.051946 24\n",
" -7 placebo -0.032442 -0.098791 0.033575 24\n",
" -6 placebo 0.004729 -0.060527 0.071197 24\n",
" -5 placebo 0.021254 -0.042810 0.082547 36\n",
" -4 placebo 0.018158 -0.044094 0.081336 36\n",
" -3 placebo 0.051156 -0.011460 0.112558 36\n",
" -2 placebo 0.024409 -0.037726 0.085558 36\n",
" -1 placebo -0.033889 -0.096769 0.030393 36\n",
" 0 ATT 1.074842 0.990362 1.159714 36\n",
" 1 ATT 1.477406 1.395102 1.559364 36\n",
" 2 ATT 1.936514 1.854290 2.016924 36\n",
" 3 ATT 2.491325 2.412439 2.572188 36\n",
" 4 ATT 2.494277 2.412555 2.577045 36\n",
" 5 ATT 2.439444 2.327826 2.554720 24\n",
" 6 ATT 2.508034 2.391311 2.625448 24\n",
" 7 ATT 2.495111 2.382095 2.607191 24\n",
" 8 ATT 2.554547 2.444958 2.666289 24\n",
" 9 ATT 2.504438 2.390342 2.620304 24\n",
" 10 ATT 2.699058 2.515626 2.879858 12\n",
" 11 ATT 2.379267 2.192838 2.562129 12\n",
" 12 ATT 2.441211 2.260847 2.617471 12\n",
" 13 ATT 2.580254 2.405220 2.759980 12\n",
" 14 ATT 2.625220 2.452779 2.803349 12\n",
"\n",
"Model coefficients:\n",
" Intercept 0.78, 94% HDI [0.63, 0.95]\n",
" C(unit)[T.1] -2.6, 94% HDI [-2.9, -2.3]\n",
" C(unit)[T.2] 0.92, 94% HDI [0.62, 1.2]\n",
" C(unit)[T.3] 1.4, 94% HDI [1.2, 1.6]\n",
" C(unit)[T.4] -4.4, 94% HDI [-4.6, -4.2]\n",
" C(unit)[T.5] -3.2, 94% HDI [-3.4, -3]\n",
" C(unit)[T.6] -0.37, 94% HDI [-0.58, -0.18]\n",
" C(unit)[T.7] -1.1, 94% HDI [-1.3, -0.94]\n",
" C(unit)[T.8] -0.43, 94% HDI [-0.73, -0.13]\n",
" C(unit)[T.9] -2.3, 94% HDI [-2.6, -2.1]\n",
" C(unit)[T.10] 1.4, 94% HDI [1.2, 1.5]\n",
" C(unit)[T.11] 1, 94% HDI [0.73, 1.3]\n",
" C(unit)[T.12] -0.37, 94% HDI [-0.58, -0.18]\n",
" C(unit)[T.13] 1.8, 94% HDI [1.6, 2]\n",
" C(unit)[T.14] 0.34, 94% HDI [0.12, 0.54]\n",
" C(unit)[T.15] -2.3, 94% HDI [-2.6, -2.1]\n",
" C(unit)[T.16] 0.21, 94% HDI [-0.088, 0.51]\n",
" C(unit)[T.17] -2.4, 94% HDI [-2.6, -2.2]\n",
" C(unit)[T.18] 1.2, 94% HDI [0.99, 1.5]\n",
" C(unit)[T.19] -0.7, 94% HDI [-0.89, -0.5]\n",
" C(unit)[T.20] -0.9, 94% HDI [-1.1, -0.7]\n",
" C(unit)[T.21] -1.7, 94% HDI [-2, -1.4]\n",
" C(unit)[T.22] 2, 94% HDI [1.8, 2.2]\n",
" C(unit)[T.23] -0.9, 94% HDI [-1.1, -0.72]\n",
" C(unit)[T.24] -1.3, 94% HDI [-1.6, -1.1]\n",
" C(unit)[T.25] -1.3, 94% HDI [-1.6, -0.95]\n",
" C(unit)[T.26] 0.51, 94% HDI [0.27, 0.74]\n",
" C(unit)[T.27] 0.18, 94% HDI [-0.021, 0.37]\n",
" C(unit)[T.28] 0.23, 94% HDI [0.015, 0.43]\n",
" C(unit)[T.29] 0.18, 94% HDI [-0.025, 0.39]\n",
" C(unit)[T.30] 3.7, 94% HDI [3.5, 4]\n",
" C(unit)[T.31] -1.3, 94% HDI [-1.6, -1.1]\n",
" C(unit)[T.32] -1.4, 94% HDI [-1.6, -1.2]\n",
" C(unit)[T.33] -2.1, 94% HDI [-2.3, -1.9]\n",
" C(unit)[T.34] 0.65, 94% HDI [0.35, 0.96]\n",
" C(unit)[T.35] 1.8, 94% HDI [1.5, 2.1]\n",
" C(unit)[T.36] -0.81, 94% HDI [-1, -0.61]\n",
" C(unit)[T.37] -2.1, 94% HDI [-2.3, -1.8]\n",
" C(unit)[T.38] -2.3, 94% HDI [-2.5, -2.1]\n",
" C(unit)[T.39] 0.68, 94% HDI [0.48, 0.87]\n",
" C(unit)[T.40] 0.91, 94% HDI [0.66, 1.1]\n",
" C(unit)[T.41] 0.7, 94% HDI [0.47, 0.93]\n",
" C(unit)[T.42] -2.1, 94% HDI [-2.3, -1.9]\n",
" C(unit)[T.43] -0.16, 94% HDI [-0.37, 0.045]\n",
" C(unit)[T.44] -0.2, 94% HDI [-0.41, 0.0086]\n",
" C(unit)[T.45] -0.0012, 94% HDI [-0.29, 0.28]\n",
" C(unit)[T.46] 1.1, 94% HDI [0.91, 1.4]\n",
" C(unit)[T.47] -0.23, 94% HDI [-0.52, 0.066]\n",
" C(unit)[T.48] 0.81, 94% HDI [0.59, 1]\n",
" C(unit)[T.49] -0.54, 94% HDI [-0.84, -0.25]\n",
" C(time)[T.1] 0.38, 94% HDI [0.26, 0.49]\n",
" C(time)[T.2] -1.7, 94% HDI [-1.8, -1.6]\n",
" C(time)[T.3] -0.56, 94% HDI [-0.68, -0.44]\n",
" C(time)[T.4] -0.77, 94% HDI [-0.89, -0.66]\n",
" C(time)[T.5] -0.96, 94% HDI [-1.1, -0.84]\n",
" C(time)[T.6] -0.51, 94% HDI [-0.64, -0.38]\n",
" C(time)[T.7] 1.3, 94% HDI [1.2, 1.4]\n",
" C(time)[T.8] -1, 94% HDI [-1.1, -0.89]\n",
" C(time)[T.9] 0.72, 94% HDI [0.59, 0.84]\n",
" C(time)[T.10] -1.9, 94% HDI [-2.1, -1.8]\n",
" C(time)[T.11] -0.55, 94% HDI [-0.7, -0.42]\n",
" C(time)[T.12] -0.028, 94% HDI [-0.17, 0.11]\n",
" C(time)[T.13] 0.38, 94% HDI [0.24, 0.53]\n",
" C(time)[T.14] 0.51, 94% HDI [0.37, 0.65]\n",
" C(time)[T.15] 0.45, 94% HDI [0.28, 0.63]\n",
" C(time)[T.16] -0.61, 94% HDI [-0.79, -0.42]\n",
" C(time)[T.17] -0.69, 94% HDI [-0.87, -0.5]\n",
" C(time)[T.18] 0.55, 94% HDI [0.37, 0.73]\n",
" C(time)[T.19] -0.55, 94% HDI [-0.73, -0.37]\n",
" y_hat_sigma 0.31, 94% HDI [0.29, 0.33]\n"
]
}
],
"source": [
"result.summary(round_to=2)"
]
},
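{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `n_obs` pattern in the table above follows from counting. The sketch below assumes calendar times 0 to 19 and 12 units per cohort. That cohort size is implied by `n_obs` = 12 at event-times 10 to 14, where only the $t=5$ cohort is still observed.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"n_periods = 20  # calendar times 0..19\n",
"cohort_sizes = {5: 12, 10: 12, 15: 12}  # adoption time -> units in cohort\n",
"\n",
"# Every unit contributes one row per period, at event-time t - g\n",
"n_obs = Counter()\n",
"for g, size in cohort_sizes.items():\n",
"    for t in range(n_periods):\n",
"        n_obs[t - g] += size\n",
"\n",
"for e in (-15, -10, -5, 0, 5, 10, 14):\n",
"    print(e, n_obs[e])\n",
"# -15 12, -10 24, -5 36, 0 36, 5 24, 10 12, 14 12\n",
"```\n"
]
},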
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effect Summary\n",
"\n",
"`effect_summary()` returns a prose summary of the causal effects. The headline effect is the simple (unweighted) average of the post-treatment event-time estimates. If pre-treatment placebo estimates are available, the summary also reports whether their average is close to zero, as expected under parallel trends:\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Staggered DiD analysis: The average post-treatment effect across event-times was 2.31 (average 94% HDI [2.19, 2.44]). Pre-treatment placebo check: Average pre-treatment effect was -0.01, consistent with parallel trends assumption. Analysis includes 3 treatment cohort(s).\n"
]
}
],
"source": [
"effect_summary = result.effect_summary()\n",
"print(effect_summary.text)"
]
},
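{
"cell_type": "markdown",
"metadata": {},
"source": [
"The headline numbers can be reproduced from the event-time table. The 2.31 is the simple average of the 15 post-treatment estimates, and the interval matches the averages of their lower and upper bounds. The sketch below copies those values from the `summary()` output above. It also computes an `n_obs`-weighted average, which gives every treated observation equal weight instead of every event-time. That alternative is shown for comparison only; `effect_summary()` does not report it.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Post-treatment rows (event-times 0..14) copied from the summary table\n",
"att = np.array(\n",
"    [1.074842, 1.477406, 1.936514, 2.491325, 2.494277,\n",
"     2.439444, 2.508034, 2.495111, 2.554547, 2.504438,\n",
"     2.699058, 2.379267, 2.441211, 2.580254, 2.625220]\n",
")\n",
"att_lower = np.array(\n",
"    [0.990362, 1.395102, 1.854290, 2.412439, 2.412555,\n",
"     2.327826, 2.391311, 2.382095, 2.444958, 2.390342,\n",
"     2.515626, 2.192838, 2.260847, 2.405220, 2.452779]\n",
")\n",
"att_upper = np.array(\n",
"    [1.159714, 1.559364, 2.016924, 2.572188, 2.577045,\n",
"     2.554720, 2.625448, 2.607191, 2.666289, 2.620304,\n",
"     2.879858, 2.562129, 2.617471, 2.759980, 2.803349]\n",
")\n",
"n_obs = np.repeat([36, 24, 12], 5)  # 36 for e=0..4, 24 for 5..9, 12 for 10..14\n",
"\n",
"print(f\"unweighted mean: {att.mean():.3f}\")  # 2.313\n",
"print(f\"mean of bounds:  [{att_lower.mean():.2f}, {att_upper.mean():.2f}]\")  # [2.19, 2.44]\n",
"print(f\"n_obs-weighted:  {np.average(att, weights=n_obs):.3f}\")  # 2.205\n",
"```\n",
"\n",
"The two averages differ because of how each event-time is estimated. Event-times 10 to 14 rely on the $t=5$ cohort alone, while the first five event-times pool all three cohorts, yet the simple average counts them equally. Also, averaging interval endpoints is not the same as computing an interval for the average effect. For roughly Gaussian posteriors, the interval for the average effect is no wider, and usually narrower.\n"
]
},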
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-Treatment Placebo Check\n",
"\n",
"The `att_event_time_` table includes pre-treatment event times (negative values). These represent the average **residuals** (observed - predicted) for eventually-treated units *before* they receive treatment.\n",
"\n",
":::{important}\n",
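"Pre-treatment estimates are **not** ATT (Average Treatment effect on the Treated). They are **placebo/fit diagnostics** that check the counterfactual model and the parallel trends assumption. Values close to zero are reassuring, but these observations are also part of the fitting sample. With unit fixed effects, least squares forces each treated unit's pre-treatment residuals to sum to zero, and the Bayesian fit behaves approximately the same way. Their overall average is therefore near zero almost by construction. The more informative signal is a systematic pattern across event-times, such as a drift leading up to event-time 0.\n",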
":::\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pre-treatment (placebo) effects:\n",
" Mean: -0.009\n",
" Should be close to zero if parallel trends holds\n",
"\n",
"Post-treatment effects:\n",
" Mean: 2.313\n"
]
}
],
"source": [
"# Separate pre- and post-treatment effects\n",
"att_et = result.att_event_time_\n",
"pre_treatment = att_et[att_et[\"event_time\"] < 0]\n",
"post_treatment = att_et[att_et[\"event_time\"] >= 0]\n",
"\n",
"print(\"Pre-treatment (placebo) effects:\")\n",
"print(f\" Mean: {pre_treatment['att'].mean():.3f}\")\n",
"print(\" Should be close to zero if parallel trends holds\")\n",
"print()\n",
"print(\"Post-treatment effects:\")\n",
"print(f\" Mean: {post_treatment['att'].mean():.3f}\")\n",
"# True effect averaged over the same event-times as the estimates;\n",
"# event-times missing from TRUE_EFFECTS default to 2.5, as in the plot cell\n",
"true_post = [TRUE_EFFECTS.get(e, 2.5) for e in post_treatment[\"event_time\"]]\n",
"print(f\" True average effect: {np.mean(true_post):.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Examine the Event-Time ATT Table\n",
"\n",
"The `att_event_time_` attribute provides direct access to the underlying event-time ATT estimates as a pandas DataFrame. This is the same data visualized in the event-study plot, but in tabular form.\n",
"\n",
"**When to use this table:**\n",
"\n",
"- **Reporting precise estimates**: When you need exact point estimates and credible intervals for a paper or presentation, rather than reading approximate values from a plot.\n",
"- **Custom analysis**: When you want to perform additional calculations on the estimates, such as computing cumulative effects, testing specific hypotheses, or comparing effects at particular event-times.\n",
"- **Debugging and validation**: When checking that the model is behaving as expected, or comparing estimates across different model specifications.\n",
"- **Exporting results**: When you need to save estimates to a file or integrate them into a larger analysis pipeline.\n",
"\n",
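"The table includes `event_time` (periods relative to treatment), the point estimate `att`, the interval bounds `att_lower` and `att_upper`, and `n_obs` (the number of observations at each event-time). It has no `type` column; that label appears only in the `summary()` printout. To separate placebo rows from ATT rows, use the sign of `event_time`.\n"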
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" event_time att att_lower att_upper n_obs\n",
"0 -15 -0.089727 -0.182651 0.001678 12\n",
"1 -14 -0.076510 -0.165531 0.015464 12\n",
"2 -13 -0.117002 -0.208039 -0.026421 12\n",
"3 -12 0.030745 -0.059782 0.121330 12\n",
"4 -11 0.074539 -0.016959 0.164344 12\n",
"5 -10 -0.076136 -0.143237 -0.012652 24\n",
"6 -9 0.081386 0.014587 0.148938 24\n",
"7 -8 -0.013679 -0.080336 0.051946 24\n",
"8 -7 -0.032442 -0.098791 0.033575 24\n",
"9 -6 0.004729 -0.060527 0.071197 24\n",
"10 -5 0.021254 -0.042810 0.082547 36\n",
"11 -4 0.018158 -0.044094 0.081336 36\n",
"12 -3 0.051156 -0.011460 0.112558 36\n",
"13 -2 0.024409 -0.037726 0.085558 36\n",
"14 -1 -0.033889 -0.096769 0.030393 36\n",
"15 0 1.074842 0.990362 1.159714 36\n",
"16 1 1.477406 1.395102 1.559364 36\n",
"17 2 1.936514 1.854290 2.016924 36\n",
"18 3 2.491325 2.412439 2.572188 36\n",
"19 4 2.494277 2.412555 2.577045 36\n",
"20 5 2.439444 2.327826 2.554720 24\n",
"21 6 2.508034 2.391311 2.625448 24\n",
"22 7 2.495111 2.382095 2.607191 24\n",
"23 8 2.554547 2.444958 2.666289 24\n",
"24 9 2.504438 2.390342 2.620304 24\n",
"25 10 2.699058 2.515626 2.879858 12\n",
"26 11 2.379267 2.192838 2.562129 12\n",
"27 12 2.441211 2.260847 2.617471 12\n",
"28 13 2.580254 2.405220 2.759980 12\n",
"29 14 2.625220 2.452779 2.803349 12"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result.att_event_time_"
]
},
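{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an example of custom analysis, the sketch below computes the cumulative effect and the period-to-period change over the first five post-treatment event-times. To stay self-contained, it rebuilds those rows by hand from the table above. With a fitted model, you would start from `result.att_event_time_` directly. Post-treatment rows are selected with `event_time >= 0`.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Event-times 0..4 of att_event_time_, copied from the table above\n",
"att_et = pd.DataFrame(\n",
"    {\n",
"        \"event_time\": [0, 1, 2, 3, 4],\n",
"        \"att\": [1.074842, 1.477406, 1.936514, 2.491325, 2.494277],\n",
"        \"n_obs\": [36, 36, 36, 36, 36],\n",
"    }\n",
")\n",
"\n",
"post = att_et[att_et[\"event_time\"] >= 0].sort_values(\"event_time\")\n",
"post = post.assign(\n",
"    cumulative_att=post[\"att\"].cumsum(),  # effect accumulated so far\n",
"    change=post[\"att\"].diff(),  # growth since the previous event-time\n",
")\n",
"print(post.round(3))\n",
"csv_text = post.to_csv(index=False)  # pass a file path instead to export\n",
"```\n",
"\n",
"Here the effect grows by roughly 0.40 to 0.55 per period through event-time 3. It is essentially flat from 3 to 4, with a cumulative effect of about 9.47 over the first five periods. Summing posterior means gives the posterior mean of the cumulative effect. An interval for it, however, needs the posterior draws rather than the per-event-time bounds.\n"
]
},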
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Group-Time ATT Table\n",
"\n",
"The `att_group_time_` attribute provides the most granular level of treatment effect estimates: effects for each combination of treatment cohort (G) and calendar time (t). This is the \"building block\" data from which event-time effects are aggregated.\n",
"\n",
"**Understanding the difference from event-time effects:**\n",
"\n",
"- **Event-time ATT** (`att_event_time_`): Aggregates effects across cohorts at the same *relative* time since treatment (e.g., \"effect 2 periods after treatment\" averages over every cohort observed 2 periods after its adoption).\n",
"- **Group-time ATT** (`att_group_time_`): Keeps cohort and calendar time separate, showing effects for specific cohort-time pairs (e.g., \"effect for cohort treated at t=5, observed at t=7\").\n",
"\n",
"**When to use this table:**\n",
"\n",
"- **Cohort heterogeneity analysis**: When you suspect treatment effects differ across cohorts (e.g., early adopters vs late adopters respond differently to treatment).\n",
"- **Calendar time effects**: When you want to check if treatment effects vary with calendar time, not just time since treatment (e.g., macroeconomic conditions may amplify or dampen effects).\n",
"- **Diagnostics**: When event-time effects look suspicious and you want to trace the issue back to specific cohort-time combinations.\n",
"- **Custom aggregation**: When you want to compute alternative summary measures (e.g., cohort-specific average effects, or effects weighted by cohort size).\n",
"\n",
"The table includes `cohort` (treatment time G), `time` (calendar time t), and the estimate `att` with its interval bounds `att_lower` and `att_upper`. The cell below shows only the first 20 rows:\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" cohort time att att_lower att_upper\n",
"0 5.0 5 1.105831 0.981070 1.234096\n",
"1 5.0 6 1.313366 1.191865 1.438411\n",
"2 5.0 7 1.874254 1.738034 2.002803\n",
"3 5.0 8 2.324531 2.195446 2.449806\n",
"4 5.0 9 2.495319 2.365605 2.621871\n",
"5 5.0 10 2.330355 2.188604 2.468263\n",
"6 5.0 11 2.469498 2.326770 2.611914\n",
"7 5.0 12 2.452538 2.311010 2.599357\n",
"8 5.0 13 2.521460 2.381104 2.661465\n",
"9 5.0 14 2.492747 2.348853 2.636524\n",
"10 5.0 15 2.699058 2.515626 2.879858\n",
"11 5.0 16 2.379267 2.192838 2.562129\n",
"12 5.0 17 2.441211 2.260847 2.617471\n",
"13 5.0 18 2.580254 2.405220 2.759980\n",
"14 5.0 19 2.625220 2.452779 2.803349\n",
"15 10.0 10 0.995009 0.864157 1.123532\n",
"16 10.0 11 1.493974 1.367916 1.621203\n",
"17 10.0 12 1.948346 1.815420 2.081204\n",
"18 10.0 13 2.485107 2.350294 2.616990\n",
"19 10.0 14 2.378103 2.246150 2.507052"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result.att_group_time_.head(20)"
]
},
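{
"cell_type": "markdown",
"metadata": {},
"source": [
"The relationship between the two tables can be sketched with pandas. Each event-time estimate is a weighted average of the group-time estimates that share the same `time - cohort`. The sketch assumes each cohort is weighted by its number of treated units; with 12 units per cohort here, that is a plain average. It uses only the 20 rows displayed above, so event-times 0 to 9 are incomplete: the later cohort-10 rows and all cohort-15 rows are not shown. For event-times 10 to 14, only the $t=5$ cohort contributes, and the result reproduces `att_event_time_` exactly (for example, 2.699058 at event-time 10).\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# The 20 att_group_time_ rows displayed above (posterior means only)\n",
"gt = pd.DataFrame(\n",
"    {\n",
"        \"cohort\": [5.0] * 15 + [10.0] * 5,\n",
"        \"time\": list(range(5, 20)) + list(range(10, 15)),\n",
"        \"att\": [\n",
"            1.105831, 1.313366, 1.874254, 2.324531, 2.495319,\n",
"            2.330355, 2.469498, 2.452538, 2.521460, 2.492747,\n",
"            2.699058, 2.379267, 2.441211, 2.580254, 2.625220,\n",
"            0.995009, 1.493974, 1.948346, 2.485107, 2.378103,\n",
"        ],\n",
"    }\n",
")\n",
"cohort_size = {5.0: 12, 10.0: 12}  # assumed weights: units per cohort\n",
"\n",
"gt[\"event_time\"] = (gt[\"time\"] - gt[\"cohort\"]).astype(int)\n",
"gt[\"weight\"] = gt[\"cohort\"].map(cohort_size)\n",
"gt[\"weighted_att\"] = gt[\"att\"] * gt[\"weight\"]\n",
"\n",
"sums = gt.groupby(\"event_time\")[[\"weighted_att\", \"weight\"]].sum()\n",
"event_study = sums[\"weighted_att\"] / sums[\"weight\"]\n",
"print(event_study.round(6))\n",
"```\n",
"\n",
"Applied to all rows of `att_group_time_`, the same groupby would include every cohort observed at each event-time.\n"
]
},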
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using scikit-learn Models\n",
"\n",
"For faster analysis, you can use scikit-learn models instead of PyMC. The error bars represent approximate 95% confidence intervals (±1.96 standard errors).\n",
"\n",
"```python\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"result_ols = cp.StaggeredDifferenceInDifferences(\n",
" df,\n",
" formula=\"y ~ 1 + C(unit) + C(time)\",\n",
" unit_variable_name=\"unit\",\n",
" time_variable_name=\"time\",\n",
" treated_variable_name=\"treated\",\n",
" treatment_time_variable_name=\"treatment_time\",\n",
" model=LinearRegression(),\n",
")\n",
"\n",
"# Plot the event-study results\n",
"fig, ax = result_ols.plot()\n",
"plt.show()\n",
"```\n",
"\n",
"The OLS approach produces similar point estimates to the Bayesian model but runs much faster. However, the uncertainty quantification differs: OLS uses asymptotic standard errors while PyMC provides full posterior distributions.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Key Takeaways\n",
"\n",
"1. **Staggered adoption requires special handling** - standard TWFE can produce biased estimates\n",
"\n",
"2. **The imputation approach** fits a model on untreated observations and predicts counterfactuals\n",
"\n",
"3. **Event-study curves** show dynamic treatment effects and allow for parallel trends checks\n",
"\n",
"4. **Pre-treatment \"placebo\" estimates** (event-time < 0) are **not** treatment effects—they are fit diagnostics. Values near zero support the parallel trends assumption.\n",
"\n",
"5. **Post-treatment ATT estimates** (event-time ≥ 0) are the actual Average Treatment effect on the Treated\n",
"\n",
"6. **CausalPy supports both Bayesian and OLS** approaches for flexibility\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
":::{bibliography}\n",
":filter: docname in docnames\n",
":::\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "CausalPy",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.14.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}