Exponential Family Core#
Scope. Why every distribution lives behind one log-partition function, how the triad pattern keeps gradient/Hessian and JAX/CPU evaluation consistent, and how these primitives plug into the η→θ Bregman solver.
Where things live. Module hierarchy and the full triad table are in the API Reference. This file records the rationale.
1. One Function, Three Parametrizations#
Each distribution is described by a single convex log-partition \(\psi(\theta)\). Everything else — densities, expectation parameters, Fisher information, Bregman inversion — is derived from it. Three parametrizations form a triangle:
classical (μ, σ², α, β, …) ←→ natural θ ←→ expectation η = ∇ψ(θ)
via from_classical / natural_params ↑
from_expectation
from_classical, from_natural, and from_expectation are the only three
constructors a user ever needs. from_expectation(η) runs the universal
Bregman solver min_θ [ψ(θ) − θ·η] from fitting/solvers.py.
Why it matters: distributions never need to know about parameter conversions other than their own classical mapping. The conversion graph collapses into derivatives and Bregman minimisation of a single function.
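As a concrete illustration, the following is a minimal sketch for the Gamma distribution written in plain JAX, not with the library's classes. The function name gamma_log_partition is hypothetical, and the sufficient statistic is assumed to be t(x) = [log x, x], so that θ = (α − 1, −β). It shows the expectation parameters and the Fisher information falling out of derivatives of ψ alone.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, digamma


def gamma_log_partition(theta):
    """psi(theta) for Gamma(alpha, beta), assuming theta = (alpha - 1, -beta)."""
    alpha, beta = theta[0] + 1.0, -theta[1]
    return gammaln(alpha) - alpha * jnp.log(beta)


alpha, beta = 3.0, 2.0
theta = jnp.array([alpha - 1.0, -beta])

# Expectation parameters: eta = grad psi(theta) = (E[log X], E[X]).
eta = jax.grad(gamma_log_partition)(theta)

# The same quantities from the textbook closed forms, for comparison.
eta_closed_form = jnp.array([digamma(alpha) - jnp.log(beta), alpha / beta])

# Fisher information in natural coordinates is the Hessian of psi.
fisher = jax.hessian(gamma_log_partition)(theta)
```

Nothing in the sketch is Gamma-specific except the body of gamma_log_partition, which is the point of the design.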
2. The Log-Partition Triad#
Three functions × two backends:
| | JAX (JIT-able) | CPU (numpy/scipy) |
|---|---|---|
| log-partition | _log_partition_from_theta | _log_partition_cpu |
| gradient | _grad_log_partition | _grad_log_partition_cpu |
| Hessian | _hessian_log_partition | _hessian_log_partition_cpu |
| Tier | Default | Override when |
|---|---|---|
| 1: | abstract | always (it is the distribution) |
| 2: | | analytical formula avoids recompilation or saves Bessel calls |
| 3: | numpy wrappers around the JAX version | distribution calls |
2.1 Why @classmethod for the triad#
Tier 2 needs to differentiate the subclass’s _log_partition_from_theta,
not the parent’s. @classmethod gives cls, which routes to the right
implementation. @staticmethod cannot — it would always call the parent.
cls is not traced by JAX, so this is JIT-safe.
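The routing argument can be sketched with a toy hierarchy. The class names EFSketch, ExponentialSketch and ExponentialAnalytic are hypothetical and only the gradient and Hessian members of the triad are shown; the method names follow the triad table. The Exponential distribution is assumed to use θ = −rate, so ψ(θ) = −log(−θ).

```python
import jax
import jax.numpy as jnp
import numpy as np


class EFSketch:
    """Hypothetical base class; not the library's ExponentialFamily."""

    @classmethod
    def _log_partition_from_theta(cls, theta):
        raise NotImplementedError  # every distribution supplies this

    @classmethod
    def _grad_log_partition(cls, theta):
        # cls is the subclass, so autodiff differentiates the subclass's psi.
        return jax.grad(cls._log_partition_from_theta)(theta)

    @classmethod
    def _hessian_log_partition(cls, theta):
        return jax.hessian(cls._log_partition_from_theta)(theta)

    @classmethod
    def _grad_log_partition_cpu(cls, theta):
        # Default CPU version: a numpy wrapper around the JAX version.
        return np.asarray(cls._grad_log_partition(jnp.asarray(theta)))


class ExponentialSketch(EFSketch):
    @classmethod
    def _log_partition_from_theta(cls, theta):
        return -jnp.log(-theta[0])


class ExponentialAnalytic(ExponentialSketch):
    @classmethod
    def _grad_log_partition(cls, theta):
        # Analytical override: closed form instead of autodiff.
        return -1.0 / theta


theta = jnp.array([-4.0])
auto = ExponentialSketch._grad_log_partition(theta)
analytic = ExponentialAnalytic._grad_log_partition(theta)
cpu = ExponentialAnalytic._grad_log_partition_cpu(np.array([-4.0]))

# cls is bound before tracing, so the bound method can be jitted directly.
jitted = jax.jit(ExponentialSketch._grad_log_partition)(theta)
```

The inherited CPU wrapper in ExponentialAnalytic picks up the analytical gradient without being redefined, because it also dispatches through cls.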
2.2 Why the CPU tier exists#
scipy.special.kve is fast for vectorised Bessel evaluation in the
EM hot path; a vmap over JAX’s lax.cond-dispatched log_kv triggers
separate kernel launches per regime check and is much slower. Distributions
that don’t call log_kv (Gamma, InverseGamma, InverseGaussian) inherit
the default np.asarray(jax_version(...)) wrappers — they pay nothing for
the CPU tier.
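A CPU override for a Bessel-dependent distribution might build on something like the helper below. This is a sketch, not the library's code: the name log_kv_cpu is hypothetical. It relies on the scaling identity kve(v, x) = K_v(x)·exp(x), which keeps the logarithm finite for large x, where K_v(x) itself falls below the smallest representable double.

```python
import numpy as np
from scipy.special import kve


def log_kv_cpu(v, x):
    """log K_v(x) from the exponentially scaled Bessel function.

    kve(v, x) = K_v(x) * exp(x), hence log K_v(x) = log(kve(v, x)) - x.
    """
    x = np.asarray(x, dtype=float)
    return np.log(kve(v, x)) - x


# One vectorised call covers small, moderate and very large arguments.
x = np.array([0.5, 5.0, 800.0])
vals = log_kv_cpu(0.5, x)

# For v = 1/2 there is a closed form: K_{1/2}(x) = sqrt(pi / (2 x)) * exp(-x).
closed_form = 0.5 * np.log(np.pi / (2.0 * x)) - x
```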
3. Bregman Solver Interface#
fitting/solvers.py exposes solve_bregman(f, η, θ₀, ...) for any convex
\(f\) — not just log-partitions. The interface uses θ-space only for
gradient/Hessian; bounds are handled by the solver via
_setup_reparam (φ ↔ θ). Distributions never see the reparametrization.
solve_bregman(f, eta, theta0, *, backend, method, bounds,
grad_fn, hess_fn, max_steps, tol, verbose)
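To make the objective concrete, here is a minimal damped-Newton sketch for min_θ f(θ) − θ·η. It is not the library's solve_bregman: the helper newton_bregman is hypothetical, it works directly in θ-space, and it keeps iterates inside the bounds by step halving rather than by the φ ↔ θ reparametrization described above. The Gamma log-partition with θ = (α − 1, −β) is assumed as the test function.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, digamma

jax.config.update("jax_enable_x64", True)  # double precision for tight tolerances


def psi(theta):
    """Gamma log-partition, assuming theta = (alpha - 1, -beta)."""
    return gammaln(theta[0] + 1.0) - (theta[0] + 1.0) * jnp.log(-theta[1])


def newton_bregman(f, eta, theta0, lower, upper, max_steps=50, tol=1e-10):
    """Damped Newton for min_theta f(theta) - theta . eta (illustrative only)."""
    objective = lambda th: f(th) - jnp.dot(th, eta)
    grad_fn, hess_fn = jax.grad(f), jax.hessian(f)
    theta = theta0
    for _ in range(max_steps):
        g = grad_fn(theta) - eta          # gradient of the Bregman objective
        if jnp.linalg.norm(g) < tol:
            break
        step = jnp.linalg.solve(hess_fn(theta), g)
        t, current = 1.0, objective(theta)
        for _ in range(40):               # halve until feasible and not worse
            trial = theta - t * step
            feasible = bool(jnp.all((trial > lower) & (trial < upper)))
            if feasible and objective(trial) <= current:
                theta = trial
                break
            t *= 0.5
    return theta


# Target: the expectation parameters of Gamma(alpha=3, beta=2).
eta = jnp.array([digamma(3.0) - jnp.log(2.0), 1.5])
theta_hat = newton_bregman(
    psi, eta,
    theta0=jnp.array([0.0, -1.0]),
    lower=jnp.array([-1.0, -jnp.inf]),
    upper=jnp.array([jnp.inf, 0.0]),
)
```

Because f is convex, the Hessian is positive definite on the domain and the Newton direction is always a descent direction, so the halving loop only has to guard against overshooting and leaving the bounds.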
| Decision | Choice | Why |
|---|---|---|
| | separate | Distributions don’t know about φ-space; solver applies chain rule via |
| Newton implementation | hand-rolled | No JAX library provides a Newton minimizer that accepts a user-supplied Hessian (Optimistix Newton is root-finding only, JAXopt has no Newton) |
| Multi-start | | Orthogonal wrapper, not baked into solver name |
| Result type | | Survives |
| Cached JIT | module-level | The GIG warm-start hot path otherwise retraces per call |
backend × method matrix:
| backend | method | bounds | What runs |
|---|---|---|---|
| | | reparam | hand-rolled |
| | | native (LBFGSB) | |
| | | reparam | |
| | | reparam | |
| | | native | |
| | | none | |
GIG warm-start hot path: backend='cpu', method='lbfgs' (scipy’s L-BFGS-B
with scipy.special.kve) avoids GPU dispatch on this small problem, which has
a 3-D parameter vector and a scalar objective.
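The CPU route can be sketched with scipy.optimize.minimize and native box bounds. This is an illustration only: it substitutes the 2-D Gamma problem (θ = (α − 1, −β), an assumed convention) for the 3-D GIG problem so that the gradient has a closed form, and the function name objective_and_grad is hypothetical.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln, digamma


def objective_and_grad(theta, eta):
    """Bregman objective psi(theta) - theta . eta and its gradient (Gamma case)."""
    alpha, beta = theta[0] + 1.0, -theta[1]
    value = gammaln(alpha) - alpha * np.log(beta) - theta @ eta
    grad = np.array([digamma(alpha) - np.log(beta), alpha / beta]) - eta
    return value, grad


# Target: the expectation parameters of Gamma(alpha=3, beta=2).
eta = np.array([digamma(3.0) - np.log(2.0), 1.5])

result = minimize(
    objective_and_grad,
    x0=np.array([0.0, -1.0]),
    args=(eta,),
    jac=True,                       # the function returns (value, gradient)
    method="L-BFGS-B",
    bounds=[(-1.0 + 1e-8, None), (None, -1e-8)],  # theta1 > -1, theta2 < 0
    options={"ftol": 1e-14, "gtol": 1e-10},
)
theta_hat = result.x
```

Everything here runs in numpy and scipy, so no device dispatch or JIT compilation is involved for a problem of this size.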
4. Pre-1.0 Decisions Recorded Here#
4.1 D3 — MultivariateNormal as ExponentialFamily#
Promoted from a plain eqx.Module to a full ExponentialFamily. EF
structure:
| Component | Expression |
|---|---|
| \(t(x)\) | \([x,\;\operatorname{vec}(xx^\top)]\), shape \((d + d^2,)\) |
| \(\theta\) | \([\Sigma^{-1}\mu,\;-\tfrac12\operatorname{vec}(\Sigma^{-1})]\) |
| \(\log h(x)\) | \(0\) |
| \(\psi(\theta)\) | \(\tfrac12\theta_1^\top\Lambda^{-1}\theta_1 - \tfrac12\log\lvert\Lambda\rvert + \tfrac d2\log(2\pi)\), \(\Lambda = -2\,\mathrm{reshape}(\theta_2)\) |
vec uses row-major (ravel()) throughout. All conversions are
analytical (Tier 2 override) — no Bregman solver is ever invoked for
MVN. _log_partition_from_theta uses Cholesky of \(\Lambda\) for numerical
stability; log_prob overrides the inherited EF formula with a direct
Cholesky path (more efficient).
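The table can be checked numerically with a short sketch. The function names mvn_natural_params and mvn_log_partition are hypothetical stand-ins, not the library's methods; the code follows the expressions above, uses row-major ravel() for vec, and evaluates ψ through a Cholesky factor of Λ. Differentiating ψ should then recover η = [μ, vec(Σ + μμᵀ)].

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def mvn_natural_params(mu, Sigma):
    Lam = jnp.linalg.inv(Sigma)
    return jnp.concatenate([Lam @ mu, -0.5 * Lam.ravel()])  # row-major vec


def mvn_log_partition(theta, d):
    theta1 = theta[:d]
    Lam = -2.0 * theta[d:].reshape(d, d)
    Lam = 0.5 * (Lam + Lam.T)                     # symmetric under autodiff
    L = jnp.linalg.cholesky(Lam)                  # Lam = L L^T
    z = solve_triangular(L, theta1, lower=True)   # z.z = theta1^T Lam^{-1} theta1
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return 0.5 * z @ z - 0.5 * log_det + 0.5 * d * jnp.log(2.0 * jnp.pi)


mu = jnp.array([1.0, -2.0])
Sigma = jnp.array([[2.0, 0.5], [0.5, 1.0]])
theta = mvn_natural_params(mu, Sigma)

# Expectation parameters by autodiff versus the analytical form.
eta = jax.grad(mvn_log_partition)(theta, 2)
eta_expected = jnp.concatenate([mu, (Sigma + jnp.outer(mu, mu)).ravel()])

# psi written in classical parameters, for comparison.
psi_classical = (0.5 * mu @ jnp.linalg.solve(Sigma, mu)
                 + 0.5 * jnp.linalg.slogdet(Sigma)[1]
                 + jnp.log(2.0 * jnp.pi))
```

The explicit symmetrization matters only for autodiff: it makes the gradient with respect to the d² entries of θ₂ a symmetric matrix, matching vec(Σ + μμᵀ).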
4.2 D4 — Keep jaxopt for now#
JAXopt is unmaintained (last release 0.8.3) and emits a DeprecationWarning
on import. We keep it: it is the only pure-JAX library with a native
box-constrained quasi-Newton solver (LBFGSB). Migrating to optax.scale_by_lbfgs
loses the convergence loop; optimistix.LBFGS lacks box constraints.
The deprecation warning is suppressed in normix/__init__.py. Migration
recipe (when JAXopt breaks): wrap optax.scale_by_lbfgs in
jax.lax.while_loop (~100 lines), then drop jaxopt.
4.3 Constraints handling#
Constraints are enforced by clamping with jnp.maximum(x, LOG_EPS), not with paramax. Reasons:
The reparametrization we need is 8 lines, fully understandable.
EM does not need gradients through the constraints (it parameterises θ in the constrained space, not log-space).
No extra dependency.
jnp.where is preferred over lax.cond whenever possible: it is
vmap-compatible without changing the trace, and the clamping prevents
NaN gradients without branch divergence.
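Both idioms can be illustrated with a small sketch. The value of LOG_EPS below is an assumption (the library's constant may differ), and safe_log and log1p_ratio are hypothetical helpers chosen only to show the pattern: clamp the argument before the risky operation, and select between regimes with jnp.where after making the unused branch finite.

```python
import jax
import jax.numpy as jnp

LOG_EPS = 1e-30  # assumed value for illustration


def safe_log(x):
    # Clamp first, so log never sees a non-positive argument.
    return jnp.log(jnp.maximum(x, LOG_EPS))


def log1p_ratio(x):
    """log(1 + x) / x with a series branch near zero, selected by jnp.where."""
    small = jnp.abs(x) < 1e-4
    x_safe = jnp.where(small, 1.0, x)  # keeps the unselected branch finite
    return jnp.where(small, 1.0 - 0.5 * x, jnp.log1p(x_safe) / x_safe)


xs = jnp.array([-1.0, 0.0, 2.0])
values = safe_log(xs)

# vmap and grad compose without any special handling of the branches.
grads = jax.vmap(jax.grad(safe_log))(xs)
ratio_grads = jax.vmap(jax.grad(log1p_ratio))(jnp.array([0.0, 0.5]))
```

In the clamped region the gradient is zero rather than NaN, which is acceptable when, as stated above, the fitting loop does not need gradients through the constraint.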
4.4 Module-level functions are forbidden#
Distribution behaviour lives on the class as @classmethod or
@staticmethod. No _helper(self.alpha, ...) module-level functions
that close over attributes — they leak the class API into module globals
and are hard to override.
5. Cross-References#
η→θ optimization for GIG: Solvers and Bessel Functions.
Theory: GIG distribution, EM algorithm.