# Solvers and Bessel Functions
**Scope.** Why the Bregman solver is decoupled from `ExponentialFamily`, why Bessel evaluation has two backends and four numerical regimes, and why the EM hot path runs on a CPU/GPU hybrid.

**Where things live.** The `backend` × `method` matrix is in Exponential Family Core § 3. This file owns the deeper rationale.
## 1. Bregman Solver (`fitting/solvers.py`)
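The η→θ inversion is

\[
\hat\theta = \arg\min_{\theta}\; \bigl[\psi(\theta) - \langle \theta, \eta \rangle\bigr],
\]

whose stationarity condition \(\nabla\psi(\hat\theta) = \eta\) is the moment-matching equation. This problem is convex in \(\theta\) for any convex \(\psi\), because the second term is linear. The solver takes \(\psi\) as a generic `f` callable, not a log-partition method: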
```
solve_bregman(f, eta, theta0, *, backend, method, bounds,
              grad_fn, hess_fn, max_steps, tol, verbose) -> BregmanResult
```
| Decision | Choice | Rationale |
|---|---|---|
| Generic | generic | Bregman works for any convex function; the solver shouldn't know about EFs |
| | separate, both θ-space | Solver applies \(\theta \leftrightarrow \phi\) chain rule via … |
| Result type | | Survives … |
| Multi-start | orthogonal | Not baked into solver names; … |
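The following is a minimal sketch of what a CPU path of such a generic solver could look like. It assumes only what the signature above states: `f` is any convex callable, and the objective is \(f(\theta) - \langle\theta, \eta\rangle\). The helper name `solve_bregman_cpu_sketch` and its reduced argument list are hypothetical. `scipy.optimize.minimize(..., method="L-BFGS-B", bounds=...)` is the real SciPy API. The demo uses the univariate normal with \(\theta = (\mu/\sigma^2, -1/(2\sigma^2))\), whose log-partition (up to an additive constant) is \(\psi(\theta) = -\theta_1^2/(4\theta_2) - \tfrac12\log(-2\theta_2)\).

```python
import numpy as np
from scipy.optimize import minimize


def solve_bregman_cpu_sketch(f, eta, theta0, *, grad_fn=None, bounds=None,
                             max_steps=500, tol=1e-10):
    """Minimise f(theta) - <theta, eta> for any convex f (hypothetical helper)."""
    eta = np.asarray(eta, dtype=float)

    def objective(theta):
        return f(theta) - theta @ eta

    # The gradient of the objective is grad f(theta) - eta; it vanishes when grad f = eta.
    jac = None if grad_fn is None else (lambda theta: grad_fn(theta) - eta)
    res = minimize(objective, np.asarray(theta0, dtype=float), jac=jac,
                   method="L-BFGS-B", bounds=bounds,
                   options={"maxiter": max_steps, "gtol": tol, "ftol": 1e-15})
    return res.x, res


def normal_psi(theta):
    t1, t2 = theta
    return -t1**2 / (4.0 * t2) - 0.5 * np.log(-2.0 * t2)


def normal_grad_psi(theta):
    t1, t2 = theta
    return np.array([-t1 / (2.0 * t2), t1**2 / (4.0 * t2**2) - 1.0 / (2.0 * t2)])


# eta = (E[x], E[x^2]) for mean 1, variance 1, so theta* = (1, -0.5)
theta_hat, info = solve_bregman_cpu_sketch(
    normal_psi, [1.0, 2.0], theta0=[0.0, -1.0], grad_fn=normal_grad_psi,
    bounds=[(None, None), (None, -1e-8)],   # native box constraint keeps theta_2 < 0
)
```

Nothing in the helper refers to the normal distribution. It only sees `f`, its gradient and `eta`, which is the decoupling the table argues for.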
### 1.1 Bounds: reparam vs native
| Bound | Transform \(\theta \to \phi\) | Inverse \(\phi \to \theta\) |
|---|---|---|
| \((-\infty, 0)\) | \(\phi = \log(-\theta)\) | \(\theta = -\exp(\phi)\) |
| \((0, +\infty)\) | \(\phi = \log(\theta)\) | \(\theta = \exp(\phi)\) |
| \((\ell, h)\) | \(\phi = \mathrm{logit}((\theta-\ell)/(h-\ell))\) | \(\theta = \ell + (h-\ell)\sigma(\phi)\) |
| \((-\infty, +\infty)\) | \(\phi = \theta\) | \(\theta = \phi\) |
`backend='cpu'` passes `bounds` directly to `scipy.optimize.minimize` (native L-BFGS-B box constraints). `jaxopt.LBFGSB` also supports bounds natively. Other JAX backends reparameterise.
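The reparameterisations in the table can be sketched in JAX as forward/inverse pairs. The solver then works in \(\phi\). By the chain rule, differentiating the composed objective gives the \(\theta\)-gradient times \(d\theta/d\phi\). The helper names `to_phi` and `to_theta` are hypothetical. Shifting the one-sided bounds to an arbitrary endpoint is a small generalisation of the table, assumed here for convenience. `jax.nn.sigmoid` and `jax.scipy.special.logit` are real JAX functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

jax.config.update("jax_enable_x64", True)


def to_phi(theta, lo, hi):
    """theta -> phi for one scalar bound (lo, hi); None means unbounded on that side."""
    if lo is None and hi is None:
        return theta
    if lo is None:                          # (-inf, hi); hi = 0 gives log(-theta)
        return jnp.log(hi - theta)
    if hi is None:                          # (lo, +inf); lo = 0 gives log(theta)
        return jnp.log(theta - lo)
    return logit((theta - lo) / (hi - lo))  # (lo, hi)


def to_theta(phi, lo, hi):
    """Inverse map phi -> theta; every phi in R lands strictly inside the bound."""
    if lo is None and hi is None:
        return phi
    if lo is None:
        return hi - jnp.exp(phi)
    if hi is None:
        return lo + jnp.exp(phi)
    return lo + (hi - lo) * jax.nn.sigmoid(phi)


# Chain rule on (0, +inf): d theta / d phi = exp(phi) = theta, so
# d/d phi [theta^2] = 2 theta * theta = 18 at theta = 3.
phi = to_phi(jnp.asarray(3.0), 0.0, None)
grad_phi = jax.grad(lambda p: to_theta(p, 0.0, None) ** 2)(phi)
```

Since `lo`/`hi` are checked with Python `is None`, the branch is chosen before tracing. Only the selected transform ends up in the compiled graph.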
### 1.2 Newton: hand-rolled, JIT-cached
No JAX library provides a Newton minimizer that accepts a user-supplied Hessian:
| Library | Newton | Custom Hessian | Box constraints |
|---|---|---|---|
| Optimistix | root-finding only | no | no |
| JAXopt | none | n/a | yes (`LBFGSB` only) |
| Optax | none | n/a | n/a |
So we ship a hand-rolled Newton via `lax.while_loop`, which gives true early stopping. For repeated warm-started solves on the same shape (the GIG EM hot path), `make_jit_newton_solver(f, grad_fn, hess_fn, bounds)` builds a `@jax.jit`-decorated specialised solve whose XLA cache survives across calls. This is critical: without it, per-call retracing dominated GH EM time.
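The sketch below shows the pattern under stated assumptions: a user-supplied gradient and Hessian, undamped Newton steps, and no bounds. The real `make_jit_newton_solver` also takes `bounds`. `make_newton_solver_sketch` is a hypothetical name. `lax.while_loop` stops as soon as the gradient norm drops below `tol`. Because `jax.jit` wraps a closure that is built once, repeated calls with same-shaped inputs reuse one compiled executable; the trace counter exists only to make that visible. The demo inverts the Bernoulli log-partition \(\psi(\theta) = \sum_i \log(1 + e^{\theta_i})\), whose exact answer is \(\theta = \mathrm{logit}(\eta)\).

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def make_newton_solver_sketch(f, grad_fn, hess_fn, *, max_steps=50, tol=1e-10):
    """Build one jitted, undamped Newton solve for min_theta f(theta) - <theta, eta>."""
    traces = {"n": 0}  # Python side effect: runs only while JAX traces `solve`

    @jax.jit
    def solve(theta0, eta):
        traces["n"] += 1

        def residual(theta):            # gradient of the Bregman objective
            return grad_fn(theta) - eta

        def cond(state):
            _, g, k = state
            return (jnp.linalg.norm(g) > tol) & (k < max_steps)

        def body(state):
            theta, g, k = state
            theta = theta - jnp.linalg.solve(hess_fn(theta), g)   # Newton step
            return theta, residual(theta), k + 1

        init = (theta0, residual(theta0), jnp.int32(0))
        theta, g, k = lax.while_loop(cond, body, init)
        grad_norm = jnp.linalg.norm(g)
        return theta, f(theta) - theta @ eta, grad_norm, k, grad_norm <= tol

    return solve, traces


psi = lambda th: jnp.sum(jax.nn.softplus(th))
hess = lambda th: jnp.diag(jax.nn.sigmoid(th) * (1.0 - jax.nn.sigmoid(th)))
solve, traces = make_newton_solver_sketch(psi, jax.nn.sigmoid, hess)

theta_a, fun_a, gnorm_a, steps_a, ok_a = solve(jnp.zeros(2), jnp.array([0.3, 0.8]))
# Warm start from the previous answer: same shapes and dtypes, so no retrace.
theta_b, *_ = solve(theta_a, jnp.array([0.35, 0.75]))
```

Undamped Newton assumes the start is close enough to converge. A warm start from the previous EM iteration is exactly that situation.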
### 1.3 `BregmanResult` and `lax.scan`
```python
from dataclasses import dataclass
from typing import Any

import jax


@dataclass(frozen=True)
class BregmanResult:
    theta: jax.Array
    fun: Any            # may be a JAX scalar (under scan) or a Python float
    grad_norm: Any
    num_steps: int
    converged: Any      # bool / 0-d JAX bool
    elapsed_time: float = 0.0
```
Loose `Any` typing is deliberate: forcing a Python `float`/`bool` would raise `ConcretizationTypeError` when the result flows through `lax.scan`. `verbose` is threaded into the solver for printed diagnostics.
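The failure mode behind the loose typing can be reproduced in a few lines. A `lax.scan` body may carry a traced JAX scalar, but `float()` needs a concrete value, and none exists during tracing. This sketch does not use `BregmanResult` itself. `jax.errors.ConcretizationTypeError` is the real JAX exception class.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body_keep_traced(carry, x):
    fun = carry + x ** 2
    return fun, fun                    # the traced JAX scalar flows through untouched


def body_force_float(carry, x):
    fun = float(carry + x ** 2)        # needs a concrete value, which a tracer lacks
    return fun, fun


xs = jnp.arange(3.0)
init = jnp.zeros((), dtype=xs.dtype)
total, history = lax.scan(body_keep_traced, init, xs)   # total = 5.0, history = [0, 1, 5]

try:
    lax.scan(body_force_float, init, xs)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
```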
## 2. GIG η→θ
The GIG Fisher information can be ill-conditioned (condition number up to \(10^{30}\)) when \(a \ll b\) or \(a \gg b\). Vanilla L-BFGS-B fails without rescaling.
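The η-rescaling in § 2.1 rests on a scale property of the GIG. Here it is written for the density \(f(x) \propto x^{p-1}\exp\{-(a x + b/x)/2\}\); that parameterisation is an assumption, chosen to be consistent with the \(\sqrt{ab}\) that appears in the rescaled parameters. Dividing by \(s = \sqrt{b/a}\) makes the two rate parameters equal, which is where \(\tilde a = \tilde b = \sqrt{ab}\) comes from:

```latex
X \sim \mathrm{GIG}(p, a, b),\quad Y = X/s
\;\Rightarrow\;
f_Y(y) = s\,f_X(s y) \propto y^{p-1}\exp\Bigl\{-\tfrac12\bigl(a s\, y + \tfrac{b}{s}\,\tfrac{1}{y}\bigr)\Bigr\}

s = \sqrt{b/a}
\;\Rightarrow\;
\tilde a = a s = \sqrt{ab}, \qquad \tilde b = b / s = \sqrt{ab}
```

With \(\tilde a = \tilde b\), the \(x\) and \(1/x\) directions are on the same scale, which is why the Fisher matrix of the rescaled problem becomes symmetric.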
### 2.1 η-rescaling
Before optimization:
The rescaled GIG has \(\tilde a = \tilde b = \sqrt{ab}\) and a symmetric Fisher matrix. After solving for \(\tilde\theta\):
### 2.2 Solver choice in EM
Default: `backend='cpu'`, `method='lbfgs'`, i.e. `scipy.optimize.minimize` with the Bessel terms evaluated by `scipy.special.kve`. This avoids GPU kernel-dispatch overhead on what is only a 3-parameter problem.
For the warm-started Newton path (`backend='jax'`, `method='newton'`), the cached `_gig_jax_newton_jit` keeps a single XLA executable across all warm-started solves.
When `theta0` is not provided, `GeneralizedInverseGaussian.from_expectation` runs `solve_bregman_multistart` on the η-rescaled problem, with seeds from the Gamma / InverseGamma / InverseGaussian special cases.
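A minimal sketch of the multi-start idea follows: run the same bounded solve from several seeds, discard non-finite results, and keep the lowest objective. `multistart_sketch` is a hypothetical helper that takes the seeds as a plain list. It does not reproduce `solve_bregman_multistart`'s signature or the Gamma / InverseGamma / InverseGaussian seeding. The demo uses the exponential distribution, \(\psi(\theta) = -\log(-\theta)\) on \(\theta < 0\), whose exact answer is \(\theta = -1/\eta\).

```python
import numpy as np
from scipy.optimize import minimize


def multistart_sketch(objective, grad, seeds, bounds):
    """Run L-BFGS-B from every seed and keep the lowest finite objective (hypothetical)."""
    best = None
    for theta0 in seeds:
        res = minimize(objective, np.atleast_1d(np.asarray(theta0, dtype=float)),
                       jac=grad, method="L-BFGS-B", bounds=bounds,
                       options={"gtol": 1e-12, "ftol": 1e-15})
        if not (np.all(np.isfinite(res.x)) and np.isfinite(res.fun)):
            continue                            # discard seeds that diverged
        if best is None or res.fun < best.fun:
            best = res
    return best


# psi(theta) = -log(-theta), so the Bregman objective is -log(-theta) - theta * eta
eta = 2.5
objective = lambda th: -np.log(-th[0]) - th[0] * eta
grad = lambda th: np.array([-1.0 / th[0] - eta])
best = multistart_sketch(objective, grad, seeds=[-100.0, -1.0, -0.01],
                         bounds=[(None, -1e-8)])
```

On a convex problem every converged seed reaches the same minimum. The value of multi-start is robustness when some seeds stall in badly conditioned regions, which is the GIG situation described above.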
## 3. Bessel Functions
`log_kv(v, z, backend='jax'|'cpu')` is the unified entry point in `normix/utils/bessel.py`.
### 3.1 Pure-JAX backend (default)
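Four-regime dispatch via `lax.cond`. With an unbatched predicate, only the selected branch executes at runtime. Under `vmap` with a batched predicate, JAX converts `lax.cond` into a `select`, which evaluates every branch. The regimes are: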
| Regime | Trigger | Method |
|---|---|---|
| Hankel | \(z > \max(25, v^2/4)\) | DLMF 10.40.2 asymptotic |
| Olver | \(\lvert v\rvert > 25\), not Hankel | DLMF 10.41.3–4 uniform expansion |
| Small-\(z\) | \(z < 10^{-6}\), \(\lvert v\rvert > 0.5\) | leading asymptotic |
| Quadrature | otherwise | 64-point Gauss–Legendre (Takekawa 2022) |
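The dispatch rule and the simplest branch can be sketched in scalar NumPy; the JAX implementation uses `lax.cond`, so this is not it. The checks follow the table top to bottom. Where triggers overlap (e.g. \(z < 10^{-6}\) together with \(\lvert v\rvert > 25\)), that priority is an assumption. Only the Hankel branch is implemented: DLMF 10.40.2, with coefficients \(a_k(\nu)\) built by the recursion \(a_k/a_{k-1} = (4\nu^2 - (2k-1)^2)/(8k)\). `scipy.special.kve` serves as the reference. Both function names are hypothetical.

```python
import numpy as np
from scipy.special import kve


def bessel_regime(v, z):
    """Pick a regime using the table's triggers, checked top to bottom."""
    v = abs(v)
    if z > max(25.0, v * v / 4.0):
        return "hankel"
    if v > 25.0:
        return "olver"
    if z < 1e-6 and v > 0.5:
        return "small_z"
    return "quadrature"


def log_kv_hankel(v, z, n_terms=12):
    """log K_v(z) from the large-argument expansion (DLMF 10.40.2), truncated."""
    mu = 4.0 * v * v
    term, total = 1.0, 1.0
    for k in range(1, n_terms + 1):
        # a_k(v) / z^k from the previous term: multiply by (mu - (2k-1)^2) / (8 k z)
        term *= (mu - (2 * k - 1) ** 2) / (8.0 * k * z)
        total += term
    return 0.5 * np.log(np.pi / (2.0 * z)) - z + np.log(total)


# kve(v, z) = K_v(z) * exp(z), so log K_v(z) = log(kve(v, z)) - z
reference = np.log(kve(2.0, 50.0)) - 50.0
approx = log_kv_hankel(2.0, 50.0)
```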
Custom JVP via `@jax.custom_jvp`:

- \(\partial/\partial z\): exact recurrence \(K'_\nu = -(K_{\nu-1} + K_{\nu+1})/2\).
- \(\partial/\partial v\): central finite difference with \(\varepsilon = 10^{-5}\).
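The sketch below shows the custom-JVP rule with a stand-in primal. Here \(\log K_\nu(z)\) comes from a trapezoid rule on the integral \(K_\nu(z) = \int_0^\infty e^{-z\cosh t}\cosh(\nu t)\,dt\) (valid for \(z > 0\)), not from the four-regime dispatch or the Takekawa Gauss–Legendre rule. The JVP applies the \(z\)-recurrence divided by \(K_\nu\), since the function returns a log. The \(v\)-derivative uses a central difference with \(\varepsilon = 10^{-5}\). `log_kv_sketch` is a hypothetical name; SciPy's `kve` is used only for the cross-check.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
from scipy.special import kve

jax.config.update("jax_enable_x64", True)   # FD with eps = 1e-5 needs float64

_T = jnp.linspace(0.0, 10.0, 2001)          # t-grid; integrand is negligible beyond t = 10
_H = _T[1] - _T[0]
_W = jnp.full_like(_T, _H).at[0].set(0.5 * _H).at[-1].set(0.5 * _H)  # trapezoid weights


@jax.custom_jvp
def log_kv_sketch(v, z):
    # log sum_j w_j exp(-z cosh t_j) cosh(v t_j), evaluated in log space
    log_cosh_vt = jnp.logaddexp(v * _T, -v * _T) - jnp.log(2.0)
    return logsumexp(-z * jnp.cosh(_T) + log_cosh_vt, b=_W)


@log_kv_sketch.defjvp
def _log_kv_sketch_jvp(primals, tangents):
    v, z = primals
    dv, dz = tangents
    y = log_kv_sketch(v, z)
    # d/dz log K_v = K'_v / K_v with K'_v = -(K_{v-1} + K_{v+1}) / 2
    dlogk_dz = -0.5 * (jnp.exp(log_kv_sketch(v - 1.0, z) - y)
                       + jnp.exp(log_kv_sketch(v + 1.0, z) - y))
    # d/dv: central finite difference with eps = 1e-5
    eps = 1e-5
    dlogk_dv = (log_kv_sketch(v + eps, z) - log_kv_sketch(v - eps, z)) / (2.0 * eps)
    return y, dlogk_dv * dv + dlogk_dz * dz


v, z = 2.0, 3.0
value = log_kv_sketch(v, z)
ref_value = np.log(kve(v, z)) - z
grad_z = jax.grad(log_kv_sketch, argnums=1)(v, z)
ref_grad_z = -0.5 * (kve(v - 1.0, z) + kve(v + 1.0, z)) / kve(v, z)
grad_v = jax.grad(log_kv_sketch, argnums=0)(v, z)
fd_grad_v = (np.log(kve(v + 1e-4, z)) - np.log(kve(v - 1e-4, z))) / 2e-4
```

In the ratio form, the `kve` scaling factors \(e^{z}\) cancel, so the same recurrence works with exponentially scaled values.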
### 3.2 CPU backend (EM hot path)
`scipy.special.kve`, fully vectorised NumPy; not JIT-able. For large \(N\), one vectorised `kve` call (a C loop over the elements) beats `vmap`-ing JAX's `lax.cond`-dispatched implementation, which causes separate kernel launches per condition check.
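Here is a sketch of the CPU path's core. It assumes \(\log K_\nu(z)\) is computed as `log(kve(v, z)) - z`. `scipy.special.kve` returns \(K_\nu(z)e^{z}\), so the exponential factor is removed analytically and large \(z\) does not underflow. `log_kv_cpu_sketch` is a hypothetical name. Whole arrays go through one ufunc call, with no per-element Python dispatch.

```python
import numpy as np
from scipy.special import kv, kve


def log_kv_cpu_sketch(v, z):
    """log K_v(z) for broadcastable NumPy inputs, via the exponentially scaled kve."""
    v = np.asarray(v, dtype=float)
    z = np.asarray(z, dtype=float)
    return np.log(kve(v, z)) - z


z = np.array([0.5, 5.0, 50.0, 1000.0])
stable = log_kv_cpu_sketch(1.0, z)
with np.errstate(divide="ignore"):
    naive = np.log(kv(1.0, z))   # K_1(1000) underflows to 0.0, so the log is -inf
```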
### 3.3 Why `backend` is a Python-level string
Resolved before JAX tracing begins. `backend='jax'` keeps the code traceable; `backend='cpu'` runs eagerly, which is appropriate because EM loops are already Python `for` loops at the CPU end.
### 3.4 CPU triad for Bessel-dependent distributions
Design rule: any distribution that calls `log_kv` must override the Tier 3 CPU classmethods so that the CPU solver path (`solve_bregman(backend='cpu')`) avoids JAX dispatch entirely. The three classmethods are `_log_partition_cpu`, `_grad_log_partition_cpu` and `_hessian_log_partition_cpu`, all NumPy in / NumPy out.
Distributions that don't call `log_kv` (Gamma, InverseGamma, InverseGaussian) inherit the default wrappers and pay nothing.
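The following sketches what the triad could look like for the GIG. It assumes natural parameters \(\theta = (p-1, -b/2, -a/2)\) for the density \(\propto x^{p-1}\exp\{-(ax + b/x)/2\}\), which gives \(\psi(\theta) = \log 2 + \log K_p(\sqrt{ab}) + \tfrac p2\log(b/a)\). The classmethod names come from the design rule above. The class name, the `theta` argument convention and the finite-difference Hessian are assumptions made to keep the sketch short. Everything is NumPy in, NumPy out.

```python
import numpy as np
from scipy.special import kve
from scipy.stats import geninvgauss


def _log_kv(v, z):
    return np.log(kve(v, z)) - z              # log K_v(z) without underflow


class GIGCpuTriadSketch:
    """Hypothetical GIG with theta = (p - 1, -b/2, -a/2); NumPy in, NumPy out."""

    @staticmethod
    def _params(theta):
        theta = np.asarray(theta, dtype=float)
        return theta[0] + 1.0, -2.0 * theta[2], -2.0 * theta[1]   # p, a, b

    @classmethod
    def _log_partition_cpu(cls, theta):
        p, a, b = cls._params(theta)
        return np.log(2.0) + _log_kv(p, np.sqrt(a * b)) + 0.5 * p * np.log(b / a)

    @classmethod
    def _grad_log_partition_cpu(cls, theta, eps=1e-5):
        p, a, b = cls._params(theta)
        omega = np.sqrt(a * b)
        ratio = kve(p + 1.0, omega) / kve(p, omega)          # K_{p+1} / K_p
        dlogk_dp = (_log_kv(p + eps, omega) - _log_kv(p - eps, omega)) / (2.0 * eps)
        e_log_x = dlogk_dp + 0.5 * np.log(b / a)
        e_inv_x = np.sqrt(a / b) * ratio - 2.0 * p / b
        e_x = np.sqrt(b / a) * ratio
        return np.array([e_log_x, e_inv_x, e_x])             # (E log X, E 1/X, E X)

    @classmethod
    def _hessian_log_partition_cpu(cls, theta, eps=1e-4):
        # Central differences of the gradient, then symmetrised (a simplification).
        theta = np.asarray(theta, dtype=float)
        hess = np.empty((3, 3))
        for j in range(3):
            step = np.zeros(3)
            step[j] = eps
            hess[:, j] = (cls._grad_log_partition_cpu(theta + step)
                          - cls._grad_log_partition_cpu(theta - step)) / (2.0 * eps)
        return 0.5 * (hess + hess.T)


theta = np.array([-0.5, -1.5, -1.0])          # p = 0.5, b = 3, a = 2
eta = GIGCpuTriadSketch._grad_log_partition_cpu(theta)
H = GIGCpuTriadSketch._hessian_log_partition_cpu(theta)
fd = np.array([(GIGCpuTriadSketch._log_partition_cpu(theta + d)
                - GIGCpuTriadSketch._log_partition_cpu(theta - d)) / 2e-6
               for d in 1e-6 * np.eye(3)])
# SciPy's geninvgauss(p, sqrt(ab), scale=sqrt(b/a)) is the same GIG, used as a reference.
ref_mean = geninvgauss(0.5, np.sqrt(6.0), scale=np.sqrt(1.5)).mean()
```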
## 4. CPU/GPU Hybrid Backend
EM timing on 468 stocks, 2552 observations (GH distribution):
| Phase | JAX (GPU) | CPU hybrid | Speedup |
|---|---|---|---|
| E-step | ~1.1 s | ~0.07 s | ~15× |
| M-step (GIG solve) | ~5–7 s | ~0.01 s | ~500× |
Hybrid strategy:

- Quad forms (\(L_\Sigma^{-1}(x-\mu)\) etc.) stay in JAX (\(d\)-dimensional, GPU-friendly).
- `log_kv` calls and GIG optimization move to CPU (`backend='cpu'`).
`NormalMixture.e_step(X, backend='cpu')` is the hybrid path:

- Quad forms (\(L^{-1}(x-\mu)\), \(\lVert z\rVert^2\), \(\lVert w\rVert^2\)) stay in JAX `vmap` (GPU-friendly).
- Bessel calls go to CPU via `GIG.expectation_params_batch(backend='cpu')`.
- `_posterior_gig_params(z2, w2)` lives on each `JointNormalMixture` subclass.
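The sketch below illustrates the hybrid split. It assumes the standard normal mean-variance mixture with mixing density \(\propto w^{p-1}e^{-(aw + b/w)/2}\) and skewness \(\gamma\). For that model the posterior is \(W \mid x \sim \mathrm{GIG}(p - d/2,\; a + \lVert w\rVert^2,\; b + \lVert z\rVert^2)\), with \(z = L_\Sigma^{-1}(x-\mu)\) and \(w = L_\Sigma^{-1}\gamma\). Quad forms are computed in JAX with one batched triangular solve. The results are moved to host once, and the Bessel ratio is evaluated with `scipy.special.kve`. All function names are hypothetical, and only \(E[W \mid x]\) is computed here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import solve_triangular
from scipy.special import kve

jax.config.update("jax_enable_x64", True)


@jax.jit
def quad_forms_jax(X, mu, gamma, L):
    """GPU-friendly part: ||L^{-1}(x_i - mu)||^2 for every row, and ||L^{-1} gamma||^2."""
    Z = solve_triangular(L, (X - mu).T, lower=True)    # shape (d, N), one batched solve
    w = solve_triangular(L, gamma, lower=True)          # shape (d,)
    return jnp.sum(Z * Z, axis=0), jnp.dot(w, w)


def gig_mean_cpu(p, a, b):
    """E[W] for W ~ GIG(p, a, b), vectorised NumPy; kve's exp factors cancel in the ratio."""
    omega = np.sqrt(a * b)
    return np.sqrt(b / a) * kve(p + 1.0, omega) / kve(p, omega)


def hybrid_e_step_sketch(X, mu, gamma, L, p, a, b):
    z2, w2 = quad_forms_jax(X, mu, gamma, L)
    z2, w2 = np.asarray(z2), float(w2)                  # device -> host, once per E-step
    d = X.shape[1]
    return gig_mean_cpu(p - 0.5 * d, a + w2, b + z2)    # posterior E[W | x_i]


X = jnp.array([[0.5, -1.0], [2.0, 0.3], [0.0, 0.0]])
mu, gamma = jnp.zeros(2), jnp.array([0.2, -0.1])
L = jnp.array([[1.0, 0.0], [0.5, 2.0]])
ew = hybrid_e_step_sketch(X, mu, gamma, L, p=1.0, a=2.0, b=1.0)

# Same quantities in plain NumPy, with Sigma = L L^T, as a cross-check
Sigma_inv = np.linalg.inv(np.asarray(L) @ np.asarray(L).T)
Xn, g = np.asarray(X), np.asarray(gamma)
z2_ref = np.einsum("ij,jk,ik->i", Xn, Sigma_inv, Xn)
ew_reference = gig_mean_cpu(1.0 - 1.0, 2.0 + g @ Sigma_inv @ g, 1.0 + z2_ref)
```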
Default fitter settings reflect the hot path: `e_step_backend='jax'`, `m_step_backend='cpu'`, `m_step_method='newton'`.
## 5. Random Variate Generation
PINV (Polynomial-Interpolation-based Numerical Inversion) in `utils/rvs.py` is pure JAX and works for any univariate log-kernel; no normalising constant is needed:
- `build_pinv_table(log_kernel, mode, *, x_of_w, n_grid, tail_eps)` builds a quantile table in JAX: tail bisection via `lax.fori_loop`, trapezoidal CDF via `jnp.cumsum`.
- `rvs_pinv(key, u_grid, x_grid, n)` samples via `jnp.interp` (GPU-friendly, vectorised).
Distributions on \((0,\infty)\) supply `log_kernel(w) = log_prob(exp(w)) + w` (the `+ w` is the log-Jacobian of \(x = e^w\)) and seed the table at `jnp.log(self.mode())`. Closed-form `mode()` lives on the distribution itself (Gamma, InverseGamma, InverseGaussian, GIG).
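The table-and-interpolate idea can be sketched in simplified form. This sketch uses a fixed, uniform grid in \(w = \log x\) between limits placed around \(\log(\text{mode})\). The real `build_pinv_table` instead finds the tails by bisection and takes `x_of_w`, `n_grid`, `tail_eps`. The CDF is trapezoidal via `jnp.cumsum`, and sampling is inverse-CDF via `jnp.interp`. `pinv_table_sketch` and `rvs_from_table` are hypothetical names. The demo kernel is an unnormalised Gamma(3, 1): \(\log p(e^w) + w = 3w - e^{w}\) up to a constant, with mode \(x = 2\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pinv_table_sketch(log_kernel, w_lo, w_hi, n_grid=4096):
    """Quantile table (u_grid, w_grid) from an unnormalised log-kernel on a fixed w-grid."""
    w_grid = jnp.linspace(w_lo, w_hi, n_grid)
    lk = log_kernel(w_grid)
    dens = jnp.exp(lk - jnp.max(lk))                     # no normalising constant needed
    h = w_grid[1] - w_grid[0]
    increments = 0.5 * (dens[1:] + dens[:-1]) * h         # trapezoid areas
    cdf = jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    return cdf / cdf[-1], w_grid


def rvs_from_table(key, u_grid, w_grid, n):
    """Inverse-CDF sampling: uniform u -> w by linear interpolation -> x = exp(w)."""
    u = jax.random.uniform(key, (n,))
    return jnp.exp(jnp.interp(u, u_grid, w_grid))


log_kernel = lambda w: 3.0 * w - jnp.exp(w)              # Gamma(3, 1) on w = log x
w_mode = jnp.log(2.0)                                     # log of the x-space mode
u_grid, w_grid = pinv_table_sketch(log_kernel, w_mode - 10.0, w_mode + 4.0)
median = jnp.exp(jnp.interp(0.5, u_grid, w_grid))
samples = rvs_from_table(jax.random.PRNGKey(0), u_grid, w_grid, 100_000)
```

Working in \(w = \log x\) spreads the right-skewed mass of a positive variable more evenly over the grid, which is one reason for the log-Jacobian form above.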
`InverseGaussian.ppf` and both `GIG.cdf` / `GIG.ppf` inline a single `build_pinv_table` call; `log_prob` is the only kernel.
GIG-specific sampling is inlined in `distributions/generalized_inverse_gaussian.py`:

- `_gig_rvs_devroye(key, p, a, b, n)`: TDR on \(w = \log x\), batch-parallel (no `while_loop`).
- `_gig_rvs_pinv(key, u_grid, x_grid, n)`: alias of `rvs_pinv`, used by `GIG.rvs(method='pinv')`.
Neither method evaluates the Bessel normalising constant.
### Quantile Functions (`cdf`, `ppf`)
- `Gamma.ppf` and `InverseGamma.ppf` invert the regularised incomplete gamma via `normix.utils.gammaincinv`: a pure-JAX Newton iteration on `jax.scipy.special.gammainc` with a Wilson–Hilferty seed. This is the JAX analogue of `scipy.special.gammaincinv`.
- `InverseGaussian.ppf`, `GIG.cdf` and `GIG.ppf` build a PINV table from `log_prob` (above).
- Univariate normal-mixture marginals (`UnivariateVarianceGamma`, `UnivariateNormalInverseGamma`, `UnivariateNormalInverseGaussian`, `UnivariateGeneralizedHyperbolic`) use the same generic PINV machinery with `log_kernel(w) = self.log_prob(jnp.atleast_1d(w))`, seeded at `self.mean()` (there is no closed-form mode for Bessel mixtures).
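The Newton inversion described in the first bullet can be sketched as follows, assuming unit rate. The Wilson–Hilferty approximation \(x_0 = a\,(1 - \tfrac{1}{9a} + z_q/(3\sqrt a))^3\), with \(z_q = \Phi^{-1}(q)\), seeds a fixed number of Newton steps on \(P(a, x) - q\). The derivative of \(P(a, x)\) in \(x\) is the Gamma(a, 1) density. `gammaincinv_sketch` is a hypothetical name, not `normix.utils.gammaincinv`. `jax.scipy.special.gammainc`, `gammaln` and `ndtri` are real JAX functions, and SciPy's `gammaincinv` appears only as a reference.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.scipy.special import gammainc, gammaln, ndtri
from scipy.special import gammaincinv   # reference only

jax.config.update("jax_enable_x64", True)


@jax.jit
def gammaincinv_sketch(a, q):
    """x with P(a, x) = q (regularised lower incomplete gamma), elementwise."""
    a = jnp.asarray(a, dtype=jnp.float64)
    q = jnp.asarray(q, dtype=jnp.float64)
    # Wilson-Hilferty seed, clipped so it stays positive for small a
    z = ndtri(q)
    x0 = jnp.maximum(a * (1.0 - 1.0 / (9.0 * a) + z / (3.0 * jnp.sqrt(a))) ** 3, 1e-8)

    def newton_step(_, x):
        log_pdf = (a - 1.0) * jnp.log(x) - x - gammaln(a)       # d/dx P(a, x)
        x_new = x - (gammainc(a, x) - q) / jnp.exp(log_pdf)
        return jnp.where(x_new > 0.0, x_new, 0.5 * x)            # keep iterates positive

    return lax.fori_loop(0, 30, newton_step, x0)


qs = np.array([0.1, 0.5, 0.9])
x_jax = gammaincinv_sketch(3.0, qs)
x_ref = gammaincinv(3.0, qs)
```

A fixed iteration count keeps the loop shape static under `jit`. Once an iterate has converged, further Newton steps leave it essentially unchanged.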
## 6. Cross-References
- Triad design: Exponential Family Core.
- Theory: GIG distribution, EM algorithm.