Incremental (mini-batch) EM

When data arrives in a stream, or is too large to sweep in full at every iteration, IncrementalEMFitter updates the model from mini-batches. Each step computes the batch expectation parameters \(\hat\eta\), blends them into a running estimate \(\eta_t\) through an \(\eta\)-update rule, and then runs the M-step on \(\eta_t\). The choice of rule sets the bias/variance trade-off and how quickly the online estimate forgets old batches.
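To make the three stages concrete, here is a minimal NumPy sketch of the same loop on a much simpler model, a two-component 1-D Gaussian mixture. It is not normix's implementation: the helper names `e_step` and `m_step`, the mixture model, and the exponential-averaging blend with weight `w` are all assumptions chosen for illustration.

```python
import numpy as np

def e_step(x, pi, mu, var):
    """Batch expectation parameters: per-component means of (r, r*x, r*x^2)."""
    d = x[:, None] - mu                                    # shape (n, 2)
    log_r = np.log(pi) - 0.5 * np.log(2 * np.pi * var) - 0.5 * d**2 / var
    log_r -= log_r.max(axis=1, keepdims=True)              # stabilise the exp
    r = np.exp(log_r)
    r /= r.sum(axis=1, keepdims=True)                      # responsibilities
    xs = x[:, None]
    return np.stack([r.mean(0), (r * xs).mean(0), (r * xs**2).mean(0)])

def m_step(eta):
    """Map running expectation parameters back to (weights, means, variances)."""
    s0, s1, s2 = eta
    mu = s1 / s0
    return s0 / s0.sum(), mu, s2 / s0 - mu**2

# Simulated data: 30% from N(-2, 1), 70% from N(2, 1)
rng = np.random.default_rng(0)
n = 20_000
from_left = rng.random(n) < 0.3
x = np.where(from_left, rng.normal(-2.0, 1.0, n), rng.normal(2.0, 1.0, n))

# Deliberately rough starting point
pi, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
eta = e_step(x[:512], pi, mu, var)       # initialise the running estimate
w = 0.1                                  # weight on the newest batch (assumed rule)

for _ in range(200):
    batch = x[rng.integers(0, n, size=512)]   # sample a mini-batch
    eta_hat = e_step(batch, pi, mu, var)      # 1. batch expectation parameters
    eta = (1.0 - w) * eta + w * eta_hat       # 2. blend into the running estimate
    pi, mu, var = m_step(eta)                 # 3. M-step on the blended estimate

print("weights:", pi.round(3), "means:", mu.round(3), "variances:", var.round(3))
```

Only line 2 of the loop changes when a different update rule is used; the E-step and M-step stay the same.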

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np

from normix import (
    NormalInverseGaussian,
    IdentityUpdate, RobbinsMonroUpdate, SampleWeightedUpdate,
    EWMAUpdate, AffineUpdate, Shrinkage, eta0_from_model,
)
from normix.fitting.em import IncrementalEMFitter
from normix.utils.plotting import set_theme

set_theme()
np.set_printoptions(precision=4, suppress=True)

Setup

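# Ground-truth bivariate NIG model used to simulate the data
true = NormalInverseGaussian.from_classical(
    mu=jnp.array([0.0, 0.0]),
    gamma=jnp.array([0.4, -0.3]),
    sigma=jnp.array([[1.0, 0.3], [0.3, 1.0]]),
    mu_ig=1.0, lam=1.5)
X = true.rvs(20_000, seed=0)                   # 20,000 simulated observations
init = NormalInverseGaussian.default_init(X)   # common starting model for every fit
key = jax.random.PRNGKey(0)                    # PRNG key, passed to fit() for batch sampling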

The six \(\eta\)-update rules

A rule maps the previous estimate \(\eta_{t-1}\) and the batch estimate \(\hat\eta\) to the new \(\eta_t\). normix ships six:

| Rule | Update \(\eta_t\) |
|---|---|
| IdentityUpdate | \(\hat\eta\) (no memory) |
| RobbinsMonroUpdate(tau0) | step-size \(\propto 1/(t + \tau_0)\) |
| SampleWeightedUpdate | weight by cumulative sample count |
| EWMAUpdate(w) | exponential moving average, weight \(w\) |
| AffineUpdate(a, b, c) | \(a + b\,\eta_{t-1} + c\,\hat\eta\) |
| Shrinkage(base, eta0, tau) | wrap a base rule, shrink toward \(\eta_0\) |
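The rules can be written out as plain functions to show how each one combines \(\eta_{t-1}\) and \(\hat\eta\). This is a hedged sketch, not the normix source: the function names are hypothetical, the Robbins-Monro proportionality constant is taken to be 1, the EWMA weight \(w\) is assumed to apply to the new batch, the sample-weighted rule is assumed to be a running mean weighted by sample counts, and shrinkage is assumed to be a convex combination with weight \(\tau\) on \(\eta_0\).

```python
import numpy as np

def identity(eta_prev, eta_hat):
    # No memory: the batch estimate replaces the running estimate
    return eta_hat

def robbins_monro(eta_prev, eta_hat, t, tau0=10.0):
    # Decaying step size gamma_t = 1 / (t + tau0)  (constant of 1 assumed)
    gamma = 1.0 / (t + tau0)
    return eta_prev + gamma * (eta_hat - eta_prev)

def sample_weighted(eta_prev, eta_hat, n_seen, n_batch):
    # Running mean weighted by the number of samples behind each term (assumed form)
    return (n_seen * eta_prev + n_batch * eta_hat) / (n_seen + n_batch)

def ewma(eta_prev, eta_hat, w=0.1):
    # Exponential moving average with weight w on the new batch (assumed convention)
    return (1.0 - w) * eta_prev + w * eta_hat

def affine(eta_prev, eta_hat, a=0.0, b=0.5, c=0.5):
    # General affine combination a + b * eta_prev + c * eta_hat
    return a + b * eta_prev + c * eta_hat

def shrink(eta_base, eta0, tau=0.3):
    # Pull the output of a base rule toward a fixed target eta0 (assumed form)
    return (1.0 - tau) * eta_base + tau * eta0

eta_prev = np.array([1.0, 2.0])
eta_hat = np.array([3.0, 6.0])

print("identity       ", identity(eta_prev, eta_hat))
print("robbins_monro  ", robbins_monro(eta_prev, eta_hat, t=10))
print("sample_weighted", sample_weighted(eta_prev, eta_hat, n_seen=1536, n_batch=512))
print("ewma           ", ewma(eta_prev, eta_hat))
print("affine         ", affine(eta_prev, eta_hat))
print("shrink         ", shrink(identity(eta_prev, eta_hat), np.zeros(2)))
```

Under these assumptions Identity, EWMA and the sample-weighted mean are all special cases of the affine form with \(a = 0\) and \(b + c = 1\); they differ only in how the weight on the new batch is chosen.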

rules = {
    "Identity": IdentityUpdate(),
    "RobbinsMonro": RobbinsMonroUpdate(tau0=10.0),
    "SampleWeighted": SampleWeightedUpdate(),
    "EWMA(0.1)": EWMAUpdate(w=0.1),
    "Affine(½,½)": AffineUpdate(b=0.5, c=0.5),
    "Shrinkage": Shrinkage(IdentityUpdate(), eta0_from_model(init), tau=0.3),
}

target = float(true.marginal_log_likelihood(X))
print(f"target mean log-likelihood (true model): {target:.4f}\n")

finals = {}
for name, rule in rules.items():
    fitter = IncrementalEMFitter(
        batch_size=512, max_steps=60, eta_update=rule,
        e_step_backend="cpu", m_step_backend="cpu")
    res = fitter.fit(init, X, key=key)
    finals[name] = float(res.model.marginal_log_likelihood(X))
    print(f"{name:16s} final mean log-lik = {finals[name]:.4f}")

target mean log-likelihood (true model): -2.7653
Identity         final mean log-lik = -2.7660
RobbinsMonro     final mean log-lik = -2.7770
SampleWeighted   final mean log-lik = -2.7679
EWMA(0.1)        final mean log-lik = -2.7675
Affine(½,½)      final mean log-lik = -2.7665
Shrinkage        final mean log-lik = -2.7779

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
names = list(finals)
ax.barh(names, [finals[n] for n in names], color="#2D5A8A")
ax.axvline(target, color="0.4", ls="--", lw=1.2, label="true-model LL")
ax.set_xlabel("final mean log-likelihood")
ax.set_xlim(min(finals.values()) - 0.01, target + 0.005)
ax.set_title("Mini-batch EM: $\\eta$-update rules after 60 steps")
ax.legend()
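plt.show()

[Figure: "Mini-batch EM: η-update rules after 60 steps". Horizontal bar chart of the final mean log-likelihood for each rule, with the true-model log-likelihood as a dashed vertical line.]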

After just 60 mini-batches every rule ends within about 0.013 nats of the true-model log-likelihood (-2.7653): Identity is closest at -2.7660, while RobbinsMonro (-2.7770) and Shrinkage (-2.7779) trail the most. The rules differ mainly in how they get there: the averaging rules (SampleWeighted, EWMA, Affine) damp the per-step noise, while Identity and the decaying RobbinsMonro step are more volatile.
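The damping effect of averaging can be checked in isolation with a toy simulation. The sketch below assumes scalar batch estimates that are the true value plus i.i.d. Gaussian noise, and an EWMA of the form \((1-w)\,\eta_{t-1} + w\,\hat\eta\); neither assumption is taken from normix, and in real EM the batch estimates also depend on the current model.

```python
import numpy as np

rng = np.random.default_rng(0)
true_eta, noise_sd, w, steps = 1.0, 0.5, 0.1, 5000

# Noisy batch estimates scattered around a fixed true value
eta_hat = true_eta + noise_sd * rng.standard_normal(steps)

# Identity: the running estimate is just the latest batch estimate
identity_path = eta_hat.copy()

# EWMA: blend each batch estimate into the running estimate
ewma_path = np.empty(steps)
eta = true_eta
for t in range(steps):
    eta = (1.0 - w) * eta + w * eta_hat[t]
    ewma_path[t] = eta

identity_sd = identity_path.std()
ewma_sd = ewma_path.std()
# For i.i.d. noise the stationary EWMA variance is w / (2 - w) times the input variance
predicted_ratio = np.sqrt(w / (2.0 - w))

print(f"identity sd = {identity_sd:.3f}")
print(f"ewma sd     = {ewma_sd:.3f}")
print(f"sd ratio    = {ewma_sd / identity_sd:.3f} (predicted {predicted_ratio:.3f})")
```

The price of the lower variance is slower reaction: a small \(w\) also means the running estimate takes longer to leave a poor starting point.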

Following a single trajectory

With verbose=1 the fitter records the log-likelihood at diagnostic checkpoints (every 6 steps in the runs below), which we can plot to compare the mini-batch ascent of two contrasting rules:

fig, ax = plt.subplots()
for name, rule in [("Identity", IdentityUpdate()),
                   ("EWMA(0.1)", EWMAUpdate(w=0.1))]:
    res = IncrementalEMFitter(
        batch_size=512, max_steps=60, eta_update=rule, verbose=1,
        e_step_backend="cpu", m_step_backend="cpu").fit(init, X, key=key)
    ll = np.asarray(res.log_likelihoods)
    ax.plot(np.linspace(0, 60, len(ll)), ll, marker="o", ms=3, label=name)
ax.axhline(target, color="0.4", ls="--", lw=1.2, label="true-model LL")
ax.set_xlabel("mini-batch step"); ax.set_ylabel("mean log-likelihood")
ax.set_title("Incremental EM trajectories")
ax.legend()
plt.show()

EM [incremental] NormalInverseGaussian: rule=IdentityUpdate, batch_size=512, max_steps=60, inner_iter=1
  step    6/60  LL=-2.770447  |Δparams|=3.6950e-01
  step   12/60  LL=-2.768975  |Δparams|=3.6683e-01
  step   18/60  LL=-2.766060  |Δparams|=5.4291e-01
  step   24/60  LL=-2.767411  |Δparams|=7.0969e-01
  step   30/60  LL=-2.768273  |Δparams|=2.6339e-01
  step   36/60  LL=-2.766347  |Δparams|=8.0837e-01
  step   42/60  LL=-2.768970  |Δparams|=6.6540e-01
  step   48/60  LL=-2.767970  |Δparams|=4.0360e-01
  step   54/60  LL=-2.769208  |Δparams|=5.3523e-01
  step   60/60  LL=-2.766001  |Δparams|=2.6125e+00
  Done (3.96s), final LL=-2.766001
EM [incremental] NormalInverseGaussian: rule=EWMAUpdate, batch_size=512, max_steps=60, inner_iter=1
  step    6/60  LL=-2.794093  |Δparams|=2.0370e-01
  step   12/60  LL=-2.784090  |Δparams|=1.1905e-01
  step   18/60  LL=-2.778310  |Δparams|=4.6644e-02
  step   24/60  LL=-2.773891  |Δparams|=3.8607e-02
  step   30/60  LL=-2.771556  |Δparams|=2.2443e-02
  step   36/60  LL=-2.770150  |Δparams|=3.9477e-02
  step   42/60  LL=-2.768800  |Δparams|=2.7101e-02
  step   48/60  LL=-2.768355  |Δparams|=2.6078e-02
  step   54/60  LL=-2.768068  |Δparams|=3.9422e-02
  step   60/60  LL=-2.767476  |Δparams|=4.1363e-02
  Done (3.86s), final LL=-2.767476

[Figure: "Incremental EM trajectories". Mean log-likelihood against mini-batch step for Identity and EWMA(0.1), with the true-model log-likelihood as a dashed horizontal line.]

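The Identity rule is noisier because it discards all history: its log-likelihood wanders between about -2.770 and -2.766, and |Δparams| stays between 0.26 and 2.6 at every checkpoint. EWMA smooths the estimate across batches: its log-likelihood rises steadily from -2.794 to -2.767, and |Δparams| falls from 0.20 to roughly 0.02 to 0.05.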
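How much history EWMA keeps follows from unrolling its recursion. The derivation below assumes the update has the form \(\eta_t = (1-w)\,\eta_{t-1} + w\,\hat\eta_t\), with \(\eta_{\text{init}}\) denoting the starting estimate (a label introduced here to avoid confusion with the shrinkage target \(\eta_0\)):

```latex
\begin{aligned}
\eta_t &= (1-w)\,\eta_{t-1} + w\,\hat\eta_t \\
       &= (1-w)^t\,\eta_{\text{init}} + w \sum_{k=0}^{t-1} (1-w)^k\,\hat\eta_{t-k}
\end{aligned}
```

The batch seen \(k\) steps ago carries weight \(w(1-w)^k\), so the effective window is on the order of \(1/w\) batches. With \(w = 0.1\) that is about 10 batches, and after 60 steps the starting estimate retains a weight of only \(0.9^{60} \approx 0.002\). Identity corresponds to \(w = 1\), where every term except the latest batch vanishes.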

Shrinkage toward a target

Shrinkage wraps any base rule and pulls the estimate toward a fixed \(\eta_0\), which regularizes the fit when batches are small or the stream is noisy. The targets module builds sensible \(\eta_0\) values:

  • eta0_from_model(model) — the model’s own current expectation parameters.

  • eta0_isotropic(model, sigma2) — isotropic covariance target.

  • eta0_diagonal(model, diag) — diagonal covariance target.

  • eta0_with_sigma(model, Sigma0) — explicit covariance target.
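Why a structured target helps with small batches can be seen on a bare second-moment matrix. The sketch below is only an illustration under stated assumptions: normix shrinks the full expectation parameters \(\eta\), whereas here the same idea is applied directly to a raw second-moment matrix, and the convex-combination form with weight `tau` on the target is assumed rather than taken from the library.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_batch, tau, sigma2 = 4, 3, 0.5, 1.0

# A batch with fewer points than dimensions gives a singular second-moment matrix
batch = rng.standard_normal((n_batch, d))
second_moment = batch.T @ batch / n_batch      # rank at most 3 in 4 dimensions

# Isotropic target sigma2 * I, blended in with weight tau (assumed form)
target = sigma2 * np.eye(d)
shrunk = (1.0 - tau) * second_moment + tau * target

min_eig_raw = np.linalg.eigvalsh(second_moment).min()
min_eig_shrunk = np.linalg.eigvalsh(shrunk).min()

print(f"smallest eigenvalue, raw batch estimate: {min_eig_raw:.2e}")
print(f"smallest eigenvalue, after shrinkage:    {min_eig_shrunk:.2e}")
```

Because the raw estimate is positive semi-definite, every eigenvalue of the blended matrix is at least `tau * sigma2`, so the shrunk estimate is always invertible. The cost is bias toward the target whenever the true structure differs from it.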

from normix import eta0_isotropic

rule = Shrinkage(RobbinsMonroUpdate(tau0=10.0), eta0_isotropic(init, 1.0), tau=0.5)
res = IncrementalEMFitter(
    batch_size=256, max_steps=60, eta_update=rule,
    e_step_backend="cpu", m_step_backend="cpu").fit(init, X, key=key)
print("shrinkage-to-isotropic final mean log-lik:",
      float(res.model.marginal_log_likelihood(X)))

shrinkage-to-isotropic final mean log-lik: -2.8295220213379735

This is about 0.064 nats below the true-model value of -2.7653. Some gap is expected: the isotropic target drops the 0.3 off-diagonal entries of the true model's sigma, so shrinking toward it trades a little fit for a more regular estimate.

Takeaways

  • IncrementalEMFitter updates from mini-batches; fit(model, X, key=...) needs a PRNG key for batch sampling.

  • Six \(\eta\)-update rules trade off memory vs responsiveness; averaging rules reduce variance, Identity reacts fastest.

  • Shrinkage + the eta0_* targets regularize the online estimate toward a chosen structure.

Next: Initialization and multi-start looks at where the EM loop starts and how to make fits robust to local optima.