Diagnostics
onlinecml.diagnostics.ate_tracker.ATETracker
Tracks the running ATE estimate with online confidence intervals.
Maintains a running mean and variance of per-observation pseudo-outcomes and optionally records history for convergence plotting.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_every | int | Append a history entry every log_every steps. | 1 |
| warmup | int | Number of initial pseudo-outcomes to skip when accumulating the ATE estimate. History is not recorded during warmup. | 0 |
| forgetting_factor | float | Controls how quickly old pseudo-outcomes are forgotten. The default 1.0 means no forgetting; values below 1.0 switch to exponentially weighted statistics (see Notes). | 1.0 |
Notes
Unlike BaseOnlineEstimator, this is a standalone diagnostic tool
that users instantiate separately and feed pseudo-outcomes into. It is
not tied to any specific estimation method.
When forgetting_factor < 1.0, the internal Welford state is replaced
by an EWMA (EWMAStats). The CI formula remains mean ± z * sqrt(var/n)
where var and n come from the EWMA estimates.
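The Notes do not say how EWMAStats weights observations or defines n. A minimal sketch, assuming each older pseudo-outcome's weight is multiplied by forgetting_factor at every step and that n is Kish's effective sample size (sum of weights, squared, over the sum of squared weights). The class name ForgettingMeanVar is hypothetical:

```python
import math
from statistics import NormalDist


class ForgettingMeanVar:
    """Hypothetical exponentially weighted mean/variance with an effective n.

    Each new observation gets weight 1; every older weight is multiplied by
    forgetting_factor at each step (an assumption about EWMAStats).
    """

    def __init__(self, forgetting_factor: float = 1.0) -> None:
        if not 0.0 < forgetting_factor <= 1.0:
            raise ValueError("forgetting_factor must be in (0, 1]")
        self.lam = forgetting_factor
        self.sw = 0.0    # sum of weights
        self.sw2 = 0.0   # sum of squared weights
        self.swx = 0.0   # weighted sum of x
        self.swxx = 0.0  # weighted sum of x**2

    def update(self, x: float) -> None:
        lam = self.lam
        self.sw = lam * self.sw + 1.0
        self.sw2 = lam * lam * self.sw2 + 1.0  # squared weights decay by lam**2
        self.swx = lam * self.swx + x
        self.swxx = lam * self.swxx + x * x

    @property
    def mean(self) -> float:
        return self.swx / self.sw if self.sw > 0 else 0.0

    @property
    def var(self) -> float:
        # Weighted population variance, clipped at 0 to absorb rounding error.
        if self.sw == 0:
            return 0.0
        return max(self.swxx / self.sw - self.mean ** 2, 0.0)

    @property
    def n_eff(self) -> float:
        # Kish effective sample size; equals the plain count when lam == 1.
        return self.sw ** 2 / self.sw2 if self.sw2 > 0 else 0.0

    def ci(self, alpha: float = 0.05) -> tuple[float, float]:
        if self.n_eff < 2:
            return (-math.inf, math.inf)
        z = NormalDist().inv_cdf(1 - alpha / 2)
        half = z * math.sqrt(self.var / self.n_eff)
        return (self.mean - half, self.mean + half)
```

With forgetting_factor = 1.0 the sketch reduces to the ordinary mean, population variance and count. With forgetting_factor = 0.9 the effective n saturates near (1 + 0.9) / (1 - 0.9) = 19, so under these assumptions the interval stops narrowing; that is the price of tracking a drifting effect. The raw sum-of-squares update is kept for brevity and is less numerically stable than Welford's recurrence.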
Examples:
>>> tracker = ATETracker(log_every=10)
>>> for pseudo_outcome in [1.5, 2.3, 1.8, 2.1]:
...     tracker.update(pseudo_outcome)
>>> abs(tracker.ate - 1.925) < 1e-10
True
Source code in onlinecml/diagnostics/ate_tracker.py
ate
property
Current ATE estimate (running mean of pseudo-outcomes).
Returns 0.0 before any observations are seen (or during warmup).
history
property
Recorded history as a list of (step, ate, ci_lower, ci_upper) tuples.
n
property
Number of pseudo-outcomes processed (excluding warmup).
ci(alpha=0.05)
Return a confidence interval for the current ATE estimate.
Uses a normal approximation via the central limit theorem.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level. The default 0.05 gives a 95% CI. | 0.05 |
Returns:
| Name | Type | Description |
|---|---|---|
| lower | float | Lower bound of the confidence interval. |
| upper | float | Upper bound of the confidence interval. |
Notes
Returns (-inf, inf) until at least 2 observations have been seen.
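A minimal sketch of the documented behaviour, using Welford's recurrence for the running mean and variance and the normal approximation mean ± z * sqrt(var/n). Assumptions not confirmed by this page: var is the sample (n - 1) variance, history steps count post-warmup observations, and SketchATETracker is a hypothetical name rather than the library's implementation:

```python
import math
from statistics import NormalDist


class SketchATETracker:
    """Hypothetical re-implementation of the documented ATETracker behaviour."""

    def __init__(self, log_every: int = 1, warmup: int = 0) -> None:
        self.log_every = log_every
        self.warmup = warmup
        self._seen = 0   # all pseudo-outcomes, including warmup
        self._n = 0      # pseudo-outcomes after warmup
        self._mean = 0.0
        self._m2 = 0.0   # sum of squared deviations from the running mean
        self.history: list[tuple[int, float, float, float]] = []

    def update(self, pseudo_outcome: float) -> None:
        self._seen += 1
        if self._seen <= self.warmup:
            return  # warmup: no accumulation and no history
        self._n += 1
        delta = pseudo_outcome - self._mean
        self._mean += delta / self._n
        self._m2 += delta * (pseudo_outcome - self._mean)
        if self._n % self.log_every == 0:
            lo, hi = self.ci()
            self.history.append((self._n, self._mean, lo, hi))

    @property
    def ate(self) -> float:
        return self._mean  # 0.0 until the first post-warmup observation

    @property
    def n(self) -> int:
        return self._n

    def ci(self, alpha: float = 0.05) -> tuple[float, float]:
        if self._n < 2:
            return (-math.inf, math.inf)
        var = self._m2 / (self._n - 1)  # sample variance (assumed)
        z = NormalDist().inv_cdf(1 - alpha / 2)
        half = z * math.sqrt(var / self._n)
        return (self._mean - half, self._mean + half)

    def convergence_width(self, alpha: float = 0.05) -> float:
        lo, hi = self.ci(alpha)
        return hi - lo  # inf while the CI is unbounded
```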
convergence_width(alpha=0.05)
Return the current confidence interval width.
Useful as an early-stopping criterion: stop collecting data when the CI width falls below a target threshold.
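As an illustration of that early-stopping rule, a self-contained sketch that inlines the same Welford update instead of importing onlinecml. The function run_until_converged, the min_n guard and the simulated stream are all hypothetical:

```python
import math
import random
from statistics import NormalDist


def run_until_converged(stream, target_width: float, alpha: float = 0.05, min_n: int = 30):
    """Consume pseudo-outcomes until the (1 - alpha) CI width drops below target_width.

    Returns (n_used, ate, width). Welford updates stand in for ATETracker here.
    """
    z = NormalDist().inv_cdf(1 - alpha / 2)
    n, mean, m2 = 0, 0.0, 0.0
    width = math.inf
    for x in stream:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
        if n >= 2:
            width = 2 * z * math.sqrt(m2 / (n - 1) / n)
        # min_n guards against stopping on a spuriously narrow early CI
        if n >= min_n and width < target_width:
            break
    return n, mean, width


rng = random.Random(0)
# Hypothetical stream: pseudo-outcomes with true ATE 2.0 and unit-variance noise.
stream = (2.0 + rng.gauss(0.0, 1.0) for _ in range(100_000))
n_used, ate, width = run_until_converged(stream, target_width=0.2)
```

For unit-variance pseudo-outcomes the width 2 * z * sigma / sqrt(n) falls below 0.2 at roughly n = (2 * 1.96 / 0.2)^2, about 384 observations. A floor such as min_n matters because with only a handful of points the sample variance, and therefore the width, can be spuriously small.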
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level for the CI. | 0.05 |
Returns:
| Type | Description |
|---|---|
| float | Width of the current CI (upper - lower). |
plot(ax=None)
Plot the ATE convergence curve with shaded confidence band.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ax | Axes or None | Axes to plot on. If None, a new figure is created. | None |
Returns:
| Type | Description |
|---|---|
| Axes | The axes with the convergence plot. |
reset()
update(pseudo_outcome)
Incorporate one pseudo-outcome into the running ATE estimate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pseudo_outcome | float | Per-observation pseudo-outcome (e.g. IPW score, DR score, or per-observation CATE estimate). The running mean of these values converges to the ATE under the relevant identification assumptions. | required |
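For reference, the IPW and DR scores mentioned above have standard textbook forms, where w is the treatment indicator, e the propensity score, and mu1, mu0 the outcome-model predictions under treatment and control. The helper names below are hypothetical, and onlinecml's own estimators may add refinements such as propensity clipping:

```python
def ipw_pseudo_outcome(y: float, w: int, e: float) -> float:
    """Horvitz-Thompson (IPW) score: W*Y/e - (1-W)*Y/(1-e)."""
    return w * y / e - (1 - w) * y / (1 - e)


def dr_pseudo_outcome(y: float, w: int, e: float, mu1: float, mu0: float) -> float:
    """AIPW (doubly robust) score: the outcome-model contrast plus IPW-weighted residuals."""
    return mu1 - mu0 + w * (y - mu1) / e - (1 - w) * (y - mu0) / (1 - e)
```

Each score would then be passed to update one observation at a time.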
onlinecml.diagnostics.smd.OnlineSMD
Bases: Base
Online covariate balance diagnostics via Standardized Mean Difference.
Tracks the raw and IPW-weighted SMD for a set of covariates, updated one observation at a time. Used to monitor whether treatment and control groups are comparable in covariate distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| covariates | list of str | Names of covariates to track. Must match keys in the feature dicts passed to update. | required |
Notes
SMD for a covariate is defined as:
.. math::
\text{SMD} = \frac{\bar{X}_T - \bar{X}_C}{\sqrt{(s_T^2 + s_C^2) / 2}}
where s^2 is the sample variance within each group. Returns 0.0
when either group has fewer than 2 observations.
Raw SMD uses unweighted RunningStats; weighted SMD uses
WeightedRunningStats with West's (1979) algorithm (population
variance). The weighted SMD is used by is_balanced.
This class does NOT inherit from BaseOnlineEstimator — it is a
standalone diagnostic tool.
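A minimal single-covariate sketch of the two SMD variants described above: unit-weight statistics with the sample variance for the raw SMD, and West's (1979) weighted recurrence with the population variance for the weighted SMD. The class and function names are hypothetical; returning 0.0 for a zero pooled variance and the strict < comparison in is_balanced are assumptions:

```python
import math


class _Stats:
    """Weighted running mean/variance (West, 1979); with w = 1 it reduces to Welford."""

    def __init__(self) -> None:
        self.n = 0
        self.w_sum = 0.0
        self.mean = 0.0
        self.s = 0.0  # weighted sum of squared deviations

    def update(self, x: float, w: float = 1.0) -> None:
        self.n += 1
        self.w_sum += w
        delta = x - self.mean
        self.mean += (w / self.w_sum) * delta
        self.s += w * delta * (x - self.mean)

    def sample_var(self) -> float:      # raw SMD (unit weights)
        return self.s / (self.n - 1)

    def population_var(self) -> float:  # weighted SMD
        return self.s / self.w_sum


def smd(t: _Stats, c: _Stats, weighted: bool) -> float:
    if t.n < 2 or c.n < 2:
        return 0.0
    if weighted:
        vt, vc = t.population_var(), c.population_var()
    else:
        vt, vc = t.sample_var(), c.sample_var()
    pooled = math.sqrt((vt + vc) / 2)
    return (t.mean - c.mean) / pooled if pooled > 0 else 0.0


class OneCovariateSMD:
    """Hypothetical single-covariate version of the raw and weighted SMD."""

    def __init__(self) -> None:
        self.raw = {0: _Stats(), 1: _Stats()}
        self.wtd = {0: _Stats(), 1: _Stats()}

    def update(self, x: float, treatment: int, weight: float = 1.0) -> None:
        self.raw[treatment].update(x)          # unit weights
        self.wtd[treatment].update(x, weight)  # importance weights

    def raw_smd(self) -> float:
        return smd(self.raw[1], self.raw[0], weighted=False)

    def weighted_smd(self) -> float:
        return smd(self.wtd[1], self.wtd[0], weighted=True)


def is_balanced(weighted_smds: dict, thr: float = 0.1) -> bool:
    # Assumed comparison: strict |SMD| < thr for every covariate.
    return all(abs(v) < thr for v in weighted_smds.values())
```

With only one observation per arm, as in the Examples for this class, both SMDs stay at 0.0 because each group needs at least 2 observations.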
Examples:
>>> smd = OnlineSMD(covariates=["age", "income"])
>>> smd.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
>>> smd.update({"age": 45, "income": 70000}, treatment=0, weight=0.8)
>>> report = smd.report()
>>> "age" in report
True
Source code in onlinecml/diagnostics/smd.py
is_balanced(thr=0.1)
Check whether all covariates are balanced after weighting.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| thr | float | Maximum absolute weighted SMD. The default 0.1 is the conventional "well-balanced" threshold. | 0.1 |
Returns:
| Type | Description |
|---|---|
| bool | True if every covariate's absolute weighted SMD is within thr. |
report()
Return the current SMD for each tracked covariate.
Returns:
| Type | Description |
|---|---|
| dict | Mapping from each covariate name to its current SMD values. |
update(x, treatment, weight=1.0)
Update balance statistics with one observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary. Missing covariates default to 0.0. | required |
| treatment | int | Treatment indicator (0 = control, 1 = treated). | required |
| weight | float | Importance weight for this observation (e.g. IPW weight). The default 1.0 means unweighted. | 1.0 |
onlinecml.diagnostics.live_love_plot.LiveLovePlot
Real-time Love Plot for monitoring covariate balance online.
Displays raw and weighted Standardized Mean Differences (SMD) for a
set of covariates. Updates the plot every update_every steps.
A vertical reference line at |SMD| = 0.1 marks the conventional
"well-balanced" threshold.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| covariates | list of str | Names of covariates to display (in order). | required |
| update_every | int | Redraw the plot every update_every steps. | 100 |
| balance_threshold | float | Position of the reference line. | 0.1 |
Notes
This class wraps OnlineSMD internally. Users can either pass
feature dicts directly to update or maintain an external
OnlineSMD instance and call render with its report.
The plot is only rendered when matplotlib is available. If
matplotlib is not installed, calls to update and render
are no-ops.
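A sketch of what a single Love Plot render might look like, written against matplotlib's object-oriented API (matplotlib.figure.Figure, which needs no display). It assumes a report mapping each covariate to a (raw_smd, weighted_smd) pair, which may not match the structure OnlineSMD.report() actually returns; render_love_plot is a hypothetical name. Like the documented class, it becomes a no-op returning None when matplotlib is missing:

```python
def render_love_plot(report: dict, threshold: float = 0.1):
    """Hypothetical Love plot of |raw| and |weighted| SMD per covariate.

    report is assumed to map covariate -> (raw_smd, weighted_smd).
    Returns the Axes, or None if matplotlib is not installed.
    """
    try:
        from matplotlib.figure import Figure
    except ImportError:
        return None  # mirrors the documented no-op behaviour
    names = list(report)
    ys = list(range(len(names)))
    raw = [abs(report[k][0]) for k in names]
    wtd = [abs(report[k][1]) for k in names]
    fig = Figure(figsize=(5, 0.5 * len(names) + 1.5))
    ax = fig.add_subplot()
    ax.scatter(raw, ys, marker="o", label="raw")
    ax.scatter(wtd, ys, marker="x", label="weighted")
    ax.axvline(threshold, linestyle="--", color="gray")  # balance reference line
    ax.set_yticks(ys)
    ax.set_yticklabels(names)
    ax.set_xlabel("|SMD|")
    ax.legend()
    return ax
```

To write the figure to disk, ax.figure.savefig(path) works on a Figure created this way.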
Examples:
>>> plot = LiveLovePlot(covariates=["age", "income"], update_every=50)
>>> plot.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
>>> ax = plot.render()
Source code in onlinecml/diagnostics/live_love_plot.py
render(ax=None)
Render the Love Plot from current SMD data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ax | Axes or None | Axes to render on. If None, internal axes are created or reused. | None |
Returns:
| Type | Description |
|---|---|
| Axes or None | The rendered axes, or None if matplotlib is unavailable. |
save(path)
Save the current plot to a file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File path to save the figure to. | required |
update(x, treatment, weight=1.0)
Update the covariate balance statistics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for this observation. | required |
| treatment | int | Treatment indicator (0 or 1). | required |
| weight | float | Importance weight for this observation. | 1.0 |
onlinecml.diagnostics.overlap_checker.OverlapChecker
Monitors propensity score distributions for positivity violations.
Tracks the distribution of predicted propensity scores per treatment arm and raises warnings when extreme scores are detected. Reports the proportion of units in the common support region.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ps_min | float | Lower positivity threshold. Propensity scores below this are flagged. | 0.05 |
| ps_max | float | Upper positivity threshold. Propensity scores above this are flagged. | 0.95 |
Notes
A unit is in common support if its propensity score satisfies
ps_min < p < ps_max. Units outside this region have unreliable
causal estimates.
The report() method returns a summary including:
- Mean PS per arm
- Proportion flagged per arm
- Overall common support rate
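A minimal sketch of that bookkeeping, using the common-support definition from the Notes. Only the n_flagged key is confirmed by the Examples; the other report keys, the class name, and the choice to report inadequate overlap when no data has been seen are assumptions:

```python
class SketchOverlapChecker:
    """Hypothetical re-implementation of the documented OverlapChecker logic."""

    def __init__(self, ps_min: float = 0.05, ps_max: float = 0.95) -> None:
        self.ps_min, self.ps_max = ps_min, ps_max
        self.n = {0: 0, 1: 0}            # units seen per arm
        self.ps_sum = {0: 0.0, 1: 0.0}   # running PS sum per arm
        self.flagged = {0: 0, 1: 0}      # units outside common support per arm

    def update(self, propensity: float, treatment: int) -> None:
        self.n[treatment] += 1
        self.ps_sum[treatment] += propensity
        # Common support is the open interval (ps_min, ps_max); anything else is flagged.
        if not (self.ps_min < propensity < self.ps_max):
            self.flagged[treatment] += 1

    def report(self) -> dict:
        total = self.n[0] + self.n[1]
        n_flagged = self.flagged[0] + self.flagged[1]
        return {
            "mean_ps": {a: self.ps_sum[a] / self.n[a] if self.n[a] else None for a in (0, 1)},
            "flag_rate": {a: self.flagged[a] / self.n[a] if self.n[a] else None for a in (0, 1)},
            "n_flagged": n_flagged,
            "common_support_rate": 1 - n_flagged / total if total else None,
        }

    def is_overlap_adequate(self, max_flag_rate: float = 0.05) -> bool:
        total = self.n[0] + self.n[1]
        # Assumption: with no data, overlap is not considered adequate.
        return total > 0 and (self.flagged[0] + self.flagged[1]) / total < max_flag_rate
```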
Examples:
>>> checker = OverlapChecker(ps_min=0.05, ps_max=0.95)
>>> checker.update(propensity=0.3, treatment=1)
>>> checker.update(propensity=0.02, treatment=0)
>>> checker.report()['n_flagged']
1
Source code in onlinecml/diagnostics/overlap_checker.py
is_overlap_adequate(max_flag_rate=0.05)
Return True if the positivity violation rate is acceptable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| max_flag_rate | float | Maximum tolerable fraction of flagged units. | 0.05 |
Returns:
| Type | Description |
|---|---|
| bool | True if the fraction of flagged units is below max_flag_rate. |
report()
Return a summary of the propensity score distribution.
Returns:
| Type | Description |
|---|---|
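| dict | Summary statistics, including n_flagged (the number of flagged units) and the quantities listed in the Notes: mean PS per arm, proportion flagged per arm, and overall common support rate. |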
reset()
update(propensity, treatment)
Record a propensity score observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| propensity | float | Predicted propensity score, i.e. the estimated probability of treatment. | required |
| treatment | int | Observed treatment indicator (0 or 1). Used to track per-arm PS distributions. | required |
onlinecml.diagnostics.concept_drift_monitor.ConceptDriftMonitor
Monitors the ATE estimate stream for structural breaks (concept drift).
Wraps River's ADWIN (Adaptive Windowing) drift detector to identify changes in the distribution of per-observation pseudo-outcomes, which signal a shift in the underlying treatment effect.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
delta
|
float
|
ADWIN confidence parameter. Smaller values reduce false alarm rate at the cost of slower detection. Default 0.002. |
0.002
|
Notes
ADWIN maintains an adaptive window over a data stream and raises a drift signal when the means of an earlier and a later sub-window differ by more than a statistical threshold controlled by delta.
When drift is detected, drift_detected returns True and
n_drifts increments. The estimator being monitored should be
reset after drift is detected (the monitor does not do this
automatically — it only signals).
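The signal-then-reset contract can be sketched as follows. To keep the example dependency-free it swaps ADWIN for a much cruder detector that compares the means of two adjacent fixed-size windows; this is not ADWIN and has none of its guarantees. All class names and the simulated stream are hypothetical:

```python
from collections import deque


class ToyMeanShiftDetector:
    """Crude stand-in for ADWIN (NOT ADWIN): flags drift when the means of two
    adjacent fixed-size windows differ by more than threshold."""

    def __init__(self, half_window: int = 5, threshold: float = 1.0) -> None:
        self.half = half_window
        self.threshold = threshold
        self.buf: deque = deque(maxlen=2 * half_window)
        self.drift_detected = False  # True only for the most recent update

    def update(self, x: float) -> None:
        self.buf.append(x)
        self.drift_detected = False
        if len(self.buf) == self.buf.maxlen:
            vals = list(self.buf)
            older = sum(vals[: self.half]) / self.half
            newer = sum(vals[self.half :]) / self.half
            if abs(newer - older) > self.threshold:
                self.drift_detected = True
                self.buf.clear()  # start a fresh window after a signal


class RunningMean:
    def __init__(self) -> None:
        self.n, self.mean = 0, 0.0

    def update(self, x: float) -> None:
        self.n += 1
        self.mean += (x - self.mean) / self.n

    def reset(self) -> None:
        self.n, self.mean = 0, 0.0


detector, estimator, n_drifts = ToyMeanShiftDetector(), RunningMean(), 0
stream = [1.0] * 20 + [5.0] * 20  # hypothetical effect shift from 1.0 to 5.0
for pseudo_outcome in stream:
    detector.update(pseudo_outcome)
    estimator.update(pseudo_outcome)
    if detector.drift_detected:
        n_drifts += 1
        estimator.reset()  # the caller, not the monitor, resets the estimator
```

Here the shift is flagged once, on the second post-shift observation, and because the caller resets the estimator, its mean afterwards reflects only post-drift pseudo-outcomes.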
References
Bifet, A. and Gavalda, R. (2007). Learning from time-changing data with adaptive windowing. Proceedings of the 7th SIAM International Conference on Data Mining, 443-448.
Examples:
>>> monitor = ConceptDriftMonitor(delta=0.002)
>>> for pseudo_outcome in [1.0, 1.1, 0.9, 1.0, 5.0, 5.1, 4.9]:
...     monitor.update(pseudo_outcome)
>>> monitor.n_drifts >= 0
True
Source code in onlinecml/diagnostics/concept_drift_monitor.py
drift_detected
property
True if drift was detected on the most recent update call.
n_drifts
property
Total number of drift events detected since initialization.
n_seen
property
Total number of observations processed.
reset()
update(pseudo_outcome)
Feed one pseudo-outcome to the drift detector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pseudo_outcome | float | Per-observation pseudo-outcome (e.g. IPW score or CATE estimate). Drift in this stream indicates a shift in the treatment effect. | required |