Base
Core abstract classes and statistical utilities.
BaseOnlineEstimator
onlinecml.base.base_estimator.BaseOnlineEstimator
Bases: Base
Abstract base class for all online causal estimators.
Every estimator in OnlineCML inherits from this class. It provides the standard interface for online causal inference: processing one observation at a time, estimating the Average Treatment Effect (ATE), and reporting confidence intervals.
Inherits from river.base.Base (not river.base.Estimator) to avoid signature conflicts: our learn_one takes (x, treatment, outcome, propensity), while River's Estimator expects (x, y).
Notes
Subclasses must implement learn_one and predict_one.
All constructor parameters must be stored as self.param_name (matching the parameter name exactly) so that clone() and _get_params() work correctly.
Non-constructor state (_n_seen, _ate_stats) is initialized in each concrete __init__. It is intentionally NOT cloned: clone() returns a fresh estimator with zero observations.
reset() re-initializes all state by calling __init__ again.
The predict_ci method returns a confidence interval for the ATE (mean CATE), not for individual CATE predictions. It uses a normal approximation, justified by the central limit theorem, applied to the running variance of the per-observation pseudo-outcomes.
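The constructor convention in the notes above can be illustrated with a small sketch. It assumes that clone() and reset() rebuild an object from the attributes named after its __init__ parameters; ParamStoringBase and ToyEstimator are hypothetical stand-ins written for this example, not River's or OnlineCML's actual implementation.

```python
import inspect


class ParamStoringBase:
    """Hypothetical stand-in illustrating the self.param_name convention.

    This is NOT river.base.Base; it only mimics the idea that clone() and
    reset() rebuild an object from attributes named after __init__ params.
    """

    def _get_params(self) -> dict:
        # Read each constructor argument back from the attribute of the same name.
        sig = inspect.signature(type(self).__init__)
        return {name: getattr(self, name) for name in sig.parameters if name != "self"}

    def clone(self):
        # A fresh instance with the same hyperparameters and no learned state.
        return type(self)(**self._get_params())

    def reset(self):
        # Re-run __init__ in place, discarding all learned state.
        self.__init__(**self._get_params())


class ToyEstimator(ParamStoringBase):
    def __init__(self, alpha: float = 0.05):
        self.alpha = alpha  # constructor param stored under the same name
        self._n_seen = 0    # non-constructor state, rebuilt by __init__

    def learn_one(self, x, treatment, outcome, propensity=None):
        self._n_seen += 1


est = ToyEstimator(alpha=0.1)
est.learn_one({"age": 30}, 1, 2.5)
fresh = est.clone()
print(fresh.alpha, fresh._n_seen)  # 0.1 0
est.reset()
print(est._n_seen)  # 0
```

If an attribute were stored under a different name (say self._alpha), getattr would fail inside _get_params, which is why the names must match exactly.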
Source code in onlinecml/base/base_estimator.py
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | |
n_seen
property
Number of observations processed so far.
smd
property
Current Standardized Mean Difference per covariate.
Returns None by default. Subclasses that maintain per-group statistics (e.g. IPW-based estimators) override this property.
Returns:
| Type | Description |
|---|---|
| dict or None | Mapping from covariate name to (raw_smd, weighted_smd), or None if this estimator does not track balance diagnostics. |
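A (raw_smd, weighted_smd) pair of the kind this property returns can be sketched as follows. This assumes the common pooled-standard-deviation definition SMD = (mean_t - mean_c) / sqrt((var_t + var_c) / 2) with population-weighted variances; the exact formula OnlineCML uses is not shown on this page, and the helper names and weights below are hypothetical.

```python
import math


def weighted_mean_var(values, weights):
    """Weighted mean and population-weighted variance of a sample."""
    total_w = sum(weights)
    mean = sum(w * v for v, w in zip(values, weights)) / total_w
    var = sum(w * (v - mean) ** 2 for v, w in zip(values, weights)) / total_w
    return mean, var


def smd(treated, control, w_treated=None, w_control=None):
    """Standardized mean difference using a pooled standard deviation.

    Unweighted when no weights are given (every weight defaults to 1).
    """
    w_treated = w_treated or [1.0] * len(treated)
    w_control = w_control or [1.0] * len(control)
    m_t, v_t = weighted_mean_var(treated, w_treated)
    m_c, v_c = weighted_mean_var(control, w_control)
    pooled_sd = math.sqrt((v_t + v_c) / 2)
    return (m_t - m_c) / pooled_sd if pooled_sd > 0 else 0.0


treated_age = [40.0, 50.0, 60.0]
control_age = [30.0, 40.0, 50.0]
# Hypothetical IPW weights that up-weight the oldest control unit.
control_w = [1.0, 1.0, 4.0]
balance = {
    "age": (
        smd(treated_age, control_age),                       # raw
        smd(treated_age, control_age, w_control=control_w),  # weighted
    )
}
print(balance)
```

The weighted SMD comes out smaller than the raw one here, which is the pattern one hopes to see when weighting improves covariate balance.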
learn_one(x, treatment, outcome, propensity=None)
abstractmethod
Process one observation and update the estimator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for this observation. | required |
| treatment | int | Treatment indicator (0 = control, 1 = treated). | required |
| outcome | float | Observed outcome for this unit. | required |
| propensity | float or None | Known or logged propensity P(W=1\|X). If None, the estimator will use its internal propensity model. | None |
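One common way a known or logged propensity can enter an online estimator is through the inverse-propensity-weighted (Horvitz-Thompson) pseudo-outcome, whose mean is an unbiased ATE estimate when the propensity is correct. Whether a given OnlineCML estimator uses this exact pseudo-outcome is an assumption here; ipw_pseudo_outcome is a hypothetical helper.

```python
def ipw_pseudo_outcome(treatment: int, outcome: float, propensity: float) -> float:
    """Horvitz-Thompson pseudo-outcome for one observation.

    Its conditional mean given X equals the CATE when `propensity`
    is the true P(W=1|X).
    """
    if not 0.0 < propensity < 1.0:
        raise ValueError("propensity must lie strictly between 0 and 1")
    if treatment == 1:
        return outcome / propensity
    return -outcome / (1.0 - propensity)


# Logged propensity from a randomized experiment with P(W=1) = 0.5.
stream = [(1, 3.0), (0, 1.0), (1, 2.0), (0, 0.5)]
pseudo = [ipw_pseudo_outcome(w, y, 0.5) for w, y in stream]
ate = sum(pseudo) / len(pseudo)
print(pseudo, ate)  # [6.0, -2.0, 4.0, -1.0] 1.75
```

With a constant propensity of 0.5 and balanced groups, the IPW estimate equals the difference in group means (2.5 - 0.75 = 1.75).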
Source code in onlinecml/base/base_estimator.py
predict_ate()
Return the current running ATE estimate.
Returns:
| Type | Description |
|---|---|
| float | The current Average Treatment Effect estimate. Returns 0.0 before any observations have been processed. |
Source code in onlinecml/base/base_estimator.py
predict_ci(alpha=0.05)
Return a confidence interval for the ATE estimate.
Uses a normal approximation via the central limit theorem applied to the running variance of per-observation pseudo-outcomes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level. Default 0.05 gives a 95% CI. | 0.05 |
Returns:
| Name | Type | Description |
|---|---|---|
| lower | float | Lower bound of the confidence interval. |
| upper | float | Upper bound of the confidence interval. |
Notes
Returns (-inf, inf) until at least 2 observations have been seen.
The CI is for the ATE (mean CATE), not for individual CATE predictions.
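The normal-approximation interval described above can be sketched from running summary statistics alone. This assumes the interval is mean ± z·sqrt(variance / n), using the sample variance of the pseudo-outcomes; normal_ci is a hypothetical helper, not the library's method.

```python
import math
import statistics
from statistics import NormalDist


def normal_ci(mean: float, variance: float, n: int, alpha: float = 0.05):
    """Normal-approximation CI for a mean from running summary statistics.

    `variance` is the sample variance (ddof=1) of the per-observation
    pseudo-outcomes; the standard error of their mean is sqrt(variance / n).
    """
    if n < 2:
        return (-math.inf, math.inf)
    z = NormalDist().inv_cdf(1 - alpha / 2)  # about 1.96 for alpha = 0.05
    half_width = z * math.sqrt(variance / n)
    return (mean - half_width, mean + half_width)


pseudo = [6.0, -2.0, 4.0, -1.0]
lower, upper = normal_ci(statistics.mean(pseudo), statistics.variance(pseudo), len(pseudo))
print(lower, upper)
print(normal_ci(0.0, 0.0, 1))  # (-inf, inf)
```

With only four observations the CLT approximation is crude; the example only shows the arithmetic, and the interval narrows roughly as 1/sqrt(n) as the stream grows.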
Source code in onlinecml/base/base_estimator.py
predict_one(x)
abstractmethod
Predict the CATE for a single unit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for the unit. | required |
Returns:
| Type | Description |
|---|---|
| float | Estimated Conditional Average Treatment Effect for this unit. |
Source code in onlinecml/base/base_estimator.py
reset()
Reset the estimator to its initial (untrained) state.
Equivalent to creating a fresh instance with the same constructor arguments. All learned state is discarded.
Source code in onlinecml/base/base_estimator.py
BasePolicy
onlinecml.base.base_policy.BasePolicy
Bases: Base
Abstract base class for treatment exploration policies.
A policy decides which treatment to assign and with what probability, given a CATE score estimate and the current step count. Subclasses implement different exploration-exploitation trade-offs.
All policies follow River conventions: constructor parameters are stored as instance attributes with the same name, enabling clone() and _get_params() to work correctly.
Notes
This class intentionally does NOT inherit from River's bandit Policy because our interface operates on CATE scores (continuous real-valued estimates of causal effects) rather than bandit arm indices. The returned propensity is the probability of the chosen treatment under the policy, which is used for IPW correction downstream.
Source code in onlinecml/base/base_policy.py
choose(cate_score, step)
abstractmethod
Choose a treatment assignment given the current CATE estimate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cate_score | float | Current CATE estimate for the unit. Positive values suggest treatment is beneficial; negative values suggest control. | required |
| step | int | The current time step (used for decay schedules). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| treatment | int | The chosen treatment assignment (0 or 1). |
| propensity | float | The probability of the chosen treatment under this policy. Used for IPW correction in downstream estimators. |
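The choose(cate_score, step) contract can be sketched with an epsilon-greedy rule whose exploration rate decays with step. EpsilonGreedyPolicy and its eps0 / sqrt(step + 1) schedule are hypothetical and are not claimed to be one of OnlineCML's policies. The point to note is that the returned propensity is the marginal probability of the arm actually chosen, since uniform exploration can also land on the greedy arm.

```python
import random


class EpsilonGreedyPolicy:
    """Hypothetical policy following the choose(cate_score, step) contract.

    Exploits the sign of the CATE with probability 1 - eps, otherwise picks
    a treatment uniformly at random. eps decays as eps0 / sqrt(step + 1).
    """

    def __init__(self, eps0: float = 0.5, seed: int = 0):
        self.eps0 = eps0  # constructor params stored under the same names
        self.seed = seed
        self._rng = random.Random(seed)

    def _epsilon(self, step: int) -> float:
        return min(1.0, self.eps0 / (step + 1) ** 0.5)

    def choose(self, cate_score: float, step: int):
        eps = self._epsilon(step)
        greedy = 1 if cate_score > 0 else 0
        if self._rng.random() < eps:
            treatment = self._rng.randint(0, 1)  # explore uniformly
        else:
            treatment = greedy
        # Greedy arm: (1 - eps) + eps/2 = 1 - eps/2; other arm: eps/2.
        propensity = 1 - eps / 2 if treatment == greedy else eps / 2
        return treatment, propensity

    def update(self, reward: float) -> None:
        pass  # not adaptive, like the default no-op


policy = EpsilonGreedyPolicy(eps0=0.5, seed=42)
for step in range(3):
    treatment, propensity = policy.choose(cate_score=0.8, step=step)
    print(step, treatment, propensity)
```

Because eps/2 stays strictly positive while eps0 > 0, every arm keeps a nonzero propensity, which keeps downstream IPW weights finite.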
Source code in onlinecml/base/base_policy.py
reset()
update(reward)
Update policy state after observing a reward. The default implementation is a no-op; override it for adaptive policies.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reward | float | The observed outcome after applying the chosen treatment. | required |
Source code in onlinecml/base/base_policy.py
RunningStats
onlinecml.base.running_stats.RunningStats
Online mean and variance using Welford's single-pass algorithm.
Computes the sample mean, variance, and standard deviation of a stream of scalar values without storing the data.
Examples:
>>> stats = RunningStats()
>>> for x in [2.0, 4.0, 6.0]:
... stats.update(x)
>>> stats.mean
4.0
>>> stats.n
3
Source code in onlinecml/base/running_stats.py
mean
property
Current sample mean. Returns 0.0 before any observations.
n
property
Number of observations seen.
std
property
Current sample standard deviation. Returns 0.0 when n < 2.
variance
property
Current sample variance (ddof=1). Returns 0.0 when n < 2.
reset()
update(x)
Update statistics with a new observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | float | The new scalar value to incorporate. | required |
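Welford's single-pass update, which this class is documented to use, can be sketched as follows. WelfordStats is a hypothetical re-implementation for illustration; it tracks the count, the mean, and M2 (the running sum of squared deviations) and reports the sample variance (ddof=1), returning 0.0 when n < 2 as documented above.

```python
import math


class WelfordStats:
    """Sketch of Welford's single-pass mean/variance update."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x: float) -> None:
        self.n += 1
        delta = x - self.mean                # deviation from the old mean
        self.mean += delta / self.n
        self._m2 += delta * (x - self.mean)  # old-mean delta times new-mean delta

    @property
    def variance(self) -> float:
        return self._m2 / (self.n - 1) if self.n >= 2 else 0.0

    @property
    def std(self) -> float:
        return math.sqrt(self.variance)


stats = WelfordStats()
for x in [2.0, 4.0, 6.0]:
    stats.update(x)
print(stats.mean, stats.variance)  # 4.0 4.0
```

Multiplying the old-mean and new-mean deviations avoids the catastrophic cancellation of the naive sum-of-squares formula.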
Source code in onlinecml/base/running_stats.py
WeightedRunningStats
onlinecml.base.running_stats.WeightedRunningStats
Online weighted mean and variance using West's (1979) algorithm.
Computes the weighted mean and population-weighted variance of a stream of scalar values with associated importance weights.
Notes
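Returns the population-weighted variance S / sum_w, where S is the running weighted sum of squared deviations from the mean, rather than a sample variance. This is appropriate for SMD computation, where differences are normalized by a pooled standard deviation, but not for statistical inference.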
References
West, D.H.D. (1979). Updating mean and variance estimates: an improved method. Communications of the ACM, 22(9), 532-535.
Examples:
>>> stats = WeightedRunningStats()
>>> stats.update(2.0, w=1.0)
>>> stats.update(4.0, w=2.0)
>>> stats.mean # (2*1 + 4*2) / 3 = 10/3
3.3333333333333335
Source code in onlinecml/base/running_stats.py
mean
property
Current weighted mean. Returns 0.0 before any observations.
std
property
Population-weighted standard deviation. Returns 0.0 when sum_w <= 0.
sum_weights
property
Sum of all weights seen so far.
variance
property
Population-weighted variance (S / sum_w). Returns 0.0 when sum_w <= 0.
reset()
update(x, w=1.0)
Update statistics with a new weighted observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | float | The new scalar value to incorporate. | required |
| w | float | Importance weight. Observations with w <= 0 are silently ignored. | 1.0 |
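West's (1979) weighted update can be sketched as follows. WestWeightedStats is a hypothetical re-implementation, not the library's class. It keeps sum_w, the weighted mean, and S, and returns the population-weighted variance S / sum_w. Non-positive weights are skipped, matching the documented behaviour.

```python
class WestWeightedStats:
    """Sketch of West's (1979) weighted incremental mean/variance."""

    def __init__(self):
        self.sum_weights = 0.0
        self.mean = 0.0
        self._s = 0.0  # S: weighted sum of squared deviations from the mean

    def update(self, x: float, w: float = 1.0) -> None:
        if w <= 0:
            return  # non-positive weights are ignored
        self.sum_weights += w
        delta = x - self.mean
        r = delta * w / self.sum_weights
        self.mean += r
        # (sum_weights - w) is the total weight before this update.
        self._s += (self.sum_weights - w) * delta * r

    @property
    def variance(self) -> float:
        # Population-weighted variance S / sum_w.
        return self._s / self.sum_weights if self.sum_weights > 0 else 0.0


stats = WestWeightedStats()
stats.update(2.0, w=1.0)
stats.update(4.0, w=2.0)
print(round(stats.mean, 4), round(stats.variance, 4))  # 3.3333 0.8889
```

The result matches a direct two-pass computation: weighted mean 10/3, and weighted squared deviations 16/9 + 8/9 = 24/9 divided by sum_w = 3 gives 8/9.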