
Forests

onlinecml.forests.causal_hoeffding_tree.CausalHoeffdingTree

Bases: BaseOnlineEstimator

Online causal tree with a CATE-variance split criterion.

Grows a binary decision tree one observation at a time, using the Hoeffding bound to guarantee that, with high probability and given enough data, each split is made on the same feature a batch learner would choose.
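To make the bound concrete, here is a standalone sketch of the Hoeffding epsilon used for split decisions (the helper name `hoeffding_bound` is ours, not part of the library; it mirrors the formula in `_try_split` below):

```python
import math

def hoeffding_bound(value_range: float, delta: float, n: int) -> float:
    """Hoeffding epsilon: with probability at least 1 - delta, the observed
    mean of n samples bounded in a range of width value_range lies within
    epsilon of the true mean."""
    return math.sqrt(value_range**2 * math.log(1.0 / delta) / (2.0 * n))

# With the defaults (outcome_range=10.0, delta=1e-5) the bound shrinks as a
# leaf accumulates observations; a split whose score beats the runner-up by
# more than epsilon is accepted.
eps_200 = hoeffding_bound(10.0, 1e-5, 200)    # ~1.70 after the grace period
eps_2000 = hoeffding_bound(10.0, 1e-5, 2000)  # ~0.54 after 10x more data
```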

Improvements over a naive causal tree:

  • Multi-threshold split search: instead of a single running-mean threshold, evaluates 5 quantile-based candidates per feature and picks the best, improving split location accuracy.
  • Linear leaf models: each leaf maintains separate River LinearRegression models for the treated and control arms. predict_one returns mu1(x) - mu0(x) (individual CATE) rather than a flat leaf mean.
  • Doubly robust leaf CATE: the per-leaf ATE baseline used for split scoring is the running mean of the DR pseudo-outcome mu1 - mu0 + W(Y-mu1)/p - (1-W)(Y-mu0)/(1-p), correcting for within-leaf confounding.
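The DR pseudo-outcome in the last bullet can be written as a plain function; a minimal sketch (the helper name and the example values are ours, not the library's API):

```python
def dr_pseudo_outcome(y: float, w: int, mu1: float, mu0: float,
                      p: float, clip_ps: float = 0.1) -> float:
    """Doubly robust pseudo-outcome: its mean is an unbiased CATE estimate
    if either the outcome models (mu1, mu0) or the propensity p is correct."""
    p = max(clip_ps, min(1.0 - clip_ps, p))  # clip, as the tree does
    return mu1 - mu0 + w * (y - mu1) / p - (1 - w) * (y - mu0) / (1.0 - p)

# Treated unit whose outcome is 1.0 above its model prediction, p = 0.5:
# model CATE (1.0) plus upward correction (1.0 / 0.5) gives 3.0.
psi = dr_pseudo_outcome(y=3.0, w=1, mu1=2.0, mu0=1.0, p=0.5)
```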

Parameters:

grace_period (int, default 200)
    Minimum observations a leaf must collect before attempting a split.
delta (float, default 1e-5)
    Confidence parameter for the Hoeffding bound.
tau (float, default 0.05)
    Tie-breaking threshold.
max_depth (int or None, default 10)
    Maximum tree depth. None = unlimited.
min_arm_samples (int, default 5)
    Minimum per-arm observations required per child for split scoring and
    for switching to linear-model predictions.
mtry (int or None, default None)
    Number of features randomly considered at each split attempt.
    None = all features.
outcome_range (float, default 10.0)
    Upper bound on |CATE| for calibrating the Hoeffding bound.
clip_ps (float, default 0.1)
    Propensity score clipping bounds [clip_ps, 1 - clip_ps] for DR
    correction within leaves.
seed (int or None, default None)
    Random seed for the mtry RNG.
Notes

Split score (maximised):

.. math::

    \text{score}(j) = \frac{n_L}{n}(\hat{\tau}_L - \hat{\tau})^2
                    + \frac{n_R}{n}(\hat{\tau}_R - \hat{\tau})^2

where :math:`\hat{\tau}` is the DR-corrected leaf CATE and :math:`\hat{\tau}_k = \bar{Y}_{1,k} - \bar{Y}_{0,k}` for child k.
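The score can be checked by hand; a sketch with hypothetical running counts (`split_score` is our illustrative helper, not the library's):

```python
def split_score(n_l: int, tau_l: float, n_r: int, tau_r: float,
                tau_parent: float) -> float:
    """Weighted squared deviation of child CATEs from the parent CATE:
    the quantity maximised over candidate thresholds."""
    n = n_l + n_r
    return ((n_l / n) * (tau_l - tau_parent) ** 2
            + (n_r / n) * (tau_r - tau_parent) ** 2)

# A threshold separating heterogeneous effects scores far higher than one
# whose children barely differ from the parent:
hetero = split_score(100, 2.0, 100, 0.0, tau_parent=1.0)  # 1.0
homo   = split_score(100, 1.1, 100, 0.9, tau_parent=1.0)  # ~0.01
```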

References

Domingos, P. and Hulten, G. (2000). Mining high-speed data streams. KDD, 71-80.

Wager, S. and Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. JASA, 113(523), 1228-1242.

Examples:

>>> from onlinecml.datasets import HeterogeneousCausalStream
>>> from onlinecml.forests import CausalHoeffdingTree
>>> tree = CausalHoeffdingTree(grace_period=50, delta=0.01, seed=42)
>>> for x, w, y, _ in HeterogeneousCausalStream(n=1000, seed=0):
...     tree.learn_one(x, w, y)
>>> isinstance(tree.predict_one({'x0': 1.0, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0}), float)
True
Source code in onlinecml/forests/causal_hoeffding_tree.py
class CausalHoeffdingTree(BaseOnlineEstimator):
    """Online causal tree with a CATE-variance split criterion.

    Grows a binary decision tree one observation at a time, using the
    **Hoeffding bound** to guarantee that, with high probability and given
    enough data, each split is made on the same feature a batch learner
    would choose.

    **Improvements over a naive causal tree:**

    - *Multi-threshold split search*: instead of a single running-mean
      threshold, evaluates 5 quantile-based candidates per feature and picks
      the best, improving split location accuracy.
    - *Linear leaf models*: each leaf maintains separate River
      ``LinearRegression`` models for the treated and control arms.
      ``predict_one`` returns ``mu1(x) - mu0(x)`` (individual CATE) rather
      than a flat leaf mean.
    - *Doubly robust leaf CATE*: the per-leaf ATE baseline used for split
      scoring is the running mean of the DR pseudo-outcome
      ``mu1 - mu0 + W(Y-mu1)/p - (1-W)(Y-mu0)/(1-p)``, correcting for
      within-leaf confounding.

    Parameters
    ----------
    grace_period : int
        Minimum observations a leaf must collect before attempting a split.
        Default 200.
    delta : float
        Confidence parameter for the Hoeffding bound. Default 1e-5.
    tau : float
        Tie-breaking threshold. Default 0.05.
    max_depth : int or None
        Maximum tree depth. ``None`` = unlimited. Default 10.
    min_arm_samples : int
        Minimum per-arm observations required per child for split scoring and
        for switching to linear-model predictions. Default 5.
    mtry : int or None
        Number of features randomly considered at each split attempt.
        ``None`` = all features. Default None.
    outcome_range : float
        Upper bound on ``|CATE|`` for calibrating the Hoeffding bound.
        Default 10.0.
    clip_ps : float
        Propensity score clipping bounds ``[clip_ps, 1 - clip_ps]`` for DR
        correction within leaves. Default 0.1.
    seed : int or None
        Random seed for the mtry RNG. Default None.

    Notes
    -----
    Split score (maximised):

    .. math::

        \\text{score}(j) = \\frac{n_L}{n}(\\hat{\\tau}_L - \\hat{\\tau})^2
                         + \\frac{n_R}{n}(\\hat{\\tau}_R - \\hat{\\tau})^2

    where :math:`\\hat{\\tau}` is the DR-corrected leaf CATE and
    :math:`\\hat{\\tau}_k = \\bar{Y}_{1,k} - \\bar{Y}_{0,k}` for child ``k``.

    References
    ----------
    Domingos, P. and Hulten, G. (2000). Mining high-speed data streams.
    KDD, 71-80.

    Wager, S. and Athey, S. (2018). Estimation and inference of heterogeneous
    treatment effects using random forests. JASA, 113(523), 1228-1242.

    Examples
    --------
    >>> from onlinecml.datasets import HeterogeneousCausalStream
    >>> from onlinecml.forests import CausalHoeffdingTree
    >>> tree = CausalHoeffdingTree(grace_period=50, delta=0.01, seed=42)
    >>> for x, w, y, _ in HeterogeneousCausalStream(n=1000, seed=0):
    ...     tree.learn_one(x, w, y)
    >>> isinstance(tree.predict_one({'x0': 1.0, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0}), float)
    True
    """

    def __init__(
        self,
        grace_period: int = 200,
        delta: float = 1e-5,
        tau: float = 0.05,
        max_depth: int | None = 10,
        min_arm_samples: int = 5,
        mtry: int | None = None,
        outcome_range: float = 10.0,
        clip_ps: float = 0.1,
        seed: int | None = None,
    ) -> None:
        self.grace_period    = grace_period
        self.delta           = delta
        self.tau             = tau
        self.max_depth       = max_depth
        self.min_arm_samples = min_arm_samples
        self.mtry            = mtry
        self.outcome_range   = outcome_range
        self.clip_ps         = clip_ps
        self.seed            = seed

        self._rng = random.Random(seed)
        self._root: _Node = _Node()
        self._n_seen: int = 0
        self._ate_stats: RunningStats = RunningStats()

        # Per-leaf per-feature split statistics: {id(node) → {feat → _FeatureSplitStats}}
        self._leaf_split_stats: dict[int, dict[str, _FeatureSplitStats]] = {}

    # ------------------------------------------------------------------
    # BaseOnlineEstimator interface
    # ------------------------------------------------------------------

    def learn_one(
        self,
        x: dict,
        treatment: int,
        outcome: float,
        propensity: float | None = None,
    ) -> None:
        """Process one observation and potentially grow the tree.

        Uses a predict-first-then-learn protocol: all three leaf models
        (treated, control, propensity) predict *before* being updated, so the
        DR pseudo-outcome is computed from out-of-sample predictions.

        Parameters
        ----------
        x : dict
            Covariate dictionary.
        treatment : int
            Treatment indicator (0 or 1).
        outcome : float
            Observed outcome.
        propensity : float or None
            If provided, uses this logged propensity instead of the leaf PS
            model for the DR correction.
        """
        self._n_seen += 1
        node, depth = self._route(x)

        # ── Predict-first: all models predict before any update ─────────────
        mu1 = node.treated_model.predict_one(x)
        mu0 = node.control_model.predict_one(x)

        if propensity is not None:
            p_hat = max(self.clip_ps, min(1.0 - self.clip_ps, propensity))
        else:
            raw = node.ps_model.predict_proba_one(x)
            p_hat = max(self.clip_ps, min(1.0 - self.clip_ps, raw.get(1, 0.5)))

        # ── DR pseudo-outcome ────────────────────────────────────────────────
        psi = (
            mu1 - mu0
            + treatment * (outcome - mu1) / p_hat
            - (1 - treatment) * (outcome - mu0) / (1.0 - p_hat)
        )

        # ── Update leaf-level statistics ─────────────────────────────────────
        node.stats.update(outcome, treatment)
        node.dr_stats.update(psi)
        node.n_since_split += 1

        # ── Learn-after: update models on current observation ────────────────
        if treatment == 1:
            node.treated_model.learn_one(x, outcome)
        else:
            node.control_model.learn_one(x, outcome)
        if propensity is None:
            node.ps_model.learn_one(x, treatment)

        # ── Update per-feature split statistics ──────────────────────────────
        leaf_id = id(node)
        if leaf_id not in self._leaf_split_stats:
            self._leaf_split_stats[leaf_id] = {}
        for feat, val in x.items():
            if feat not in self._leaf_split_stats[leaf_id]:
                self._leaf_split_stats[leaf_id][feat] = _FeatureSplitStats()
            self._leaf_split_stats[leaf_id][feat].update(val, outcome, treatment)

        # ── Update global ATE estimate (DR-corrected) ─────────────────────
        self._ate_stats.update(psi)

        # ── Attempt split ────────────────────────────────────────────────────
        if node.n_since_split >= self.grace_period:
            if self.max_depth is None or depth < self.max_depth:
                self._try_split(node, depth)

    def predict_one(self, x: dict) -> float:
        """Predict CATE for a single unit using the leaf's linear models.

        Parameters
        ----------
        x : dict
            Covariate dictionary.

        Returns
        -------
        float
            Estimated CATE: ``mu1(x) - mu0(x)`` from the leaf's linear models
            when enough per-arm data exists; DR-corrected mean otherwise.
        """
        node, _ = self._route(x)
        return node.predict_cate(x, self.min_arm_samples)

    # ------------------------------------------------------------------
    # Tree internals
    # ------------------------------------------------------------------

    def _route(self, x: dict) -> tuple[_Node, int]:
        """Route ``x`` to its leaf, returning ``(leaf_node, depth)``."""
        node = self._root
        depth = 0
        while not node.is_leaf:
            val = x.get(node.feature, 0.0)  # type: ignore[arg-type]
            node = node.left if val <= node.threshold else node.right  # type: ignore[assignment]
            depth += 1
        return node, depth

    def _try_split(self, node: _Node, depth: int) -> None:
        """Evaluate causal split candidates and split if the Hoeffding bound allows."""
        leaf_id    = id(node)
        feat_stats = self._leaf_split_stats.get(leaf_id, {})
        if not feat_stats:
            return

        # Use DR-corrected CATE as the global baseline τ
        cate_global = node.dr_stats.mean if node.dr_stats.n >= 2 else node.stats.cate

        n = node.n_since_split
        best_score   = float("-inf")
        best_feat    = None
        best_thresh  = 0.0
        second_score = float("-inf")

        # Optional mtry: randomly restrict candidate features
        candidates = list(feat_stats.keys())
        if self.mtry is not None and self.mtry < len(candidates):
            candidates = self._rng.sample(candidates, self.mtry)

        for feat in candidates:
            score, thresh = feat_stats[feat].best_split(cate_global, self.min_arm_samples)
            if score > best_score:
                second_score = best_score
                best_score   = score
                best_feat    = feat
                best_thresh  = thresh
            elif score > second_score:
                second_score = score

        if best_feat is None or best_score <= 0:
            return

        R = self.outcome_range
        hoeffding_bound = math.sqrt(R * R * math.log(1.0 / self.delta) / (2.0 * n))

        gap = best_score - max(second_score, 0.0)
        if gap > hoeffding_bound or hoeffding_bound < self.tau:
            self._split_node(node, best_feat, best_thresh, leaf_id)

    def _split_node(
        self,
        node: _Node,
        feature: str,
        threshold: float,
        leaf_id: int,
    ) -> None:
        """Convert a leaf into an internal node."""
        node.feature   = feature
        node.threshold = threshold
        node.left      = _Node()
        node.right     = _Node()
        node.n_since_split = 0
        self._leaf_split_stats.pop(leaf_id, None)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def n_nodes(self) -> int:
        """Total number of nodes (internal + leaf) in the tree."""
        return self._count_nodes(self._root)

    def _count_nodes(self, node: _Node | None) -> int:
        """Recursively count all nodes."""
        if node is None:
            return 0
        return 1 + self._count_nodes(node.left) + self._count_nodes(node.right)

    @property
    def n_leaves(self) -> int:
        """Number of leaf nodes."""
        return self._count_leaves(self._root)

    def _count_leaves(self, node: _Node | None) -> int:
        """Recursively count leaf nodes."""
        if node is None:
            return 0
        if node.is_leaf:
            return 1
        return self._count_leaves(node.left) + self._count_leaves(node.right)

n_leaves property

Number of leaf nodes.

n_nodes property

Total number of nodes (internal + leaf) in the tree.

learn_one(x, treatment, outcome, propensity=None)

Process one observation and potentially grow the tree.

Uses a predict-first-then-learn protocol: all three leaf models (treated, control, propensity) predict before being updated, so the DR pseudo-outcome is computed from out-of-sample predictions.

Parameters:

x (dict, required)
    Covariate dictionary.
treatment (int, required)
    Treatment indicator (0 or 1).
outcome (float, required)
    Observed outcome.
propensity (float or None, default None)
    If provided, uses this logged propensity instead of the leaf PS model
    for the DR correction.
Source code in onlinecml/forests/causal_hoeffding_tree.py
def learn_one(
    self,
    x: dict,
    treatment: int,
    outcome: float,
    propensity: float | None = None,
) -> None:
    """Process one observation and potentially grow the tree.

    Uses a predict-first-then-learn protocol: all three leaf models
    (treated, control, propensity) predict *before* being updated, so the
    DR pseudo-outcome is computed from out-of-sample predictions.

    Parameters
    ----------
    x : dict
        Covariate dictionary.
    treatment : int
        Treatment indicator (0 or 1).
    outcome : float
        Observed outcome.
    propensity : float or None
        If provided, uses this logged propensity instead of the leaf PS
        model for the DR correction.
    """
    self._n_seen += 1
    node, depth = self._route(x)

    # ── Predict-first: all models predict before any update ─────────────
    mu1 = node.treated_model.predict_one(x)
    mu0 = node.control_model.predict_one(x)

    if propensity is not None:
        p_hat = max(self.clip_ps, min(1.0 - self.clip_ps, propensity))
    else:
        raw = node.ps_model.predict_proba_one(x)
        p_hat = max(self.clip_ps, min(1.0 - self.clip_ps, raw.get(1, 0.5)))

    # ── DR pseudo-outcome ────────────────────────────────────────────────
    psi = (
        mu1 - mu0
        + treatment * (outcome - mu1) / p_hat
        - (1 - treatment) * (outcome - mu0) / (1.0 - p_hat)
    )

    # ── Update leaf-level statistics ─────────────────────────────────────
    node.stats.update(outcome, treatment)
    node.dr_stats.update(psi)
    node.n_since_split += 1

    # ── Learn-after: update models on current observation ────────────────
    if treatment == 1:
        node.treated_model.learn_one(x, outcome)
    else:
        node.control_model.learn_one(x, outcome)
    if propensity is None:
        node.ps_model.learn_one(x, treatment)

    # ── Update per-feature split statistics ──────────────────────────────
    leaf_id = id(node)
    if leaf_id not in self._leaf_split_stats:
        self._leaf_split_stats[leaf_id] = {}
    for feat, val in x.items():
        if feat not in self._leaf_split_stats[leaf_id]:
            self._leaf_split_stats[leaf_id][feat] = _FeatureSplitStats()
        self._leaf_split_stats[leaf_id][feat].update(val, outcome, treatment)

    # ── Update global ATE estimate (DR-corrected) ─────────────────────
    self._ate_stats.update(psi)

    # ── Attempt split ────────────────────────────────────────────────────
    if node.n_since_split >= self.grace_period:
        if self.max_depth is None or depth < self.max_depth:
            self._try_split(node, depth)

predict_one(x)

Predict CATE for a single unit using the leaf's linear models.

Parameters:

x (dict, required)
    Covariate dictionary.

Returns:

float
    Estimated CATE: mu1(x) - mu0(x) from the leaf's linear models when
    enough per-arm data exists; DR-corrected mean otherwise.

Source code in onlinecml/forests/causal_hoeffding_tree.py
def predict_one(self, x: dict) -> float:
    """Predict CATE for a single unit using the leaf's linear models.

    Parameters
    ----------
    x : dict
        Covariate dictionary.

    Returns
    -------
    float
        Estimated CATE: ``mu1(x) - mu0(x)`` from the leaf's linear models
        when enough per-arm data exists; DR-corrected mean otherwise.
    """
    node, _ = self._route(x)
    return node.predict_cate(x, self.min_arm_samples)

onlinecml.forests.online_causal_forest.OnlineCausalForest

Bases: BaseOnlineEstimator

Ensemble of CausalHoeffdingTrees for online CATE estimation.

Grows n_trees independent CausalHoeffdingTree instances in parallel. Each tree receives a random subsample of each observation (Poisson bootstrap, Oza 2001). The forest CATE prediction is the mean of all tree predictions.

Each tree is monitored for concept drift via an ADWIN detector on its normalised prediction signal. On drift detection the affected tree is reset and starts growing from scratch, while the remaining trees continue uninterrupted.

Parameters:

n_trees (int, default 10)
    Number of trees in the ensemble.
grace_period (int, default 200)
    Grace period passed to each CausalHoeffdingTree.
delta (float, default 1e-5)
    Hoeffding confidence parameter for each tree.
tau (float, default 0.05)
    Tie-breaking threshold for each tree.
max_depth (int or None, default 10)
    Maximum tree depth.
subsample_rate (float, default 1.0)
    Expected number of times each tree sees each observation (Poisson
    bootstrap lambda). 1.0 = standard online bagging.
mtry (int or None, default None)
    Number of features randomly considered at each split attempt per tree.
    None = all features. int(sqrt(p)) is a common choice when many features
    are informative.
min_arm_samples (int, default 5)
    Passed to each tree.
outcome_range (float, default 10.0)
    Passed to each tree. Upper bound on |CATE| for calibrating the
    Hoeffding bound.
clip_ps (float, default 0.1)
    Propensity clipping bounds for DR correction within leaves.
drift_detection (bool, default True)
    If True, attach an ADWIN detector to each tree and reset trees on drift.
seed (int or None, default None)
    Random seed for the subsampling RNG.
Notes

Online bagging (Oza 2001): each incoming observation is presented to each tree a Poisson(subsample_rate)-distributed number of times, drawn independently per tree.
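A sketch of the Poisson draw behind this, mirroring the Knuth-style sampler in `_poisson` below (variable names are ours). A draw of 0, which happens with probability about exp(-1) ~ 0.37 at the default rate, means the tree skips that observation entirely, giving each tree its own implicit out-of-bag sample:

```python
import math
import random

def poisson_knuth(lam: float, rng: random.Random) -> int:
    """Knuth's Poisson sampler; fine for small rates (e.g. lam <= 20)."""
    if lam <= 0:
        return 0
    threshold, k, p = math.exp(-lam), 0, 1.0
    while p > threshold:
        k += 1
        p *= rng.random()
    return k - 1

rng = random.Random(0)
draws = [poisson_knuth(1.0, rng) for _ in range(10_000)]
mean_draws = sum(draws) / len(draws)     # close to lam = 1.0
skip_rate = draws.count(0) / len(draws)  # close to exp(-1) ~ 0.37
```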

Drift detection follows the ARF approach: each tree's prediction is normalised to [0, 1] using the running mean ± 3σ window and fed to ADWIN. When ADWIN raises an alarm, the tree and its detector are reset.
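The normalisation step can be sketched as a plain function (`normalise_pred` is our illustrative helper; the actual logic lives inline in `learn_one`):

```python
def normalise_pred(pred: float, mean: float, std: float) -> float:
    """Map a prediction into [0, 1] via the running mean +/- 3 sigma:
    mean - 3*sigma maps to 0.0 and mean + 3*sigma maps to 1.0."""
    sigma = std or 1e-8  # guard against a zero-variance start
    z = (pred - mean) / (6.0 * sigma) + 0.5
    return max(0.0, min(1.0, z))

# Typical predictions land strictly inside (0, 1); outliers saturate, and a
# sustained shift in this normalised signal is what ADWIN detects:
mid  = normalise_pred(1.0, mean=1.0, std=0.5)   # 0.5
high = normalise_pred(10.0, mean=1.0, std=0.5)  # clipped to 1.0
```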

References

Oza, N.C. (2001). Online bagging and boosting. Proc. American Statistical Association, 229-234.

Gomes, H.M. et al. (2017). Adaptive random forests for evolving data stream classification. Machine Learning, 106(9), 1469-1495.

Examples:

>>> from onlinecml.datasets import LinearCausalStream
>>> from onlinecml.forests import OnlineCausalForest
>>> forest = OnlineCausalForest(n_trees=5, grace_period=50, seed=0)
>>> for x, w, y, _ in LinearCausalStream(n=500, seed=0):
...     forest.learn_one(x, w, y)
>>> isinstance(forest.predict_one({'x0': 0.5, 'x1': -0.3, 'x2': 0.0, 'x3': 0.1, 'x4': -0.2}), float)
True
Source code in onlinecml/forests/online_causal_forest.py
class OnlineCausalForest(BaseOnlineEstimator):
    """Ensemble of CausalHoeffdingTrees for online CATE estimation.

    Grows ``n_trees`` independent ``CausalHoeffdingTree`` instances in parallel.
    Each tree receives a random subsample of each observation (Poisson bootstrap,
    Oza 2001).  The forest CATE prediction is the mean of all tree predictions.

    Each tree is monitored for concept drift via an ADWIN detector on its
    normalised prediction signal.  On drift detection the affected tree is
    reset and starts growing from scratch, while the remaining trees continue
    uninterrupted.

    Parameters
    ----------
    n_trees : int
        Number of trees in the ensemble. Default 10.
    grace_period : int
        Grace period passed to each ``CausalHoeffdingTree``. Default 200.
    delta : float
        Hoeffding confidence parameter for each tree. Default 1e-5.
    tau : float
        Tie-breaking threshold for each tree. Default 0.05.
    max_depth : int or None
        Maximum tree depth. Default 10.
    subsample_rate : float
        Expected number of times each tree sees each observation (Poisson
        bootstrap ``lambda``). 1.0 = standard online bagging. Default 1.0.
    mtry : int or None
        Number of features randomly considered at each split attempt per tree.
        ``None`` = all features. ``int(sqrt(p))`` is a common choice when many
        features are informative. Default None.
    min_arm_samples : int
        Passed to each tree. Default 5.
    outcome_range : float
        Passed to each tree. Upper bound on ``|CATE|`` for calibrating the
        Hoeffding bound. Default 10.0.
    clip_ps : float
        Propensity clipping bounds for DR correction within leaves. Default 0.1.
    drift_detection : bool
        If ``True``, attach an ADWIN detector to each tree and reset trees on
        drift. Default True.
    seed : int or None
        Random seed for the subsampling RNG.

    Notes
    -----
    Online bagging (Oza 2001): each incoming observation is presented to each
    tree a ``Poisson(subsample_rate)``-distributed number of times, drawn
    independently per tree.

    Drift detection follows the ARF approach: each tree's prediction is
    normalised to ``[0, 1]`` using the running ``mean ± 3σ`` window and fed to
    ADWIN.  When ADWIN raises an alarm, the tree and its detector are reset.

    References
    ----------
    Oza, N.C. (2001). Online bagging and boosting. Proc. American Statistical
    Association, 229-234.

    Gomes, H.M. et al. (2017). Adaptive random forests for evolving data stream
    classification. Machine Learning, 106(9), 1469-1495.

    Examples
    --------
    >>> from onlinecml.datasets import LinearCausalStream
    >>> from onlinecml.forests import OnlineCausalForest
    >>> forest = OnlineCausalForest(n_trees=5, grace_period=50, seed=0)
    >>> for x, w, y, _ in LinearCausalStream(n=500, seed=0):
    ...     forest.learn_one(x, w, y)
    >>> isinstance(forest.predict_one({'x0': 0.5, 'x1': -0.3, 'x2': 0.0, 'x3': 0.1, 'x4': -0.2}), float)
    True
    """

    def __init__(
        self,
        n_trees: int = 10,
        grace_period: int = 200,
        delta: float = 1e-5,
        tau: float = 0.05,
        max_depth: int | None = 10,
        subsample_rate: float = 1.0,
        mtry: int | None = None,
        min_arm_samples: int = 5,
        outcome_range: float = 10.0,
        clip_ps: float = 0.1,
        drift_detection: bool = True,
        seed: int | None = None,
    ) -> None:
        self.n_trees         = n_trees
        self.grace_period    = grace_period
        self.delta           = delta
        self.tau             = tau
        self.max_depth       = max_depth
        self.subsample_rate  = subsample_rate
        self.mtry            = mtry
        self.min_arm_samples = min_arm_samples
        self.outcome_range   = outcome_range
        self.clip_ps         = clip_ps
        self.drift_detection = drift_detection
        self.seed            = seed

        self._rng = random.Random(seed)
        self._trees: list[CausalHoeffdingTree] = [
            self._new_tree(i) for i in range(n_trees)
        ]
        self._n_seen: int = 0
        self._ate_stats: RunningStats = RunningStats()

        # Per-tree drift monitoring
        self._drift_detectors: list[ADWIN] | None = (
            [ADWIN() for _ in range(n_trees)] if drift_detection else None
        )
        self._pred_stats: list[RunningStats] | None = (
            [RunningStats() for _ in range(n_trees)] if drift_detection else None
        )

    # ------------------------------------------------------------------
    # BaseOnlineEstimator interface
    # ------------------------------------------------------------------

    def learn_one(
        self,
        x: dict,
        treatment: int,
        outcome: float,
        propensity: float | None = None,
    ) -> None:
        """Process one observation, updating all trees via online bagging.

        After each tree update, optionally checks for concept drift using its
        ADWIN detector and resets the tree if drift is detected.

        Parameters
        ----------
        x : dict
            Covariate dictionary.
        treatment : int
            Treatment indicator (0 or 1).
        outcome : float
            Observed outcome.
        propensity : float or None
            If provided, passed to each tree's DR correction.
        """
        self._n_seen += 1

        for i, tree in enumerate(self._trees):
            k = self._poisson(self.subsample_rate)
            for _ in range(k):
                tree.learn_one(x, treatment, outcome, propensity)

            # ── Drift detection ─────────────────────────────────────────────
            if self._drift_detectors is not None:
                pred  = tree.predict_one(x)
                stats = self._pred_stats[i]  # type: ignore[index]
                stats.update(pred)
                if stats.n > 1:
                    sigma      = stats.std or 1e-8
                    normalised = (pred - stats.mean) / (6.0 * sigma) + 0.5
                    normalised = max(0.0, min(1.0, normalised))
                    if self._drift_detectors[i].update(normalised):
                        self._trees[i]            = self._new_tree(i)
                        self._drift_detectors[i]  = ADWIN()
                        self._pred_stats[i]       = RunningStats()  # type: ignore[index]

    def predict_one(self, x: dict) -> float:
        """Predict CATE as the mean across all tree predictions.

        Parameters
        ----------
        x : dict
            Covariate dictionary.

        Returns
        -------
        float
            Mean CATE across all trees. Returns ``0.0`` when untrained.
        """
        preds = [t.predict_one(x) for t in self._trees]
        return sum(preds) / len(preds) if preds else 0.0

    def predict_ate(self) -> float:
        """Return the current ATE estimate as the mean of tree DR-corrected ATEs.

        Each tree maintains a running mean of its own DR pseudo-outcomes.
        The forest ATE is their simple average, which converges to the true
        ATE without cold-start bias from untrained linear leaf models.

        Returns
        -------
        float
            Mean ATE across all trees. Returns ``0.0`` before any data.
        """
        ates = [t.predict_ate() for t in self._trees]
        return sum(ates) / len(ates) if ates else 0.0

    def predict_ci(self, alpha: float = 0.05) -> tuple[float, float]:
        """Return a confidence interval for the ATE as the mean of tree CIs.

        Parameters
        ----------
        alpha : float
            Significance level. Default 0.05 gives a 95% CI.

        Returns
        -------
        lower : float
            Mean lower bound across all tree CIs.
        upper : float
            Mean upper bound across all tree CIs.
        """
        cis = [t.predict_ci(alpha) for t in self._trees]
        lowers = [lo for lo, _ in cis]
        uppers = [hi for _, hi in cis]
        n = len(cis)
        return (sum(lowers) / n, sum(uppers) / n) if n else (float("-inf"), float("inf"))

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def n_nodes(self) -> list[int]:
        """Number of nodes in each tree."""
        return [t.n_nodes for t in self._trees]

    @property
    def n_leaves(self) -> list[int]:
        """Number of leaf nodes in each tree."""
        return [t.n_leaves for t in self._trees]

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _new_tree(self, index: int) -> CausalHoeffdingTree:
        """Instantiate a fresh ``CausalHoeffdingTree`` for position ``index``.

        Parameters
        ----------
        index : int
            Tree index, used to derive a per-tree seed from ``self.seed``.

        Returns
        -------
        CausalHoeffdingTree
            A freshly initialised tree with this forest's hyperparameters.
        """
        return CausalHoeffdingTree(
            grace_period    = self.grace_period,
            delta           = self.delta,
            tau             = self.tau,
            max_depth       = self.max_depth,
            min_arm_samples = self.min_arm_samples,
            mtry            = self.mtry,
            outcome_range   = self.outcome_range,
            clip_ps         = self.clip_ps,
            seed            = (self.seed + index) if self.seed is not None else None,
        )

    def _poisson(self, lam: float) -> int:
        """Draw from Poisson(lam) using Knuth's algorithm (lam ≤ 20).

        Parameters
        ----------
        lam : float
            Poisson rate parameter.

        Returns
        -------
        int
            A non-negative integer sample from Poisson(lam).
        """
        if lam <= 0:
            return 0
        L = math.exp(-lam)
        k = 0
        p = 1.0
        while p > L:
            k += 1
            p *= self._rng.random()
        return k - 1

    def reset(self) -> None:
        """Reset the forest to its initial (untrained) state."""
        fresh = self.clone()
        self.__dict__.update(fresh.__dict__)

n_leaves property

Number of leaf nodes in each tree.

n_nodes property

Number of nodes in each tree.

learn_one(x, treatment, outcome, propensity=None)

Process one observation, updating all trees via online bagging.

After each tree update, optionally checks for concept drift using its ADWIN detector and resets the tree if drift is detected.

Parameters:

Name Type Description Default
x dict

Covariate dictionary.

required
treatment int

Treatment indicator (0 or 1).

required
outcome float

Observed outcome.

required
propensity float or None

If provided, passed to each tree's DR correction.

None
Source code in onlinecml/forests/online_causal_forest.py
def learn_one(
    self,
    x: dict,
    treatment: int,
    outcome: float,
    propensity: float | None = None,
) -> None:
    """Process one observation, updating all trees via online bagging.

    After each tree update, optionally checks for concept drift using a
    per-tree ADWIN detector and resets the tree if drift is detected.

    Parameters
    ----------
    x : dict
        Covariate dictionary.
    treatment : int
        Treatment indicator (0 or 1).
    outcome : float
        Observed outcome.
    propensity : float or None
        If provided, passed to each tree's DR correction.
    """
    self._n_seen += 1

    for i, tree in enumerate(self._trees):
        k = self._poisson(self.subsample_rate)
        for _ in range(k):
            tree.learn_one(x, treatment, outcome, propensity)

        # ── Drift detection ─────────────────────────────────────────────
        if self._drift_detectors is not None:
            pred  = tree.predict_one(x)
            stats = self._pred_stats[i]  # type: ignore[index]
            stats.update(pred)
            if stats.n > 1:
                sigma      = stats.std or 1e-8
                normalised = (pred - stats.mean) / (6.0 * sigma) + 0.5
                normalised = max(0.0, min(1.0, normalised))
                if self._drift_detectors[i].update(normalised):
                    self._trees[i]            = self._new_tree(i)
                    self._drift_detectors[i]  = ADWIN()
                    self._pred_stats[i]       = RunningStats()  # type: ignore[index]
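ADWIN expects a stream bounded in [0, 1], so the code above squashes each prediction by mapping ±3 running standard deviations onto the unit interval and clipping. A self-contained sketch of that mapping, using a Welford-style stand-in for the library's `RunningStats` (the real class is not shown on this page):

```python
class RunningStats:
    """Welford's online mean/variance (stand-in for the library's class)."""

    def __init__(self):
        self.n, self.mean, self._m2 = 0, 0.0, 0.0

    def update(self, x: float) -> None:
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self._m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        # Sample standard deviation; 0.0 until two observations arrive.
        return (self._m2 / (self.n - 1)) ** 0.5 if self.n > 1 else 0.0


def normalise(pred: float, stats: RunningStats) -> float:
    sigma = stats.std or 1e-8                      # guard against zero spread
    z = (pred - stats.mean) / (6.0 * sigma) + 0.5  # maps ±3σ onto [0, 1]
    return max(0.0, min(1.0, z))


stats = RunningStats()
for p in [0.0, 1.0, 2.0, 3.0, 4.0]:
    stats.update(p)
# mean = 2.0, so a prediction at the mean maps to 0.5;
# predictions beyond ±3σ saturate at 0.0 or 1.0.
```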

predict_ate()

Return the current ATE estimate as the mean of tree DR-corrected ATEs.

Each tree maintains a running mean of its own DR pseudo-outcomes. The forest ATE is their simple average, which is consistent for the ATE when either the outcome or propensity models are well specified, and which avoids cold-start bias from untrained linear leaf models.

Returns:

Type Description
float

Mean ATE across all trees. Returns 0.0 before any data.

Source code in onlinecml/forests/online_causal_forest.py
def predict_ate(self) -> float:
    """Return the current ATE estimate as the mean of tree DR-corrected ATEs.

    Each tree maintains a running mean of its own DR pseudo-outcomes.
    The forest ATE is their simple average, which is consistent for the
    ATE when either the outcome or propensity models are well specified,
    and which avoids cold-start bias from untrained linear leaf models.

    Returns
    -------
    float
        Mean ATE across all trees. Returns ``0.0`` before any data.
    """
    ates = [t.predict_ate() for t in self._trees]
    return sum(ates) / len(ates) if ates else 0.0
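The per-tree running mean is taken over doubly robust (AIPW) pseudo-outcomes, mu1 - mu0 + W(Y-mu1)/p - (1-W)(Y-mu0)/(1-p), with p clipped to [clip_ps, 1 - clip_ps] as described in the class docstring. A minimal sketch of one pseudo-outcome (the function name is illustrative, not part of the API):

```python
def dr_pseudo_outcome(
    y: float, w: int, mu1: float, mu0: float,
    p: float, clip_ps: float = 0.1,
) -> float:
    """Doubly robust (AIPW) pseudo-outcome: unbiased for the treatment
    effect if either the outcome models or the propensity is correct."""
    p = min(max(p, clip_ps), 1.0 - clip_ps)  # clip extreme propensities
    return mu1 - mu0 + w * (y - mu1) / p - (1 - w) * (y - mu0) / (1.0 - p)


# With perfect outcome models the inverse-propensity corrections vanish:
tau = dr_pseudo_outcome(y=3.0, w=1, mu1=3.0, mu0=1.0, p=0.5)  # → 2.0
```

Averaging these pseudo-outcomes within each tree, then across trees, is what lets `predict_ate` report a sensible estimate even while the linear leaf models are still warming up.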

predict_ci(alpha=0.05)

Return a confidence interval for the ATE as the mean of tree CIs.

Parameters:

Name Type Description Default
alpha float

Significance level. Default 0.05 gives a 95% CI.

0.05

Returns:

Name Type Description
lower float

Mean lower bound across all tree CIs.

upper float

Mean upper bound across all tree CIs.

Source code in onlinecml/forests/online_causal_forest.py
def predict_ci(self, alpha: float = 0.05) -> tuple[float, float]:
    """Return a confidence interval for the ATE as the mean of tree CIs.

    Parameters
    ----------
    alpha : float
        Significance level. Default 0.05 gives a 95% CI.

    Returns
    -------
    lower : float
        Mean lower bound across all tree CIs.
    upper : float
        Mean upper bound across all tree CIs.
    """
    cis = [t.predict_ci(alpha) for t in self._trees]
    lowers = [lo for lo, _ in cis]
    uppers = [hi for _, hi in cis]
    n = len(cis)
    return (sum(lowers) / n, sum(uppers) / n) if n else (float("-inf"), float("inf"))

predict_one(x)

Predict CATE as the mean across all tree predictions.

Parameters:

Name Type Description Default
x dict

Covariate dictionary.

required

Returns:

Type Description
float

Mean CATE across all trees. Returns 0.0 when untrained.

Source code in onlinecml/forests/online_causal_forest.py
def predict_one(self, x: dict) -> float:
    """Predict CATE as the mean across all tree predictions.

    Parameters
    ----------
    x : dict
        Covariate dictionary.

    Returns
    -------
    float
        Mean CATE across all trees. Returns ``0.0`` when untrained.
    """
    preds = [t.predict_one(x) for t in self._trees]
    return sum(preds) / len(preds) if preds else 0.0

reset()

Reset the forest to its initial (untrained) state.

Source code in onlinecml/forests/online_causal_forest.py
def reset(self) -> None:
    """Reset the forest to its initial (untrained) state."""
    fresh = self.clone()
    self.__dict__.update(fresh.__dict__)