
Diagnostics

onlinecml.diagnostics.ate_tracker.ATETracker

Tracks the running ATE estimate with online confidence intervals.

Maintains a running mean and variance of per-observation pseudo-outcomes and optionally records history for convergence plotting.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| log_every | int | Append a history entry every log_every post-warmup observations. Default 1 (log every observation); set a larger value for long streams. | 1 |
| warmup | int | Number of initial pseudo-outcomes to skip when accumulating the ATE estimate. History is not recorded during warmup. Default 0. | 0 |
| forgetting_factor | float | Controls how quickly old pseudo-outcomes are forgotten. 1.0 = cumulative Welford mean (no forgetting, default). Values < 1.0 (e.g. 0.95–0.99) switch to an EWMA so the tracker adapts to concept drift, with alpha = 1 - forgetting_factor. | 1.0 |

Notes

Unlike BaseOnlineEstimator, this is a standalone diagnostic tool that users instantiate separately and feed pseudo-outcomes into. It is not tied to any specific estimation method.

When forgetting_factor < 1.0, the internal Welford state is replaced by an EWMA (EWMAStats). The CI formula remains mean ± z * sqrt(var/n) where var and n come from the EWMA estimates.
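
For intuition about the drift behaviour, here is a minimal sketch of an exponentially weighted mean and variance with alpha = 1 - forgetting_factor, compared against a cumulative mean on a stream with a level shift. EWMASketch is a hypothetical helper that uses a common incremental EWMA update; it is not necessarily how EWMAStats works, and it omits the effective sample size n that the library feeds into the CI.

```python
class EWMASketch:
    """Hypothetical exponentially weighted mean/variance (not the library's EWMAStats)."""

    def __init__(self, alpha: float) -> None:
        if not 0.0 < alpha <= 1.0:
            raise ValueError("alpha must be in (0, 1]")
        self.alpha = alpha
        self.mean = 0.0
        self.variance = 0.0
        self.count = 0  # raw number of updates, not an effective sample size

    def update(self, x: float) -> None:
        self.count += 1
        if self.count == 1:
            self.mean = x  # seed the mean with the first observation
            return
        diff = x - self.mean
        incr = self.alpha * diff
        self.mean += incr
        # Incremental exponentially weighted variance update.
        self.variance = (1.0 - self.alpha) * (self.variance + diff * incr)


# Simulated pseudo-outcomes whose level jumps from 0.0 to 5.0 halfway through.
stream = [0.0] * 200 + [5.0] * 200

forgetting_factor = 0.95
ewma = EWMASketch(alpha=1.0 - forgetting_factor)
for x in stream:
    ewma.update(x)
cumulative_mean = sum(stream) / len(stream)

print(f"cumulative mean: {cumulative_mean:.3f}")  # cumulative mean: 2.500
print(f"EWMA mean:       {ewma.mean:.3f}")        # EWMA mean:       5.000
```

The cumulative mean averages both regimes, while the EWMA mean tracks the new level. Roughly speaking, the most recent 1/alpha observations (about 20 here) dominate the EWMA, so smaller forgetting factors react faster but give noisier estimates.
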

Examples:

>>> tracker = ATETracker(log_every=10)
>>> for pseudo_outcome in [1.5, 2.3, 1.8, 2.1]:
...     tracker.update(pseudo_outcome)
>>> abs(tracker.ate - 1.925) < 1e-10
True
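
The bookkeeping described above (skip warmup observations, accumulate a Welford running mean and variance, log (step, ate, ci_lower, ci_upper) every log_every post-warmup observations, report a normal-approximation CI) can be sketched in plain Python for the cumulative case. TrackerSketch is a hypothetical stand-in, not the library class; it assumes a sample variance (n - 1 denominator) and uses the standard library's statistics.NormalDist in place of scipy.stats.norm.ppf.

```python
import math
from statistics import NormalDist


class TrackerSketch:
    """Hypothetical re-creation of the ATETracker bookkeeping (cumulative case only)."""

    def __init__(self, log_every: int = 1, warmup: int = 0) -> None:
        self.log_every = log_every
        self.warmup = warmup
        self.n_total = 0  # counts warmup observations too
        self.n = 0        # post-warmup observations
        self.mean = 0.0
        self._m2 = 0.0    # Welford sum of squared deviations
        self.history: list[tuple[int, float, float, float]] = []

    def update(self, pseudo_outcome: float) -> None:
        self.n_total += 1
        if self.n_total <= self.warmup:
            return  # warmup: neither accumulated nor logged
        # Welford's update of the running mean and sum of squares.
        self.n += 1
        delta = pseudo_outcome - self.mean
        self.mean += delta / self.n
        self._m2 += delta * (pseudo_outcome - self.mean)
        if self.n % self.log_every == 0:
            lo, hi = self.ci()
            self.history.append((self.n_total, self.mean, lo, hi))

    def ci(self, alpha: float = 0.05) -> tuple[float, float]:
        if self.n < 2:
            return (-math.inf, math.inf)
        variance = self._m2 / (self.n - 1)  # sample variance
        se = math.sqrt(variance / self.n)
        z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
        return (self.mean - z * se, self.mean + z * se)


tracker = TrackerSketch(log_every=2, warmup=1)
for y in [100.0, 1.5, 2.3, 1.8, 2.1]:  # the first value is discarded as warmup
    tracker.update(y)

print(round(tracker.mean, 3))                  # 1.925
print([step for step, *_ in tracker.history])  # [3, 5]
```

With warmup=1 the first value never enters the estimate, so the mean matches the class example; history steps count warmup observations, as in ATETracker's (step, ...) tuples.
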
Source code in onlinecml/diagnostics/ate_tracker.py
class ATETracker:
    """Tracks the running ATE estimate with online confidence intervals.

    Maintains a running mean and variance of per-observation pseudo-outcomes
    and optionally records history for convergence plotting.

    Parameters
    ----------
    log_every : int
        Append a history entry every ``log_every`` observations. Default 1
        (log every observation). Set to a larger value for long streams.
    warmup : int
        Number of initial pseudo-outcomes to skip when accumulating the ATE
        estimate. History is not recorded during warmup. Default 0.
    forgetting_factor : float
        Controls how quickly old pseudo-outcomes are forgotten.
        ``1.0`` = cumulative Welford mean (no forgetting, default).
        Values < 1.0 (e.g. 0.95–0.99) switch to EWMA so the tracker
        adapts to concept drift. ``alpha = 1 - forgetting_factor``.

    Notes
    -----
    Unlike ``BaseOnlineEstimator``, this is a standalone diagnostic tool
    that users instantiate separately and feed pseudo-outcomes into. It is
    not tied to any specific estimation method.

    When ``forgetting_factor < 1.0``, the internal Welford state is replaced
    by an EWMA (``EWMAStats``). The CI formula remains ``mean ± z * sqrt(var/n)``
    where ``var`` and ``n`` come from the EWMA estimates.

    Examples
    --------
    >>> tracker = ATETracker(log_every=10)
    >>> for pseudo_outcome in [1.5, 2.3, 1.8, 2.1]:
    ...     tracker.update(pseudo_outcome)
    >>> abs(tracker.ate - 1.925) < 1e-10
    True
    """

    def __init__(
        self,
        log_every: int = 1,
        warmup: int = 0,
        forgetting_factor: float = 1.0,
    ) -> None:
        self.log_every = log_every
        self.warmup = warmup
        self.forgetting_factor = forgetting_factor
        # Delegate statistics to the appropriate backend
        self._stats: RunningStats | EWMAStats = (
            EWMAStats(alpha=1.0 - forgetting_factor)
            if forgetting_factor < 1.0
            else RunningStats()
        )
        self._n_total: int = 0  # includes warmup observations
        self._history: list[tuple[int, float, float, float]] = []

    def update(self, pseudo_outcome: float) -> None:
        """Incorporate one pseudo-outcome into the running ATE estimate.

        Parameters
        ----------
        pseudo_outcome : float
            Per-observation pseudo-outcome (e.g. IPW score, DR score, or
            per-obs CATE estimate). The running mean of these values
            converges to the ATE under the relevant identification assumptions.
        """
        self._n_total += 1
        if self._n_total <= self.warmup:
            return

        self._stats.update(pseudo_outcome)

        if self._stats.n % self.log_every == 0:
            lo, hi = self.ci()
            self._history.append((self._n_total, self._stats.mean, lo, hi))

    def reset(self) -> None:
        """Reset all state to the initial (empty) condition."""
        self._stats.reset()
        self._n_total = 0
        self._history = []

    @property
    def ate(self) -> float:
        """Current ATE estimate (running mean of pseudo-outcomes).

        Returns 0.0 before any observations are seen (or during warmup).
        """
        return self._stats.mean

    @property
    def n(self) -> int:
        """Number of pseudo-outcomes processed (excluding warmup)."""
        return self._stats.n

    @property
    def history(self) -> list[tuple[int, float, float, float]]:
        """Recorded history as a list of ``(step, ate, ci_lower, ci_upper)`` tuples."""
        return list(self._history)

    def ci(self, alpha: float = 0.05) -> tuple[float, float]:
        """Return a confidence interval for the current ATE estimate.

        Uses a normal approximation via the central limit theorem.

        Parameters
        ----------
        alpha : float
            Significance level. Default 0.05 gives a 95% CI.

        Returns
        -------
        lower : float
            Lower bound of the confidence interval.
        upper : float
            Upper bound of the confidence interval.

        Notes
        -----
        Returns ``(-inf, inf)`` before at least 2 observations are seen.
        """
        n = self._stats.n
        if n < 2:
            return (float("-inf"), float("inf"))
        variance = self._stats.variance
        se = math.sqrt(variance / n)
        z = scipy.stats.norm.ppf(1.0 - alpha / 2.0)
        mean = self._stats.mean
        return (mean - z * se, mean + z * se)

    def convergence_width(self, alpha: float = 0.05) -> float:
        """Return the current confidence interval width.

        Useful as an early-stopping criterion: stop collecting data when
        the CI width falls below a target threshold.

        Parameters
        ----------
        alpha : float
            Significance level for the CI. Default 0.05.

        Returns
        -------
        float
            Width of the current CI (``upper - lower``). Returns ``inf``
            before 2 observations are seen.
        """
        lo, hi = self.ci(alpha)
        return hi - lo

    def plot(self, ax: "matplotlib.axes.Axes | None" = None) -> "matplotlib.axes.Axes":
        """Plot the ATE convergence curve with shaded confidence band.

        Parameters
        ----------
        ax : matplotlib.axes.Axes or None
            Axes to plot on. If None, creates a new figure.

        Returns
        -------
        matplotlib.axes.Axes
            The axes with the convergence plot.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        if not self._history:
            return ax

        steps = [h[0] for h in self._history]
        ates = [h[1] for h in self._history]
        lows = [h[2] for h in self._history]
        highs = [h[3] for h in self._history]

        ax.plot(steps, ates, label="ATE estimate", color="steelblue")
        ax.fill_between(steps, lows, highs, alpha=0.2, color="steelblue", label="95% CI")
        ax.axhline(self._stats.mean, linestyle="--", color="gray", linewidth=0.8, label="Current ATE")
        ax.set_xlabel("Observations")
        ax.set_ylabel("ATE estimate")
        ax.set_title("ATE Convergence")
        ax.legend()
        return ax

ate property

Current ATE estimate (running mean of pseudo-outcomes).

Returns 0.0 before any observations are seen (or during warmup).

history property

Recorded history as a list of (step, ate, ci_lower, ci_upper) tuples.

n property

Number of pseudo-outcomes processed (excluding warmup).

ci(alpha=0.05)

Return a confidence interval for the current ATE estimate.

Uses a normal approximation via the central limit theorem.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level. Default 0.05 gives a 95% CI. | 0.05 |

Returns:

| Name | Type | Description |
|---|---|---|
| lower | float | Lower bound of the confidence interval. |
| upper | float | Upper bound of the confidence interval. |

Notes

Returns (-inf, inf) before at least 2 observations are seen.
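
As a worked check, here is the 95% interval for the four pseudo-outcomes from the class example, computed with the standard library. It assumes the tracker's variance is the sample variance (n - 1 denominator); a population variance would give a slightly narrower interval.

```python
import math
from statistics import NormalDist, fmean, variance

pseudo_outcomes = [1.5, 2.3, 1.8, 2.1]
n = len(pseudo_outcomes)
mean = fmean(pseudo_outcomes)                  # 1.925
se = math.sqrt(variance(pseudo_outcomes) / n)  # sqrt(0.1225 / 4) = 0.175
z = NormalDist().inv_cdf(1 - 0.05 / 2)         # about 1.96
lower, upper = mean - z * se, mean + z * se
print(f"95% CI: ({lower:.3f}, {upper:.3f})")   # 95% CI: (1.582, 2.268)
```
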

convergence_width(alpha=0.05)

Return the current confidence interval width.

Useful as an early-stopping criterion: stop collecting data when the CI width falls below a target threshold.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Significance level for the CI. Default 0.05. | 0.05 |

Returns:

| Type | Description |
|---|---|
| float | Width of the current CI (upper - lower). Returns inf before 2 observations are seen. |
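
Because the half-width shrinks like 1/sqrt(n), the width criterion also gives a rough sample-size estimate: n >= (2 * z * s / w)^2 for a target width w and pseudo-outcome standard deviation s. The sketch below uses hypothetical helpers (required_n, stream_until_narrow) on simulated data to show both the planning formula and a stopping loop. It waits for a minimum sample size before checking the width, because with very few observations the sample variance, and hence the width, can be misleadingly small; treat any such data-dependent stopping rule as a heuristic.

```python
import math
import random
from statistics import NormalDist


def required_n(std: float, target_width: float, alpha: float = 0.05) -> int:
    """Smallest n with 2 * z * std / sqrt(n) <= target_width (normal approximation)."""
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
    return math.ceil((2.0 * z * std / target_width) ** 2)


def stream_until_narrow(target_width: float, alpha: float = 0.05,
                        min_n: int = 30, seed: int = 0) -> tuple[int, float, float]:
    """Read simulated pseudo-outcomes until the CI width falls below target_width."""
    rng = random.Random(seed)
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
    n, mean, m2 = 0, 0.0, 0.0
    while True:
        y = rng.gauss(2.0, 1.0)  # hypothetical stream: true ATE 2.0, unit noise
        n += 1
        delta = y - mean
        mean += delta / n
        m2 += delta * (y - mean)
        # Do not trust the width until a minimum sample size is reached.
        if n >= min_n:
            width = 2.0 * z * math.sqrt(m2 / (n - 1) / n)
            if width < target_width:
                return n, mean, width


print(required_n(std=1.0, target_width=0.1))  # 1537
n, ate, width = stream_until_narrow(target_width=0.1)
print(n, round(ate, 2), round(width, 4))
```
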

plot(ax=None)

Plot the ATE convergence curve with shaded confidence band.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ax | Axes or None | Axes to plot on. If None, creates a new figure. | None |

Returns:

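| Type | Description |
|---|---|
| Axes | The axes with the convergence plot. If no history has been recorded yet, the axes are returned with nothing drawn. |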

reset()

Reset all state to the initial (empty) condition.

update(pseudo_outcome)

Incorporate one pseudo-outcome into the running ATE estimate.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pseudo_outcome | float | Per-observation pseudo-outcome (e.g. IPW score, DR score, or per-observation CATE estimate). The running mean of these values converges to the ATE under the relevant identification assumptions. | required |
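
The pseudo-outcomes named above can be computed per observation with the standard formulas. In this sketch, ipw_score and aipw_score are hypothetical helpers (not part of onlinecml), e is assumed to be a propensity score strictly between 0 and 1, and mu1 and mu0 are predictions from some outcome model.

```python
def ipw_score(y: float, t: int, e: float) -> float:
    """Inverse-propensity-weighted pseudo-outcome; e is P(T=1 | X)."""
    return t * y / e - (1 - t) * y / (1 - e)


def aipw_score(y: float, t: int, e: float, mu1: float, mu0: float) -> float:
    """Doubly robust (augmented IPW) pseudo-outcome; mu1/mu0 are outcome-model predictions."""
    return (mu1 - mu0
            + t * (y - mu1) / e
            - (1 - t) * (y - mu0) / (1 - e))


# One treated unit: y=3.0, e=0.5, outcome model predicts mu1=2.5, mu0=1.0.
print(ipw_score(3.0, 1, 0.5))                     # 6.0
print(aipw_score(3.0, 1, 0.5, mu1=2.5, mu0=1.0))  # 2.5
```

Each returned value is what would be passed to update. The running mean of IPW scores targets the ATE when the propensity model is correct; the AIPW mean remains consistent if either the propensity model or the outcome model is correct.
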

onlinecml.diagnostics.smd.OnlineSMD

Bases: Base

Online covariate balance diagnostics via Standardized Mean Difference.

Tracks the raw and IPW-weighted SMD for a set of covariates, updated one observation at a time. Used to monitor whether treatment and control groups are comparable in covariate distributions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| covariates | list of str | Names of covariates to track. Must match keys in the feature dicts passed to update. | required |

Notes

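SMD for a covariate is defined as:

$$
\text{SMD} = \frac{\bar{X}_T - \bar{X}_C}{\sqrt{(s_T^2 + s_C^2) / 2}}
$$

where $\bar{X}_T$ and $\bar{X}_C$ are the treated and control means and $s_T^2$ and $s_C^2$ are the within-group variances. The raw SMD is 0.0 when either group has fewer than 2 observations, the weighted SMD is 0.0 when either group has zero total weight, and both are 0.0 when the pooled variance is zero.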

Raw SMD uses unweighted RunningStats; weighted SMD uses WeightedRunningStats with West's (1979) algorithm (population variance). The weighted SMD is used by is_balanced.

This class does NOT inherit from BaseOnlineEstimator; it is a standalone diagnostic tool.
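
To make the two variants concrete, the following sketch computes the raw SMD (Welford mean and sample variance) and the weighted SMD (West's weighted incremental mean and population variance) for a single covariate, using the pooled denominator from the formula above. The helpers and data are hypothetical; this is not the library's RunningStats or WeightedRunningStats code.

```python
import math


def welford(values):
    """Unweighted mean and sample variance (n - 1 denominator)."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return mean, (m2 / (n - 1) if n >= 2 else 0.0)


def west_weighted(values, weights):
    """West (1979) weighted incremental mean and population variance."""
    sum_w, mean, s = 0.0, 0.0, 0.0
    for x, w in zip(values, weights):
        sum_w += w
        delta = x - mean
        mean += (w / sum_w) * delta
        s += w * delta * (x - mean)
    return mean, (s / sum_w if sum_w > 0 else 0.0)


def smd(mean_t, var_t, mean_c, var_c):
    """SMD with pooled variance (var_t + var_c) / 2; 0.0 if the pooled variance is zero."""
    pooled = (var_t + var_c) / 2.0
    return 0.0 if pooled <= 0.0 else (mean_t - mean_c) / math.sqrt(pooled)


age_t, w_t = [30.0, 40.0, 50.0], [2.0, 1.0, 1.0]  # treated ages and hypothetical weights
age_c, w_c = [20.0, 30.0, 40.0], [1.0, 1.0, 2.0]  # control ages and hypothetical weights

raw = smd(*welford(age_t), *welford(age_c))
weighted = smd(*west_weighted(age_t, w_t), *west_weighted(age_c, w_c))
print(round(raw, 3), round(weighted, 3))  # 1.0 0.603
```

Here the weights pull the group means closer (37.5 vs 32.5 instead of 40 vs 30), so the SMD drops from 1.0 to about 0.60, still well above the 0.1 threshold used by is_balanced.
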

Examples:

>>> smd = OnlineSMD(covariates=["age", "income"])
>>> smd.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
>>> smd.update({"age": 45, "income": 70000}, treatment=0, weight=0.8)
>>> report = smd.report()
>>> "age" in report
True
Source code in onlinecml/diagnostics/smd.py
class OnlineSMD(Base):
    """Online covariate balance diagnostics via Standardized Mean Difference.

    Tracks the raw and IPW-weighted SMD for a set of covariates, updated
    one observation at a time. Used to monitor whether treatment and control
    groups are comparable in covariate distributions.

    Parameters
    ----------
    covariates : list of str
        Names of covariates to track. Must match keys in the feature dicts
        passed to ``update``.

    Notes
    -----
    SMD for a covariate is defined as:

    .. math::

        \\text{SMD} = \\frac{\\bar{X}_T - \\bar{X}_C}{\\sqrt{(s_T^2 + s_C^2) / 2}}

    where ``s^2`` is the sample variance within each group. Returns 0.0
    when either group has fewer than 2 observations.

    Raw SMD uses unweighted ``RunningStats``; weighted SMD uses
    ``WeightedRunningStats`` with West's (1979) algorithm (population
    variance). The weighted SMD is used by ``is_balanced``.

    This class does NOT inherit from ``BaseOnlineEstimator`` — it is a
    standalone diagnostic tool.

    Examples
    --------
    >>> smd = OnlineSMD(covariates=["age", "income"])
    >>> smd.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
    >>> smd.update({"age": 45, "income": 70000}, treatment=0, weight=0.8)
    >>> report = smd.report()
    >>> "age" in report
    True
    """

    def __init__(self, covariates: list[str]) -> None:
        self.covariates = covariates
        # Lazily initialized on first update call
        self._stats: dict[str, dict] = {}

    def _init_covariate(self, cov: str) -> None:
        """Initialize tracking stats for a covariate.

        Parameters
        ----------
        cov : str
            Covariate name to initialize.
        """
        self._stats[cov] = {
            "raw_treated": RunningStats(),
            "raw_control": RunningStats(),
            "weighted_treated": WeightedRunningStats(),
            "weighted_control": WeightedRunningStats(),
        }

    def update(self, x: dict, treatment: int, weight: float = 1.0) -> None:
        """Update balance statistics with one observation.

        Parameters
        ----------
        x : dict
            Feature dictionary. Missing covariates default to 0.0.
        treatment : int
            Treatment indicator (0 = control, 1 = treated).
        weight : float
            Importance weight for this observation (e.g. IPW weight).
            Default 1.0 (unweighted).
        """
        for cov in self.covariates:
            if cov not in self._stats:
                self._init_covariate(cov)
            val = float(x.get(cov, 0.0))
            s = self._stats[cov]
            if treatment == 1:
                s["raw_treated"].update(val)
                s["weighted_treated"].update(val, w=weight)
            else:
                s["raw_control"].update(val)
                s["weighted_control"].update(val, w=weight)

    @staticmethod
    def _compute_smd(stats_t: RunningStats, stats_c: RunningStats) -> float:
        """Compute SMD between two groups using their running statistics.

        Parameters
        ----------
        stats_t : RunningStats
            Running stats for the treated group.
        stats_c : RunningStats
            Running stats for the control group.

        Returns
        -------
        float
            SMD value. Returns 0.0 if either group has fewer than 2 obs.
        """
        if stats_t.n < 2 or stats_c.n < 2:
            return 0.0
        pooled_var = (stats_t.variance + stats_c.variance) / 2.0
        if pooled_var <= 0.0:
            return 0.0
        return (stats_t.mean - stats_c.mean) / math.sqrt(pooled_var)

    @staticmethod
    def _compute_weighted_smd(
        stats_t: WeightedRunningStats, stats_c: WeightedRunningStats
    ) -> float:
        """Compute weighted SMD between two groups.

        Parameters
        ----------
        stats_t : WeightedRunningStats
            Weighted running stats for the treated group.
        stats_c : WeightedRunningStats
            Weighted running stats for the control group.

        Returns
        -------
        float
            Weighted SMD value. Returns 0.0 if either group has no weight mass.
        """
        if stats_t.sum_weights <= 0.0 or stats_c.sum_weights <= 0.0:
            return 0.0
        pooled_var = (stats_t.variance + stats_c.variance) / 2.0
        if pooled_var <= 0.0:
            return 0.0
        return (stats_t.mean - stats_c.mean) / math.sqrt(pooled_var)

    def report(self) -> dict[str, tuple[float, float]]:
        """Return the current SMD for each tracked covariate.

        Returns
        -------
        dict
            Mapping from covariate name to ``(raw_smd, weighted_smd)``.
            Covariates with insufficient data return ``(0.0, 0.0)``.
        """
        result = {}
        for cov in self.covariates:
            if cov not in self._stats:
                result[cov] = (0.0, 0.0)
            else:
                s = self._stats[cov]
                raw = self._compute_smd(s["raw_treated"], s["raw_control"])
                weighted = self._compute_weighted_smd(
                    s["weighted_treated"], s["weighted_control"]
                )
                result[cov] = (raw, weighted)
        return result

    def is_balanced(self, thr: float = 0.1) -> bool:
        """Check whether all covariates are balanced after weighting.

        Parameters
        ----------
        thr : float
            Maximum absolute weighted SMD threshold. Default 0.1
            (the conventional "well-balanced" threshold).

        Returns
        -------
        bool
            True if all covariates have ``|weighted_smd| < thr``.
        """
        return all(abs(smd_val) < thr for _, smd_val in self.report().values())

is_balanced(thr=0.1)

Check whether all covariates are balanced after weighting.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| thr | float | Threshold on the absolute weighted SMD; a covariate counts as balanced when abs(weighted_smd) < thr. Default 0.1 (the conventional "well-balanced" threshold). | 0.1 |

Returns:

| Type | Description |
|---|---|
| bool | True if all covariates have abs(weighted_smd) < thr. |
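
The rule applied by is_balanced can be restated against a report()-shaped dict, together with a hypothetical worst_covariate helper for finding the covariate that blocks balance. The SMD values below are invented for illustration. One consequence of the code shown on this page: before any observations arrive, report() returns (0.0, 0.0) for every covariate, so is_balanced() returns True vacuously and an early True should be read alongside the number of observations seen.

```python
def is_balanced(report: dict[str, tuple[float, float]], thr: float = 0.1) -> bool:
    """Same rule as OnlineSMD.is_balanced: every |weighted SMD| must be < thr."""
    return all(abs(weighted) < thr for _raw, weighted in report.values())


def worst_covariate(report: dict[str, tuple[float, float]]) -> tuple[str, float]:
    """Covariate with the largest |weighted SMD| (hypothetical helper)."""
    name = max(report, key=lambda c: abs(report[c][1]))
    return name, report[name][1]


report = {"age": (0.45, 0.04), "income": (-0.30, -0.12)}  # invented values
print(is_balanced(report))      # False: abs(-0.12) >= 0.1
print(worst_covariate(report))  # ('income', -0.12)

# With no data, every entry is (0.0, 0.0), so the check passes vacuously.
print(is_balanced({"age": (0.0, 0.0)}))  # True
```
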

report()

Return the current SMD for each tracked covariate.

Returns:

| Type | Description |
|---|---|
| dict | Mapping from covariate name to (raw_smd, weighted_smd). Covariates with insufficient data return (0.0, 0.0). |

update(x, treatment, weight=1.0)

Update balance statistics with one observation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary. Missing covariates default to 0.0. | required |
| treatment | int | Treatment indicator (1 = treated; any other value, including 0, is counted as control). | required |
| weight | float | Importance weight for this observation (e.g. IPW weight). Default 1.0 (unweighted). | 1.0 |

onlinecml.diagnostics.live_love_plot.LiveLovePlot

Real-time Love Plot for monitoring covariate balance online.

Displays raw and weighted Standardized Mean Differences (SMD) for a set of covariates and redraws the plot every update_every calls to update. Dashed vertical reference lines at ±balance_threshold (0.1 by default) mark the conventional "well-balanced" threshold.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| covariates | list of str | Names of covariates to display (in order). | required |
| update_every | int | Redraw the plot every update_every calls to update. Default 100. | 100 |
| balance_threshold | float | Position of the dashed reference lines drawn at ±balance_threshold. Default 0.1. | 0.1 |

Notes

This class wraps an internal OnlineSMD instance: feature dicts passed to update are forwarded to it, and render draws its current report. render does not take a report argument, so it always reflects the internal tracker.

Drawing requires matplotlib. If matplotlib is not installed, render returns None without drawing; update still accumulates the balance statistics but skips the periodic redraw.

Examples:

>>> plot = LiveLovePlot(covariates=["age", "income"], update_every=50)
>>> plot.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
>>> ax = plot.render()
Source code in onlinecml/diagnostics/live_love_plot.py
class LiveLovePlot:
    """Real-time Love Plot for monitoring covariate balance online.

    Displays raw and weighted Standardized Mean Differences (SMD) for a
    set of covariates. Updates the plot every ``update_every`` steps.
    A vertical reference line at ``|SMD| = 0.1`` marks the conventional
    "well-balanced" threshold.

    Parameters
    ----------
    covariates : list of str
        Names of covariates to display (in order).
    update_every : int
        Redraw the plot every ``update_every`` calls to ``update``.
        Default 100.
    balance_threshold : float
        Reference line position. Default 0.1.

    Notes
    -----
    This class wraps ``OnlineSMD`` internally: feature dicts passed to
    ``update`` are forwarded to it, and ``render`` draws its current
    report (``render`` does not accept an external report).

    Drawing requires ``matplotlib``. If it is not installed, ``render``
    returns ``None`` without drawing, and ``update`` still accumulates
    statistics but skips the periodic redraw.

    Examples
    --------
    >>> plot = LiveLovePlot(covariates=["age", "income"], update_every=50)
    >>> plot.update({"age": 30, "income": 50000}, treatment=1, weight=1.2)
    >>> ax = plot.render()
    """

    def __init__(
        self,
        covariates: list[str],
        update_every: int = 100,
        balance_threshold: float = 0.1,
    ) -> None:
        self.covariates = covariates
        self.update_every = update_every
        self.balance_threshold = balance_threshold
        self._n: int = 0
        self._fig = None
        self._ax = None
        # Internal SMD tracker
        from onlinecml.diagnostics.smd import OnlineSMD
        self._smd = OnlineSMD(covariates=covariates)

    def update(self, x: dict, treatment: int, weight: float = 1.0) -> None:
        """Update the covariate balance statistics.

        Parameters
        ----------
        x : dict
            Feature dictionary for this observation.
        treatment : int
            Treatment indicator (0 or 1).
        weight : float
            Importance weight for this observation. Default 1.0.
        """
        self._smd.update(x, treatment, weight=weight)
        self._n += 1
        if self._n % self.update_every == 0:
            self.render()

    def render(self, ax: "matplotlib.axes.Axes | None" = None) -> "matplotlib.axes.Axes | None":
        """Render the Love Plot from current SMD data.

        Parameters
        ----------
        ax : matplotlib.axes.Axes or None
            Axes to render on. If None, creates or reuses internal axes.

        Returns
        -------
        matplotlib.axes.Axes or None
            The rendered axes, or None if matplotlib is unavailable.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            return None

        report = self._smd.report()
        if not report:
            return None

        if ax is None:
            if self._fig is None:
                self._fig, self._ax = plt.subplots(figsize=(8, max(3, len(self.covariates) * 0.5)))
            ax = self._ax

        ax.cla()
        covs = list(report.keys())
        raw_smds = [report[c][0] for c in covs]
        weighted_smds = [report[c][1] for c in covs]
        y_pos = list(range(len(covs)))

        ax.scatter(raw_smds, y_pos, marker="o", label="Raw SMD", color="steelblue", zorder=3)
        ax.scatter(
            weighted_smds, y_pos, marker="^", label="Weighted SMD", color="darkorange", zorder=3
        )
        ax.axvline(
            self.balance_threshold, linestyle="--", color="red", linewidth=0.8, label="Threshold"
        )
        ax.axvline(
            -self.balance_threshold, linestyle="--", color="red", linewidth=0.8
        )
        ax.axvline(0, linestyle="-", color="black", linewidth=0.5)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(covs)
        ax.set_xlabel("Standardized Mean Difference")
        ax.set_title(f"Love Plot (n={self._n})")
        ax.legend(loc="lower right")

        if self._fig is not None:
            self._fig.tight_layout()
            plt.pause(0.001)

        return ax

    def save(self, path: str) -> None:
        """Save the current plot to a file.

        Parameters
        ----------
        path : str
            File path (e.g. ``'balance.png'``).
        """
        if self._fig is not None:
            self._fig.savefig(path, bbox_inches="tight")

render(ax=None)

Render the Love Plot from current SMD data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ax | Axes or None | Axes to render on. If None, creates or reuses internal axes. | None |

Returns:

| Type | Description |
|---|---|
| Axes or None | The rendered axes, or None if matplotlib is unavailable or no covariates are configured. |

save(path)

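Save the current plot to a file. This only has an effect once the internal figure exists, i.e. after render has been called without an explicit ax (directly or through update); otherwise nothing is written.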

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File path (e.g. 'balance.png'). | required |

update(x, treatment, weight=1.0)

Update the covariate balance statistics.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | dict | Feature dictionary for this observation. | required |
| treatment | int | Treatment indicator (0 or 1). | required |
| weight | float | Importance weight for this observation. Default 1.0. | 1.0 |
Source code in onlinecml/diagnostics/live_love_plot.py
def update(self, x: dict, treatment: int, weight: float = 1.0) -> None:
    """Update the covariate balance statistics.

    Parameters
    ----------
    x : dict
        Feature dictionary for this observation.
    treatment : int
        Treatment indicator (0 or 1).
    weight : float
        Importance weight for this observation. Default 1.0.
    """
    self._smd.update(x, treatment, weight=weight)
    self._n += 1
    if self._n % self.update_every == 0:
        self.render()
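LiveLovePlot.update hands the balance bookkeeping to an internal `_smd` tracker whose code is not shown in this section. The sketch below is a hypothetical stand-in. It assumes per-arm weighted running means and variances (West's incremental weighted update) and the common signed SMD convention (mean_t - mean_c) / sqrt((var_t + var_c) / 2). The library's actual tracker may standardize differently. `WeightedArmStats` and `OnlineWeightedSMD` are illustrative names, not onlinecml classes.

```python
import math
from collections import defaultdict


class WeightedArmStats:
    """Weighted running mean/variance for one covariate in one arm."""

    def __init__(self) -> None:
        self.w_sum = 0.0  # total weight seen so far
        self.mean = 0.0
        self.m2 = 0.0  # weighted sum of squared deviations

    def update(self, value: float, weight: float) -> None:
        if weight <= 0:  # assumption: non-positive weights are ignored
            return
        self.w_sum += weight
        delta = value - self.mean
        self.mean += (weight / self.w_sum) * delta
        self.m2 += weight * delta * (value - self.mean)

    @property
    def variance(self) -> float:
        # Weighted variance with a plain W denominator (no small-sample correction).
        return self.m2 / self.w_sum if self.w_sum > 0 else 0.0


class OnlineWeightedSMD:
    """Hypothetical stand-in for the SMD tracker LiveLovePlot delegates to."""

    def __init__(self) -> None:
        # stats[covariate] -> (control stats, treated stats)
        self.stats = defaultdict(lambda: (WeightedArmStats(), WeightedArmStats()))

    def update(self, x: dict, treatment: int, weight: float = 1.0) -> None:
        arm = int(bool(treatment))
        for name, value in x.items():
            self.stats[name][arm].update(float(value), weight)

    def smd(self, name: str) -> float:
        control, treated = self.stats[name]
        pooled_sd = math.sqrt((treated.variance + control.variance) / 2.0)
        if pooled_sd == 0.0:
            return 0.0
        return (treated.mean - control.mean) / pooled_sd


tracker = OnlineWeightedSMD()
for value, treated in [(1.0, 1), (3.0, 1), (0.0, 0), (2.0, 0)]:
    tracker.update({"age": value}, treatment=treated)
print(tracker.smd("age"))  # 1.0
```

A weight of 2 behaves like seeing the observation twice, which is why inverse-propensity weights can be passed straight through to the weighted SMD.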

onlinecml.diagnostics.overlap_checker.OverlapChecker

Monitors propensity score distributions for positivity violations.

Tracks the distribution of predicted propensity scores per treatment arm and counts units whose scores fall outside [ps_min, ps_max]. Reports the proportion of units in the common support region.

Parameters:

Name    Type   Description                                                                   Default
ps_min  float  Lower positivity threshold. PS values below this are flagged. Default 0.05.  0.05
ps_max  float  Upper positivity threshold. PS values above this are flagged. Default 0.95.  0.95
Notes

A unit is in common support if its propensity score satisfies ps_min <= p <= ps_max. Causal estimates that lean on units outside this region (for example through extreme inverse-propensity weights) are unreliable.

The report() method returns a summary including:

  • Total and flagged observation counts
  • Overall flag rate and common support rate
  • Mean PS per arm
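To make the flagging rule concrete, here is a hypothetical batch helper, `overlap_report`, that recomputes the same keys as report() from lists of scores and treatments. It mirrors the comparison used in update, so a score exactly equal to ps_min or ps_max is not flagged. The 0.0 mean reported for an empty arm is an assumption, since how RunningStats behaves before any data arrives is not shown here.

```python
from typing import Dict, Sequence


def overlap_report(
    propensities: Sequence[float],
    treatments: Sequence[int],
    ps_min: float = 0.05,
    ps_max: float = 0.95,
) -> Dict[str, float]:
    """Batch recomputation of the OverlapChecker.report() keys (hypothetical helper)."""
    if len(propensities) != len(treatments):
        raise ValueError("propensities and treatments must have equal length")
    n_total = len(propensities)
    # Same rule as OverlapChecker.update: only scores strictly outside the
    # thresholds are flagged, so p == ps_min or p == ps_max is in support.
    n_flagged = sum(1 for p in propensities if p < ps_min or p > ps_max)
    flag_rate = n_flagged / n_total if n_total > 0 else 0.0
    treated = [p for p, w in zip(propensities, treatments) if w]
    control = [p for p, w in zip(propensities, treatments) if not w]
    return {
        "n_total": n_total,
        "n_flagged": n_flagged,
        "flag_rate": flag_rate,
        "common_support_rate": 1.0 - flag_rate,
        # Assumption: an arm with no observations reports a mean of 0.0.
        "mean_ps_treated": sum(treated) / len(treated) if treated else 0.0,
        "mean_ps_control": sum(control) / len(control) if control else 0.0,
    }
```

With the two observations from the example below (0.3 treated, 0.02 control) this gives n_flagged = 1 and flag_rate = 0.5, matching the streaming result.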

Examples:

>>> checker = OverlapChecker(ps_min=0.05, ps_max=0.95)
>>> checker.update(propensity=0.3, treatment=1)
>>> checker.update(propensity=0.02, treatment=0)
>>> checker.report()['n_flagged']
1
Source code in onlinecml/diagnostics/overlap_checker.py
class OverlapChecker:
    """Monitors propensity score distributions for positivity violations.

    Tracks the distribution of predicted propensity scores per treatment
    arm and counts units whose scores fall outside ``[ps_min, ps_max]``.
    Reports the proportion of units in the common support region.

    Parameters
    ----------
    ps_min : float
        Lower positivity threshold. PS values below this are flagged.
        Default 0.05.
    ps_max : float
        Upper positivity threshold. PS values above this are flagged.
        Default 0.95.

    Notes
    -----
    A unit is in *common support* if its propensity score satisfies
    ``ps_min <= p <= ps_max``. Causal estimates that lean on units
    outside this region (for example through extreme inverse-propensity
    weights) are unreliable.

    The ``report()`` method returns a summary including:

    - Total and flagged observation counts
    - Overall flag rate and common support rate
    - Mean PS per arm

    Examples
    --------
    >>> checker = OverlapChecker(ps_min=0.05, ps_max=0.95)
    >>> checker.update(propensity=0.3, treatment=1)
    >>> checker.update(propensity=0.02, treatment=0)
    >>> checker.report()['n_flagged']
    1
    """

    def __init__(self, ps_min: float = 0.05, ps_max: float = 0.95) -> None:
        self.ps_min = ps_min
        self.ps_max = ps_max
        self._ps_stats = [RunningStats(), RunningStats()]   # [control, treated]
        self._n_flagged: int = 0
        self._n_total: int = 0

    def update(self, propensity: float, treatment: int) -> None:
        """Record a propensity score observation.

        Parameters
        ----------
        propensity : float
            Predicted propensity score ``P(W=1|X)`` for this unit.
        treatment : int
            Observed treatment indicator (0 or 1). Used to track
            per-arm PS distributions.
        """
        arm = int(bool(treatment))
        self._ps_stats[arm].update(propensity)
        self._n_total += 1
        if propensity < self.ps_min or propensity > self.ps_max:
            self._n_flagged += 1

    def report(self) -> dict:
        """Return a summary of the propensity score distribution.

        Returns
        -------
        dict
            Keys:

            - ``'n_total'`` — total observations seen
            - ``'n_flagged'`` — observations outside ``[ps_min, ps_max]``
            - ``'flag_rate'`` — proportion flagged
            - ``'common_support_rate'`` — ``1 - flag_rate``
            - ``'mean_ps_treated'`` — mean PS in the treated arm
            - ``'mean_ps_control'`` — mean PS in the control arm
        """
        flag_rate = self._n_flagged / self._n_total if self._n_total > 0 else 0.0
        return {
            "n_total": self._n_total,
            "n_flagged": self._n_flagged,
            "flag_rate": flag_rate,
            "common_support_rate": 1.0 - flag_rate,
            "mean_ps_treated": self._ps_stats[1].mean,
            "mean_ps_control": self._ps_stats[0].mean,
        }

    def is_overlap_adequate(self, max_flag_rate: float = 0.05) -> bool:
        """Return True if the positivity violation rate is acceptable.

        Parameters
        ----------
        max_flag_rate : float
            Maximum tolerable fraction of flagged units. Default 0.05.

        Returns
        -------
        bool
            True if the flagged fraction is at most ``max_flag_rate``.
        """
        r = self.report()
        return r["flag_rate"] <= max_flag_rate

    def reset(self) -> None:
        """Reset all statistics."""
        self._ps_stats = [RunningStats(), RunningStats()]
        self._n_flagged = 0
        self._n_total = 0
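OverlapChecker keeps one RunningStats per arm and reads only its `.mean`. RunningStats itself is not shown in this section. Assuming it is a Welford-style accumulator, a minimal stand-in looks like the following. `WelfordStats` is a hypothetical name, and its `variance` property is an extra that OverlapChecker does not use.

```python
import math


class WelfordStats:
    """Minimal Welford accumulator (hypothetical stand-in for RunningStats)."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0
        self._m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x: float) -> None:
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self._m2 += delta * (x - self.mean)  # uses the updated mean

    @property
    def variance(self) -> float:
        # Sample variance (n - 1 denominator); 0.0 until two values are seen.
        return self._m2 / (self.n - 1) if self.n > 1 else 0.0

    @property
    def std(self) -> float:
        return math.sqrt(self.variance)


# Per-arm bookkeeping in the style of OverlapChecker.update
arms = [WelfordStats(), WelfordStats()]  # [control, treated]
for propensity, treatment in [(0.3, 1), (0.5, 1), (0.2, 0)]:
    arms[int(bool(treatment))].update(propensity)
```

Welford's update avoids accumulating a raw sum of squares, which keeps the variance numerically stable on long streams.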

is_overlap_adequate(max_flag_rate=0.05)

Return True if the positivity violation rate is acceptable.

Parameters:

Name           Type   Description                                                  Default
max_flag_rate  float  Maximum tolerable fraction of flagged units. Default 0.05.  0.05

Returns:

Type  Description
bool  True if the flagged fraction is at most max_flag_rate.
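OverlapChecker only counts out-of-range scores; nothing in its source emits a warning. If you want an actual warning when overlap degrades, a thin wrapper can check the running flag rate every few observations using the same inclusive rule as is_overlap_adequate (adequate means flag_rate <= max_flag_rate). `monitor_overlap` and its `check_every` parameter are hypothetical, and the counting is re-implemented so the sketch runs without onlinecml installed.

```python
import warnings
from typing import Iterable, List


def monitor_overlap(
    propensities: Iterable[float],
    ps_min: float = 0.05,
    ps_max: float = 0.95,
    max_flag_rate: float = 0.05,
    check_every: int = 100,
) -> List[int]:
    """Warn at checkpoints where overlap is inadequate (hypothetical helper).

    Returns the observation counts at which a warning was issued.
    """
    alerts: List[int] = []
    n_total = 0
    n_flagged = 0
    for propensity in propensities:
        n_total += 1
        # Same flagging rule as OverlapChecker.update.
        if propensity < ps_min or propensity > ps_max:
            n_flagged += 1
        if n_total % check_every == 0:
            flag_rate = n_flagged / n_total
            # Inverse of is_overlap_adequate: warn only when strictly above the limit.
            if flag_rate > max_flag_rate:
                warnings.warn(
                    f"positivity: {flag_rate:.1%} of {n_total} units outside "
                    f"[{ps_min}, {ps_max}]",
                    RuntimeWarning,
                )
                alerts.append(n_total)
    return alerts
```

Because the comparison is inclusive, a flag rate exactly equal to max_flag_rate (for example 1 flagged unit out of 20 with the default 0.05) does not trigger a warning.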


report()

Return a summary of the propensity score distribution.

Returns:

Type  Description
dict  Summary with the following keys:

  • 'n_total': total observations seen
  • 'n_flagged': observations outside [ps_min, ps_max]
  • 'flag_rate': proportion flagged
  • 'common_support_rate': 1 - flag_rate
  • 'mean_ps_treated': mean PS in the treated arm
  • 'mean_ps_control': mean PS in the control arm

reset()

Reset all statistics.


update(propensity, treatment)

Record a propensity score observation.

Parameters:

Name        Type   Description                                                                     Default
propensity  float  Predicted propensity score P(W=1|X) for this unit.                              required
treatment   int    Observed treatment indicator (0 or 1). Used to track per-arm PS distributions.  required

onlinecml.diagnostics.concept_drift_monitor.ConceptDriftMonitor

Monitors the ATE estimate stream for structural breaks (concept drift).

Wraps River's ADWIN (Adaptive Windowing) drift detector to identify changes in the mean of the per-observation pseudo-outcome stream, which signal a shift in the underlying treatment effect.

Parameters:

Name   Type   Description                                                                                                          Default
delta  float  ADWIN confidence parameter. Smaller values reduce the false-alarm rate at the cost of slower detection. Default 0.002.  0.002
Notes

ADWIN maintains an adaptive window over a data stream and raises a drift signal when the means of an earlier and a later sub-window differ by more than a bound that depends on delta and the sub-window sizes; the older sub-window is then dropped.
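The cut rule can be sketched directly. The code below is a simplified, quadratic-time version of the basic ADWIN0 rule from Bifet and Gavaldà (2007): for every split of the window into an older part of size n0 and a newer part of size n1, drift is signaled when the two means differ by more than eps_cut = sqrt(ln(4n/delta) / (2m)), with m = 1 / (1/n0 + 1/n1) and n the window length. The bound assumes values in [0, 1]. River's ADWIN uses a compressed window and may differ in details, so treat `NaiveAdwin` and `drift_points` (hypothetical names) as an illustration only.

```python
import math
from typing import List, Optional, Sequence


class NaiveAdwin:
    """Simplified ADWIN0-style detector (hypothetical; illustration only).

    Stores the whole window and tests every split point, so one update
    costs O(len(window)). The cut rule assumes values in [0, 1].
    """

    def __init__(self, delta: float = 0.002) -> None:
        self.delta = delta
        self.window: List[float] = []
        self.drift_detected = False

    def _find_cut(self) -> Optional[int]:
        """Return the first split index whose sub-window means differ too much."""
        n = len(self.window)
        total = sum(self.window)
        head_sum = 0.0
        for i in range(1, n):
            head_sum += self.window[i - 1]
            n0, n1 = i, n - i
            mu0 = head_sum / n0  # mean of the older sub-window
            mu1 = (total - head_sum) / n1  # mean of the newer sub-window
            m = 1.0 / (1.0 / n0 + 1.0 / n1)
            # eps_cut = sqrt(ln(4 / delta') / (2m)) with delta' = delta / n
            eps_cut = math.sqrt(math.log(4.0 * n / self.delta) / (2.0 * m))
            if abs(mu0 - mu1) > eps_cut:
                return i
        return None

    def update(self, x: float) -> None:
        self.window.append(x)
        self.drift_detected = False
        cut = self._find_cut()
        while cut is not None:  # shrink until no split violates the bound
            del self.window[:cut]  # drop the older sub-window
            self.drift_detected = True
            cut = self._find_cut()


def drift_points(stream: Sequence[float], delta: float = 0.002) -> List[int]:
    """Stream positions at which the sketch signals drift."""
    detector = NaiveAdwin(delta)
    hits: List[int] = []
    for t, x in enumerate(stream):
        detector.update(x)
        if detector.drift_detected:
            hits.append(t)
    return hits
```

On a stream of 50 zeros followed by 50 ones the sketch stays silent through the zeros and fires a few observations after the switch, once the newer sub-window is large enough for the gap to exceed eps_cut.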

When drift is detected, drift_detected is True (for that update only) and n_drifts is incremented. The monitor does not reset the estimator being monitored; it only signals, so the caller should reset that estimator after drift is detected.

References

Bifet, A. and Gavaldà, R. (2007). Learning from time-changing data with adaptive windowing. Proceedings of the 7th SIAM International Conference on Data Mining, 443-448.

Examples:

>>> monitor = ConceptDriftMonitor(delta=0.002)
>>> for pseudo_outcome in [1.0, 1.1, 0.9, 1.0, 5.0, 5.1, 4.9]:
...     monitor.update(pseudo_outcome)
>>> monitor.n_drifts >= 0
True
Source code in onlinecml/diagnostics/concept_drift_monitor.py
class ConceptDriftMonitor:
    """Monitors the ATE estimate stream for structural breaks (concept drift).

    Wraps River's ADWIN (Adaptive Windowing) drift detector to identify
    changes in the mean of the per-observation pseudo-outcome stream,
    which signal a shift in the underlying treatment effect.

    Parameters
    ----------
    delta : float
        ADWIN confidence parameter. Smaller values reduce false alarm
        rate at the cost of slower detection. Default 0.002.

    Notes
    -----
    ADWIN maintains an adaptive window over a data stream and raises a
    drift signal when the means of an earlier and a later sub-window
    differ by more than a bound that depends on ``delta`` and the
    sub-window sizes; the older sub-window is then dropped.

    When drift is detected, ``drift_detected`` is True (for that update
    only) and ``n_drifts`` is incremented. The monitor does not reset the
    estimator being monitored; it only signals, so the caller should
    reset that estimator after drift is detected.

    References
    ----------
    Bifet, A. and Gavaldà, R. (2007). Learning from time-changing data
    with adaptive windowing. Proceedings of the 7th SIAM International
    Conference on Data Mining, 443-448.

    Examples
    --------
    >>> monitor = ConceptDriftMonitor(delta=0.002)
    >>> for pseudo_outcome in [1.0, 1.1, 0.9, 1.0, 5.0, 5.1, 4.9]:
    ...     monitor.update(pseudo_outcome)
    >>> monitor.n_drifts >= 0
    True
    """

    def __init__(self, delta: float = 0.002) -> None:
        self.delta = delta
        self._detector = ADWIN(delta=delta)
        self._n_drifts: int = 0
        self._drift_detected: bool = False
        self._n_seen: int = 0

    def update(self, pseudo_outcome: float) -> None:
        """Feed one pseudo-outcome to the drift detector.

        Parameters
        ----------
        pseudo_outcome : float
            Per-observation pseudo-outcome (e.g. IPW score or CATE estimate).
            Drift in this stream indicates a shift in the treatment effect.
        """
        self._detector.update(pseudo_outcome)
        self._n_seen += 1
        if self._detector.drift_detected:
            self._n_drifts += 1
            self._drift_detected = True
        else:
            self._drift_detected = False

    @property
    def drift_detected(self) -> bool:
        """True if drift was detected on the most recent ``update`` call."""
        return self._drift_detected

    @property
    def n_drifts(self) -> int:
        """Total number of drift events detected since initialization."""
        return self._n_drifts

    @property
    def n_seen(self) -> int:
        """Total number of observations processed."""
        return self._n_seen

    def reset(self) -> None:
        """Reset the detector and drift counters."""
        self._detector = ADWIN(delta=self.delta)
        self._n_drifts = 0
        self._drift_detected = False
        self._n_seen = 0

drift_detected property

True if drift was detected on the most recent update call.
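Because drift_detected only reflects the most recent update and the monitor never resets anything besides itself, a caller that wants a fresh estimate after drift has to reset it explicitly. The sketch below shows that pattern with a hypothetical `ResettingMean` that accepts any object exposing `update(pseudo_outcome)` and `drift_detected`, the two members ConceptDriftMonitor provides. `JumpDetector` is a toy stand-in used only so the example runs without River; it is not a real drift test.

```python
from typing import List, Optional, Protocol


class DriftSignal(Protocol):
    """Structural type matching the ConceptDriftMonitor members used here."""

    @property
    def drift_detected(self) -> bool: ...

    def update(self, pseudo_outcome: float) -> None: ...


class ResettingMean:
    """Hypothetical running ATE that restarts whenever the detector fires."""

    def __init__(self, detector: DriftSignal) -> None:
        self.detector = detector
        self.n = 0
        self.mean = 0.0
        self.resets: List[int] = []  # stream positions where a reset happened
        self._t = 0

    def update(self, pseudo_outcome: float) -> None:
        self.detector.update(pseudo_outcome)
        if self.detector.drift_detected:
            # The monitor only signals; discarding the stale estimate is our job.
            self.n = 0
            self.mean = 0.0
            self.resets.append(self._t)
        self.n += 1
        self.mean += (pseudo_outcome - self.mean) / self.n  # incremental mean
        self._t += 1


class JumpDetector:
    """Toy detector: fires when consecutive values differ by more than threshold."""

    def __init__(self, threshold: float = 1.0) -> None:
        self.threshold = threshold
        self.drift_detected = False
        self._prev: Optional[float] = None

    def update(self, pseudo_outcome: float) -> None:
        self.drift_detected = (
            self._prev is not None
            and abs(pseudo_outcome - self._prev) > self.threshold
        )
        self._prev = pseudo_outcome
```

Here the observation that triggers the reset is counted toward the new estimate; discarding it instead is an equally defensible choice.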

n_drifts property

Total number of drift events detected since initialization.

n_seen property

Total number of observations processed.

reset()

Reset the detector and drift counters.


update(pseudo_outcome)

Feed one pseudo-outcome to the drift detector.

Parameters:

Name            Type   Description                                                                                                                  Default
pseudo_outcome  float  Per-observation pseudo-outcome (e.g. IPW score or CATE estimate). Drift in this stream indicates a shift in the treatment effect.  required