# EM Framework
**Scope.** Why model and fitter are separate, what the η-update rule abstraction buys, how the Shrinkage combinator generalises penalised EM, and what the four covariance-regularisation modes do.

**Where things live.** The public API is in the API Reference. The parameter facade and from_expectation dispatch are in Mixture Architecture. Solver internals are in Solvers and Bessel Functions.
## 1. Model Knows Math, Fitter Knows Iteration
Following the GMMX style: distributions implement E/M-step math and
return immutable models; fitters implement iteration, convergence, and
diagnostics. Fitters are plain Python classes, not eqx.Module
(they carry no JAX state — moving them onto pytrees would buy nothing
and make their attributes traceable).
```python
class BatchEMFitter: ...        # full-dataset EM with convergence monitoring
class IncrementalEMFitter: ...  # mini-batch / online with fixed budget
```
Both return EMResult:
| Field | Always set | Notes |
|---|---|---|
| | yes | the fitted pytree |
| | yes | max relative L2 change in the convergence parameters (§2.1) |
| | yes | iterations actually run |
| | yes (Batch) / | |
| | optional | |
| | optional (verbose ≥ 1) | per-iteration LL trace |
| | yes | wall clock |
## 2. EM Steps on Marginal Mixtures
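```python
from __future__ import annotations  # keeps the η / stats type names as forward references
from typing import Literal

class MarginalMixture:  # signature summary only; bodies elided
    def e_step(self, X, *, backend: Literal['jax', 'cpu'] = 'jax') -> NormalMixtureEta | FactorMixtureStats: ...
    def m_step(self, eta, **kw) -> MarginalMixture: ...               # full update
    def m_step_normal(self, eta) -> MarginalMixture: ...              # MCECM cycle 1
    def m_step_subordinator(self, eta, **kw) -> MarginalMixture: ...  # MCECM cycle 2
    def compute_eta_from_model(self): ...  # -> stats pytree; incremental warm-start
    def em_convergence_params(self): ...   # -> pytree; convergence hook
```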
e_step returns aggregated expectation parameters as an eqx.Module
pytree, not raw per-observation dicts. The first six fields of
FactorMixtureStats are identical to those of NormalMixtureEta, so shrinkage
targets and rule weights carry over between the two families (see
docs/design/mixtures.md §7).
### 2.1 Convergence on a pytree
em_convergence_params() returns a pytree whose leaf-wise change
defines convergence:
| Marginal | Returns |
|---|---|
| | |
| | |
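Subordinator parameters \((p, a, b)\) are excluded: their solver has
its own tolerance, and including them inflates iteration counts.

Returning \(\Sigma\) from FactorNormalMixture (rather than \((F, D)\))
sidesteps the \(r \times r\) orthogonal gauge freedom of \(F\). For any
orthogonal \(Q\), the loadings \(F\) and \(FQ\) give the same
\(\Sigma = FF^\top + D\), so a criterion on \(F\) could register change
where the model has not moved. The Σ-recovery test in
tests/test_factor_mixture.py passes without an orthogonalisation step.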
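A quick NumPy check of the gauge argument, assuming \(\Sigma = FF^\top + D\) with diagonal \(D\). A fixed rotation \(Q\) moves the loadings by a relative amount \(2\sin(\theta/2)\) but leaves \(\Sigma\) unchanged to machine precision. The sizes, seed and angle are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 6, 2
F = rng.normal(size=(d, r))                  # factor loadings
D = np.diag(rng.uniform(0.5, 1.5, size=d))   # diagonal idiosyncratic variances

theta = 0.7
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])  # r x r rotation (orthogonal)
F_rot = F @ Q                                # same model, different gauge

Sigma = F @ F.T + D
Sigma_rot = F_rot @ F_rot.T + D

rel_change_F = np.linalg.norm(F_rot - F) / np.linalg.norm(F)
rel_change_Sigma = np.linalg.norm(Sigma_rot - Sigma) / np.linalg.norm(Sigma)
print(f"relative change in F: {rel_change_F:.3f}, in Sigma: {rel_change_Sigma:.1e}")
```

A convergence test on \(F\) would see a large step here even though the fitted distribution is identical.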
_param_change(new, old) returns the maximum relative L2 change across leaves,
clamping the denominator at _PARAM_EPS so that leaves at or near zero do not
divide by zero.
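The sketch below shows what such a metric computes. Plain dicts stand in for the pytree, and the value of the epsilon is made up. It assumes the denominator is the norm of the *old* leaf, which is a guess; the library helper may normalise differently.

```python
import numpy as np

_PARAM_EPS = 1e-12  # hypothetical value; the library's constant may differ

def param_change(new: dict, old: dict) -> float:
    """Max over leaves of ||new - old||_2 / max(||old||_2, _PARAM_EPS)."""
    changes = []
    for key in old:
        diff = np.linalg.norm(np.asarray(new[key]) - np.asarray(old[key]))
        scale = max(np.linalg.norm(np.asarray(old[key])), _PARAM_EPS)
        changes.append(diff / scale)
    return max(changes)

old = {"mu": np.array([1.0, 2.0]), "gamma": np.array([0.0, 0.0])}
new = {"mu": np.array([1.0, 2.2]), "gamma": np.array([0.0, 0.0])}
print(param_change(new, old))  # driven by the mu leaf; the all-zero gamma leaf contributes 0
```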
## 3. η-Update Rules: Two Layers
Online and shrinkage updates compose. The current abstraction has two layers so future ML-style predictors can plug in without an API revision.
```python
from __future__ import annotations  # lets EtaLike / Weight stay unresolved forward references
from typing import Optional
import equinox as eqx

class EtaUpdateRule(eqx.Module):
    """Most general: η_t = rule(η_{t-1}, η̂)."""
    def initial_state(self) -> dict: ...
    def __call__(self, eta_prev, eta_new, step, batch_size, state
                 ) -> tuple[EtaLike, dict]: ...

class AffineRule(EtaUpdateRule):
    """Specialisation: η_t = a + b·η_{t-1} + c·η̂."""
    def weights(self, step, batch_size, state
                ) -> tuple[Optional[EtaLike], Weight, Weight, dict]: ...  # (a, b, c, new_state)
```
__call__ (rather than predict / apply) is the Equinox idiom:
a module is its forward pass, just as an eqx.nn.Linear instance is applied as linear(x).
| Rule (concrete) | Layer | Stored | \((a, b, c)\) |
|---|---|---|---|
| | affine | — | \((0, 0, 1)\) |
| | affine | | \((0, 1-w, w)\) |
| | affine | | \(c = 1/(\tau_0 + t)\) |
| | affine | (state-only \(n\)) | \(n/(n+m), m/(n+m)\) |
| | affine | | as-is |
| | non-affine combinator | | derived (see §4) |
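To make the \((a, b, c)\) column concrete, here is a dependency-light sketch. Each affine rule only supplies its weights, and one shared helper applies \(\eta_t = a + b\,\eta_{t-1} + c\,\hat\eta\). The class names are illustrative rather than the library's, η is a single NumPy array, a = None means a zero offset, and steps are assumed to be counted from 1.

```python
from dataclasses import dataclass
import numpy as np

def apply_affine(a, b, c, eta_prev, eta_new):
    """η_t = a + b·η_{t-1} + c·η̂, with a=None meaning no offset."""
    out = b * eta_prev + c * eta_new
    return out if a is None else a + out

@dataclass(frozen=True)
class IdentitySketch:
    def initial_state(self):
        return {}
    def weights(self, step, batch_size, state):
        return None, 0.0, 1.0, state             # (a, b, c) = (0, 0, 1)

@dataclass(frozen=True)
class EWMASketch:
    w: float                                     # stored smoothing weight
    def initial_state(self):
        return {}
    def weights(self, step, batch_size, state):
        return None, 1.0 - self.w, self.w, state  # (0, 1-w, w)

@dataclass(frozen=True)
class RobbinsMonroSketch:
    tau0: float                                  # stored offset τ₀
    def initial_state(self):
        return {}
    def weights(self, step, batch_size, state):
        c = 1.0 / (self.tau0 + step)             # c = 1/(τ₀ + t)
        return None, 1.0 - c, c, state

def run(rule, eta_init, batches):
    """Thread state through the rule the way a fitter loop would."""
    eta, state = eta_init, rule.initial_state()
    for step, eta_hat in enumerate(batches, start=1):
        a, b, c, state = rule.weights(step, 1, state)
        eta = apply_affine(a, b, c, eta, eta_hat)
    return eta

batches = [np.array([1.0]), np.array([3.0]), np.array([5.0])]
last = run(IdentitySketch(), np.zeros(1), batches)                   # keeps only the newest η̂
rm_mean = run(RobbinsMonroSketch(tau0=0.0), np.zeros(1), batches)    # running mean of the batches
ewma = run(EWMASketch(w=0.5), np.zeros(1), batches)
```

With \(\tau_0 = 0\), the Robbins–Monro weights reduce to \(c = 1/t\), which reproduces the plain running mean.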
### 3.1 Generalised affine_combine
affine_combine(eta_prev, eta_new, b, c, a=None) accepts three weight
forms via jax.tree.map:
| Form | Math | Use case |
|---|---|---|
| Scalar (or 0-d array) | \(b\cdot I_n\) broadcast to every leaf | EWMA, Robbins–Monro, uniform shrinkage |
| Stats-shape pytree | block-diagonal: leaf-wise multiply | per-field / per-element shrinkage |
| Callable | arbitrary linear operator | user-supplied |
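A sketch of how the three weight forms can be dispatched. Dicts stand in for Equinox pytrees and a hand-rolled tree_map replaces jax.tree.map. It only illustrates the semantics in the table and is not the library's implementation.

```python
import numpy as np

def tree_map(f, *trees):
    """Stand-in for jax.tree.map over flat dicts of arrays (same keys)."""
    return {k: f(*(t[k] for t in trees)) for k in trees[0]}

def apply_weight(w, tree):
    """Apply one weight, in any of the three supported forms, to an η-shaped dict."""
    if callable(w):                                # arbitrary linear operator on the whole tree
        return w(tree)
    if isinstance(w, dict):                        # block-diagonal: leaf-wise multiply
        return tree_map(lambda wl, x: wl * x, w, tree)
    return tree_map(lambda x: w * x, tree)         # scalar b·I_n broadcast to every leaf

def affine_combine(eta_prev, eta_new, b, c, a=None):
    """η_t = a + b·η_{t-1} + c·η̂, with a=None meaning no offset."""
    out = tree_map(np.add, apply_weight(b, eta_prev), apply_weight(c, eta_new))
    return out if a is None else tree_map(np.add, a, out)

eta_prev = {"s1": np.array([1.0, 1.0]), "s2": np.eye(2)}
eta_new = {"s1": np.array([3.0, 5.0]), "s2": 2.0 * np.eye(2)}

uniform = affine_combine(eta_prev, eta_new, 0.5, 0.5)              # scalar form
per_field = affine_combine(eta_prev, eta_new,                       # pytree form
                           {"s1": 1.0, "s2": 0.0}, {"s1": 0.0, "s2": 1.0})
doubled = affine_combine(eta_prev, eta_new, 0.0,                    # callable form
                         lambda t: {k: 2.0 * v for k, v in t.items()})
```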
We do not ship a flat \(n \times n\) matrix form, where \(n\) is the total
number of scalars in η. That would force the public API to commit to a
flatten order across \((s_1,\dots,s_6)\), leaking that ordering into every
user script. Users who want a true \(n \times n\) linear operator wrap an
eqx.nn.Linear inside a custom rule and use jax.flatten_util.ravel_pytree locally.
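The following NumPy sketch shows what flattening locally looks like. The user fixes a leaf order inside their own rule, ravels η into one length-\(n\) vector, applies \(n \times n\) matrices, and unravels the result. The ravel helper is a hand-written stand-in for jax.flatten_util.ravel_pytree, which likewise returns the flat vector together with an unravel function. The dense_rule name is hypothetical.

```python
import numpy as np

def ravel(tree):
    """Flatten a dict of arrays in sorted-key order; return (vector, unravel)."""
    keys = sorted(tree)
    shapes = [np.shape(tree[k]) for k in keys]
    flat = np.concatenate([np.ravel(tree[k]) for k in keys])

    def unravel(vec):
        out, start = {}, 0
        for k, shape in zip(keys, shapes):
            size = int(np.prod(shape))
            out[k] = np.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flat, unravel

def dense_rule(eta_prev, eta_new, M_prev, M_new):
    """η_t = M_prev·vec(η_{t-1}) + M_new·vec(η̂) with user-owned n×n matrices."""
    x_prev, unravel = ravel(eta_prev)
    x_new, _ = ravel(eta_new)
    return unravel(M_prev @ x_prev + M_new @ x_new)

eta_prev = {"s1": np.array([1.0, 2.0]), "s2": np.eye(2)}
eta_new = {"s1": np.array([3.0, 4.0]), "s2": 3.0 * np.eye(2)}
n = ravel(eta_prev)[0].size                    # 2 + 4 = 6 scalars
M = 0.5 * np.eye(n)                            # any n×n operator would do
eta_t = dense_rule(eta_prev, eta_new, M, M)    # here: plain averaging
```

The flatten order lives entirely inside the user's rule, so changing it never touches the public API.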
### 3.2 State and trainability
eqx.Module is immutable. Anything that genuinely changes per iteration
(cumulative sample count, RNN hidden state, momentum buffers) lives in
state, threaded by the fitter:
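```python
state = rule.initial_state()
for step in range(n_steps):  # n_steps: the fitter's iteration budget
    # eta_new is this step's E-step statistic η̂
    eta_t, state = rule(eta_prev, eta_new, step, batch_size, state)
```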
Pure rules round-trip an empty state = {} at zero cost. Every JAX-array
leaf of an eqx.Module is automatically a parameter visible to
jax.grad / optax. The architecture leaves the door open for
meta-learning step-size schedules end-to-end (the lax.scan fast path
satisfies the no-Python-branches requirement); meta-learning is not
a goal of any current implementation phase.
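Here is a minimal sketch of a stateful rule in this style: a sample-weighted running average. Its only evolving quantity, the cumulative count \(n\), is threaded through state while the rule object itself stays immutable. The sketch follows the EtaUpdateRule call signature from §3 but is a plain frozen dataclass rather than an eqx.Module. The class name is hypothetical.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class SampleWeightedSketch:
    """η_t = n/(n+m)·η_{t-1} + m/(n+m)·η̂, with the count n carried in state."""

    def initial_state(self) -> dict:
        return {"n": 0}

    def __call__(self, eta_prev, eta_new, step, batch_size, state):
        n, m = state["n"], batch_size
        eta_t = (n * eta_prev + m * eta_new) / (n + m)
        return eta_t, {"n": n + m}        # fresh state dict; self is never mutated

rule = SampleWeightedSketch()
state = rule.initial_state()
eta = np.zeros(2)
batches = [(np.array([1.0, 1.0]), 10), (np.array([4.0, 7.0]), 20)]
for step, (eta_hat, m) in enumerate(batches):
    eta, state = rule(eta, eta_hat, step, m, state)
print(eta, state)
```

On the first call \(n = 0\), so the initial η is ignored. When each η̂ is a batch mean, the result is the pooled mean of all batches seen so far.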
## 4. Penalised EM via the Shrinkage Combinator
### 4.1 Theory recap (η-affine derivation)
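For an exponential family with log-partition \(\psi\), natural parameter \(\theta\) and expectation parameter \(\eta = \nabla\psi(\theta)\), the penalised M-step solves

\[
\theta_{k+1} = \arg\max_\theta \Big\{ \theta^\top\hat\eta_k - \psi(\theta) \;-\; \tau\big[\psi(\theta) - \theta^\top\eta_0\big] \Big\},
\]

where \(\hat\eta_k\) is the E-step statistic and \(\eta_0\) the prior expectation. Up to a constant, the penalty is \(\tau\) times the KL divergence from \(p_{\theta_0}\) to \(p_\theta\), where \(\nabla\psi(\theta_0) = \eta_0\). The objective equals \((1+\tau)\big[\theta^\top\tilde\eta_k - \psi(\theta)\big]\), so the problem is identical to ordinary MLE (\(\nabla\psi(\theta_{k+1}) = \tilde\eta_k\)) on the shrunk expectation

\[
\tilde\eta_k = \frac{\tau}{1+\tau}\,\eta_0 + \frac{1}{1+\tau}\,\hat\eta_k.
\]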
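Here is a one-dimensional sanity check of the derivation, using the Poisson family as an assumed example (\(\psi(\theta) = e^\theta\), so \(\eta = e^\theta\)). A brute-force maximiser of the penalised objective lands on the same θ as the closed-form MLE on \(\tilde\eta\). The numbers are arbitrary.

```python
import numpy as np

eta_hat, eta0, tau = 4.0, 1.0, 0.5           # E-step statistic, prior, shrinkage weight

def penalised_objective(theta):
    psi = np.exp(theta)                       # Poisson log-partition ψ(θ) = e^θ
    return (theta * eta_hat - psi) - tau * (psi - theta * eta0)

# Brute-force maximiser on a fine grid (spacing 1e-5)
grid = np.linspace(-3.0, 3.0, 600_001)
theta_grid = grid[np.argmax(penalised_objective(grid))]

# Closed form: ordinary MLE (∇ψ(θ) = η̃) on the shrunk expectation
eta_tilde = tau / (1 + tau) * eta0 + 1 / (1 + tau) * eta_hat
theta_shrunk = np.log(eta_tilde)

print(theta_grid, theta_shrunk)
```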
This generalises in two directions:

- **Per-field \(\tau\).** Replace the scalar by a pytree \(\tau\) with the same structure as \(\eta\). Setting \(\tau\) non-zero only on the \(E[X X^\top/Y]\) leaf shrinks only \(\Sigma_{k+1}\).
- **Composition with running rules.** Replace \(\hat\eta_k\) by \(\mathrm{base}(\eta_{k-1},\hat\eta_k)\) (Robbins–Monro, EWMA, sample-weighted, etc.). Each leaf of η is still updated by an independent affine map, and affine_combine already handles tree.map-broadcast weights.
### 4.2 The combinator
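```python
from __future__ import annotations  # EtaUpdateRule / EtaLike resolved elsewhere
from typing import Union
import jax

class Shrinkage(EtaUpdateRule):
    """η_t = (τ/(1+τ)) ⊙ η_0 + (1/(1+τ)) ⊙ base(η_{t-1}, η̂).

    base : EtaUpdateRule (e.g. IdentityUpdate, RobbinsMonroUpdate)
    eta0 : NormalMixtureEta or FactorMixtureStats (the prior)
    tau  : scalar | stats-shape pytree
    """
    base: EtaUpdateRule
    eta0: EtaLike
    tau: Union[jax.Array, EtaLike]

    def initial_state(self):
        # the combinator adds no state of its own; it threads the base rule's
        return self.base.initial_state()
```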
Shrinkage lives one level above AffineRule: even though its action
on η is affine, its base may be an arbitrary EtaUpdateRule (including
non-affine predictors).
The combinator detects per-field τ via type(tau) is type(eta0), so the
same code works for both NormalMixtureEta and FactorMixtureStats.
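The following dependency-light sketch shows what the combinator computes. Dicts stand in for NormalMixtureEta, a pass-through identity rule serves as base, and a per-field τ is detected with the same type(tau) is type(eta0) test. The class names and leaf keys are illustrative, not the library's; the library's Shrinkage is an eqx.Module operating on real η pytrees.

```python
from dataclasses import dataclass
from typing import Any
import numpy as np

class IdentitySketch:
    """Base rule that simply returns the fresh E-step statistic."""
    def initial_state(self):
        return {}
    def __call__(self, eta_prev, eta_new, step, batch_size, state):
        return eta_new, state

@dataclass(frozen=True)
class ShrinkageSketch:
    base: Any
    eta0: dict
    tau: Any  # scalar, or a dict with the same keys as eta0

    def initial_state(self):
        return self.base.initial_state()

    def __call__(self, eta_prev, eta_new, step, batch_size, state):
        eta_base, state = self.base(eta_prev, eta_new, step, batch_size, state)
        per_field = type(self.tau) is type(self.eta0)
        out = {}
        for k in self.eta0:
            t = self.tau[k] if per_field else self.tau
            out[k] = t / (1 + t) * self.eta0[k] + 1 / (1 + t) * eta_base[k]
        return out, state

# Hypothetical leaf names; "s_XXt_over_Y" plays the role of the E[XX^T/Y] leaf
eta0 = {"s_mean": np.array([0.0]), "s_XXt_over_Y": np.eye(2)}
eta_hat = {"s_mean": np.array([3.0]), "s_XXt_over_Y": 4.0 * np.eye(2)}

uniform = ShrinkageSketch(IdentitySketch(), eta0, tau=1.0)
sigma_only = ShrinkageSketch(IdentitySketch(), eta0, tau={"s_mean": 0.0, "s_XXt_over_Y": 1.0})
eta_u, _ = uniform(None, eta_hat, 0, 1, uniform.initial_state())
eta_s, _ = sigma_only(None, eta_hat, 0, 1, sigma_only.initial_state())
```

Swapping IdentitySketch for any running rule changes only what base returns. The shrinkage algebra stays in one place.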
### 4.3 Why a combinator (not ShrunkX copies)
Composing shrinkage with a running rule by hand requires getting the
algebra right twice. The combinator does it in one place. The four
hypothetical ShrunkRobbinsMonro / ShrunkEWMA / ShrunkSampleWeighted
/ ShrunkIdentity classes collapse into one Shrinkage(base, eta0, tau)
with no behaviour lost.
### 4.4 Usage patterns

| Use case | Construction |
|---|---|
| Batch EM, uniform shrinkage | |
| Batch EM, Σ-only shrinkage | |
| Robbins–Monro online + shrinkage | |
| EWMA + shrinkage | |
| Sample-weighted + shrinkage | |
The former ShrinkageUpdate class was removed (pre-1.0 rename); users
now construct Shrinkage(IdentityUpdate(), eta0, tau) explicitly.
When eta_update is set, BatchEMFitter switches off its lax.scan
fast path and uses the Python-loop path. Because the combinator wraps an
arbitrary base rule, the fitter cannot make blanket JIT-friendliness
assumptions.
### 4.5 Shrinkage targets
normix/fitting/shrinkage_targets.py provides four constructors. Each
returns a complete six-field NormalMixtureEta, even when the user
intends to shrink only one statistic. This keeps the public contract
“η₀ is always a full prior” simple.
| Builder | Effect |
|---|---|
| | prior = current model’s η (L2 trust-region in η-space) |
| | \(\Sigma_0 = \sigma^2 I_d\) |
| | \(\Sigma_0 = \mathrm{diag}(s^2)\) |
| | arbitrary user-supplied PSD matrix |
The dispersion-substitution variants reuse the model’s
\((\mu, \gamma, p, a, b)\) to fill the other five fields, so the result
is still a coherent prior expectation parameter; the user’s per-field
tau then decides which fields are actually shrunk. To view a target as a
model, invert it with type(model).from_expectation(eta0_isotropic(model, σ²)).
### 4.6 Choosing \(\tau\) and \(\Sigma_0\)
Practical guidance (the scalar τ has the interpretation “the prior contributes the same weight as \(n_{\text{prior}} = \tau\,n\) pseudo-observations”):
| Regime | Suggested \(\tau\) | Notes |
|---|---|---|
| \(n \gg d\) | \(0\) – \(0.05\) | Shrinkage rarely helps |
| \(n \approx d\) | \(0.1\) – \(1\) | Sample \(\Sigma\) ill-conditioned |
| \(n < d\) | \(\geq 1\) | Required for invertibility |
| Online streaming | match base rule’s horizon | E.g. \(\tau \sim 1/T\) |
Cross-validation on the held-out log-likelihood \(\log f(x_{\text{test}})\) is the
default way to choose τ; running fits across a τ-grid maps to jax.vmap over the rule’s τ leaves.
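A τ-sweep vectorises because the shrunk statistic is affine in the prior weight \(\tau/(1+\tau)\), so a whole grid can be evaluated in one broadcast. In this sketch NumPy broadcasting stands in for jax.vmap, and η is reduced to a single array leaf. The held-out scoring step that cross-validation would add is omitted.

```python
import numpy as np

eta_hat = np.array([2.0, 6.0])          # current E-step statistic (one leaf, illustrative)
eta0 = np.array([1.0, 1.0])             # prior target
tau_grid = np.array([0.0, 0.05, 0.1, 1.0, 3.0])

w_prior = tau_grid / (1 + tau_grid)     # weight on η₀ for each τ
# Shape (len(tau_grid), 2): one shrunk statistic per grid point
eta_sweep = w_prior[:, None] * eta0 + (1 - w_prior)[:, None] * eta_hat

n = 200
n_prior = tau_grid * n                  # pseudo-observation reading of τ
for t, m, row in zip(tau_grid, n_prior, eta_sweep):
    print(f"tau={t:<5} pseudo-obs={m:>6.1f} eta={row}")
```

At \(\tau = 1\), the prior counts as many pseudo-observations as the data and the result sits halfway between \(\hat\eta\) and \(\eta_0\).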
## 5. Covariance Regularisations
After each M-step the fitter optionally rescales the model. The
regularisation family is enumerated by
BatchEMFitter._REGULARIZATIONS:
| Mode | What it enforces | Implementation |
|---|---|---|
| | identity (no rescale) | — |
| | \(\det\Sigma = 1\) (the original GH convention) | |
| | \(\log\det\Sigma = \log\det\Sigma_0\), the initial model’s log-determinant | |
| | \(a = b = \sqrt{ab}\) on the GIG subordinator (orbit invariant) | |
All four are reparametrisations: the marginal density of \(X\) is unchanged. They move the model along the orbit \(Y \to s\,Y\), \(\Sigma \to \Sigma/s\), \(\gamma \to \gamma/s\), with the subordinator absorbing the scale.
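The following NumPy check assumes the normal variance-mean form \(X \mid Y \sim N(\mu + \gamma Y,\ Y\Sigma)\) and the GIG kernel \(y^{p-1}\exp(-(a y + b/y)/2)\), the convention under which the §5.3 rules read \(a \to a/s\), \(b \to b\,s\). It verifies two things:

1. The conditional law of \(X\) is unchanged after rescaling.
2. The density of \(sY\) is the GIG kernel with \((p, a/s, b\,s)\), up to a constant.

All values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
d, s = 3, 2.5
mu, gamma = rng.normal(size=d), rng.normal(size=d)
A = rng.normal(size=(d, d))
Sigma = A @ A.T + d * np.eye(d)

# (1) X | Y ~ N(mu + gamma*Y, Y*Sigma) is unchanged by Y -> sY, gamma -> gamma/s, Sigma -> Sigma/s
y = 0.7
mean_old, cov_old = mu + gamma * y, y * Sigma
y_new = s * y
mean_new, cov_new = mu + (gamma / s) * y_new, y_new * (Sigma / s)

# (2) The density of Y' = sY is f(y'/s)/s; it should be proportional to the GIG
#     kernel with (p, a/s, b*s), i.e. the ratio below must not depend on y'
def gig_kernel(y, p, a, b):
    return y ** (p - 1) * np.exp(-(a * y + b / y) / 2)

p, a, b = -0.5, 1.3, 0.8
y_grid = np.linspace(0.1, 10.0, 50)
ratio = (gig_kernel(y_grid / s, p, a, b) / s) / gig_kernel(y_grid, p, a / s, b * s)
print(np.allclose(mean_old, mean_new), np.allclose(cov_old, cov_new), ratio.min(), ratio.max())
```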
### 5.1 Why three Σ-targeting modes
| Mode | Use when |
|---|---|
| | EM should leave the scale alone (e.g. when downstream code reads \(\Sigma\) directly) |
| | Comparing across distributions where only the orbit matters; classical GH convention |
| | Running multiple distributions on the same data and you want their displayed \((a, b, \gamma)\) on a comparable scale (e.g. the SP500 study compares VG / NIG / NInvG / GH side by side; only …) |
### 5.2 Why 'a_eq_b' matters separately
GH’s \((p, a, b, \gamma)\) has a one-parameter orbit \(s\) that rescales
\(a, b, \gamma\) (together with \(\Sigma\)). The most useful canonical representative sets
\(a = b = \sqrt{ab}\), so the orbit-invariant product \(ab\) can be read off
directly. NIG, where \(a = \lambda/\mu_{IG}^2\) and \(b = \lambda\),
has the same orbit; the canonical representative there is \(\mu_{IG} = 1\).
VG (Gamma subordinator, \(b = 0\)) and NInvG (InverseGamma, \(a = 0\)) already
sit on a degenerate orbit, so 'a_eq_b' is a no-op for those.
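The canonical representative follows from the GIG rule \(a \to a/s\), \(b \to b\,s\) in §5.3. Solving \(a/s = b\,s\) gives \(s = \sqrt{a/b}\), after which \(a' = b' = \sqrt{ab}\). For NIG, the same scale is \(s = 1/\mu_{IG}\). The sketch below uses a hypothetical helper name.

```python
import math

def a_eq_b_scale(a: float, b: float) -> float:
    """Scale s that moves GIG (a, b) to a/s == b*s; requires a, b > 0."""
    return math.sqrt(a / b)

a, b = 4.0, 0.25
s = a_eq_b_scale(a, b)
a_new, b_new = a / s, b * s
print(s, a_new, b_new, a * b, a_new * b_new)   # ab is preserved along the orbit

# NIG: a = lam/mu_ig**2, b = lam; the same s maps (mu_ig, lam) -> (s*mu_ig, s*lam)
mu_ig, lam = 0.5, 0.25
s_nig = a_eq_b_scale(lam / mu_ig**2, lam)
mu_new, lam_new = s_nig * mu_ig, s_nig * lam
print(mu_new)   # canonical NIG representative has mu_ig == 1
```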
### 5.3 The _rescale / _build_rescaled pattern
Each marginal owns the linear-algebra side of the rescale; the
subordinator side is delegated to a per-subclass _build_rescaled:
```python
import jax.numpy as jnp

class NormalMixture(MarginalMixture):
    def _rescale(self, scale):
        # Σ → Σ/s (so L → L/√s), γ → γ/s, then defer the subordinator to the subclass
        L_new = self._joint.L_Sigma / jnp.sqrt(scale)
        gamma_new = self._joint.gamma / scale
        return self._build_rescaled(self._joint.mu, gamma_new, L_new, scale)

    def regularize_det_sigma(self, target_log_det=0.0):
        # pick s so that log|Σ/s| = log|Σ| - d·log s equals target_log_det
        d = self._joint.gamma.shape[-1]  # dimension of X (assumes gamma is a d-vector)
        s = jnp.exp((self._joint.log_det_sigma() - target_log_det) / d)
        return self._rescale(s)

    def regularize_det_sigma_one(self):
        return self.regularize_det_sigma(0.0)

    def regularize_a_eq_b(self):
        # default no-op; overridden by GH and NIG
        return self
```
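The scale in regularize_det_sigma comes from \(\log\det(\Sigma/s) = \log\det\Sigma - d\log s\). Setting that to the target gives \(s = \exp\big((\log\det\Sigma - \text{target})/d\big)\). Here is a NumPy check of the arithmetic, where L plays the role of the Cholesky factor that the method divides by \(\sqrt{s}\):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4
A = rng.normal(size=(d, d))
Sigma = A @ A.T + np.eye(d)
L = np.linalg.cholesky(Sigma)

_, log_det = np.linalg.slogdet(Sigma)
target_log_det = 0.0                        # the |Σ| = 1 convention
s = np.exp((log_det - target_log_det) / d)

L_new = L / np.sqrt(s)                      # mirrors L_Sigma / sqrt(scale)
Sigma_new = L_new @ L_new.T                 # equals Sigma / s
_, log_det_new = np.linalg.slogdet(Sigma_new)
print(log_det, log_det_new)
```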
Per-subclass _build_rescaled knows how the subordinator absorbs
\(Y \to s\,Y\):
| Subordinator | Stored | \(Y \to s\,Y\) rule |
|---|---|---|
| Gamma(\(\alpha,\beta\)) (VG) | | \(\beta \to \beta/s\), \(\alpha\) unchanged |
| InverseGamma(\(\alpha,\beta\)) (NInvG) | | \(\beta \to \beta\cdot s\), \(\alpha\) unchanged |
| InverseGaussian(\(\mu_{IG},\lambda\)) (NIG) | | \(\mu_{IG} \to s\mu_{IG}\), \(\lambda \to s\lambda\) |
| GIG(\(p, a, b\)) (GH) | | \(a \to a/s\), \(b \to b\cdot s\), \(p\) unchanged |
FactorNormalMixture follows the same pattern: \(F \to F/\sqrt{s}\),
\(D \to D/s\) (so \(\Sigma = FF^\top + D \to \Sigma/s\)), plus the same subordinator rules.
A lesson from Phase 4: NIG’s _build_rescaled was originally
mu_ig/scale, lam/scale, which is wrong; the correct rule is
mu_ig·scale, lam·scale. The orbit-invariance test in
tests/test_regularizations.py would have caught it (and now does).
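This kind of consistency check catches the NIG slip. It maps NIG's \((\mu_{IG}, \lambda)\) to GIG form via \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\) (the convention in §5.2). It then requires that rescaling in NIG coordinates agrees with the GIG rule \(a \to a/s\), \(b \to b\,s\). The helper names are hypothetical, not the test file's contents.

```python
import math

def nig_to_gig(mu_ig, lam):
    """NIG's inverse-Gaussian subordinator written as GIG(p=-1/2, a, b)."""
    return -0.5, lam / mu_ig**2, lam

def gig_rescale(p, a, b, s):
    return p, a / s, b * s

def nig_rescale_fixed(mu_ig, lam, s):
    return mu_ig * s, lam * s

def nig_rescale_buggy(mu_ig, lam, s):       # the original (wrong) rule
    return mu_ig / s, lam / s

def agrees(rule, mu_ig=0.8, lam=1.7, s=3.0):
    expected = gig_rescale(*nig_to_gig(mu_ig, lam), s)
    got = nig_to_gig(*rule(mu_ig, lam, s))
    return all(math.isclose(e, g) for e, g in zip(expected, got))

print(agrees(nig_rescale_fixed), agrees(nig_rescale_buggy))
```

The buggy rule maps \((a, b)\) to \((s\,a,\ b/s)\), which is the inverse scaling, so any \(s \neq 1\) exposes it.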
## 6. Loop Dispatch (Batch)
```python
use_scan = (
    algorithm == 'em'
    and e_step_backend == 'jax'
    and m_step_backend == 'jax'
    and verbose <= 1
    and eta_update is None
)
```
Otherwise the fitter falls back to a Python for loop (any algorithm other than 'em',
CPU backends, verbose tables, or any eta_update).
IncrementalEMFitter runs lax.scan over minibatch steps when both
backends are 'jax' and verbose == 0; with inner_iter > 1, each step nests a
lax.fori_loop. RNG keys are pre-stacked (_materialize_incremental_subkeys)
so the scan can consume one key per step.
## 7. Cross-References
- Solvers (η→θ): Solvers and Bessel Functions.
- Theory: EM algorithm, Shrinkage, Factor analysis.