Solvers and Bessel Functions

Scope. Why the Bregman solver is decoupled from ExponentialFamily, why Bessel evaluation has two backends and four numerical regimes, and why the EM hot path runs on a CPU/GPU hybrid.

Where things live. The backend × method matrix is in Exponential Family Core § 3. This file owns the deeper rationale.


1. Bregman Solver (fitting/solvers.py)

The η→θ inversion is

\[ \theta_* = \arg\min_\theta\,[\,\psi(\theta) - \theta\cdot\eta\,]. \]

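The objective is convex in \(\theta\) for any convex \(\psi\), and its stationarity condition \(\nabla\psi(\theta_*) = \eta\) is exactly the η→θ inversion. The solver takes \(\psi\) as a generic callable f rather than a log-partition method: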

```python
solve_bregman(f, eta, theta0, *, backend, method, bounds,
              grad_fn, hess_fn, max_steps, tol, verbose) -> BregmanResult
```
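A minimal sketch of why a generic f suffices, assuming only JAX: the objective is built from any convex callable, here the Poisson log-partition \(\psi(\theta)=\sum_i e^{\theta_i}\), whose inverse is known in closed form (\(\theta_* = \log\eta\)). It uses jax.scipy.optimize.minimize (BFGS) as a stand-in solver; bregman_objective is a hypothetical helper, not part of solve_bregman.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)


def bregman_objective(f, eta):
    """Hypothetical helper: theta -> f(theta) - theta . eta for any convex callable f."""
    return lambda theta: f(theta) - jnp.dot(theta, eta)


def poisson_log_partition(theta):
    # psi(theta) = sum(exp(theta)); its gradient map is exp, so theta_* = log(eta)
    return jnp.sum(jnp.exp(theta))


eta = jnp.array([2.0, 0.5])
result = minimize(bregman_objective(poisson_log_partition, eta),
                  jnp.zeros(2), method="BFGS", options={"gtol": 1e-9})
theta_star = result.x
# At the minimiser, grad f reproduces eta (the stationarity condition).
grad_at_star = jax.grad(poisson_log_partition)(theta_star)
```

Nothing in the solve refers to exponential families: any convex f with the same call signature would work.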

| Decision | Choice | Rationale |
|---|---|---|
| Generic f (vs log_partition_fn) | generic | Bregman works for any convex function; the solver shouldn't know about EFs |
| grad_fn + hess_fn | separate, both θ-space | The solver applies the \(\theta \leftrightarrow \phi\) chain rule via jax.jacobian(to_theta); distributions never touch reparametrization |
| Result type | BregmanResult | Survives lax.scan (loose-typed Any scalars where needed) |
| Multi-start | orthogonal solve_bregman_multistart | Not baked into solver names; vmap for JAX, a Python for loop for CPU |
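The multi-start row can be illustrated with a small JAX sketch: one fixed-length Newton solve per starting point, batched with vmap, keeping the start with the lowest objective. The helper names and the undamped, unbounded Newton step are assumptions for illustration only. In the Poisson example, the start at −5 overshoots to a huge θ and is still far from the optimum after 50 steps, so the selection step discards it.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_fixed_steps(f, eta, theta0, n_steps=50):
    """Hypothetical: plain (undamped) Newton on f(theta) - theta . eta for a fixed step count."""
    obj = lambda th: f(th) - jnp.dot(th, eta)
    grad, hess = jax.grad(obj), jax.hessian(obj)

    def body(_, th):
        return th - jnp.linalg.solve(hess(th), grad(th))

    return jax.lax.fori_loop(0, n_steps, body, theta0)


def multistart_sketch(f, eta, theta0s):
    """Hypothetical: one solve per start in parallel via vmap; keep the lowest objective."""
    thetas = jax.vmap(lambda t0: newton_fixed_steps(f, eta, t0))(theta0s)
    objs = jax.vmap(lambda th: f(th) - jnp.dot(th, eta))(thetas)
    # Unconverged starts have larger objectives; nanargmin also skips NaN from failed ones.
    best = jnp.nanargmin(objs)
    return thetas[best], objs


f = lambda th: jnp.sum(jnp.exp(th))          # Poisson log-partition
eta = jnp.array([2.0])
starts = jnp.array([[-5.0], [0.0], [5.0]])
theta_best, objs = multistart_sketch(f, eta, starts)
```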

1.1 Bounds: reparam vs native

| Bound | Transform \(\theta \to \phi\) | Inverse \(\phi \to \theta\) |
|---|---|---|
| \((-\infty, 0)\) | \(\phi = \log(-\theta)\) | \(\theta = -\exp(\phi)\) |
| \((0, +\infty)\) | \(\phi = \log(\theta)\) | \(\theta = \exp(\phi)\) |
| \((\ell, h)\) | \(\phi = \mathrm{logit}((\theta-\ell)/(h-\ell))\) | \(\theta = \ell + (h-\ell)\sigma(\phi)\) |
| \((-\infty, +\infty)\) | \(\phi = \theta\) | \(\theta = \phi\) |

backend='cpu' passes bounds directly to scipy.optimize.minimize (native L-BFGS-B box constraints). jaxopt.LBFGSB also supports bounds natively. Other JAX backends reparameterise.
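A sketch of the reparametrized path for the JAX backends, assuming per-coordinate bounds encoded as None, "neg", "pos" or (lo, hi); this encoding and all function names are invented here. Gradients map as \(\nabla_\phi = J^\top \nabla_\theta\) with \(J\) = jax.jacobian(to_theta). Because each map is elementwise, the Hessian picks up an extra diagonal term: \(H_\phi = J^\top H_\theta J + \mathrm{diag}(\nabla_\theta f \odot \theta''(\phi))\).

```python
import jax
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.scipy.special import logit

jax.config.update("jax_enable_x64", True)


def scalar_maps(bound):
    """Hypothetical: (to_theta, to_phi) for one coordinate, following the table above."""
    if bound is None:                        # (-inf, +inf): identity
        return (lambda p: p), (lambda t: t)
    if bound == "neg":                       # (-inf, 0)
        return (lambda p: -jnp.exp(p)), (lambda t: jnp.log(-t))
    if bound == "pos":                       # (0, +inf)
        return jnp.exp, jnp.log
    lo, hi = bound                           # (lo, hi)
    return (lambda p: lo + (hi - lo) * sigmoid(p)), (lambda t: logit((t - lo) / (hi - lo)))


def make_maps(bounds):
    pairs = [scalar_maps(b) for b in bounds]
    to_theta = lambda phi: jnp.stack([m[0](phi[i]) for i, m in enumerate(pairs)])
    to_phi = lambda theta: jnp.stack([m[1](theta[i]) for i, m in enumerate(pairs)])
    return to_theta, to_phi


def phi_space_derivatives(grad_fn, hess_fn, to_theta, phi):
    """Chain rule: turn theta-space gradient/Hessian callables into phi-space values."""
    theta = to_theta(phi)
    J = jax.jacobian(to_theta)(phi)          # diagonal, since every map is elementwise
    g = grad_fn(theta)
    ones = jnp.ones_like(phi)
    dtheta = lambda p: jax.jvp(to_theta, (p,), (ones,))[1]   # theta_i'(phi_i)
    d2theta = jax.jvp(dtheta, (phi,), (ones,))[1]            # theta_i''(phi_i)
    grad_phi = J.T @ g
    hess_phi = J.T @ hess_fn(theta) @ J + jnp.diag(g * d2theta)
    return grad_phi, hess_phi
```

Distributions only supply θ-space grad_fn and hess_fn; everything involving φ stays inside the solver.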

1.2 Newton: hand-rolled, JIT-cached

None of the common JAX optimization libraries provides a Newton minimizer that accepts a user-supplied Hessian:

| Library | Newton | Custom Hessian | Box constraints |
|---|---|---|---|
| Optimistix | root-finding only | no | no |
| JAXopt | none | n/a | yes (LBFGSB only) |
| Optax | none | n/a | n/a |

So we ship a hand-rolled Newton built on lax.while_loop, which exits as soon as the tolerance is met (true early stopping) instead of running a fixed number of steps. For repeated warm-started solves on the same shape (the GIG EM hot path), make_jit_newton_solver(f, grad_fn, hess_fn, bounds) builds a @jax.jit-decorated specialised solve whose XLA cache survives across calls. This is critical: without it, per-call retracing dominated GH EM time.
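The following sketch shows the two ingredients: a lax.while_loop Newton iteration with early exit, and a factory that jit-compiles the solve once so later calls with the same shapes reuse the compiled executable. make_newton_solver_sketch is a hypothetical name; unlike make_jit_newton_solver it ignores bounds and takes undamped Newton steps.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def make_newton_solver_sketch(f, grad_fn, hess_fn, *, max_steps=50, tol=1e-10):
    """Hypothetical factory: one jitted Newton solver per (f, grad_fn, hess_fn).

    The returned function object is created once and reused, so its XLA
    compilation cache persists across calls with the same shapes and dtypes.
    """

    @jax.jit
    def solve(eta, theta0):
        def residual(theta):
            return grad_fn(theta) - eta          # gradient of f(theta) - theta . eta

        def cond(state):
            _, k, gnorm = state
            return (k < max_steps) & (gnorm > tol)   # stop early once converged

        def body(state):
            theta, k, _ = state
            theta = theta - jnp.linalg.solve(hess_fn(theta), residual(theta))
            return theta, k + 1, jnp.linalg.norm(residual(theta))

        init = (theta0, jnp.asarray(0, dtype=jnp.int32), jnp.linalg.norm(residual(theta0)))
        theta, steps, gnorm = jax.lax.while_loop(cond, body, init)
        return theta, f(theta) - jnp.dot(theta, eta), steps, gnorm

    return solve


solve = make_newton_solver_sketch(
    lambda t: jnp.sum(jnp.exp(t)), jnp.exp, lambda t: jnp.diag(jnp.exp(t)))
theta, fun, steps, gnorm = solve(jnp.array([2.0, 0.5]), jnp.zeros(2))
# A warm-started second call with the same shapes reuses the compiled executable.
theta2, *_ = solve(jnp.array([2.1, 0.45]), theta)
```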

1.3 BregmanResult and lax.scan

```python
from dataclasses import dataclass
from typing import Any

import jax


@dataclass(frozen=True)
class BregmanResult:
    theta: jax.Array
    fun: Any            # JAX scalar under lax.scan, otherwise a Python float
    grad_norm: Any
    num_steps: int
    converged: Any      # Python bool or 0-d JAX bool
    elapsed_time: float = 0.0
```

Loose Any typing is deliberate: forcing Python float/bool would raise ConcretizationTypeError when the result flows through lax.scan. verbose is threaded into the solver for printed diagnostics.
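A minimal demonstration of the failure mode, assuming nothing beyond JAX: inside a lax.scan body every value is a tracer, so returning a JAX scalar works, while calling float() on it raises ConcretizationTypeError. The Poisson closed-form inverse stands in for a real solve.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError


def step_keep_jax(carry, eta):
    theta = jnp.log(eta)                      # closed-form Poisson inverse stands in for a solve
    fun = jnp.exp(theta) - theta * eta        # stays a traced 0-d array: fine inside scan
    return carry, fun


def step_force_float(carry, eta):
    theta = jnp.log(eta)
    fun = float(jnp.exp(theta) - theta * eta)  # float() needs a concrete value: fails while tracing
    return carry, fun


etas = jnp.array([1.0, 2.0, 3.0])
_, funs = jax.lax.scan(step_keep_jax, 0.0, etas)

try:
    jax.lax.scan(step_force_float, 0.0, etas)
    forced_float_failed = False
except ConcretizationTypeError:
    forced_float_failed = True
```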


2. GIG η→θ

The GIG Fisher information can be ill-conditioned (condition number up to \(10^{30}\)) when \(a \ll b\) or \(a \gg b\). Vanilla L-BFGS-B fails without rescaling.

2.1 η-rescaling

Before optimization:

\[ s = \sqrt{\eta_2/\eta_3},\qquad \tilde\eta = \bigl(\eta_1 + \tfrac12\log(\eta_2/\eta_3),\,\sqrt{\eta_2\eta_3},\,\sqrt{\eta_2\eta_3}\bigr). \]

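The rescaling corresponds to the change of variable \(Y = sX\), which maps GIG\((p, a, b)\) to GIG\((p, a/s, sb)\). The rescaled problem keeps \(\tilde a\tilde b = ab\) and has \(\tilde\eta_2 = \tilde\eta_3\), which balances \(\tilde a\) against \(\tilde b\) (exactly \(\tilde a = \tilde b = \sqrt{ab}\) when \(p = 0\)) and removes the extreme scale disparity behind the ill-conditioning. After solving for \(\tilde\theta\):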

\[ \theta = (\tilde\theta_1,\;\tilde\theta_2/s,\;s\cdot\tilde\theta_3). \]
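The scaling identities behind these formulas: if \(X\sim\mathrm{GIG}(p,a,b)\) with density \(\propto x^{p-1}e^{-(ax+b/x)/2}\), then \(Y = sX \sim \mathrm{GIG}(p, a/s, sb)\), so \(\theta_Y = (\theta_1, s\theta_2, \theta_3/s)\) and \(\eta_Y = (\eta_1+\log s,\ \eta_2/s,\ s\eta_3)\). The sketch below assumes \(\eta = (E[\log X], E[1/X], E[X])\) and \(\theta = (p-1, -b/2, -a/2)\), the ordering these formulas imply; the function names are hypothetical.

```python
import jax.numpy as jnp


def rescale_eta(eta):
    """eta = (E[log X], E[1/X], E[X]) -> (s, eta_tilde) for Y = s X with E[1/Y] = E[Y]."""
    e1, e2, e3 = eta[0], eta[1], eta[2]
    s = jnp.sqrt(e2 / e3)
    r = jnp.sqrt(e2 * e3)                     # E[1/Y] = e2 / s = s * e3 = r
    eta_tilde = jnp.array([e1 + 0.5 * jnp.log(e2 / e3), r, r])
    return s, eta_tilde


def unscale_theta(theta_tilde, s):
    """Map the solution for Y = s X back to X: theta = (t1, t2 / s, s * t3)."""
    return jnp.array([theta_tilde[0], theta_tilde[1] / s, s * theta_tilde[2]])
```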

2.2 Solver choice in EM

Default: backend='cpu', method='lbfgs', i.e. scipy.optimize.minimize (L-BFGS-B) with scipy.special.kve. Keeping this three-parameter problem on the CPU avoids GPU kernel-dispatch overhead.
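A sketch of a CPU-side solve in this spirit, assuming the parametrization \(\theta = (p-1, -b/2, -a/2)\) with \(\psi(\theta) = \log 2 + \tfrac p2\log(b/a) + \log K_p(\sqrt{ab})\). The gradient uses the standard GIG moments \(E[X] = \sqrt{b/a}\,K_{p+1}/K_p\) and \(E[1/X] = \sqrt{a/b}\,K_{p-1}/K_p\), evaluated with kve (its exponential factors cancel in the ratios), plus a central difference in the order for \(E[\log X]\). It skips the η-rescaling and multi-start, so treat it as an illustration of the scipy call pattern, not the library's solver; all names are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import kve


def _log_kv(v, z):
    # log K_v(z) through the scaled kve(v, z) = K_v(z) e^z, which avoids underflow
    return np.log(kve(v, z)) - z


def _unpack(theta):
    # theta = (p - 1, -b/2, -a/2) for the density  x^(p-1) exp(-(a x + b/x)/2)
    p, b, a = theta[0] + 1.0, -2.0 * theta[1], -2.0 * theta[2]
    return p, a, b


def gig_log_partition(theta):
    p, a, b = _unpack(theta)
    return np.log(2.0) + 0.5 * p * np.log(b / a) + _log_kv(p, np.sqrt(a * b))


def gig_grad_log_partition(theta, h=1e-5):
    """Gradient of psi: (E[log X], E[1/X], E[X])."""
    p, a, b = _unpack(theta)
    w = np.sqrt(a * b)
    kp = kve(p, w)
    e_log = 0.5 * np.log(b / a) + (_log_kv(p + h, w) - _log_kv(p - h, w)) / (2.0 * h)
    e_inv = np.sqrt(a / b) * kve(p - 1.0, w) / kp
    e_x = np.sqrt(b / a) * kve(p + 1.0, w) / kp
    return np.array([e_log, e_inv, e_x])


def solve_gig_cpu_sketch(eta, theta0):
    """Minimise psi(theta) - theta . eta with L-BFGS-B and native box constraints."""
    eta = np.asarray(eta, dtype=float)
    bounds = [(None, None), (None, -1e-12), (None, -1e-12)]   # theta_2, theta_3 < 0
    res = minimize(
        lambda th: gig_log_partition(th) - th @ eta,
        np.asarray(theta0, dtype=float),
        jac=lambda th: gig_grad_log_partition(th) - eta,
        method="L-BFGS-B",
        bounds=bounds,
        options={"gtol": 1e-12, "ftol": 1e-15, "maxiter": 1000},
    )
    return res.x
```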

For the warm-started Newton path (backend='jax', method='newton'), the cached _gig_jax_newton_jit keeps a single XLA executable across all warm-started solves.

When theta0 is not provided, GeneralizedInverseGaussian.from_expectation runs solve_bregman_multistart on the η-rescaled problem, with seeds from the Gamma / InverseGamma / InverseGaussian special cases.


3. Bessel Functions

log_kv(v, z, backend='jax'|'cpu') is the unified entry point in normix/utils/bessel.py.

3.1 Pure-JAX backend (default)

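Four-regime dispatch via lax.cond. For unbatched inputs only the selected branch executes at runtime; under vmap with a batched predicate, lax.cond is lowered to a select and every branch is evaluated.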

| Regime | Trigger | Method |
|---|---|---|
| Hankel | \(z > \max(25, v^2/4)\) | DLMF 10.40.2 asymptotic |
| Olver | \(\lvert v\rvert > 25\), not Hankel | DLMF 10.41.3-4 uniform expansion |
| Small-\(z\) | \(z < 10^{-6}\), \(\lvert v\rvert > 0.5\) | leading asymptotic |
| Quadrature | otherwise | 64-point Gauss–Legendre (Takekawa 2022) |
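A simplified sketch of the regime selection and of a quadrature branch, assuming JAX with float64 enabled. The integral representation \(K_\nu(z) = \int_0^\infty e^{-z\cosh t}\cosh(\nu t)\,dt\) is evaluated with 64 Gauss–Legendre nodes on a truncated interval via log-sum-exp. The cut-off heuristic is an assumption of this sketch and is not Takekawa's (2022) node placement; the function names are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)

# 64-point Gauss-Legendre nodes and weights on [-1, 1], computed once in NumPy.
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(64)


def regime(v, z):
    """0 = Hankel, 1 = Olver, 2 = small-z, 3 = quadrature, checked in the table's order."""
    av = jnp.abs(v)
    hankel = z > jnp.maximum(25.0, v * v / 4.0)
    olver = av > 25.0
    small_z = (z < 1e-6) & (av > 0.5)
    return jnp.where(hankel, 0, jnp.where(olver, 1, jnp.where(small_z, 2, 3)))


def log_kv_quad(v, z):
    """log K_v(z) from K_v(z) = int_0^inf exp(-z cosh t) cosh(v t) dt."""
    v = jnp.abs(v)                                    # K_{-v} = K_v
    upper = jnp.log(2.0 * (v + 50.0) / z) + 1.0       # heuristic cut-off; integrand negligible beyond
    t = 0.5 * upper * (_NODES + 1.0)                  # map nodes from [-1, 1] to [0, upper]
    w = 0.5 * upper * _WEIGHTS
    log_cosh_vt = jnp.logaddexp(v * t, -v * t) - jnp.log(2.0)   # stable log cosh(v t)
    return logsumexp(-z * jnp.cosh(t) + log_cosh_vt, b=w)
```

Working in log space keeps the sum finite even when \(K_\nu(z)\) itself would overflow or underflow.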

Custom JVP via @jax.custom_jvp:

  • \(\partial/\partial z\): exact recurrence \(K'_\nu = -(K_{\nu-1} + K_{\nu+1})/2\), so \(\partial_z \log K_\nu = -(K_{\nu-1} + K_{\nu+1})/(2K_\nu)\).

  • \(\partial/\partial v\): central FD with \(\varepsilon = 10^{-5}\).
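A sketch of the custom-JVP pattern, assuming a large-\(z\) base evaluator so the block stays short: the Hankel series (DLMF 10.40.2) stands in for the full four-regime dispatch and is only accurate in the Hankel regime. The \(z\)-derivative uses the recurrence above and the \(\nu\)-derivative a central difference with \(\varepsilon = 10^{-5}\); function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _log_kv_hankel(v, z, n_terms=12):
    """Large-z asymptotic: K_v(z) ~ sqrt(pi/(2z)) e^{-z} sum_k a_k(v) / z^k."""
    mu = 4.0 * v * v
    term, total = 1.0, 1.0
    for k in range(1, n_terms):
        # a_k / z^k from a_{k-1} / z^{k-1}: ratio (4v^2 - (2k-1)^2) / (8 k z)
        term = term * (mu - (2 * k - 1) ** 2) / (8.0 * k * z)
        total = total + term
    return 0.5 * jnp.log(jnp.pi / (2.0 * z)) - z + jnp.log(total)


@jax.custom_jvp
def log_kv(v, z):
    return _log_kv_hankel(v, z)


@log_kv.defjvp
def _log_kv_jvp(primals, tangents):
    v, z = primals
    dv, dz = tangents
    out = log_kv(v, z)
    # d/dz log K_v = K'_v / K_v with K'_v = -(K_{v-1} + K_{v+1}) / 2
    dlog_dz = -0.5 * (jnp.exp(log_kv(v - 1.0, z) - out) + jnp.exp(log_kv(v + 1.0, z) - out))
    # d/dv by central finite differences with eps = 1e-5
    eps = 1e-5
    dlog_dv = (log_kv(v + eps, z) - log_kv(v - eps, z)) / (2.0 * eps)
    return out, dlog_dz * dz + dlog_dv * dv
```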

3.2 CPU backend (EM hot path)

scipy.special.kve, fully vectorised NumPy. Not JIT-able. For large \(N\), a single kve C call per element beats vmapping JAX's lax.cond-dispatched implementation: under vmap the conditionals become selects, so every regime branch is evaluated and each condition check adds its own kernel launches.
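A sketch of the CPU evaluation, assuming only NumPy and SciPy: kve(v, z) = \(K_v(z)e^{z}\) keeps the log finite where \(K_v(z)\) itself underflows, and one ufunc call covers the whole array. log_kv_cpu is a hypothetical name.

```python
import numpy as np
from scipy.special import kv, kve


def log_kv_cpu(v, z):
    """Vectorised log K_v(z) via the exponentially scaled kve(v, z) = K_v(z) * exp(z)."""
    v, z = np.asarray(v, dtype=float), np.asarray(z, dtype=float)
    return np.log(kve(v, z)) - z


v = np.array([0.5, 2.0, 1.0])
z = np.array([1.0, 3.0, 1000.0])
logs = log_kv_cpu(v, z)          # one C-level ufunc call over the whole array
with np.errstate(divide="ignore"):
    naive = np.log(kv(v, z))     # kv(1, 1000) underflows to 0, so its log is -inf
```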

3.3 Why backend is a Python-level string

The string is resolved by ordinary Python control flow before JAX tracing begins, so it never becomes a traced value. backend='jax' keeps the code traceable; backend='cpu' runs eagerly, which is appropriate because the EM loop driving the CPU path is already a Python for loop.

3.4 CPU triad for Bessel-dependent distributions

Design rule: any distribution that calls log_kv must override the Tier 3 CPU classmethods so the CPU solver path (solve_bregman(backend='cpu')) avoids JAX dispatch entirely. The three classmethods are _log_partition_cpu, _grad_log_partition_cpu, _hessian_log_partition_cpu — all numpy in / numpy out.

Distributions that don’t call log_kv (Gamma, InverseGamma, InverseGaussian) inherit the default wrappers. They pay nothing.


4. CPU/GPU Hybrid Backend

EM timing on 468 stocks, 2552 observations (GH distribution):

| Phase | JAX (GPU) | CPU hybrid | Speedup |
|---|---|---|---|
| E-step | ~1.1 s | ~0.07 s | ~15× |
| M-step (GIG solve) | ~5–7 s | ~0.01 s | ~500× |

Hybrid strategy:

  • Quad forms (\(L_\Sigma^{-1}(x-\mu)\), \(\|z\|^2\), \(\|w\|^2\)) stay in JAX under vmap: they are d-dimensional and GPU-friendly.

  • log_kv calls and GIG optimization move to CPU (backend='cpu').

NormalMixture.e_step(X, backend='cpu') implements this split for the E-step:

  • Bessel calls go to CPU via GIG.expectation_params_batch(backend='cpu').

  • _posterior_gig_params(z2, w2) lives on each JointNormalMixture subclass.
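A sketch of the split for a generalized hyperbolic model, with the quad forms jit-compiled in JAX and the Bessel moments computed in NumPy/SciPy. The posterior \(W\mid x \sim \mathrm{GIG}(p - d/2,\; a + \|L^{-1}\gamma\|^2,\; b + \|L^{-1}(x-\mu)\|^2)\) is the standard GH result under the density convention \(\propto w^{p-1}e^{-(aw+b/w)/2}\) with \(\Sigma = LL^\top\). The function names are hypothetical and do not reproduce _posterior_gig_params or GIG.expectation_params_batch.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from scipy.special import kve

jax.config.update("jax_enable_x64", True)


@jax.jit
def quad_forms(X, mu, gamma, L):
    """JAX side: z2_i = ||L^{-1}(x_i - mu)||^2 and w2 = ||L^{-1} gamma||^2."""
    Z = jax.vmap(lambda x: solve_triangular(L, x - mu, lower=True))(X)
    W = solve_triangular(L, gamma, lower=True)
    return jnp.sum(Z * Z, axis=1), jnp.dot(W, W)


def posterior_gig_params_gh(z2, w2, p, a, b, d):
    """Hypothetical GH posterior: W | x ~ GIG(p - d/2, a + w2, b + z2)."""
    return p - 0.5 * d, a + w2, b + z2


def gig_moments_cpu(p, a, b):
    """CPU side: E[W] and E[1/W] from kve ratios (the exp(omega) factors cancel)."""
    omega = np.sqrt(a * b)
    kp = kve(p, omega)
    e_w = np.sqrt(b / a) * kve(p + 1.0, omega) / kp
    e_inv_w = np.sqrt(a / b) * kve(p - 1.0, omega) / kp
    return e_w, e_inv_w


rng = np.random.default_rng(0)
d, n = 3, 5
X = rng.normal(size=(n, d))
mu, gamma = np.zeros(d), np.full(d, 0.2)
L = np.linalg.cholesky(np.eye(d) + 0.1)   # Sigma = I + 0.1 * ones, positive definite
z2, w2 = quad_forms(jnp.asarray(X), jnp.asarray(mu), jnp.asarray(gamma), jnp.asarray(L))
p_post, a_post, b_post = posterior_gig_params_gh(
    np.asarray(z2), float(w2), p=-0.5, a=1.0, b=1.0, d=d)
e_w, e_inv_w = gig_moments_cpu(p_post, a_post, b_post)
```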

Default fitter settings reflect the hot path: e_step_backend='jax', m_step_backend='cpu', m_step_method='newton'.


5. Random Variate Generation

PINV (Polynomial-Interpolation-based Numerical Inversion) in utils/rvs.py is pure JAX and works for any univariate log-kernel — no normalising constant needed:

  • build_pinv_table(log_kernel, mode, *, x_of_w, n_grid, tail_eps) builds a quantile table in JAX. Tail bisection via lax.fori_loop, trapezoidal CDF via jnp.cumsum.

  • rvs_pinv(key, u_grid, x_grid, n) samples via jnp.interp (GPU-friendly, vectorised).

Distributions on \((0,\infty)\) supply log_kernel(w) = log_prob(exp(w)) + w and seed the table at jnp.log(self.mode()). Closed-form mode() lives on the distribution itself (Gamma, InverseGamma, InverseGaussian, GIG). InverseGaussian.ppf and both GIG.cdf / GIG.ppf inline a single build_pinv_table call — log_prob is the only kernel.
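A simplified sketch of the table-then-interpolate idea for a distribution on \((0,\infty)\), using the log_kernel(w) = log_prob(exp(w)) + w convention above. It swaps in a fixed grid around the mode instead of tail bisection, and linear interpolation instead of PINV's polynomial interpolation; pinv_table_sketch and rvs_pinv_sketch are hypothetical names.

```python
import jax
import jax.numpy as jnp


def pinv_table_sketch(log_kernel, w_mode, *, below=12.0, above=4.0, n_grid=4096):
    """Fixed w-grid around the mode (no tail bisection); returns (u_grid, x_grid).

    log_kernel is an unnormalised log density in w = log x.
    """
    w = jnp.linspace(w_mode - below, w_mode + above, n_grid)
    logk = log_kernel(w)
    dens = jnp.exp(logk - jnp.max(logk))                  # shift before exp for stability
    incr = 0.5 * (dens[1:] + dens[:-1]) * (w[1] - w[0])   # trapezoid areas
    cdf = jnp.concatenate([jnp.zeros(1), jnp.cumsum(incr)])
    return cdf / cdf[-1], jnp.exp(w)                      # normalisation happens here


def rvs_pinv_sketch(key, u_grid, x_grid, n):
    """Inverse-CDF sampling by linear interpolation of the table."""
    u = jax.random.uniform(key, (n,))
    return jnp.interp(u, u_grid, x_grid)


# Gamma(shape=3, rate=1) without its normalising constant:
# log_prob(x) = 2 log x - x + const, so log_kernel(w) = 2 w - exp(w) + w.
shape = 3.0
log_kernel = lambda w: (shape - 1.0) * w - jnp.exp(w) + w
u_grid, x_grid = pinv_table_sketch(log_kernel, jnp.log(shape - 1.0))   # mode = shape - 1
samples = rvs_pinv_sketch(jax.random.PRNGKey(0), u_grid, x_grid, 200_000)
median = jnp.interp(0.5, u_grid, x_grid)
```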

GIG-specific sampling is inlined in distributions/generalized_inverse_gaussian.py:

  • _gig_rvs_devroye(key, p, a, b, n) — TDR on \(w = \log x\), batch-parallel (no while_loop).

  • _gig_rvs_pinv(key, u_grid, x_grid, n) — alias of rvs_pinv used by GIG.rvs(method='pinv').

Neither method evaluates the Bessel normalising constant.

Quantile Functions (cdf, ppf)

  • Gamma.ppf and InverseGamma.ppf invert the regularised incomplete gamma via normix.utils.gammaincinv — a pure-JAX Newton iteration on jax.scipy.special.gammainc with a Wilson–Hilferty seed. This is the JAX analogue of scipy.special.gammaincinv.

  • InverseGaussian.ppf, GIG.cdf, GIG.ppf build a PINV table from log_prob (above).

  • Univariate Normal-mixture marginals (UnivariateVarianceGamma, UnivariateNormalInverseGamma, UnivariateNormalInverseGaussian, UnivariateGeneralizedHyperbolic) use the same generic PINV machinery with log_kernel(w) = self.log_prob(jnp.atleast_1d(w)), seeded at self.mean() (no closed-form mode for Bessel mixtures).
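A sketch of the Gamma quantile approach, assuming JAX with float64 enabled: Newton's method for \(P(a, e^t) = u\) in \(t = \log x\), seeded by Wilson–Hilferty. Iterating in \(\log x\) is a choice made here (not necessarily normix's) that keeps iterates positive; when the Wilson–Hilferty cube is non-positive (small \(a\), small \(u\)), the sketch falls back to the small-\(x\) approximation \(P(a,x)\approx x^a/\Gamma(a+1)\). gammaincinv_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaln, ndtri

jax.config.update("jax_enable_x64", True)


def gammaincinv_sketch(a, u, n_iter=40):
    """Solve P(a, x) = u for x by Newton's method on t = log x."""
    a = jnp.asarray(a, dtype=jnp.float64)
    u = jnp.asarray(u, dtype=jnp.float64)
    # Wilson-Hilferty: X / a is approximately (1 - 1/(9a) + z sqrt(1/(9a)))^3, z ~ N(0, 1)
    c = 1.0 / (9.0 * a)
    wh = a * (1.0 - c + ndtri(u) * jnp.sqrt(c)) ** 3
    small_x = (jnp.log(u) + gammaln(a + 1.0)) / a        # log((u Gamma(a+1))^(1/a))
    t0 = jnp.where(wh > 0, jnp.log(jnp.where(wh > 0, wh, 1.0)), small_x)

    def body(_, t):
        x = jnp.exp(t)
        # dP/dt = x * gamma_pdf(x) = exp(a t - x - log Gamma(a))
        dp_dt = jnp.exp(a * t - x - gammaln(a))
        return t - (gammainc(a, x) - u) / dp_dt

    return jnp.exp(jax.lax.fori_loop(0, n_iter, body, t0))
```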


6. Cross-References