OnlineCML Quickstart
This notebook demonstrates the core OnlineCML workflow: stream a causal dataset one observation at a time, estimate the average treatment effect (ATE), and inspect diagnostics.
What you will learn:
- How to use LinearCausalStream to generate synthetic data
- How to fit OnlineIPW in a streaming loop
- How to read ATE estimates and confidence intervals
- How to use OnlineSMD to check covariate balance
In [1]:
from onlinecml.datasets import LinearCausalStream
from onlinecml.reweighting import OnlineIPW
from onlinecml.diagnostics import OnlineSMD
1. Generate a synthetic stream
LinearCausalStream yields (x_dict, treatment, outcome, true_cate) tuples.
The true ATE is 2.0 by default.
In [2]:
stream = LinearCausalStream(n=1000, true_ate=2.0, seed=42)
print(f"Stream length: {len(stream)}")
# Peek at the first observation
for x, w, y, tau in stream:
    print(f"x={x}, treatment={w}, outcome={y:.2f}, true_cate={tau:.2f}")
    break
Stream length: 1000
x={'x0': -1.302179506862318, 'x1': 0.12784040316728537, 'x2': -0.3162425923435822, 'x3': -0.016801157504288795, 'x4': -0.85304392757358}, treatment=1, outcome=3.66, true_cate=2.00
2. Fit OnlineIPW in a streaming loop
In [3]:
ipw = OnlineIPW()
smd = OnlineSMD(covariates=["x0", "x1", "x2"])
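for x, w, y, tau in LinearCausalStream(n=1000, true_ate=2.0, seed=42):
    # Propensity estimate from the model before it sees this observation
    # (kept for inspection; not used further in this loop)
    ps = ipw.ps_model.predict_one(x)
    # Update the balance diagnostic with the current IPW weight
    smd.update(x, treatment=w, weight=ipw.ps_model.ipw_weight(x, w))
    # Then update the estimator with the observation
    ipw.learn_one(x, w, y)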
print(f"Estimated ATE : {ipw.predict_ate():.3f}")
print(f"True ATE : 2.000")
print(f"95% CI : {ipw.predict_ci()}")
print(f"n_seen : {ipw.n_seen}")
Estimated ATE : 2.405
True ATE : 2.000
95% CI : (np.float64(2.042000117883833), np.float64(2.7681087859491837))
n_seen : 1000
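To make the ATE estimate and its confidence interval less of a black box, here is a minimal standalone sketch of a streaming IPW estimator. It does not use OnlineCML and is not a description of how OnlineIPW is implemented: the `RunningMean` class, the `ipw_pseudo_outcome` helper, the data-generating process, and the clipping threshold are all assumptions made for illustration. The interval is a plain normal approximation around a running mean, computed with Welford's algorithm, and the propensity is taken as known rather than learned.

```python
import math
import random


class RunningMean:
    """Welford's online mean/variance for a stream of scalars (illustrative helper)."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value):
        self.n += 1
        delta = value - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (value - self.mean)

    def ci(self, z=1.96):
        """Normal-approximation confidence interval for the mean."""
        if self.n < 2:
            return (float("-inf"), float("inf"))
        se = math.sqrt(self.m2 / (self.n - 1) / self.n)
        return (self.mean - z * se, self.mean + z * se)


def ipw_pseudo_outcome(w, y, ps, clip=0.01):
    """IPW pseudo-outcome with the propensity clipped to [clip, 1 - clip]."""
    ps = max(clip, min(1.0 - clip, ps))
    return w * y / ps - (1 - w) * y / (1 - ps)


rng = random.Random(0)
est = RunningMean()
for _ in range(20000):
    x = rng.gauss(0.0, 1.0)
    ps = 1.0 / (1.0 + math.exp(-0.5 * x))  # true propensity, known by construction
    w = 1 if rng.random() < ps else 0
    y = 2.0 * w + x + rng.gauss(0.0, 1.0)  # synthetic outcome with true ATE = 2.0
    est.update(ipw_pseudo_outcome(w, y, ps))

low, high = est.ci()
print(f"ATE ~ {est.mean:.3f}, 95% CI ({low:.3f}, {high:.3f})")
```

With a learned propensity model, as in the cell above, early estimates can be noisy because the weights are computed from a model that has seen few observations. That is one plausible reason an estimate after 1000 observations can still sit some distance from the true value.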
3. Covariate balance
In [4]:
report = smd.report()
print("Covariate balance (raw SMD | weighted SMD):")
for cov, (raw, weighted) in report.items():
    print(f" {cov}: raw={raw:+.3f} weighted={weighted:+.3f}")
print(f"\nBalance adequate (|SMD| < 0.1): {smd.is_balanced()}")
Covariate balance (raw SMD | weighted SMD):
 x0: raw=+0.093 weighted=+0.040
 x1: raw=-0.247 weighted=-0.105
 x2: raw=+0.125 weighted=+0.033

Balance adequate (|SMD| < 0.1): False
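For readers unfamiliar with the standardized mean difference, the sketch below computes it by hand on a tiny batch. It is an illustration under stated assumptions, not the OnlineSMD implementation: the helper names are hypothetical, the data and weights are made up, and the weighted variant uses one common convention (weighted group means divided by the unweighted pooled standard deviation). OnlineSMD may use a different denominator or update rule.

```python
import math


def weighted_mean(values, weights):
    """Weighted average of values."""
    total = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total


def variance(values):
    """Unweighted sample variance (ddof=1)."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - 1)


def smd(treated, control, w_treated=None, w_control=None):
    """Standardized mean difference between treated and control groups.

    Assumed convention: difference of (optionally weighted) means divided by
    the unweighted pooled standard deviation of the two groups.
    """
    w_treated = w_treated or [1.0] * len(treated)
    w_control = w_control or [1.0] * len(control)
    diff = weighted_mean(treated, w_treated) - weighted_mean(control, w_control)
    pooled_sd = math.sqrt((variance(treated) + variance(control)) / 2.0)
    return diff / pooled_sd


# Made-up covariate values for the two groups
treated = [1.0, 2.0, 3.0, 4.0]
control = [0.0, 1.0, 2.0, 3.0]

raw = smd(treated, control)
# Hypothetical weights that down-weight the observations driving the imbalance
weighted = smd(treated, control, w_treated=[4, 3, 2, 1], w_control=[1, 2, 3, 4])
print(f"raw={raw:+.3f} weighted={weighted:+.3f}")
```

In the report above, weighting shrinks every SMD toward zero, but x1 remains at -0.105, just outside the 0.1 threshold, which is why the balance check returns False.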
4. ATE convergence plot
In [5]:
import matplotlib
import matplotlib.pyplot as plt
from onlinecml.diagnostics import ATETracker
from onlinecml.reweighting import OnlineIPW
tracker = ATETracker(log_every=10, warmup=50)
ipw2 = OnlineIPW()
for x, w, y, tau in LinearCausalStream(n=1000, true_ate=2.0, seed=42):
    ipw2.learn_one(x, w, y)
    # Compute IPW pseudo-outcome for tracker
    ps = ipw2.ps_model.predict_one(x)
    ps = max(0.01, min(0.99, ps))
    psi = (w * y / ps) - ((1 - w) * y / (1 - ps))
    tracker.update(psi)
ax = tracker.plot()
ax.axhline(2.0, color="red", linestyle="--", label="True ATE")
ax.legend()
plt.tight_layout()
plt.savefig("/tmp/ate_convergence.png", dpi=100)
print("Saved to /tmp/ate_convergence.png")
Saved to /tmp/ate_convergence.png
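For reference, the pseudo-outcome computed in the loop above can be written out as follows. The notation is introduced here for illustration: W_i is the treatment indicator, Y_i the outcome, and ê(X_i) the estimated propensity score after clipping to [0.01, 0.99]. The averaging formula on the right assumes the tracker plots a running mean of the values it receives, which the source does not state explicitly.

```latex
\psi_i = \frac{W_i \, Y_i}{\hat e(X_i)} - \frac{(1 - W_i)\, Y_i}{1 - \hat e(X_i)},
\qquad
\hat\tau_{\mathrm{IPW}} = \frac{1}{n} \sum_{i=1}^{n} \psi_i
```

When ê equals the true propensity and the usual unconfoundedness and overlap assumptions hold, E[ψ_i] equals the ATE, so the running average should settle near the dashed line at 2.0 as n grows.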