
Base

Core abstract classes and statistical utilities.

BaseOnlineEstimator

onlinecml.base.base_estimator.BaseOnlineEstimator

Bases: Base

Abstract base class for all online causal estimators.

Every estimator in OnlineCML inherits from this class. It provides the standard interface for online causal inference: processing one observation at a time, estimating the Average Treatment Effect (ATE), and reporting confidence intervals.

Inherits from river.base.Base (not river.base.Estimator) to avoid signature conflicts: our learn_one takes (x, treatment, outcome, propensity) while River's Estimator expects (x, y).

Notes

Subclasses must implement learn_one and predict_one.

All constructor parameters must be stored as self.param_name (matching the parameter name exactly) so that clone() and _get_params() work correctly.

Non-constructor state (_n_seen, _ate_stats) is initialized in each concrete __init__. It is intentionally NOT cloned: clone() returns a fresh estimator with zero observations. reset() re-initializes all state by copying the attributes of a fresh clone() onto the instance, which amounts to running __init__ again with the same constructor arguments.

The predict_ci method returns a confidence interval for the ATE (mean CATE), not for individual CATE predictions. This uses a normal approximation via the central limit theorem applied to the running pseudo-outcome variance.
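
The parameter-storage rule above is easiest to see in a small stand-alone sketch. The classes below are hypothetical stand-ins (ToyBase is not river.base.Base, and ToyEstimator is not an OnlineCML estimator); they only mimic the idea that clone() rebuilds an instance from constructor arguments read back off same-named attributes.

```python
import inspect


class ToyBase:
    """Hypothetical stand-in for a River-style base class."""

    def _get_params(self) -> dict:
        # Read each constructor argument back from the attribute of the same name.
        signature = inspect.signature(type(self).__init__)
        names = [name for name in signature.parameters if name != "self"]
        return {name: getattr(self, name) for name in names}

    def clone(self):
        # Rebuilt from constructor parameters only, so the clone starts untrained.
        return type(self)(**self._get_params())

    def reset(self) -> None:
        # Same approach as the reset() listing: copy a fresh clone's attributes.
        self.__dict__.update(self.clone().__dict__)


class ToyEstimator(ToyBase):
    def __init__(self, clip: float = 0.05) -> None:
        self.clip = clip   # constructor parameter, stored under the same name
        self._n_seen = 0   # non-constructor state, rebuilt by __init__

    def learn_one(self, x: dict, treatment: int, outcome: float, propensity=None) -> None:
        self._n_seen += 1


est = ToyEstimator(clip=0.1)
est.learn_one({"age": 30.0}, treatment=1, outcome=2.5)
fresh = est.clone()
print(fresh.clip, fresh._n_seen)  # 0.1 0
est.reset()
print(est.clip, est._n_seen)  # 0.1 0
```

If clip were stored as self._clip instead, the getattr lookup in this sketch would fail, which is the kind of breakage the naming rule is meant to prevent.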

Source code in onlinecml/base/base_estimator.py
import abc

import scipy.stats
from river.base import Base


class BaseOnlineEstimator(Base):
    """Abstract base class for all online causal estimators.

    Every estimator in OnlineCML inherits from this class. It provides
    the standard interface for online causal inference: processing one
    observation at a time, estimating the Average Treatment Effect (ATE),
    and reporting confidence intervals.

    Inherits from ``river.base.Base`` (not ``river.base.Estimator``) to
    avoid signature conflicts: our ``learn_one`` takes ``(x, treatment,
    outcome, propensity)`` while River's Estimator expects ``(x, y)``.

    Notes
    -----
    Subclasses must implement ``learn_one`` and ``predict_one``.

    All constructor parameters must be stored as ``self.param_name``
    (matching the parameter name exactly) so that ``clone()`` and
    ``_get_params()`` work correctly.

    Non-constructor state (``_n_seen``, ``_ate_stats``) is initialized
    in each concrete ``__init__``. It is intentionally NOT cloned:
    ``clone()`` returns a fresh estimator with zero observations.
    ``reset()`` re-initializes all state by copying the attributes of a
    fresh ``clone()`` onto the instance.

    The ``predict_ci`` method returns a confidence interval for the ATE
    (mean CATE), not for individual CATE predictions. This uses a normal
    approximation via the central limit theorem applied to the running
    pseudo-outcome variance.
    """

    @abc.abstractmethod
    def learn_one(
        self,
        x: dict,
        treatment: int,
        outcome: float,
        propensity: float | None = None,
    ) -> None:
        """Process one observation and update the estimator.

        Parameters
        ----------
        x : dict
            Feature dictionary for this observation.
        treatment : int
            Treatment indicator (0 = control, 1 = treated).
        outcome : float
            Observed outcome for this unit.
        propensity : float or None
            Known or logged propensity P(W=1|X). If None, the estimator
            will use its internal propensity model.
        """

    @abc.abstractmethod
    def predict_one(self, x: dict) -> float:
        """Predict the CATE for a single unit.

        Parameters
        ----------
        x : dict
            Feature dictionary for the unit.

        Returns
        -------
        float
            Estimated Conditional Average Treatment Effect for this unit.
        """

    def predict_ate(self) -> float:
        """Return the current running ATE estimate.

        Returns
        -------
        float
            The current Average Treatment Effect estimate. Returns 0.0
            before any observations have been processed.
        """
        return self._ate_stats.mean

    def predict_ci(self, alpha: float = 0.05) -> tuple[float, float]:
        """Return a confidence interval for the ATE estimate.

        Uses a normal approximation via the central limit theorem applied
        to the running variance of per-observation pseudo-outcomes.

        Parameters
        ----------
        alpha : float
            Significance level. Default 0.05 gives a 95% CI.

        Returns
        -------
        lower : float
            Lower bound of the confidence interval.
        upper : float
            Upper bound of the confidence interval.

        Notes
        -----
        Returns ``(-inf, inf)`` before at least 2 observations are seen.
        The CI is for the ATE (mean CATE), not for individual CATE predictions.
        """
        n = self._ate_stats.n
        if n < 2:
            return (float("-inf"), float("inf"))
        ate = self._ate_stats.mean
        se = (self._ate_stats.variance / n) ** 0.5
        z = scipy.stats.norm.ppf(1.0 - alpha / 2.0)
        return (ate - z * se, ate + z * se)

    def reset(self) -> None:
        """Reset the estimator to its initial (untrained) state.

        Equivalent to creating a fresh instance with the same constructor
        arguments. All learned state is discarded.
        """
        fresh = self.clone()
        self.__dict__.update(fresh.__dict__)

    @property
    def n_seen(self) -> int:
        """Number of observations processed so far."""
        return self._n_seen

    @property
    def smd(self) -> dict | None:
        """Current Standardized Mean Difference per covariate.

        Returns None by default. Subclasses that maintain per-group
        statistics (e.g. IPW-based estimators) override this property.

        Returns
        -------
        dict or None
            Mapping from covariate name to (raw_smd, weighted_smd), or
            None if this estimator does not track balance diagnostics.
        """
        return None

n_seen property

Number of observations processed so far.

smd property

Current Standardized Mean Difference per covariate.

Returns None by default. Subclasses that maintain per-group statistics (e.g. IPW-based estimators) override this property.

Returns:

| Type | Description |
|---|---|
| dict or None | Mapping from covariate name to (raw_smd, weighted_smd), or None if this estimator does not track balance diagnostics. |
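
To make the (raw_smd, weighted_smd) pair concrete, here is a minimal batch sketch for one covariate. It assumes the common convention of dividing the difference in group means by the pooled standard deviation sqrt((var_t + var_c) / 2); the exact formula used by OnlineCML subclasses is not shown on this page, and the helper names and weights below are hypothetical.

```python
import math


def weighted_mean_var(values, weights):
    """Weighted mean and population-weighted variance of one group."""
    total = sum(weights)
    mean = sum(w * v for v, w in zip(values, weights)) / total
    var = sum(w * (v - mean) ** 2 for v, w in zip(values, weights)) / total
    return mean, var


def smd(treated, control, w_treated=None, w_control=None):
    """Standardized mean difference; unit weights give the raw SMD."""
    w_treated = w_treated if w_treated is not None else [1.0] * len(treated)
    w_control = w_control if w_control is not None else [1.0] * len(control)
    mean_t, var_t = weighted_mean_var(treated, w_treated)
    mean_c, var_c = weighted_mean_var(control, w_control)
    pooled_sd = math.sqrt((var_t + var_c) / 2.0)  # assumed pooling convention
    return (mean_t - mean_c) / pooled_sd if pooled_sd > 0.0 else 0.0


age_treated = [30.0, 40.0, 50.0]
age_control = [20.0, 30.0, 40.0]
raw_smd = smd(age_treated, age_control)
# Hypothetical weights that up-weight the younger treated units.
weighted_smd = smd(age_treated, age_control, w_treated=[2.0, 1.0, 0.5])
print({"age": (raw_smd, weighted_smd)})
```

In this toy data the weighting shrinks the imbalance, so the weighted SMD is smaller than the raw one.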

learn_one(x, treatment, outcome, propensity=None) abstractmethod

Process one observation and update the estimator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for this observation. | required |
| treatment | int | Treatment indicator (0 = control, 1 = treated). | required |
| outcome | float | Observed outcome for this unit. | required |
| propensity | float or None | Known or logged propensity P(W=1\|X). If None, the estimator will use its internal propensity model. | None |

Source code in onlinecml/base/base_estimator.py
@abc.abstractmethod
def learn_one(
    self,
    x: dict,
    treatment: int,
    outcome: float,
    propensity: float | None = None,
) -> None:
    """Process one observation and update the estimator.

    Parameters
    ----------
    x : dict
        Feature dictionary for this observation.
    treatment : int
        Treatment indicator (0 = control, 1 = treated).
    outcome : float
        Observed outcome for this unit.
    propensity : float or None
        Known or logged propensity P(W=1|X). If None, the estimator
        will use its internal propensity model.
    """

predict_ate()

Return the current running ATE estimate.

Returns:

| Type | Description |
|---|---|
| float | The current Average Treatment Effect estimate. Returns 0.0 before any observations have been processed. |

Source code in onlinecml/base/base_estimator.py
def predict_ate(self) -> float:
    """Return the current running ATE estimate.

    Returns
    -------
    float
        The current Average Treatment Effect estimate. Returns 0.0
        before any observations have been processed.
    """
    return self._ate_stats.mean

predict_ci(alpha=0.05)

Return a confidence interval for the ATE estimate.

Uses a normal approximation via the central limit theorem applied to the running variance of per-observation pseudo-outcomes.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level. Default 0.05 gives a 95% CI. | 0.05 |

Returns:

| Name | Type | Description |
|---|---|---|
| lower | float | Lower bound of the confidence interval. |
| upper | float | Upper bound of the confidence interval. |

Notes

Returns (-inf, inf) until at least 2 observations have been seen. The CI is for the ATE (mean CATE), not for individual CATE predictions.
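
The interval is mean ± z · sqrt(variance / n), with z the standard normal quantile at 1 - alpha / 2. The batch sketch below reproduces that arithmetic on a made-up list of pseudo-outcomes. It swaps scipy.stats.norm.ppf for the standard library's statistics.NormalDist().inv_cdf so that it runs without SciPy; normal_ci is a hypothetical helper, not part of OnlineCML.

```python
import math
import statistics


def normal_ci(values, alpha=0.05):
    """Normal-approximation CI for the mean of `values`."""
    n = len(values)
    if n < 2:
        # Mirrors the documented behaviour with fewer than 2 observations.
        return (float("-inf"), float("inf"))
    mean = statistics.fmean(values)
    se = math.sqrt(statistics.variance(values) / n)  # sample variance, ddof=1
    z = statistics.NormalDist().inv_cdf(1.0 - alpha / 2.0)
    return (mean - z * se, mean + z * se)


pseudo_outcomes = [1.0, 3.0, 2.0, 4.0, 0.0]  # made-up values
lower, upper = normal_ci(pseudo_outcomes)
print(round(lower, 3), round(upper, 3))  # 0.614 3.386
```

Here the mean is 2.0, the sample variance is 2.5, and the standard error is sqrt(2.5 / 5) ≈ 0.707, so the 95% interval is roughly 2.0 ± 1.386.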

Source code in onlinecml/base/base_estimator.py
def predict_ci(self, alpha: float = 0.05) -> tuple[float, float]:
    """Return a confidence interval for the ATE estimate.

    Uses a normal approximation via the central limit theorem applied
    to the running variance of per-observation pseudo-outcomes.

    Parameters
    ----------
    alpha : float
        Significance level. Default 0.05 gives a 95% CI.

    Returns
    -------
    lower : float
        Lower bound of the confidence interval.
    upper : float
        Upper bound of the confidence interval.

    Notes
    -----
    Returns ``(-inf, inf)`` before at least 2 observations are seen.
    The CI is for the ATE (mean CATE), not for individual CATE predictions.
    """
    n = self._ate_stats.n
    if n < 2:
        return (float("-inf"), float("inf"))
    ate = self._ate_stats.mean
    se = (self._ate_stats.variance / n) ** 0.5
    z = scipy.stats.norm.ppf(1.0 - alpha / 2.0)
    return (ate - z * se, ate + z * se)

predict_one(x) abstractmethod

Predict the CATE for a single unit.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for the unit. | required |

Returns:

| Type | Description |
|---|---|
| float | Estimated Conditional Average Treatment Effect for this unit. |

Source code in onlinecml/base/base_estimator.py
@abc.abstractmethod
def predict_one(self, x: dict) -> float:
    """Predict the CATE for a single unit.

    Parameters
    ----------
    x : dict
        Feature dictionary for the unit.

    Returns
    -------
    float
        Estimated Conditional Average Treatment Effect for this unit.
    """

reset()

Reset the estimator to its initial (untrained) state.

Equivalent to creating a fresh instance with the same constructor arguments. All learned state is discarded.

Source code in onlinecml/base/base_estimator.py
def reset(self) -> None:
    """Reset the estimator to its initial (untrained) state.

    Equivalent to creating a fresh instance with the same constructor
    arguments. All learned state is discarded.
    """
    fresh = self.clone()
    self.__dict__.update(fresh.__dict__)

BasePolicy

onlinecml.base.base_policy.BasePolicy

Bases: Base

Abstract base class for treatment exploration policies.

A policy decides which treatment to assign and with what probability, given a CATE score estimate and the current step count. Subclasses implement different exploration-exploitation trade-offs.

All policies follow River conventions: constructor parameters are stored as instance attributes with the same name, enabling clone() and _get_params() to work correctly.

Notes

This class intentionally does NOT inherit from River's bandit Policy because our interface operates on CATE scores (continuous real-valued estimates of causal effects) rather than bandit arm indices. The returned propensity is the probability of the chosen treatment under the policy, which is used for IPW correction downstream.
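
As an illustration of the interface, the following is a minimal epsilon-greedy sketch. ToyEpsilonGreedy is a hypothetical stand-alone class that does not inherit from BasePolicy, and its eps and seed parameters are assumptions made for the example. It returns the probability of whichever treatment it actually chose, as the contract above describes.

```python
import random
from typing import Optional


class ToyEpsilonGreedy:
    """Hypothetical policy: exploit the sign of the CATE, explore with prob. eps."""

    def __init__(self, eps: float = 0.1, seed: Optional[int] = None) -> None:
        # Constructor parameters are stored under their own names.
        self.eps = eps
        self.seed = seed
        self._rng = random.Random(seed)

    def choose(self, cate_score: float, step: int) -> tuple:
        # `step` is unused here; a decaying schedule would shrink eps with it.
        greedy = 1 if cate_score > 0 else 0
        if self._rng.random() < self.eps:
            treatment = self._rng.randint(0, 1)  # uniform exploration
        else:
            treatment = greedy
        # The greedy arm is picked with prob. (1 - eps) + eps / 2, the other with eps / 2.
        if treatment == greedy:
            propensity = 1.0 - self.eps / 2.0
        else:
            propensity = self.eps / 2.0
        return treatment, propensity


policy = ToyEpsilonGreedy(eps=0.2, seed=42)
treatment, propensity = policy.choose(cate_score=1.5, step=0)
print(treatment, propensity)
```

With eps=0.2 and a positive CATE score, treatment 1 comes back with propensity 0.9 and treatment 0 with propensity 0.1.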

Source code in onlinecml/base/base_policy.py
import abc

from river.base import Base


class BasePolicy(Base):
    """Abstract base class for treatment exploration policies.

    A policy decides which treatment to assign and with what probability,
    given a CATE score estimate and the current step count. Subclasses
    implement different exploration-exploitation trade-offs.

    All policies follow River conventions: constructor parameters are
    stored as instance attributes with the same name, enabling clone()
    and _get_params() to work correctly.

    Notes
    -----
    This class intentionally does NOT inherit from River's bandit Policy
    because our interface operates on CATE scores (continuous real-valued
    estimates of causal effects) rather than bandit arm indices. The
    returned propensity is the probability of the chosen treatment under
    the policy, which is used for IPW correction downstream.
    """

    @abc.abstractmethod
    def choose(self, cate_score: float, step: int) -> tuple[int, float]:
        """Choose a treatment assignment given the current CATE estimate.

        Parameters
        ----------
        cate_score : float
            Current CATE estimate for the unit. Positive values suggest
            treatment is beneficial; negative values suggest control.
        step : int
            The current time step (used for decay schedules).

        Returns
        -------
        treatment : int
            The chosen treatment assignment (0 or 1).
        propensity : float
            The probability of the chosen treatment under this policy.
            Used for IPW correction in downstream estimators.
        """

    def update(self, reward: float) -> None:
        """Update policy state after observing a reward.

        Parameters
        ----------
        reward : float
            The observed outcome after applying the chosen treatment.
            Default implementation is a no-op; override for adaptive policies.
        """

    def reset(self) -> None:
        """Reset the policy to its initial (untrained) state."""
        self.__init__(**self._get_params())  # type: ignore[misc]

choose(cate_score, step) abstractmethod

Choose a treatment assignment given the current CATE estimate.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cate_score | float | Current CATE estimate for the unit. Positive values suggest treatment is beneficial; negative values suggest control. | required |
| step | int | The current time step (used for decay schedules). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| treatment | int | The chosen treatment assignment (0 or 1). |
| propensity | float | The probability of the chosen treatment under this policy. Used for IPW correction in downstream estimators. |
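
Note that choose returns the probability of the treatment that was chosen, whereas BaseOnlineEstimator.learn_one documents its propensity argument as P(W=1|X). The two differ whenever control is chosen. The sketch below shows one way to convert between them and how the result would enter a standard Horvitz-Thompson style IPW pseudo-outcome; both helpers are hypothetical, and the pseudo-outcome formula is the textbook one, assumed here rather than taken from OnlineCML.

```python
def treated_propensity(treatment: int, chosen_prob: float) -> float:
    """Convert P(chosen treatment) into P(W=1|X)."""
    return chosen_prob if treatment == 1 else 1.0 - chosen_prob


def ipw_pseudo_outcome(treatment: int, outcome: float, e: float) -> float:
    """Textbook IPW pseudo-outcome: W*Y/e - (1-W)*Y/(1-e)."""
    return treatment * outcome / e - (1 - treatment) * outcome / (1.0 - e)


# Suppose a policy returned (0, 0.9): control was chosen with probability 0.9.
e = treated_propensity(treatment=0, chosen_prob=0.9)
psi = ipw_pseudo_outcome(treatment=0, outcome=2.0, e=e)
print(round(e, 3), round(psi, 3))  # 0.1 -2.222
```

Passing 0.9 straight through as P(W=1|X) would instead weight the control outcome by 1 / (1 - 0.9) = 10, which is why the conversion matters.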

Source code in onlinecml/base/base_policy.py
@abc.abstractmethod
def choose(self, cate_score: float, step: int) -> tuple[int, float]:
    """Choose a treatment assignment given the current CATE estimate.

    Parameters
    ----------
    cate_score : float
        Current CATE estimate for the unit. Positive values suggest
        treatment is beneficial; negative values suggest control.
    step : int
        The current time step (used for decay schedules).

    Returns
    -------
    treatment : int
        The chosen treatment assignment (0 or 1).
    propensity : float
        The probability of the chosen treatment under this policy.
        Used for IPW correction in downstream estimators.
    """

reset()

Reset the policy to its initial (untrained) state.

Source code in onlinecml/base/base_policy.py
def reset(self) -> None:
    """Reset the policy to its initial (untrained) state."""
    self.__init__(**self._get_params())  # type: ignore[misc]

update(reward)

Update policy state after observing a reward. The default implementation is a no-op; override it for adaptive policies.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reward | float | The observed outcome after applying the chosen treatment. | required |

Source code in onlinecml/base/base_policy.py
def update(self, reward: float) -> None:
    """Update policy state after observing a reward.

    Parameters
    ----------
    reward : float
        The observed outcome after applying the chosen treatment.
        Default implementation is a no-op; override for adaptive policies.
    """

RunningStats

onlinecml.base.running_stats.RunningStats

Online mean and variance using Welford's single-pass algorithm.

Computes the sample mean, variance, and standard deviation of a stream of scalar values without storing the data.

Examples:

>>> stats = RunningStats()
>>> for x in [2.0, 4.0, 6.0]:
...     stats.update(x)
>>> stats.mean
4.0
>>> stats.n
3
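
The update rule can be checked against a two-pass computation. The sketch below restates the same recurrences as a plain function (welford is a hypothetical helper, not part of OnlineCML) and compares the result with statistics.variance from the standard library.

```python
import statistics


def welford(values):
    """Single-pass mean and sample variance (ddof=1)."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in values:
        n += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / n
        m2 += delta * (x - mean)  # old deviation times new deviation
    variance = m2 / (n - 1) if n >= 2 else 0.0
    return n, mean, variance


data = [2.0, 4.0, 6.0]
n, mean, variance = welford(data)
print(n, mean, variance)  # 3 4.0 4.0
print(statistics.variance(data))  # 4.0
```

For this stream the accumulated sum of squared deviations goes 0, 2, 8, and dividing the final value by n - 1 = 2 gives the sample variance 4.0.
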
Source code in onlinecml/base/running_stats.py
import math


class RunningStats:
    """Online mean and variance using Welford's single-pass algorithm.

    Computes the sample mean, variance, and standard deviation of a
    stream of scalar values without storing the data.

    Examples
    --------
    >>> stats = RunningStats()
    >>> for x in [2.0, 4.0, 6.0]:
    ...     stats.update(x)
    >>> stats.mean
    4.0
    >>> stats.n
    3
    """

    def __init__(self) -> None:
        self._n: int = 0
        self._mean: float = 0.0
        self._M2: float = 0.0

    def update(self, x: float) -> None:
        """Update statistics with a new observation.

        Parameters
        ----------
        x : float
            The new scalar value to incorporate.
        """
        self._n += 1
        delta = x - self._mean
        self._mean += delta / self._n
        delta2 = x - self._mean
        self._M2 += delta * delta2

    def reset(self) -> None:
        """Reset all state to the initial (empty) condition."""
        self._n = 0
        self._mean = 0.0
        self._M2 = 0.0

    @property
    def n(self) -> int:
        """Number of observations seen."""
        return self._n

    @property
    def mean(self) -> float:
        """Current sample mean. Returns 0.0 before any observations."""
        return self._mean

    @property
    def variance(self) -> float:
        """Current sample variance (ddof=1). Returns 0.0 when n < 2."""
        if self._n < 2:
            return 0.0
        return self._M2 / (self._n - 1)

    @property
    def std(self) -> float:
        """Current sample standard deviation. Returns 0.0 when n < 2."""
        return math.sqrt(self.variance)

mean property

Current sample mean. Returns 0.0 before any observations.

n property

Number of observations seen.

std property

Current sample standard deviation. Returns 0.0 when n < 2.

variance property

Current sample variance (ddof=1). Returns 0.0 when n < 2.

reset()

Reset all state to the initial (empty) condition.

Source code in onlinecml/base/running_stats.py
def reset(self) -> None:
    """Reset all state to the initial (empty) condition."""
    self._n = 0
    self._mean = 0.0
    self._M2 = 0.0

update(x)

Update statistics with a new observation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | float | The new scalar value to incorporate. | required |

Source code in onlinecml/base/running_stats.py
def update(self, x: float) -> None:
    """Update statistics with a new observation.

    Parameters
    ----------
    x : float
        The new scalar value to incorporate.
    """
    self._n += 1
    delta = x - self._mean
    self._mean += delta / self._n
    delta2 = x - self._mean
    self._M2 += delta * delta2

WeightedRunningStats

onlinecml.base.running_stats.WeightedRunningStats

Online weighted mean and variance using West's (1979) algorithm.

Computes the weighted mean and population-weighted variance of a stream of scalar values with associated importance weights.

Notes

Returns population-weighted variance (S / sum_w), not sample variance. This is appropriate for SMD computation where we normalize by pooled standard deviation, not for statistical inference.

References

West, D.H.D. (1979). Updating mean and variance estimates: an improved method. Communications of the ACM, 22(9), 532-535.

Examples:

>>> stats = WeightedRunningStats()
>>> stats.update(2.0, w=1.0)
>>> stats.update(4.0, w=2.0)
>>> stats.mean  # (2*1 + 4*2) / 3 = 10/3
3.3333333333333335
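
The same example can be extended to the variance. The sketch below restates the weighted recurrences as a plain function (west is a hypothetical helper, not part of OnlineCML) and checks the result against the direct definition sum(w * (x - mean)^2) / sum(w).

```python
def west(pairs):
    """Single-pass weighted mean and population-weighted variance."""
    sum_w, mean, s = 0.0, 0.0, 0.0
    for x, w in pairs:
        if w <= 0.0:
            continue  # non-positive weights are skipped, as in update()
        sum_w += w
        mean_old = mean
        mean += (w / sum_w) * (x - mean_old)
        s += w * (x - mean_old) * (x - mean)
    variance = s / sum_w if sum_w > 0.0 else 0.0
    return sum_w, mean, variance


pairs = [(2.0, 1.0), (4.0, 2.0)]
sum_w, mean, variance = west(pairs)

# Two-pass reference using the definition of the population-weighted variance.
direct = sum(w * (x - mean) ** 2 for x, w in pairs) / sum_w
print(round(mean, 4), round(variance, 4), round(direct, 4))  # 3.3333 0.8889 0.8889
```

Dividing by the total weight 3.0 rather than by a corrected denominator is what makes this a population-weighted variance.
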
Source code in onlinecml/base/running_stats.py
class WeightedRunningStats:
    """Online weighted mean and variance using West's (1979) algorithm.

    Computes the weighted mean and population-weighted variance of a
    stream of scalar values with associated importance weights.

    Notes
    -----
    Returns population-weighted variance (S / sum_w), not sample variance.
    This is appropriate for SMD computation where we normalize by pooled
    standard deviation, not for statistical inference.

    References
    ----------
    West, D.H.D. (1979). Updating mean and variance estimates: an improved
    method. Communications of the ACM, 22(9), 532-535.

    Examples
    --------
    >>> stats = WeightedRunningStats()
    >>> stats.update(2.0, w=1.0)
    >>> stats.update(4.0, w=2.0)
    >>> stats.mean  # (2*1 + 4*2) / 3 = 10/3
    3.3333333333333335
    """

    def __init__(self) -> None:
        self._sum_w: float = 0.0
        self._mean: float = 0.0
        self._S: float = 0.0  # weighted sum of squared deviations

    def update(self, x: float, w: float = 1.0) -> None:
        """Update statistics with a new weighted observation.

        Parameters
        ----------
        x : float
            The new scalar value to incorporate.
        w : float
            Non-negative importance weight. Silently ignored if w <= 0.
        """
        if w <= 0.0:
            return
        self._sum_w += w
        mean_old = self._mean
        self._mean += (w / self._sum_w) * (x - mean_old)
        self._S += w * (x - mean_old) * (x - self._mean)

    def reset(self) -> None:
        """Reset all state to the initial (empty) condition."""
        self._sum_w = 0.0
        self._mean = 0.0
        self._S = 0.0

    @property
    def sum_weights(self) -> float:
        """Sum of all weights seen so far."""
        return self._sum_w

    @property
    def mean(self) -> float:
        """Current weighted mean. Returns 0.0 before any observations."""
        return self._mean

    @property
    def variance(self) -> float:
        """Population-weighted variance (S / sum_w). Returns 0.0 when sum_w <= 0."""
        if self._sum_w <= 0.0:
            return 0.0
        return self._S / self._sum_w

    @property
    def std(self) -> float:
        """Population-weighted standard deviation. Returns 0.0 when sum_w <= 0."""
        return math.sqrt(self.variance)

mean property

Current weighted mean. Returns 0.0 before any observations.

std property

Population-weighted standard deviation. Returns 0.0 when sum_w <= 0.

sum_weights property

Sum of all weights seen so far.

variance property

Population-weighted variance (S / sum_w). Returns 0.0 when sum_w <= 0.

reset()

Reset all state to the initial (empty) condition.

Source code in onlinecml/base/running_stats.py
def reset(self) -> None:
    """Reset all state to the initial (empty) condition."""
    self._sum_w = 0.0
    self._mean = 0.0
    self._S = 0.0

update(x, w=1.0)

Update statistics with a new weighted observation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | float | The new scalar value to incorporate. | required |
| w | float | Non-negative importance weight. Silently ignored if w <= 0. | 1.0 |

Source code in onlinecml/base/running_stats.py
def update(self, x: float, w: float = 1.0) -> None:
    """Update statistics with a new weighted observation.

    Parameters
    ----------
    x : float
        The new scalar value to incorporate.
    w : float
        Non-negative importance weight. Silently ignored if w <= 0.
    """
    if w <= 0.0:
        return
    self._sum_w += w
    mean_old = self._mean
    self._mean += (w / self._sum_w) * (x - mean_old)
    self._S += w * (x - mean_old) * (x - self._mean)