IPW vs AIPW: Bias, Variance, and Double Robustness¶
This notebook compares OnlineIPW (Inverse Probability Weighting) against OnlineAIPW (Augmented IPW, also known as the Doubly Robust estimator) on a linear causal stream.
Key takeaways:
- AIPW uses the outcome model to reduce variance (doubly robust)
- IPW relies entirely on the propensity model
- Both converge to the true ATE as $n \to \infty$, but AIPW typically converges faster
- `OnlineOverlapWeights` provides an alternative that is bounded and stable
In [1]:
Copied!
from onlinecml.datasets import LinearCausalStream
from onlinecml.reweighting import OnlineIPW, OnlineAIPW, OnlineOverlapWeights
from onlinecml.datasets import LinearCausalStream
from onlinecml.reweighting import OnlineIPW, OnlineAIPW, OnlineOverlapWeights
1. Single run comparison¶
In [2]:
Copied!
# Ground-truth average treatment effect baked into the simulated stream,
# and the number of observations to draw from it.
TRUE_ATE = 3.0
N = 2000

# One instance of each online estimator; all three consume the same stream.
ipw = OnlineIPW()
aipw = OnlineAIPW()
ow = OnlineOverlapWeights()

# Stream yields 4-tuples; presumably (covariates x, treatment w, outcome y,
# extra) — TODO confirm against LinearCausalStream's documentation.
# Each estimator updates incrementally from one observation at a time.
for x, w, y, _ in LinearCausalStream(n=N, true_ate=TRUE_ATE, seed=0):
    ipw.learn_one(x, w, y)
    aipw.learn_one(x, w, y)
    ow.learn_one(x, w, y)

# Report each point estimate with its confidence interval next to the truth.
print(f"True ATE : {TRUE_ATE:.3f}")
print(f"IPW : {ipw.predict_ate():.3f} CI={ipw.predict_ci()}")
print(f"AIPW : {aipw.predict_ate():.3f} CI={aipw.predict_ci()}")
print(f"OW : {ow.predict_ate():.3f} CI={ow.predict_ci()}")
# Ground-truth average treatment effect baked into the simulated stream,
# and the number of observations to draw from it.
TRUE_ATE = 3.0
N = 2000

# One instance of each online estimator; all three consume the same stream.
ipw = OnlineIPW()
aipw = OnlineAIPW()
ow = OnlineOverlapWeights()

# Stream yields 4-tuples; presumably (covariates x, treatment w, outcome y,
# extra) — TODO confirm against LinearCausalStream's documentation.
# Each estimator updates incrementally from one observation at a time.
for x, w, y, _ in LinearCausalStream(n=N, true_ate=TRUE_ATE, seed=0):
    ipw.learn_one(x, w, y)
    aipw.learn_one(x, w, y)
    ow.learn_one(x, w, y)

# Report each point estimate with its confidence interval next to the truth.
print(f"True ATE : {TRUE_ATE:.3f}")
print(f"IPW : {ipw.predict_ate():.3f} CI={ipw.predict_ci()}")
print(f"AIPW : {aipw.predict_ate():.3f} CI={aipw.predict_ci()}")
print(f"OW : {ow.predict_ate():.3f} CI={ow.predict_ci()}")
True ATE : 3.000 IPW : 3.021 CI=(np.float64(2.847136562939421), np.float64(3.1938877623874644)) AIPW : 3.014 CI=(np.float64(2.9185814066191345), np.float64(3.1092489291142984)) OW : 1.467 CI=(np.float64(1.379655672967889), np.float64(1.555221629485397))

Note that the overlap-weights estimate (1.467) is far from the true ATE. This is not necessarily a bug: overlap weights target a different estimand — the effect in the overlap population (ATO) rather than the ATE — so a gap is expected whenever the overlap population differs from the full population. Worth verifying against `OnlineOverlapWeights`' documentation.
2. Convergence comparison over time¶
In [3]:
Copied!
# Feed IPW and AIPW the same stream and snapshot their running ATE estimates
# every LOG_EVERY observations, so convergence toward the truth can be plotted.
# (`import matplotlib` on its own was unused — pyplot is the only API needed;
# OnlineIPW/OnlineAIPW were already imported in the first cell, so the
# redundant mid-notebook re-import is dropped.)
import matplotlib.pyplot as plt

LOG_EVERY = 50
TRUE_ATE = 3.0  # must match the stream's true_ate argument below

ipw2 = OnlineIPW()
aipw2 = OnlineAIPW()

