Adding a New Estimator
This guide walks you through adding a new causal estimator to OnlineCML. Follow these steps to ensure your estimator integrates cleanly with the rest of the library.
1. Choose the right base class
| Base class | Use when |
|---|---|
| BaseOnlineEstimator | Estimating ATE or CATE from (x, treatment, outcome) triples |
| BasePolicy | Deciding which treatment to assign |
| river.base.Base | Helper classes (propensity models, diagnostics) |
2. Create the file
One class per file. Place it in the appropriate subpackage:
onlinecml/
├── reweighting/ # IPW-style weighting estimators
├── metalearners/ # S/T/X/R-Learner style estimators
├── matching/ # Matching-based estimators
├── forests/ # Tree and forest estimators
└── policy/ # Exploration policies
3. Implement the class
Every constructor parameter must be stored as self.param_name, using the
same name it has in the signature (required for clone() and _get_params()
to work). Non-constructor state (running statistics, counters) must also be
initialised in __init__, but it is not passed in as a constructor parameter.

    from onlinecml.base.base_estimator import BaseOnlineEstimator
    from onlinecml.base.running_stats import RunningStats


    class MyEstimator(BaseOnlineEstimator):
        """One-line summary.

        Longer description explaining the method.

        Parameters
        ----------
        alpha : float
            Description. Default 1.0.

        Examples
        --------
        >>> from onlinecml.datasets import LinearCausalStream
        >>> m = MyEstimator()
        >>> for x, w, y, _ in LinearCausalStream(n=100, seed=0):
        ...     m.learn_one(x, w, y)
        >>> isinstance(m.predict_ate(), float)
        True
        """

        def __init__(self, alpha: float = 1.0) -> None:
            # Constructor params → stored as self.param
            self.alpha = alpha
            # Non-constructor state
            self._n_seen: int = 0
            self._ate_stats = RunningStats()

        def learn_one(
            self,
            x: dict,
            treatment: int,
            outcome: float,
            propensity: float | None = None,
        ) -> None:
            """Process one observation (predict-then-learn protocol)."""
            # 1. Predict BEFORE updating any model
            # 2. Compute pseudo-outcome
            pseudo = ...
            self._ate_stats.update(pseudo)
            self._n_seen += 1
            # 3. Update any sub-models

        def predict_one(self, x: dict) -> float:
            """Return the CATE estimate for a single unit."""
            return self._ate_stats.mean  # or a model prediction

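To see why each constructor parameter has to be stored under its own name, here is a minimal sketch of how a clone() built on signature introspection might behave. SketchBase, GoodEstimator and BadEstimator are hypothetical stand-ins written for this guide; they are not OnlineCML's or River's actual implementation, which may differ in detail.

```python
import inspect


class SketchBase:
    """Hypothetical stand-in for a base class (not the real BaseOnlineEstimator)."""

    def _get_params(self) -> dict:
        # Read the constructor parameter names from the __init__ signature,
        # then look up an attribute with the same name on the instance.
        sig = inspect.signature(type(self).__init__)
        names = [name for name in sig.parameters if name != "self"]
        return {name: getattr(self, name) for name in names}

    def clone(self):
        # Re-run __init__ with the stored parameters, so learned state is dropped.
        return type(self)(**self._get_params())


class GoodEstimator(SketchBase):
    def __init__(self, alpha: float = 1.0) -> None:
        self.alpha = alpha  # constructor param, stored under its own name
        self._n_seen = 0    # non-constructor state, rebuilt by __init__


class BadEstimator(SketchBase):
    def __init__(self, alpha: float = 1.0) -> None:
        self.a = alpha      # wrong attribute name: introspection cannot find it


good = GoodEstimator(alpha=0.5)
good._n_seen = 42           # pretend the model has seen data
fresh = good.clone()
print(fresh.alpha, fresh._n_seen)  # 0.5 0

try:
    BadEstimator(alpha=0.5).clone()
except AttributeError as exc:
    print("clone failed:", exc)
```

In this sketch the clone keeps alpha but starts again with zero observations, while the estimator that renames its attribute cannot be cloned at all.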
4. Register in __init__.py
Add the import and __all__ entry to the subpackage's __init__.py:
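As an illustration of this step, the sketch below builds a throwaway package on disk and registers a class in its __init__.py. The package name demo_subpackage and the module name my_estimator are assumptions made for the example; in the real library you would edit the existing __init__.py of the chosen subpackage, keep its current entries, and follow whatever import style (relative or absolute) it already uses.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Throwaway package; "demo_subpackage" stands in for a real subpackage
# such as onlinecml/metalearners (hypothetical name, for illustration only).
root = Path(tempfile.mkdtemp())
pkg = root / "demo_subpackage"
pkg.mkdir()
(pkg / "my_estimator.py").write_text("class MyEstimator:\n    pass\n")

# The registration itself: one import line plus one __all__ entry.
(pkg / "__init__.py").write_text(
    "from .my_estimator import MyEstimator\n"
    "\n"
    '__all__ = ["MyEstimator"]\n'
)

# Import the package to confirm the class is exposed at package level.
sys.path.insert(0, str(root))
demo = importlib.import_module("demo_subpackage")
print(demo.__all__)                 # ['MyEstimator']
print(demo.MyEstimator.__module__)  # demo_subpackage.my_estimator
```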
5. Write tests
Create tests/unit/test_my_estimator.py with at minimum:

    class TestMyEstimator:
        def test_n_seen_starts_zero(self): ...
        def test_predict_ate_before_data_returns_zero(self): ...
        def test_learn_increments_n_seen(self): ...
        def test_predict_one_returns_float(self): ...
        def test_reset_clears_state(self): ...
        def test_clone_gives_fresh_model(self): ...
        def test_ci_finite_after_data(self): ...

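The test names above could be filled in roughly as follows. This is a sketch only: StubEstimator is a hypothetical stand-in so the tests run without OnlineCML installed, and its public n_seen attribute, reset() and clone() are assumptions about the interface. The confidence-interval test is left out because this guide does not name the method it would call.

```python
class StubEstimator:
    """Hypothetical stand-in for a real estimator; replace with your class."""

    def __init__(self, alpha: float = 1.0) -> None:
        self.alpha = alpha
        self.n_seen = 0
        self._sum = 0.0

    def learn_one(self, x: dict, treatment: int, outcome: float) -> None:
        # Toy pseudo-outcome (signed outcome), only here to give the tests data.
        self._sum += outcome if treatment == 1 else -outcome
        self.n_seen += 1

    def predict_ate(self) -> float:
        return self._sum / self.n_seen if self.n_seen else 0.0

    def predict_one(self, x: dict) -> float:
        return self.predict_ate()

    def clone(self) -> "StubEstimator":
        return StubEstimator(alpha=self.alpha)

    def reset(self) -> None:
        self.__init__(alpha=self.alpha)


class TestStubEstimator:
    def test_n_seen_starts_zero(self):
        assert StubEstimator().n_seen == 0

    def test_predict_ate_before_data_returns_zero(self):
        assert StubEstimator().predict_ate() == 0.0

    def test_learn_increments_n_seen(self):
        m = StubEstimator()
        m.learn_one({"x1": 0.3}, 1, 2.0)
        assert m.n_seen == 1

    def test_predict_one_returns_float(self):
        m = StubEstimator()
        m.learn_one({"x1": 0.3}, 1, 2.0)
        assert isinstance(m.predict_one({"x1": 0.3}), float)

    def test_reset_clears_state(self):
        m = StubEstimator()
        m.learn_one({"x1": 0.3}, 1, 2.0)
        m.reset()
        assert m.n_seen == 0
        assert m.predict_ate() == 0.0

    def test_clone_gives_fresh_model(self):
        m = StubEstimator(alpha=0.5)
        m.learn_one({"x1": 0.3}, 1, 2.0)
        c = m.clone()
        assert c.n_seen == 0       # clone starts empty
        assert c.alpha == m.alpha  # but keeps the constructor params
        assert m.n_seen == 1       # and the original is untouched


if __name__ == "__main__":
    # pytest collects the class automatically; this loop lets the file run directly.
    suite = TestStubEstimator()
    for name in dir(suite):
        if name.startswith("test_"):
            getattr(suite, name)()
    print("all stub tests passed")
```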
6. Key design rules
- Predict-then-learn: always predict before updating any model
- reset() works: BaseOnlineEstimator.reset() uses clone() internals, so don't override it unless necessary
- clone() gives a fresh model: n_seen == 0 after cloning
- No pandas, no batches: all processing must work one observation at a time
- O(1) memory: avoid growing unbounded lists (use deque(maxlen=...) for buffers)
- Thread-safe (read-only): predict_one should not modify state
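The rules on predict-then-learn, bounded memory and read-only prediction can be sketched together in a small toy model. RunningMean and TwoArmSketch are hypothetical classes written for this guide (RunningMean stands in for the library's RunningStats), and the pseudo-outcome used here is a deliberately simple choice, not a method from OnlineCML.

```python
from collections import deque


class RunningMean:
    """Hypothetical O(1)-memory mean; a stand-in for RunningStats."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0

    def update(self, value: float) -> None:
        self.n += 1
        self.mean += (value - self.mean) / self.n


class TwoArmSketch:
    def __init__(self, window: int = 100) -> None:
        self.window = window
        self._arm = {0: RunningMean(), 1: RunningMean()}
        self._ate = RunningMean()
        self._recent = deque(maxlen=window)  # bounded buffer, never grows past window

    def learn_one(self, x: dict, treatment: int, outcome: float) -> None:
        # 1. Predict BEFORE updating: read the arm means as they stood
        #    before this observation arrived.
        mu0, mu1 = self._arm[0].mean, self._arm[1].mean
        # 2. Toy pseudo-outcome: observed outcome versus the other arm's prediction.
        pseudo = outcome - mu0 if treatment == 1 else mu1 - outcome
        self._ate.update(pseudo)
        self._recent.append(pseudo)
        # 3. Only now update the model for the observed arm.
        self._arm[treatment].update(outcome)

    def predict_ate(self) -> float:
        # Read-only: returns a stored value without touching any state.
        return self._ate.mean


model = TwoArmSketch(window=100)
for i in range(250):
    treatment = i % 2
    outcome = 3.0 if treatment == 1 else 1.0
    model.learn_one({"x1": float(i)}, treatment, outcome)

print(len(model._recent))   # 100: the buffer stays at its window size
print(model.predict_ate())  # close to the true difference of 2.0
```

Because the prediction in step 1 uses only information available before the current observation, the pseudo-outcome is never computed from a model that has already seen the outcome it is being compared against.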
7. Checklist before PR
- [ ] Docstring on class and all public methods (NumPy style)
- [ ] Type hints on all parameters and return values
- [ ] learn_one / predict_one / reset / clone all work
- [ ] Unit tests cover the standard test set above
- [ ] ruff check passes
- [ ] mypy passes
- [ ] pytest tests/ --cov shows ≥ 90% total coverage