Mixture Architecture#
Scope. Why normal-mixture distributions are split into a Joint and a Marginal class, why the joint is itself an ExponentialFamily, and how the factor-analysis family slots in as a sibling without forcing the joint hierarchy to lie about its EF signature.
Where things live. Class diagram and storage table are in the API Reference. EM details are in EM Framework.
1. Two Classes, Not One#
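The Generalized Hyperbolic family is a normal variance–mean mixture: \(X \mid Y \sim \mathcal{N}(\mu + \gamma Y,\; Y\Sigma)\) with a positive mixing variable \(Y\) (the subordinator). Two class hierarchies cover the two views of this model: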
JointNormalMixture(ExponentialFamily)     f(x, y): IS an exponential family
    ↑                                     (closed-form natural / sufficient
JointVarianceGamma, JointNIG, …            statistics; exact M-step)
MarginalMixture(eqx.Module)               abstract; fitter contract
    ↑                                     (Bessel-required marginal density)
NormalMixture, FactorNormalMixture        f(x) = ∫ f(x,y) dy: NOT an EF
    ↑                                     on its own
VarianceGamma, GH, FactorGH, …
The joint is an exponential family, and that is what gives closed-form M-steps. The marginal integrates \(y\) out; its density involves modified Bessel functions and is not an exponential family. Putting both roles into one class would compromise both.
2. D2 — Joint Classes Are Public#
JointVarianceGamma, JointNormalInverseGamma,
JointNormalInverseGaussian, JointGeneralizedHyperbolic, and the
abstract JointNormalMixture are first-class public ExponentialFamily
objects: exported from normix, documented alongside marginals, intended
for direct use where the joint law \(f(x,y)\) matters (simulation,
complete-data MLE, divergences, custom EM variants).
2.1 Observation vector convention#
sufficient_statistics, log_base_measure, and inherited log_prob /
pdf take a single flat array xy = jnp.concatenate([x, y]), with
\(x\) of shape (d,) and scalar \(y > 0\) last. This matches the
sufficient-statistic block \([\log y,\,1/y,\,y,\,\ldots]\) used internally.
For readability, log_prob_joint(x, y) and rvs(n, seed) -> (X, Y)
remain the preferred entry points when \(x\) and \(y\) are already separate.
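The flat-vector convention can be illustrated with a small sketch. The helpers `pack_xy`, `unpack_xy`, and `y_block` are hypothetical names introduced here for illustration; they are not part of normix. NumPy is used as a stand-in for `jax.numpy`, whose `concatenate` and `log` have the same names.

```python
import numpy as np  # stand-in for jax.numpy in this sketch

def pack_xy(x, y):
    """Flatten (x, y) into one vector with the scalar y last (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x, np.atleast_1d(float(y))])

def unpack_xy(xy):
    """Inverse of pack_xy: x is everything but the last entry, y is the last entry."""
    xy = np.asarray(xy, dtype=float)
    return xy[:-1], xy[-1]

def y_block(xy):
    """The [log y, 1/y, y] block of statistics, read off the flat vector."""
    _, y = unpack_xy(xy)
    return np.array([np.log(y), 1.0 / y, y])

x = np.array([0.5, -1.0, 2.0])   # d = 3
y = 0.25                         # scalar, must be positive
xy = pack_xy(x, y)               # shape (d + 1,)
x_back, y_back = unpack_xy(xy)
```

Keeping \(y\) in the last slot means the split is a plain slice, which works unchanged for any \(d\).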
2.2 GIG sign convention#
For GIG-based joints, natural parameters \(\theta_2,\theta_3\) must align with the GIG convention \(\theta_{\mathrm{GIG}} = [p-1,\,-b/2,\,-a/2]\) on \([\log y,\,1/y,\,y]\): scalar coefficients \(-(b/2 + \cdots)\) and \(-(a/2 + \cdots)\) on \(1/y\) and \(y\), not \(-b\) and \(-a\). Gamma and InverseGamma joints already match this limit; GH and NIG joints follow the same pattern.
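A quick numerical check of the sign convention, for the GIG part only. The function names below are illustrative, not normix API, and NumPy stands in for `jax.numpy`. The sketch compares the unnormalised GIG log-density \((p-1)\log y - (a y + b/y)/2\) with the inner product \(\theta_{\mathrm{GIG}} \cdot [\log y,\,1/y,\,y]\), and shows that using \(-b\) and \(-a\) instead does not match.

```python
import numpy as np  # stand-in for jax.numpy in this sketch

def gig_log_kernel(y, p, a, b):
    """Unnormalised GIG log-density: (p - 1) log y - (a y + b / y) / 2."""
    return (p - 1.0) * np.log(y) - 0.5 * (a * y + b / y)

def gig_natural_params(p, a, b):
    """theta_GIG = [p - 1, -b/2, -a/2], paired with [log y, 1/y, y]."""
    return np.array([p - 1.0, -0.5 * b, -0.5 * a])

def gig_sufficient_stats(y):
    return np.array([np.log(y), 1.0 / y, y])

p, a, b = 0.7, 3.0, 4.0
y = 1.3

direct = gig_log_kernel(y, p, a, b)
via_theta = gig_natural_params(p, a, b) @ gig_sufficient_stats(y)

# The convention the text warns against: -b and -a without the factor 1/2.
wrong_theta = np.array([p - 1.0, -b, -a]) @ gig_sufficient_stats(y)
```

Here `direct` and `via_theta` agree to rounding error, while `wrong_theta` is off by \(-(a y + b/y)/2\).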
2.3 from_natural for joints#
JointGeneralizedHyperbolic inverts the full joint family directly.
JointVarianceGamma, JointNormalInverseGamma, and
JointNormalInverseGaussian validate that theta lies on the
constrained subfamily before reconstructing classical parameters.
3. Marginal Mixture API#
MarginalMixture (in normix/mixtures/marginal.py) is the abstract
contract that fitters and downstream code depend on:
class MarginalMixture(eqx.Module):
    # distribution surface
    def log_prob(self, x: jax.Array) -> jax.Array: ...
    def pdf(self, x: jax.Array) -> jax.Array: ...
    def mean(self) -> jax.Array: ...
    def cov(self) -> jax.Array: ...
    def rvs(self, n: int, seed: int = 42) -> jax.Array: ...
    def marginal_log_likelihood(self, X: jax.Array) -> jax.Array: ...
    # EM hooks (stats type chosen by subclass)
    def e_step(self, X, *, backend='jax'): ...
    def m_step(self, eta, **kw) -> "MarginalMixture": ...
    def m_step_normal(self, eta) -> "MarginalMixture": ...
    def m_step_subordinator(self, eta, **kw) -> "MarginalMixture": ...
    def compute_eta_from_model(self): ...
    def em_convergence_params(self): ...
    # convenience
    def fit(self, X, **kw) -> "EMResult": ...
NormalMixture (full Σ) and FactorNormalMixture are the two
implementations. Fitters depend only on this ABC.
4. Parameter Facade on NormalMixture#
The marginal density is parameterised by the same classical tuple
\((\mu,\gamma,\Sigma,\text{subordinator})\) as the joint. Rather than
exposing this through model._joint.*, the marginal forwards them:
| Forwarded | Read | Write (via replace) |
|---|---|---|
| | property | |
| | method | |
| | method | none |
| Subordinator fields | per-subclass property | |
replace(**updates) validates keys against
_NORMAL_KEYS ∪ _subordinator_keys(), treats sigma as a write-only
alias for L_Sigma, and uses eqx.tree_at for an immutable update. The
storage stays in _joint; the facade is forwarders + immutable update —
no duplicated state.
vg2 = vg.replace(mu=new_mu)
vg3 = vg.replace(alpha=2.5, beta=0.5)
vg4 = vg.replace(sigma=sigma2 * jnp.eye(d))
gh2 = gh.replace(p=1.0, a=3.0, b=4.0)
The marginal mean() and cov() (=\(E[Y]\Sigma + \mathrm{Var}[Y]\gamma\gamma^\top\))
remain distinct from mu and sigma(): those are the conditional
Gaussian’s parameters, not the marginal’s moments.
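The distinction between conditional parameters and marginal moments can be checked by simulation. The sketch below is not normix code: it uses NumPy directly, assumes a Gamma subordinator with shape 2 and rate 1 purely for illustration, and relies on the representation \(X = \mu + \gamma Y + \sqrt{Y}\,Z\) with \(Z \sim \mathcal{N}(0, \Sigma)\). The marginal mean \(\mu + E[Y]\gamma\) follows from that representation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Conditional-Gaussian parameters (what mu and sigma() refer to in the text).
mu = np.array([1.0, -1.0])
gamma = np.array([0.5, -0.2])
Sigma = np.array([[1.0, 0.3], [0.3, 2.0]])

# Assumed subordinator for this demo: Y ~ Gamma(shape=alpha, rate=beta).
alpha, beta = 2.0, 1.0
E_Y, Var_Y = alpha / beta, alpha / beta**2

# Marginal moments of X.
marginal_mean = mu + E_Y * gamma
marginal_cov = E_Y * Sigma + Var_Y * np.outer(gamma, gamma)

# Monte Carlo check via X = mu + gamma * Y + sqrt(Y) * Z.
n = 200_000
Y = rng.gamma(shape=alpha, scale=1.0 / beta, size=n)
Z = rng.multivariate_normal(np.zeros(2), Sigma, size=n)
X = mu + Y[:, None] * gamma + np.sqrt(Y)[:, None] * Z

sample_mean = X.mean(axis=0)
sample_cov = np.cov(X, rowvar=False)
```

With these numbers the marginal covariance is `[[2.5, 0.4], [0.4, 4.08]]`, clearly different from `Sigma`, and the sample moments land close to the formulas.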
5. from_expectation as the Canonical η→Model Map#
The closed-form M-step is exactly the conjugate-dual map η → θ on the joint exponential family. Both layers expose it:
JointNormalMixture.from_expectation(eta: NormalMixtureEta, **kw) -> JointNormalMixture
NormalMixture.from_expectation(eta: NormalMixtureEta, **kw) -> NormalMixture
The joint method dispatches on isinstance(eta, NormalMixtureEta):
pytree path: closed form (_mstep_normal_params for \(\mu,\gamma,\Sigma\); _subordinator_from_eta for the subordinator).
flat jax.Array: falls back to the inherited Bregman solver. Kept for API symmetry; not the recommended route for high-dimensional joint EFs.
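For the normal block, the closed-form map can be written down by minimising \(E[(X-\mu-\gamma Y)^\top \Sigma^{-1} (X-\mu-\gamma Y)/Y]\) under the model \(X \mid Y \sim \mathcal{N}(\mu + \gamma Y, Y\Sigma)\). The sketch below is a standalone NumPy illustration of that algebra, not the library's `_mstep_normal_params`; the function name and argument order are assumptions. It then checks the formulas by building exact expectations from known parameters and recovering them.

```python
import numpy as np

def mstep_normal_params(s1, s2, s4, s5, s6):
    """Closed-form (mu, gamma, Sigma) from expectations (illustrative sketch).

    s1 = E[1/Y], s2 = E[Y], s4 = E[X], s5 = E[X/Y], s6 = E[X X^T / Y].
    """
    denom = s1 * s2 - 1.0            # positive by Jensen unless Y is degenerate
    mu = (s2 * s5 - s4) / denom
    gamma = (s1 * s4 - s5) / denom
    Sigma = (s6 - np.outer(s5, mu) - np.outer(mu, s5)
             + s1 * np.outer(mu, mu) - s2 * np.outer(gamma, gamma))
    return mu, gamma, Sigma

# Known parameters and subordinator moments (E[1/Y] * E[Y] = 1.5 > 1).
mu_true = np.array([0.3, -1.0])
gamma_true = np.array([0.5, 0.2])
Sigma_true = np.array([[1.0, 0.4], [0.4, 2.0]])
s1, s2 = 0.75, 2.0

# Exact expectations implied by X = mu + gamma * Y + sqrt(Y) * Z.
s4 = mu_true + s2 * gamma_true
s5 = s1 * mu_true + gamma_true
s6 = (s1 * np.outer(mu_true, mu_true) + np.outer(mu_true, gamma_true)
      + np.outer(gamma_true, mu_true) + s2 * np.outer(gamma_true, gamma_true)
      + Sigma_true)

mu_hat, gamma_hat, Sigma_hat = mstep_normal_params(s1, s2, s4, s5, s6)
```

Note that \(E[\log Y]\) does not enter this block at all; it only matters for the subordinator fit.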
Per-subclass plumbing collapses to two hooks:
| Method | Purpose |
|---|---|
| _subordinator_from_eta | Fit the subordinator from \((E[\log Y], E[1/Y], E[Y])\). |
| | One |
The instance m_step is a thin wrapper. Only GeneralizedHyperbolic
overrides m_step (and keeps a sanity-check m_step_subordinator):
the GIG solver needs a warm-start theta0 and a fall-back when it
wanders out of the sane region. VG / NInvG / NIG drop their per-subclass
m_step_subordinator overrides.
Inverting any prior or shrinkage target back to a model is a one-liner:
sigma_recovered = VarianceGamma.from_expectation(eta0_iso).sigma()
joint_recovered = JointVarianceGamma.from_expectation(eta0_iso)
6. Factor Analysis Family#
FactorNormalMixture is a sibling of NormalMixture, not a
subclass: it stores (μ, γ, F, D, subordinator) directly with
\(\Sigma = F F^\top + \mathrm{diag}(D)\). FactorVarianceGamma,
FactorNormalInverseGamma, FactorNormalInverseGaussian, and
FactorGeneralizedHyperbolic follow.
6.2 Woodbury everywhere#
All Σ-related linear algebra goes through the Woodbury identity at \(O(d r^2 + r^3)\) cost, never forming or factorising a dense \(d \times d\) matrix:
def _M(self):              return I_r + Fᵀ D⁻¹ F                     # (r, r)
def _solve(self, x):       return D⁻¹ x − D⁻¹ F · M⁻¹ · Fᵀ D⁻¹ x     # Σ⁻¹ x
def _quad_form(self, x):   return x · self._solve(x)                 # xᵀ Σ⁻¹ x
def _log_det_sigma(self):  return sum(log D) + log det(M)            # log |Σ|
def _beta(self):           # β = Fᵀ Σ⁻¹, used per E-step pass
_beta is computed once per E-step pass, not per observation.
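The identities above can be made concrete with a small NumPy sketch. These are free functions written for illustration, not the methods of `FactorNormalMixture`; in particular, computing \(\beta\) as \(M^{-1} F^\top D^{-1}\) uses the push-through identity \(F^\top \Sigma^{-1} = M^{-1} F^\top D^{-1}\) and is an assumption about one reasonable way to do it.

```python
import numpy as np

def woodbury_parts(F, D):
    """Return D^{-1} F, shape (d, r), and M = I_r + F^T D^{-1} F, shape (r, r)."""
    DinvF = F / D[:, None]
    M = np.eye(F.shape[1]) + F.T @ DinvF
    return DinvF, M

def solve(F, D, x):
    """Sigma^{-1} x for Sigma = F F^T + diag(D), using only an (r, r) solve."""
    DinvF, M = woodbury_parts(F, D)
    Dinv_x = x / D
    return Dinv_x - DinvF @ np.linalg.solve(M, F.T @ Dinv_x)

def quad_form(F, D, x):
    """x^T Sigma^{-1} x."""
    return x @ solve(F, D, x)

def log_det_sigma(F, D):
    """log |Sigma| = sum(log D) + log |M| (matrix determinant lemma)."""
    _, M = woodbury_parts(F, D)
    _, logdet_M = np.linalg.slogdet(M)
    return np.sum(np.log(D)) + logdet_M

def beta(F, D):
    """beta = F^T Sigma^{-1} = M^{-1} F^T D^{-1}, shape (r, d)."""
    DinvF, M = woodbury_parts(F, D)
    return np.linalg.solve(M, DinvF.T)

rng = np.random.default_rng(0)
d, r = 6, 2
F = rng.normal(size=(d, r))
D = rng.uniform(0.5, 2.0, size=d)
x = rng.normal(size=d)

Sigma_dense = F @ F.T + np.diag(D)   # formed here only to verify the sketch
```

The dense `Sigma_dense` exists only for the comparison; every function above touches nothing larger than \(d \times r\) and \(r \times r\) arrays.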
6.3 D positivity and F gauge#
D = jnp.maximum(D, D_FLOOR) after each M-step (D_FLOOR = 1e-8 in utils/constants.py).
F is identifiable only up to a right \(r \times r\) orthogonal rotation, so (μ, γ, F, D) would never converge in norm. The convergence hook em_convergence_params returns \((\mu, \gamma, \Sigma = F F^\top + \mathrm{diag}(D))\), which is invariant to the rotation.
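Both points are easy to see numerically. The sketch below uses NumPy in place of `jax.numpy`, made-up values for `F` and for the pre-floor `D`, and a fixed rotation angle; only `D_FLOOR = 1e-8` is taken from the text.

```python
import numpy as np

D_FLOOR = 1e-8  # value quoted in the text

# Made-up loadings (d = 5, r = 2) and a raw D with a negative and a zero entry.
F = np.arange(10, dtype=float).reshape(5, 2) / 10.0
D_raw = np.array([0.5, -3e-4, 1.0, 0.0, 2.0])
D = np.maximum(D_raw, D_FLOOR)          # positivity floor

# A right orthogonal rotation of the factor loadings.
angle = 0.7
R = np.array([[np.cos(angle), -np.sin(angle)],
              [np.sin(angle),  np.cos(angle)]])
F_rot = F @ R

def sigma(F, D):
    return F @ F.T + np.diag(D)

factor_gap = np.linalg.norm(F_rot - F)                     # large: F moved
sigma_gap = np.linalg.norm(sigma(F_rot, D) - sigma(F, D))  # ~0: Sigma did not
```

Monitoring `F` directly would report a large change between two parameterisations of the same model, whereas monitoring \(\Sigma\) reports none.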
6.4 default_init for FactorGH#
Mirrors the standard GeneralizedHyperbolic.default_init: short EM
on FactorNIG, FactorVG, FactorNInvG (each converted into the
GIG embedding), plus a moment-based fallback (p=1, a=1, b=1). Picks
the candidate with the highest marginal log-likelihood.
JAX-native: no try/except, no Python branching on data values.
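Selecting the best candidate without branching on data can be sketched with array operations only. Everything below is illustrative: the candidate \((p, a, b)\) rows and the log-likelihood scores are made-up numbers, and NumPy stands in for `jax.numpy` (`where`, `isfinite`, and `argmax` exist under the same names there). A failed run is assumed to show up as a non-finite score rather than an exception.

```python
import numpy as np  # stand-in for jax.numpy in this sketch

# One row of (p, a, b) per initialiser; values are made up for illustration.
candidates = np.array([
    [-0.5, 2.0, 1.5],    # NIG-style candidate (p = -1/2)
    [ 1.8, 3.0, 1e-6],   # VG-style candidate (b near 0)
    [-2.0, 1e-6, 2.5],   # NInvG-style candidate (a near 0)
    [ 1.0, 1.0, 1.0],    # moment-based fallback from the text
])

# Made-up marginal log-likelihoods; NaN marks a run that went wrong.
log_liks = np.array([-1234.5, np.nan, -1198.2, -1300.0])

# Mask non-finite scores instead of raising or branching in Python.
safe = np.where(np.isfinite(log_liks), log_liks, -np.inf)
best = np.argmax(safe)
chosen = candidates[best]
```

Because the fallback always has a finite score in this setup, the masked `argmax` always has a valid winner.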
6.5 Why not a DispersionModel ABC yet#
A previous draft proposed DispersionModel with FullDispersion(L_Σ)
and FactorDispersion(F, D) implementations. We defer this:
Only two storage variants exist today.
An interface designed for two implementations adds indirection without observable benefit.
The third variant’s needs (SVD? banded? low-rank-plus-diagonal?) are unknown, so now is the worst time to commit to an interface.
When at least three variants are needed, the starting sketch is:
class DispersionModel(eqx.Module):
    def solve(self, x): ...
    def solve_matrix(self, X): ...
    def quad_form(self, x): ...
    def log_det(self): ...
    def sigma(self): ...
    def sample_noise(self, key, shape=()): ...
Until then, JointNormalMixture._quad_forms and
FactorNormalMixture._solve carry the linear algebra inline.
7. Sufficient Statistics in Theory Order#
Both stats classes use descriptive field names in theory order
(matching docs/theory/shrinkage.rst and docs/theory/factor_analysis.rst):
class NormalMixtureEta(eqx.Module):
    E_inv_Y: jax.Array          # s_1 = E[Y⁻¹], scalar
    E_Y: jax.Array              # s_2 = E[Y], scalar
    E_log_Y: jax.Array          # s_3 = E[log Y], scalar
    E_X: jax.Array              # s_4 = E[X], (d,)
    E_X_inv_Y: jax.Array        # s_5 = E[X / Y], (d,)
    E_XXT_inv_Y: jax.Array      # s_6 = E[X X^T/Y], (d, d)
class FactorMixtureStats(eqx.Module):
    # fields s_1 … s_6 identical to NormalMixtureEta
    E_XZT_inv_sqrtY: jax.Array  # s_7 = E[X Z^T Y^{-1/2}], (d, r)
    E_Z_inv_sqrtY: jax.Array    # s_8 = E[Z Y^{-1/2}], (r,)
    E_Z_sqrtY: jax.Array        # s_9 = E[Z Y^{1/2}], (r,)
    E_ZZT: jax.Array            # s_10 = E[Z Z^T], (r, r)
Sharing the first six fields means shrinkage targets, weights, and tests written for the standard family port directly to the factor family.
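A sketch of what "port directly" can look like in practice. The classes below are plain frozen dataclasses standing in for `NormalMixtureEta` and `FactorMixtureStats`; the inheritance between them, the helper `shrink_shared`, and the convex-combination form of shrinkage are all assumptions made for illustration, not the library's design.

```python
from dataclasses import dataclass, fields, replace
import numpy as np

@dataclass(frozen=True)
class EtaSketch:                      # stand-in for NormalMixtureEta (s_1..s_6)
    E_inv_Y: float
    E_Y: float
    E_log_Y: float
    E_X: np.ndarray
    E_X_inv_Y: np.ndarray
    E_XXT_inv_Y: np.ndarray

@dataclass(frozen=True)
class FactorStatsSketch(EtaSketch):   # stand-in for FactorMixtureStats (adds s_7..s_10)
    E_XZT_inv_sqrtY: np.ndarray
    E_Z_inv_sqrtY: np.ndarray
    E_Z_sqrtY: np.ndarray
    E_ZZT: np.ndarray

SHARED = tuple(f.name for f in fields(EtaSketch))  # the six shared field names

def shrink_shared(stats, target, w):
    """Blend the six shared fields toward target; extra fields are left untouched."""
    updates = {k: (1.0 - w) * getattr(stats, k) + w * getattr(target, k)
               for k in SHARED}
    return replace(stats, **updates)

zeros, eye = np.zeros(2), np.eye(2)
base = EtaSketch(0.8, 2.0, 0.5, zeros, zeros, eye)
target = EtaSketch(1.0, 1.0, 0.0, np.ones(2), np.ones(2), 2.0 * eye)
factor = FactorStatsSketch(0.8, 2.0, 0.5, zeros, zeros, eye,
                           E_XZT_inv_sqrtY=np.zeros((2, 1)),
                           E_Z_inv_sqrtY=np.zeros(1),
                           E_Z_sqrtY=np.zeros(1),
                           E_ZZT=np.eye(1))

shrunk_base = shrink_shared(base, target, 0.25)
shrunk_factor = shrink_shared(factor, target, 0.25)
```

The same target and the same weight are applied to both objects; the factor-specific fields pass through unchanged because the helper only names the shared ones.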
8. Cross-References#
EM machinery: EM Framework.
Theory: GH distribution, Factor analysis, Shrinkage.