ipw_ates, aipw_ates, steps = [], [], []
for i, (x, w, y, _) in enumerate(LinearCausalStream(n=2000, true_ate=TRUE_ATE, seed=1)):
    ipw2.learn_one(x, w, y)
    aipw2.learn_one(x, w, y)
    # Record the current estimates at a fixed cadence (after observation i+1).
    if (i + 1) % LOG_EVERY == 0:
        steps.append(i + 1)
        ipw_ates.append(ipw2.predict_ate())
        aipw_ates.append(aipw2.predict_ate())

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(steps, ipw_ates, label="IPW", color="tab:blue")
ax.plot(steps, aipw_ates, label="AIPW", color="tab:orange")
ax.axhline(TRUE_ATE, color="red", linestyle="--", label="True ATE = 3.0")
ax.set_xlabel("Observations seen")
ax.set_ylabel("Estimated ATE")
ax.set_title("IPW vs AIPW convergence")
ax.legend()
plt.tight_layout()
# NOTE(review): hardcoded absolute path; consider a configurable output dir.
plt.savefig("/tmp/ipw_vs_aipw.png", dpi=100)
plt.show()  # also display the figure inline, not just on disk
print("Saved to /tmp/ipw_vs_aipw.png")
import matplotlib
import matplotlib.pyplot as plt
from onlinecml.reweighting import OnlineIPW, OnlineAIPW

# Snapshot cadence: record the running estimates every LOG_EVERY observations.
LOG_EVERY = 50

ipw2 = OnlineIPW()
aipw2 = OnlineAIPW()

ipw_ates, aipw_ates, steps = [], [], []
for i, (x, w, y, _) in enumerate(LinearCausalStream(n=2000, true_ate=3.0, seed=1)):
    ipw2.learn_one(x, w, y)
    aipw2.learn_one(x, w, y)
    # Record the current estimates after observation i+1.
    if (i + 1) % LOG_EVERY == 0:
        steps.append(i + 1)
        ipw_ates.append(ipw2.predict_ate())
        aipw_ates.append(aipw2.predict_ate())

# Plot both trajectories against the known true effect.
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(steps, ipw_ates, label="IPW", color="tab:blue")
ax.plot(steps, aipw_ates, label="AIPW", color="tab:orange")
ax.axhline(3.0, color="red", linestyle="--", label="True ATE = 3.0")
ax.set_xlabel("Observations seen")
ax.set_ylabel("Estimated ATE")
ax.set_title("IPW vs AIPW convergence")
ax.legend()
plt.tight_layout()
plt.savefig("/tmp/ipw_vs_aipw.png", dpi=100)
print("Saved to /tmp/ipw_vs_aipw.png")
Saved to /tmp/ipw_vs_aipw.png
3. Multiple-seed variance comparison¶
Run both estimators on 20 different seeds and compare the spread of ATE estimates.
In [4]:
Copied!
import statistics

# Number of independent replications; each uses a fresh stream seed so the
# spread of final ATE estimates reflects sampling variability alone.
N_SEEDS = 20

ipw_results, aipw_results = [], []
for run_seed in range(N_SEEDS):
    # Fresh estimators per replication — no state leaks between seeds.
    est_ipw, est_aipw = OnlineIPW(), OnlineAIPW()
    stream = LinearCausalStream(n=500, true_ate=3.0, seed=run_seed)
    for x, w, y, _ in stream:
        est_ipw.learn_one(x, w, y)
        est_aipw.learn_one(x, w, y)
    ipw_results.append(est_ipw.predict_ate())
    aipw_results.append(est_aipw.predict_ate())

# Summarize the across-seed distribution of each estimator's final estimate.
print(f"IPW mean={statistics.mean(ipw_results):.3f} std={statistics.stdev(ipw_results):.3f}")
print(f"AIPW mean={statistics.mean(aipw_results):.3f} std={statistics.stdev(aipw_results):.3f}")
import statistics

# Number of independent replications; each uses a fresh stream seed so the
# spread of final ATE estimates reflects sampling variability alone.
N_SEEDS = 20

ipw_results, aipw_results = [], []
for seed in range(N_SEEDS):
    # Fresh estimators per replication — no state leaks between seeds.
    m_ipw = OnlineIPW()
    m_aipw = OnlineAIPW()
    for x, w, y, _ in LinearCausalStream(n=500, true_ate=3.0, seed=seed):
        m_ipw.learn_one(x, w, y)
        m_aipw.learn_one(x, w, y)
    ipw_results.append(m_ipw.predict_ate())
    aipw_results.append(m_aipw.predict_ate())

# Summarize the across-seed distribution of each estimator's final estimate.
print(f"IPW mean={statistics.mean(ipw_results):.3f} std={statistics.stdev(ipw_results):.3f}")
print(f"AIPW mean={statistics.mean(aipw_results):.3f} std={statistics.stdev(aipw_results):.3f}")
IPW mean=3.794 std=0.463 AIPW mean=3.262 std=0.156

AIPW's across-seed standard deviation (0.156) is roughly a third of IPW's (0.463), illustrating the variance reduction from the augmentation term. Both means sit above the true ATE of 3.0 at n=500 — presumably finite-sample bias that shrinks with longer streams (compare the n=2000 run above), though this is worth confirming.