API Reference

Base Classes

class normix.exponential_family.ExponentialFamily[source]

Bases: Module

Abstract base class for exponential family distributions.

Concrete subclasses must implement:

_log_partition_from_theta, natural_params, sufficient_statistics, log_base_measure

abstractmethod natural_params()[source]

\(\theta\) from stored classical parameters.

Return type:

Array

abstractmethod static sufficient_statistics(x)[source]

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

abstractmethod static log_base_measure(x)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

log_partition()[source]

\(\psi(\theta)\) at current parameters.

Return type:

Array

expectation_params(backend='jax')[source]

\(\eta = \nabla\psi(\theta)\).

Parameters:

backend (str) – 'jax' (default, JIT-able) or 'cpu' (numpy/scipy).

Return type:

Array

fisher_information(backend='jax')[source]

\(I(\theta) = \nabla^2\psi(\theta)\).

Parameters:

backend (str) – 'jax' (default, JIT-able) or 'cpu' (numpy/scipy).

Return type:

Array

log_prob(x)[source]

\(\log p(x\mid\theta) = \log h(x) + \theta^\top t(x) - \psi(\theta)\), single observation.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]

\(p(x\mid\theta)\), single observation. Batch via jax.vmap.

Parameters:

x (Array)

Return type:

Array
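
Since log_prob and pdf act on a single observation, batch evaluation goes through jax.vmap. For example (using the Gamma subclass documented below; construction per its class signature):

>>> import jax
>>> import jax.numpy as jnp
>>> from normix.distributions.gamma import Gamma
>>> dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.5))
>>> X = jnp.array([0.5, 1.0, 2.0])
>>> jax.vmap(dist.log_prob)(X).shape
(3,)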

mean()[source]

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

std()[source]

\(\mathrm{Std}[X] = \sqrt{\mathrm{Var}[X]}\).

Return type:

Array

cdf(x)[source]

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations via JAX PRNG (JIT-able).

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Array

squared_hellinger(other)[source]

Squared Hellinger distance \(H^2(p, q)\).

Default uses the general exponential-family formula via \(\psi\). Subclasses may override for numerically improved variants.

Parameters:

other (ExponentialFamily)

Return type:

Array

kl_divergence(other)[source]

KL divergence \(D_{\mathrm{KL}}(\mathrm{self} \| \mathrm{other})\).

Default uses the Bregman-divergence formula via \(\psi\) and \(\nabla\psi\). Subclasses may override.

Parameters:

other (ExponentialFamily)

Return type:

Array
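
For reference, the standard exponential-family identities behind these two defaults (with \(\theta_p, \theta_q\) the natural parameters of self and other, and \(\psi\) the shared log-partition) are:

\[H^2(p, q) = 1 - \exp\!\Bigl(\psi\bigl(\tfrac{\theta_p+\theta_q}{2}\bigr) - \tfrac{1}{2}\psi(\theta_p) - \tfrac{1}{2}\psi(\theta_q)\Bigr)\]
\[D_{\mathrm{KL}}(p\,\|\,q) = \psi(\theta_q) - \psi(\theta_p) - (\theta_q - \theta_p)^\top \nabla\psi(\theta_p)\]

Both assume the two distributions share \(h\) and \(t\) and differ only in \(\theta\).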

classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

ExponentialFamily

classmethod bregman_divergence(theta, eta)[source]

Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) (conjugate dual).

Minimising over \(\theta\) yields \(\nabla\psi(\theta^*) = \eta\), i.e. the natural parameters corresponding to expectation parameters \(\eta\).

Parameters:
  • theta (Array)

  • eta (Array)

Return type:

Array

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='lbfgs', verbose=0)[source]

Construct from expectation parameters \(\eta\) by solving \(\nabla\psi(\theta) = \eta\).

Minimises the Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) via solve_bregman. Subclasses can override for closed-form inverses.

Parameters:
  • backend (str) – 'jax' (default, JIT-able) or 'cpu' (scipy, more robust).

  • method (str) – 'lbfgs' (default), 'bfgs', or 'newton'.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)

Return type:

ExponentialFamily

classmethod fit_mle(X, *, theta0=None, maxiter=500, tol=1e-10, verbose=0)[source]

MLE via exponential family identity: \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\).

Batches over X using jax.vmap, then calls from_expectation(\(\hat\eta\)).

Parameters:
  • X (jax.Array) – (n, ...) array of observations.

  • theta0 (jax.Array, optional) – Initial natural parameters \(\theta_0\) for the \(\eta\to\theta\) solver.

  • maxiter (int) – Maximum iterations for the \(\eta\to\theta\) solver.

  • tol (float) – Convergence tolerance for the \(\eta\to\theta\) solver.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

Return type:

ExponentialFamily
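
A typical fit, with the Gamma family for concreteness (data simulated here; any positive sample works):

>>> import jax.numpy as jnp
>>> from normix.distributions.gamma import Gamma
>>> X = Gamma(alpha=jnp.array(3.0), beta=jnp.array(2.0)).rvs(1000, seed=0)
>>> fitted = Gamma.fit_mle(X)   # eta_hat = mean sufficient statistics, then eta -> theta
>>> m = fitted.mean()           # close to 3/2 for large n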

fit(X, *, maxiter=500, tol=1e-10, verbose=0, **kwargs)[source]

Fit using self as initialization (warm start).

Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and solves from_expectation(\(\hat\eta\)) using self.natural_params() as the initial \(\theta_0\).

Parameters:
  • X (jax.Array) – (n, ...) array of observations.

  • maxiter (int) – Maximum iterations for the \(\eta\to\theta\) solver.

  • tol (float) – Convergence tolerance for the \(\eta\to\theta\) solver.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

Return type:

ExponentialFamily

classmethod default_init(X)[source]

Moment-based initialisation from data.

Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and inverts to get an initial model. For distributions with closed-form from_expectation (Gamma, InverseGamma, InverseGaussian), this gives the MLE directly.

Parameters:

X (Array)

Return type:

ExponentialFamily

Univariate Distributions

Gamma

Gamma distribution as an exponential family.

\[p(x \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{\alpha-1} e^{-\beta x}, \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [\log x,\; x]\]
\[\theta = [\alpha-1,\; -\beta], \quad \theta_1 > -1,\; \theta_2 < 0\]
\[\psi(\theta) = \log\Gamma(\theta_1+1) - (\theta_1+1)\log(-\theta_2)\]
\[\eta = [\psi(\alpha) - \log\beta,\; \alpha/\beta] \quad \text{(digamma, mean)}\]
class normix.distributions.gamma.Gamma(alpha, beta)[source]

Bases: ExponentialFamily

Gamma(\(\alpha\), \(\beta\)) distribution — shape \(\alpha > 0\), rate \(\beta > 0\).

Parameters:
alpha: Array
beta: Array
natural_params()[source]

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

\(E[X] = \alpha/\beta\).

Return type:

Array

var()[source]

\(\mathrm{Var}[X] = \alpha/\beta^2\).

Return type:

Array

cdf(x)[source]

CDF \(F(x) = P(\alpha, \beta x)\), the regularised lower incomplete gamma function.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations from \(\mathrm{Gamma}(\alpha, \beta)\) via JAX PRNG.

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Array

classmethod from_natural(theta)[source]

Construct Gamma from natural parameters \(\theta = [\alpha-1,\; -\beta]\).

Parameters:

theta (Array)

Return type:

Gamma

to_gig(*, boundary_eps=0.0)[source]

Exact embedding into the GIG family.

Gamma(\(\alpha\), \(\beta\)) is the \(b \to 0\) limit of GIG(\(p = \alpha,\; a = 2\beta,\; b\)). With boundary_eps = 0 the embedding stores b = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain (matters only for downstream expectation_params calls on the lifted GIG).

Parameters:

boundary_eps (float)

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, backend='jax', **kwargs)[source]

Fast \(\eta \to \theta\) via Newton on \(\psi(\alpha) - \log\alpha = \eta_1 - \log\eta_2\) (here \(\psi\) is the digamma function).

Given \(\eta = [E[\log X],\; E[X]]\): \(\alpha\) from digamma inversion, then \(\beta = \alpha / \eta_2\).

Parameters:
  • backend (str) – 'jax' (default): lax.fori_loop Newton (JIT-compatible). 'cpu': scipy.special digamma/polygamma (no XLA tracing).

  • eta (Array)

  • maxiter (int)

  • tol (float)

Return type:

Gamma
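
The underlying scalar Newton iteration can be sketched standalone with scipy (a hedged illustration mirroring the 'cpu' path in spirit; not the library's internal code):

import numpy as np
from scipy.special import digamma, polygamma

def invert_gamma_expectations(eta1, eta2, maxiter=100, tol=1e-12):
    """Solve digamma(a) - log(a) = eta1 - log(eta2), then beta = a / eta2."""
    rhs = eta1 - np.log(eta2)
    a = 0.5 / (np.log(eta2) - eta1)   # classical moment-based initial guess
    for _ in range(maxiter):
        # d/da [digamma(a) - log a] = polygamma(1, a) - 1/a > 0 for a > 0
        step = (digamma(a) - np.log(a) - rhs) / (polygamma(1, a) - 1.0 / a)
        a -= step
        if abs(step) < tol:
            break
    return a, a / eta2   # (alpha, beta)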

Inverse Gamma

InverseGamma distribution as an exponential family.

\[p(x \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{-\alpha-1} e^{-\beta/x}, \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [-1/x,\; \log x]\]
\[\theta = [\beta,\; -(\alpha+1)], \quad \theta_1 > 0,\; \theta_2 < -1\]
\[\psi(\theta) = \log\Gamma(-\theta_2-1) - (-\theta_2-1)\log\theta_1 = \log\Gamma(\alpha) - \alpha\log\beta\]
\[\eta = [-\alpha/\beta,\; \log\beta - \psi(\alpha)] \quad (E[-1/X],\; E[\log X])\]
class normix.distributions.inverse_gamma.InverseGamma(alpha, beta)[source]

Bases: ExponentialFamily

InverseGamma(\(\alpha\), \(\beta\)) — shape \(\alpha > 0\), rate \(\beta > 0\).

Parameters:
alpha: Array
beta: Array
natural_params()[source]

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

\(E[X] = \beta/(\alpha-1)\) for \(\alpha > 1\).

Return type:

Array

var()[source]

\(\mathrm{Var}[X] = \beta^2/\bigl((\alpha-1)^2(\alpha-2)\bigr)\) for \(\alpha > 2\).

Return type:

Array

cdf(x)[source]

CDF \(F(x) = Q(\alpha, \beta/x)\), the regularised upper incomplete gamma function.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations from \(\mathrm{InvGamma}(\alpha, \beta)\) via JAX PRNG.

Uses the relation: if \(X \sim \mathrm{Gamma}(\alpha, 1/\beta)\) then \(1/X \sim \mathrm{InvGamma}(\alpha, \beta)\).

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Array

classmethod from_natural(theta)[source]

Construct InverseGamma from natural parameters \(\theta = [\beta,\; -(\alpha+1)]\).

Parameters:

theta (Array)

Return type:

InverseGamma

to_gig(*, boundary_eps=0.0)[source]

Exact embedding into the GIG family.

InverseGamma(\(\alpha\), \(\beta\)) is the \(a \to 0\) limit of GIG(\(p = -\alpha,\; a,\; b = 2\beta\)). With boundary_eps = 0 the embedding stores a = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain.

Parameters:

boundary_eps (float)

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, backend='jax', **kwargs)[source]

\(\eta = [-\alpha/\beta,\; \log\beta - \psi(\alpha)]\).

\(\beta = \alpha / (-\eta_1)\); solve \(\psi(\alpha) - \log\alpha = -\eta_2 - \log(-\eta_1)\) via Newton.

Parameters:
  • backend (str) – 'jax' (default): lax.fori_loop Newton (JIT-compatible). 'cpu': scipy.special digamma/polygamma (no XLA tracing).

  • eta (Array)

  • maxiter (int)

  • tol (float)

Return type:

InverseGamma

Inverse Gaussian

Inverse Gaussian (Wald) distribution as an exponential family.

\[f(x \mid \mu, \lambda) = \sqrt{\frac{\lambda}{2\pi}}\, x^{-3/2} \exp\!\left(-\frac{\lambda(x-\mu)^2}{2\mu^2 x}\right), \quad x > 0\]

Exponential family structure:

\[h(x) = (2\pi)^{-1/2} x^{-3/2}, \quad t(x) = [x,\; 1/x]\]
\[\theta = \Bigl[-\tfrac{\lambda}{2\mu^2},\; -\tfrac{\lambda}{2}\Bigr], \quad \theta_1 < 0,\; \theta_2 < 0\]
\[\psi(\theta) = -\tfrac{1}{2}\log(-2\theta_2) - \sqrt{(-2\theta_1)(-2\theta_2)}\]
\[\bigl(\tfrac{1}{2}\log(2\pi)\ \text{is absorbed into}\ \log h(x) = -\tfrac{1}{2}\log(2\pi) - \tfrac{3}{2}\log x\bigr)\]
\[\eta = [E[X],\; E[1/X]] = [\mu,\; 1/\mu + 1/\lambda]\]
class normix.distributions.inverse_gaussian.InverseGaussian(mu, lam)[source]

Bases: ExponentialFamily

InverseGaussian(\(\mu\), \(\lambda\)) — mean \(\mu > 0\), shape \(\lambda > 0\).

Parameters:
mu: Array
lam: Array
natural_params()[source]

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

\(E[X] = \mu\).

Return type:

Array

var()[source]

\(\mathrm{Var}[X] = \mu^3/\lambda\).

Return type:

Array

cdf(x)[source]

CDF of the Inverse Gaussian distribution (log-space stable).

\[F(x) = \Phi(t_1) + \exp\!\bigl(2\lambda/\mu + \log\Phi(-t_2)\bigr)\]

where \(t_1 = \sqrt{\lambda/x}\,(x/\mu - 1)\) and \(t_2 = \sqrt{\lambda/x}\,(x/\mu + 1)\). The second term uses log_ndtr to avoid overflow when \(\lambda/\mu\) is large.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations from \(\mathrm{InvGaussian}(\mu, \lambda)\) via JAX PRNG.

Uses the algorithm from Michael, Schucany & Haas (1976):

  1. \(\nu \sim \mathcal{N}(0,1)\), \(y = \nu^2\)

  2. \(x = \mu + \frac{\mu^2 y}{2\lambda} - \frac{\mu}{2\lambda}\sqrt{4\mu\lambda y + \mu^2 y^2}\)

  3. \(z \sim \mathrm{Uniform}(0,1)\); return \(x\) if \(z \le \mu/(\mu+x)\), else \(\mu^2/x\)

Uses jnp.where for vectorized branching over the full sample array.

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Array
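
A self-contained JAX sketch of that algorithm (function and argument names are illustrative, not the library's internals):

import jax
import jax.numpy as jnp

def sample_inverse_gaussian(key, n, mu, lam):
    """Michael, Schucany & Haas (1976) sampler for InvGaussian(mu, lam)."""
    k1, k2 = jax.random.split(key)
    y = jax.random.normal(k1, (n,)) ** 2
    x = mu + mu**2 * y / (2 * lam) - mu / (2 * lam) * jnp.sqrt(
        4 * mu * lam * y + mu**2 * y**2)
    z = jax.random.uniform(k2, (n,))
    # Vectorized accept/transform branch, as in the jnp.where note above.
    return jnp.where(z <= mu / (mu + x), x, mu**2 / x)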

classmethod from_natural(theta)[source]

Construct InverseGaussian from natural parameters \(\theta = [-\lambda/(2\mu^2),\; -\lambda/2]\).

Parameters:

theta (Array)

Return type:

InverseGaussian

to_gig()[source]

Exact embedding into the GIG family.

InverseGaussian(\(\mu\), \(\lambda\)) is GIG with \(p = -1/2,\; a = \lambda/\mu^2,\; b = \lambda\). No boundary approximation: the embedding lands strictly inside GIG’s domain.

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, **kwargs)[source]

Closed-form from \(\eta = [E[X],\; E[1/X]] = [\mu,\; 1/\mu + 1/\lambda]\):

\(\mu = \eta_1\), \(\lambda = 1/(\eta_2 - 1/\eta_1)\).

Parameters:
  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)

Return type:

InverseGaussian

Generalized Inverse Gaussian

Generalized Inverse Gaussian (GIG) distribution as an exponential family.

\[f(x \mid p, a, b) = \frac{(a/b)^{p/2}}{2 K_p(\sqrt{ab})} \, x^{p-1} \exp\!\left(-\frac{ax + b/x}{2}\right), \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [\log x,\; 1/x,\; x]\]
\[\theta = [p-1,\; -b/2,\; -a/2], \quad \theta_2 \le 0,\; \theta_3 \le 0\]
\[\psi(\theta) = \log 2 + \log K_p(\sqrt{ab}) + \tfrac{p}{2}\log(b/a), \quad p = \theta_1+1,\; a = -2\theta_3,\; b = -2\theta_2\]
\[\eta = [E[\log X],\; E[1/X],\; E[X]]\]

Special cases:

  • \(b \to 0,\; p > 0\): GIG → \(\mathrm{Gamma}(p,\; a/2)\)

  • \(a \to 0,\; p < 0\): GIG → \(\mathrm{InvGamma}(-p,\; b/2)\)

  • \(p = -1/2\): GIG → InverseGaussian

η→θ rescaling (reduces Fisher condition number):

\[s = \sqrt{\eta_2/\eta_3}, \quad \tilde{\eta} = \bigl(\eta_1 + \tfrac{1}{2}\log(\eta_2/\eta_3),\; \sqrt{\eta_2\eta_3},\; \sqrt{\eta_2\eta_3}\bigr)\]

Solve \(\tilde{\eta} \to \tilde{\theta}\) with symmetric GIG (\(\tilde{a} = \tilde{b}\)), then unscale.

Log-Partition Triad Overrides:

  • _log_partition_from_theta : JAX, uses log_kv(backend='jax')

  • _grad_log_partition : analytical Bessel ratios (5 \(K_\nu\) calls)

  • _hessian_log_partition : analytical 7-Bessel Hessian in \(\theta\)-space

  • _log_partition_cpu : numpy + log_kv(backend='cpu')

  • _grad_log_partition_cpu : analytical Bessel ratios via scipy.kve

  • _hessian_log_partition_cpu: central FD on _log_partition_cpu

class normix.distributions.generalized_inverse_gaussian.GeneralizedInverseGaussian(p, a, b)[source]

Bases: ExponentialFamily

Generalized Inverse Gaussian distribution.

Stored: \(p\) (shape, any real), \(a > 0\), \(b > 0\).

Parameters:
p: Array
a: Array
b: Array
natural_params()[source]

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static expectation_params_batch(p, a, b, backend='jax')[source]

Vectorized η for arrays of (p, a, b), each shape (N,). Returns (N, 3) array where columns are [E_log_X, E_inv_X, E_X].

  • backend='jax': vmap over the scalar JAX gradient.

  • backend='cpu': vectorized scipy.kve (6 C-level array calls).

Parameters:

backend (str)

Return type:

Array

mean()[source]

\(E[X] = \eta_3\) from expectation parameters.

Return type:

Array

var()[source]

\(\mathrm{Var}[X] = \partial^2\psi/\partial\theta_3^2\) = Fisher information [2,2].

Return type:

Array

rvs(n, seed=42, method='devroye')[source]

Sample n observations from \(\mathrm{GIG}(p, a, b)\).

Parameters:
  • n (int) – Sample size.

  • seed (int) – Integer seed for JAX PRNG (or scipy random_state for 'scipy').

  • method (str) –

    Sampling algorithm:

    • 'devroye' (default) — Transformed density rejection (TDR) on \(\log(x)\), pure JAX, no Bessel functions.

    • 'pinv' — Numerical inverse CDF (CPU table build + JAX sampling), no Bessel. Best for large n with fixed parameters.

    • 'scipy'scipy.stats.geninvgauss (CPU, original fallback).

Return type:

Array

to_gamma()[source]

KL projection onto the Gamma family.

Minimises \(D_{\mathrm{KL}}(\mathrm{GIG}\,\|\,q)\) over \(q \in \mathrm{Gamma}\) by matching the Gamma sufficient statistics under the source GIG: \(E_{q^*}[\log X] = \eta_1,\; E_{q^*}[X] = \eta_3\). Solved via Gamma.from_expectation(). See docs/tech_notes/distribution_conversions.md for the derivation.

to_inverse_gamma()[source]

KL projection onto the InverseGamma family.

Matches \(E[-1/X] = -\eta_2,\; E[\log X] = \eta_1\).

to_inverse_gaussian()[source]

KL projection onto the InverseGaussian family.

Matches \(E[X] = \eta_3,\; E[1/X] = \eta_2\). The closed form \(\lambda = 1/(\eta_2 - 1/\eta_3)\) is well-defined whenever \(\eta_2 > 1/\eta_3\) (Jensen, true for every non-degenerate GIG); InverseGaussian.from_expectation() clamps near degenerate inputs.

classmethod from_natural(theta)[source]

Construct GIG from natural parameters \(\theta = [p-1,\; -b/2,\; -a/2]\).

Parameters:

theta (Array)

Return type:

GeneralizedInverseGaussian

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='newton', verbose=0)[source]

\(\eta \to \theta\) via \(\eta\)-rescaling + optimization.

Rescaling symmetrises the solve (\(\tilde{a} = \tilde{b}\)), reducing the Fisher condition number by up to \(10^{30}\) for extreme \(a/b\) ratios.

Parameters:
  • theta0 (jax.Array, optional) – Warm-start point (required for JAX solvers; if None, uses multi-start CPU solver with Gamma/InvGamma/InvGauss seeds).

  • backend (str) – 'jax' (default, JIT-able) or 'cpu' (scipy, more robust).

  • method (str) – 'newton', 'lbfgs', or 'bfgs'.

  • eta (Array)

  • maxiter (int)

  • tol (float)

  • verbose (int)

Return type:

GeneralizedInverseGaussian

normix.distributions.generalized_inverse_gaussian.GIG

alias of GeneralizedInverseGaussian

Multivariate Distributions

Multivariate Normal distribution.

Stored: mu (d,), L_Sigma (d×d) lower-triangular Cholesky of \(\Sigma\). All linear algebra via L_Sigma — never form \(\Sigma^{-1}\) explicitly.

Exponential family structure

\[t(x) = [x,\; \operatorname{vec}(xx^\top)], \quad \theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})], \quad \log h(x) = 0\]
\[\psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu - \tfrac{1}{2}\log|\Sigma^{-1}| + \tfrac{d}{2}\log(2\pi)\]

where \(\operatorname{vec}\) uses row-major order (numpy.ndarray.ravel()).

All parametrization conversions are analytical (closed-form):

  • classical \(\leftrightarrow\) natural: natural_params / from_natural

  • natural \(\to\) expectation: _grad_log_partition (analytical override)

  • expectation \(\to\) classical: from_expectation

No Bregman solver is ever invoked. fit_mle computes \(\hat\eta = n^{-1}\sum_i t(x_i)\) and calls from_expectation (closed-form). log_prob overrides the inherited EF formula with a direct Cholesky computation for efficiency.

class normix.distributions.normal.MultivariateNormal(mu, L_Sigma)[source]

Bases: ExponentialFamily

Multivariate Normal distribution as an exponential family.

Parameters:
  • mu (jax.Array) – (d,) mean vector.

  • L_Sigma (jax.Array) – (d, d) lower-triangular Cholesky factor of \(\Sigma\).

Notes

Natural parameters: \(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\). Sufficient statistics: \(t(x) = [x,\; \operatorname{vec}(xx^\top)]\). Log-partition: \(\psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu - \tfrac{1}{2}\log|\Sigma^{-1}| + \tfrac{d}{2}\log(2\pi)\). Log base measure: \(\log h(x) = 0\).

mu: Array
L_Sigma: Array
natural_params()[source]

\(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\)

Return type:

Array

static sufficient_statistics(x)[source]

\(t(x) = [x,\; \operatorname{vec}(xx^\top)]\).

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

\(\log h(x) = 0\) (base measure is Lebesgue).

Parameters:

x (Array)

Return type:

Array

classmethod from_classical(mu, sigma)[source]

Construct from mean μ and covariance matrix Σ.

Return type:

MultivariateNormal

classmethod from_expectation(eta, **_kwargs)[source]

Closed-form inversion \(\eta \to \theta \to (\mu, L_\Sigma)\).

\(\eta = [E[X],\; \operatorname{vec}(E[XX^\top])]\), so

\[\mu = \eta_1, \qquad \Sigma = \operatorname{reshape}(\eta_2, d, d) - \mu\mu^\top\]
Parameters:
  • eta (jax.Array) – Expectation parameters of shape (d + d²,).

  • **_kwargs – Ignored (accepts backend, theta0, etc. for API compatibility).

Return type:

MultivariateNormal
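
That inversion is short enough to sketch directly (standalone illustration, row-major vec as above; not the library source):

import jax.numpy as jnp

def mvn_from_expectation(eta, d):
    """eta = [E[X], vec(E[XX^T])]  ->  (mu, L_Sigma)."""
    mu = eta[:d]
    second_moment = eta[d:].reshape(d, d)
    sigma = second_moment - jnp.outer(mu, mu)   # Sigma = E[XX^T] - mu mu^T
    return mu, jnp.linalg.cholesky(sigma)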

classmethod from_natural(theta)[source]

Construct from natural parameters \(\theta = [\theta_1, \theta_2]\).

Recovers \(\Lambda = -2\,\mathrm{reshape}(\theta_2, d, d)\), then \(\mu = \Lambda^{-1}\theta_1\) and \(L_\Sigma = \mathrm{chol}(\Lambda^{-1})\).

Parameters:

theta (Array)

Return type:

MultivariateNormal

log_prob(x)[source]

\(\log f(x) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{1}{2}\|L_\Sigma^{-1}(x-\mu)\|^2\).

Overrides the inherited EF formula for numerical efficiency (Cholesky-direct).

Parameters:

x (Array)

Return type:

Array
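
A minimal sketch of the Cholesky-direct evaluation (illustrative, not the library source):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def mvn_log_prob(x, mu, L):
    """log N(x | mu, L L^T) via one triangular solve."""
    d = x.shape[0]
    z = solve_triangular(L, x - mu, lower=True)       # z = L^{-1}(x - mu)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))     # log|Sigma| from diag(L)
    return -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det + z @ z)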

mean()[source]

\(E[X] = \mu\).

Return type:

Array

cov()[source]

\(\mathrm{Cov}[X] = \Sigma = L_\Sigma L_\Sigma^\top\).

Return type:

Array

rvs(n, seed=42)[source]

Draw n i.i.d. samples via JAX PRNG.

Returns:

Shape (n, d).

Return type:

jax.Array

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.
sample(key, shape=())[source]

Draw samples via an explicit JAX key. Legacy API — prefer rvs.

Parameters:
  • key (jax.Array) – Explicit JAX PRNG key.

  • shape (tuple) – Sample shape.

Return type:

Array

property d: int

Dimensionality.

property dim: int

Dimensionality (alias for d).

property sigma: Array

Covariance matrix \(\Sigma = L_\Sigma L_\Sigma^\top\) (alias for cov()).

Mixture Base Classes

JointNormalMixture

JointNormalMixture — abstract exponential family for normal variance-mean mixtures.

Joint distribution \(f(x, y)\):

\[X \mid Y \sim \mathcal{N}(\mu + \gamma y,\; \Sigma y), \quad Y \sim \text{subordinator (GIG, Gamma, InvGamma, InvGaussian)}\]

Sufficient statistics:

\[t(x, y) = [\log y,\; 1/y,\; y,\; x,\; x/y,\; \mathrm{vec}(xx^\top/y)]\]

Natural parameters:

\[\theta_1 = p_{\mathrm{sub}} - 1 - d/2, \quad \theta_2 = -(b_{\mathrm{sub}}/2 + \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu) < 0\]
\[\theta_3 = -(a_{\mathrm{sub}}/2 + \tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma) < 0, \quad \theta_4 = \Sigma^{-1}\gamma, \quad \theta_5 = \Sigma^{-1}\mu, \quad \theta_6 = -\tfrac{1}{2}\mathrm{vec}(\Sigma^{-1})\]

For a GIG subordinator, \(\theta_2,\theta_3\) combine \(-b/2,-a/2\) with the normal quadratic forms so that \(\theta^{\top} t\) matches \(-(a y + b/y)/2\) from \(f_Y\) plus the \(y\)-dependent terms from \(f_{X\mid Y}\).

Log-partition:

\[\psi = \psi_{\mathrm{sub}}(p, a, b) + \tfrac{1}{2}\log|\Sigma| + \mu^\top\Sigma^{-1}\gamma\]

Expectation parameters (EM E-step quantities):

\[\eta_1 = E[\log Y], \quad \eta_2 = E[1/Y], \quad \eta_3 = E[Y]\]
\[\eta_4 = E[X] = \mu + \gamma E[Y], \quad \eta_5 = E[X/Y] = \mu E[1/Y] + \gamma\]
\[\eta_6 = E[XX^\top/Y] = \Sigma + \mu\mu^\top E[1/Y] + \gamma\gamma^\top E[Y] + \mu\gamma^\top + \gamma\mu^\top\]

EM M-step closed-form (let \(D = 1 - E[1/Y] \cdot E[Y]\)):

\[\mu = \frac{E[X] - E[Y] E[X/Y]}{D}, \quad \gamma = \frac{E[X/Y] - E[1/Y] E[X]}{D}\]
\[\Sigma = E[XX^\top/Y] - E[X/Y]\mu^\top - \mu E[X/Y]^\top + E[1/Y]\mu\mu^\top - E[Y]\gamma\gamma^\top\]
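
This closed-form M-step for the normal block transcribes directly; a standalone sketch (argument names mirror the expectations above and are not the library's field names):

import jax.numpy as jnp

def m_step_normal_block(E_X, E_X_over_Y, E_XXt_over_Y, E_Y, E_inv_Y):
    """Closed-form (mu, gamma, Sigma) from aggregated EM expectations."""
    D = 1.0 - E_inv_Y * E_Y
    mu = (E_X - E_Y * E_X_over_Y) / D
    gamma = (E_X_over_Y - E_inv_Y * E_X) / D
    sigma = (E_XXt_over_Y
             - jnp.outer(E_X_over_Y, mu) - jnp.outer(mu, E_X_over_Y)
             + E_inv_Y * jnp.outer(mu, mu)
             - E_Y * jnp.outer(gamma, gamma))
    return mu, gamma, sigma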
class normix.mixtures.joint.JointNormalMixture(mu, gamma, L_Sigma)[source]

Bases: ExponentialFamily

Abstract joint distribution \(f(x, y)\) for normal variance-mean mixtures.

Stored: mu (d,), gamma (d,), L_Sigma (d×d lower Cholesky of \(\Sigma\)). Subordinator parameters defined by concrete subclasses.

Parameters:
mu: Array
gamma: Array
L_Sigma: Array
abstractmethod subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

property d: int
sigma()[source]

Covariance matrix \(\Sigma = L_\Sigma L_\Sigma^\top\).

Return type:

Array

log_det_sigma()[source]

\(\log|\Sigma| = 2\sum_i \log L_{ii}\), via Cholesky diagonal.

Return type:

Array

rvs(n, seed=42)[source]

Sample \((X, Y)\) from the joint distribution via JAX PRNG.

Returns:

  • X (jax.Array) – Shape (n, d).

  • Y (jax.Array) – Shape (n,).

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Tuple[Array, Array]

log_prob_joint(x, y)[source]

\(\log f(x, y) = \log f(x\mid y) + \log f_Y(y)\).

\[\log f(x\mid y) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{d}{2}\log y - \frac{1}{2y}\|L^{-1}(x-\mu)\|^2 + \gamma^\top\Sigma^{-1}(x-\mu) - \tfrac{y}{2}\gamma^\top\Sigma^{-1}\gamma\]

\(\log f_Y(y)\) from subordinator.

Parameters:
  • x (Array)

  • y (Array)

Return type:

Array

conditional_expectations(x)[source]

Compute \(E[g(Y)\mid X=x]\) for EM E-step.

The posterior \(Y\mid X\) follows a GIG-like distribution:

\[p_{\mathrm{post}} = p_{\mathrm{eff}} - d/2, \quad a_{\mathrm{post}} = a_{\mathrm{eff}} + \gamma^\top\Sigma^{-1}\gamma, \quad b_{\mathrm{post}} = b_{\mathrm{eff}} + (x-\mu)^\top\Sigma^{-1}(x-\mu)\]

Returns dict with keys: E_log_Y, E_inv_Y, E_Y.

Parameters:

x (Array)

Return type:

Dict[str, Array]
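
In effect this evaluates the GIG expectation map at the posterior parameters; a hedged standalone sketch (helper name hypothetical):

from normix.distributions.generalized_inverse_gaussian import GeneralizedInverseGaussian

def posterior_moments(p_post, a_post, b_post):
    """[E[log Y], E[1/Y], E[Y]] of the GIG-like posterior Y | X = x."""
    eta = GeneralizedInverseGaussian(p_post, a_post, b_post).expectation_params()
    return {'E_log_Y': eta[0], 'E_inv_Y': eta[1], 'E_Y': eta[2]}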

static sufficient_statistics(xy)[source]

\(t(x,y) = [\log y,\; 1/y,\; y,\; x,\; x/y,\; \mathrm{vec}(xx^\top/y)]\).

Input: flat vector \([x_1,\ldots,x_d, y]\).

Parameters:

xy (Array)

Return type:

Array

static log_base_measure(xy)[source]

\(\log h(x)\) for a single unbatched observation.

Parameters:

xy (Array)

Return type:

Array

classmethod from_expectation(eta, **kwargs)[source]

Construct from expectation parameters \(\eta\).

Two input forms are supported:

  • NormalMixtureEta — the native η pytree of the joint normal variance-mean mixture; uses the closed-form M-step (\(\mu, \gamma, \Sigma\) analytical; subordinator via _subordinator_from_eta()).

  • flat jax.Array — handled by the generic Bregman solver inherited from ExponentialFamily.

The pytree path is the canonical η→θ map for these distributions: it is exact (no Bregman iterations on the normal block) and uses the subordinator’s own from_expectation (closed-form for Gamma / InverseGamma / InverseGaussian; numerical for GIG).

Parameters:
  • eta (NormalMixtureEta or jax.Array) – Expectation parameters.

  • **kwargs – For the pytree path: forwarded to _subordinator_from_eta() (e.g. backend, method, maxiter, theta0 for warm-starting GIG). For the flat-array path: forwarded to the parent solver.

Return type:

JointNormalMixture

NormalMixture

Marginal mixture base classes.

MarginalMixture is the abstract interface fitters and the divergences module depend on. NormalMixture is the full-covariance implementation, owning a JointNormalMixture. A factor-analysis implementation lives in the sibling class FactorNormalMixture (see docs/design/mixtures.md § 6).

NormalMixture provides:

  • log_prob(x) — closed-form marginal log-density

  • e_step(X)jax.vmap() over conditional expectations

  • m_step(X, expectations) — returns new NormalMixture

  • fit(X, ...) — convenience EM fitting with multi-start

class normix.mixtures.marginal.MarginalMixture[source]

Bases: Module

Abstract interface for marginal mixtures used by the EM fitters.

Concrete subclasses pick the storage of the Gaussian dispersion (full Cholesky in NormalMixture, low-rank-plus-diagonal in FactorNormalMixture) and the type of the EM expectation pytree (NormalMixtureEta vs. FactorMixtureStats).

The fitter depends only on this contract; it does not know which storage form a model uses.

abstractmethod log_prob(x)[source]

Marginal \(\log f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]

Marginal \(f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

marginal_log_likelihood(X)[source]

Mean log-likelihood over a dataset.

Parameters:

X (Array)

Return type:

Array

abstractmethod e_step(X, *, backend='jax')[source]

E-step: aggregated expectation parameters for the batch.

Parameters:
  • X ((n, d) data array)

  • backend (str) – 'jax' (default) or 'cpu'.

Return type:

Any

abstractmethod m_step(eta, **kwargs)[source]

Full M-step: updates all parameters; returns a new model.

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod m_step_normal(eta)[source]

M-step for normal parameters only (MCECM cycle 1).

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod m_step_subordinator(eta, **kwargs)[source]

M-step for subordinator parameters only (MCECM cycle 2).

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod compute_eta_from_model()[source]

Reconstruct the expectation pytree from the model’s own parameters.

Return type:

Any

abstractmethod em_convergence_params()[source]

Pytree whose leaf-wise change measures EM convergence.

Subordinator parameters are intentionally excluded (their solver has its own tolerance, and including them inflates iteration counts). For full-covariance models this is (mu, gamma, L_Sigma); for factor-analysis models it is (mu, gamma, F F^T + D) to sidestep the rotational gauge.

Return type:

Any

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]

Fit using self as initialisation. Returns EMResult.

Parameters:
  • X ((n, d) data array)

  • algorithm, max_iter, tol, regularization, e_step_backend, m_step_backend, m_step_method, verbose – see BatchEMFitter for their meaning.

Return type:

EMResult

class normix.mixtures.marginal.NormalMixture(joint)[source]

Bases: MarginalMixture

Marginal \(f(x) = \int_0^\infty f(x,y)\,dy\) for a normal variance-mean mixture.

Not an exponential family. Owns a JointNormalMixture (which is). The classical parameters \((\mu, \gamma, \Sigma, \text{subordinator})\) are forwarded as read-only properties; use replace() to obtain a new model with updated parameters (modules are immutable).

property joint: JointNormalMixture
property d: int
property mu: Array

\(\mu\) — location parameter (forwarded from the joint).

property gamma: Array

\(\gamma\) — skewness parameter (forwarded from the joint).

property L_Sigma: Array

Lower Cholesky factor of \(\Sigma\) (forwarded from the joint).

sigma()[source]

Dispersion \(\Sigma = L_\Sigma L_\Sigma^\top\) (forwarded from the joint).

Distinct from cov(), which returns the marginal covariance \(E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).

Return type:

Array

log_det_sigma()[source]

\(\log|\Sigma|\) (forwarded from the joint).

Return type:

Array

mean()[source]

\(E[X] = \mu + \gamma E[Y]\).

Return type:

Array

cov()[source]

\(\mathrm{Cov}[X] = E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).

Return type:

Array

rvs(n, seed=42)[source]

Sample X from the marginal distribution.

Parameters:
  • n (int) – Sample size.

  • seed (int) – PRNG seed.

Return type:

Array

squared_hellinger(other)[source]

Squared Hellinger distance via joint distributions (upper bound on marginal).

Parameters:

other (NormalMixture)

Return type:

Array

kl_divergence(other)[source]

KL divergence via joint distributions.

Parameters:

other (NormalMixture)

Return type:

Array

compute_eta_from_model()[source]

Reconstruct \(\eta\) from the model’s own parameters.

Uses the marginal expectations of the joint sufficient statistics:

\[\eta_4 = \mu + \gamma\,E[Y], \quad \eta_5 = \mu\,E[1/Y] + \gamma, \quad \eta_6 = \Sigma + \mu\mu^\top E[1/Y] + \gamma\gamma^\top E[Y] + \mu\gamma^\top + \gamma\mu^\top\]
Return type:

NormalMixtureEta

e_step(X, backend='jax')[source]

Full E-step: subordinator conditionals + batch aggregation.

Returns a NormalMixtureEta with the six aggregated expectation parameters.

Parameters:
  • X ((n, d) data array)

  • backend (str) – 'jax' (default): jax.vmap over conditional_expectations. 'cpu': quad forms in JAX + GIG Bessel on CPU.

Return type:

NormalMixtureEta

classmethod from_expectation(eta, **kwargs)[source]

Construct from expectation parameters \(\eta\).

Wraps JointNormalMixture.from_expectation(), which performs the exact closed-form M-step on the normal block and the subordinator’s from_expectation on the subordinator block.

This is the canonical η→model map: any prior or shrinkage target \(\eta_0\) can be inspected as a concrete model via cls.from_expectation(eta_0).sigma() etc.

Parameters:
  • eta (NormalMixtureEta) – Aggregated expectation parameters.

  • **kwargs – Forwarded to JointNormalMixture.from_expectation() (e.g. backend, method, maxiter, theta0).

Return type:

NormalMixture

m_step(eta, **kwargs)[source]

Full M-step: update normal params + subordinator from \(\eta\).

Equivalent to type(self).from_expectation(eta, **kwargs); self is only used to dispatch on the subclass. Subclasses with iterative subordinator solvers (e.g. GeneralizedHyperbolic) may override to inject a warm-start \(\theta_0\).

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

m_step_normal(eta)[source]

M-step for normal parameters only (MCECM Cycle 1).

Updates \(\mu, \gamma, \Sigma\); subordinator unchanged.

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

m_step_subordinator(eta, **kwargs)[source]

M-step for the subordinator only (MCECM Cycle 2).

Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

replace(**updates)[source]

Return a new model with selected parameters replaced.

Accepts any subset of:

  • normal parameters: mu, gamma, L_Sigma;

  • dispersion alias sigma — converted to L_Sigma via Cholesky (mutually exclusive with L_Sigma);

  • subordinator parameters declared by _subordinator_keys() (e.g. alpha, beta for VG / NInvG, mu_ig, lam for NIG, p, a, b for GH).

The actual storage lives in joint; this method does an immutable update via equinox.tree_at().

Examples

>>> vg2 = vg.replace(mu=new_mu)                    # change μ
>>> vg3 = vg.replace(alpha=2.5, beta=0.5)          # change subordinator
>>> vg4 = vg.replace(sigma=sigma2 * jnp.eye(d))    # set Σ via covariance
Return type:

NormalMixture

regularize_det_sigma(target_log_det=0.0)[source]

Rescale to enforce \(\log|\Sigma| = \mathrm{target\_log\_det}\).

Picks \(s = \exp((\log|\Sigma| - \tau)/d)\), where \(\tau\) = target_log_det, and applies _rescale(). The default target_log_det = 0 recovers the \(|\Sigma| = 1\) convention; passing the log-determinant of an initial reference \(\Sigma\) implements the det_sigma_x family.

Parameters:

target_log_det (float)

Return type:

NormalMixture

regularize_det_sigma_one()[source]

Enforce \(|\Sigma| = 1\). Alias for regularize_det_sigma() with target_log_det = 0.

Return type:

NormalMixture

regularize_a_eq_b()[source]

Rescale subordinator so that \(a = b = \sqrt{ab}\) for GIG-parameterised families.

Default implementation is a no-op; override in subclasses with both \(a, b > 0\) (currently GH and NIG; VG and NInvG have a degenerate a=0 or b=0 and the default no-op is the right behaviour).

Return type:

NormalMixture

em_convergence_params()[source]

Pytree whose leaf-wise change measures EM convergence.

Returns (mu, gamma, L_Sigma). Subordinator parameters (p, a, b) are excluded — their solver has its own tolerance and including them inflates iteration counts.

classmethod default_init(X)[source]

Moment-based initialisation from data.

Returns a model with:

  • mu — sample mean

  • gamma — zeros

  • Sigma — empirical covariance (regularized)

  • subordinator — distribution-specific defaults

Useful as a starting point for model.fit(X).

Parameters:

X (Array)

Return type:

NormalMixture

Mixture Distributions

Variance Gamma

Variance Gamma (VG) distribution.

Special case of GH with GIG → Gamma subordinator (\(b \to 0\), \(p > 0\)). \(Y \sim \mathrm{Gamma}(\alpha, \beta)\), i.e. GIG(\(p = \alpha\), \(a = 2\beta\), \(b \to 0\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (rate) of Gamma.

class normix.distributions.variance_gamma.JointVarianceGamma(mu, gamma, L_Sigma, alpha, beta)[source]

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{Gamma}(\alpha, \beta)\).

GIG limit: \(p = \alpha\), \(a = 2\beta\), \(b \to 0\).

Parameters:
alpha: Array
beta: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

\(\theta = [\alpha-1-d/2,\; -\tfrac{1}{2}\mu^\top\Lambda\mu,\; -(\beta+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

(Gamma subordinator: \(p=\alpha\), \(a=2\beta\), \(b\to 0\)).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
classmethod from_natural(theta)[source]

Recover classical parameters from \(\theta\).

\(\alpha = \theta_1 + 1 + d/2\), \(\beta = -\theta_3 - \gamma_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointVarianceGamma

to_joint_generalized_hyperbolic(*, boundary_eps=0.0)[source]

Exact embedding into JointGeneralizedHyperbolic.

Lifts the Gamma subordinator to GIG via Gamma.to_gig() and keeps the Normal block (\(\mu, \gamma, L_\Sigma\)) unchanged.

Parameters:

boundary_eps (float)

class normix.distributions.variance_gamma.VarianceGamma(joint)[source]

Bases: NormalMixture

Marginal Variance Gamma distribution f(x).

Parameters:

joint (JointVarianceGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
Return type:

VarianceGamma

log_prob(x)[source]

Marginal VG log-density (own formula, no GH delegation).

\[f(x) \propto \left(\frac{q}{2c}\right)^{\nu/2} K_\nu\!\left(\sqrt{2qc}\right) \exp(\gamma^\top\Sigma^{-1}(x-\mu))\]

where \(\nu = \alpha - d/2\), \(c = \beta + \tfrac{1}{2}\gamma^\top\Lambda\gamma\), \(q = (x-\mu)^\top\Lambda(x-\mu)\).

Parameters:

x (Array)

Return type:

Array

property alpha: Array

\(\alpha\) — Gamma shape (forwarded from the joint).

property beta: Array

\(\beta\) — Gamma rate (forwarded from the joint).

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit VG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the Gamma subordinator).

to_generalized_hyperbolic(*, boundary_eps=0.0)[source]

Exact embedding into the GeneralizedHyperbolic family.

Parameters:

boundary_eps (float)

Normal Inverse Gamma

Normal-Inverse Gamma (NInvG) distribution.

Special case of GH with GIG → InverseGamma subordinator (\(a \to 0\), \(p < 0\)). \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\), i.e. GIG(\(p = -\alpha\), \(a \to 0\), \(b = 2\beta\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (rate) of InverseGamma.

class normix.distributions.normal_inverse_gamma.JointNormalInverseGamma(mu, gamma, L_Sigma, alpha, beta)[source]

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\).

GIG limit: \(p = -\alpha\), \(a \to 0\), \(b = 2\beta\).

Parameters:
alpha: Array
beta: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

\(\theta = [-(\alpha+1)-d/2,\; -(\beta+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -\tfrac{1}{2}\gamma^\top\Lambda\gamma,\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

(InverseGamma subordinator: \(p=-\alpha\), \(a\to 0\), \(b=2\beta\)).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
classmethod from_natural(theta)[source]

Recover classical parameters from \(\theta\).

\(\alpha = -(\theta_1 + d/2) - 1\), \(\beta = -\theta_2 - \mu_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointNormalInverseGamma

to_joint_generalized_hyperbolic(*, boundary_eps=0.0)[source]

Exact embedding into JointGeneralizedHyperbolic.

Lifts the InverseGamma subordinator to GIG via InverseGamma.to_gig() and keeps the Normal block unchanged.

Parameters:

boundary_eps (float)

class normix.distributions.normal_inverse_gamma.NormalInverseGamma(joint)[source]

Bases: NormalMixture

Marginal Normal-Inverse Gamma distribution f(x).

Parameters:

joint (JointNormalInverseGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
log_prob(x)[source]

Marginal NInvG log-density (own formula, no GH delegation).

GIG params: \(p=-\alpha\), \(a=\gamma^\top\Lambda\gamma\), \(b=2\beta+Q(x)\). The normalising integral is \(2(b/a)^{p/2} K_p(\sqrt{ab})\).

Parameters:

x (Array)

Return type:

Array

property alpha: Array

\(\alpha\) — InverseGamma shape (forwarded from the joint).

property beta: Array

\(\beta\) — InverseGamma rate (forwarded from the joint).

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit NInvG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the InverseGamma subordinator).

to_generalized_hyperbolic(*, boundary_eps=0.0)[source]

Exact embedding into the GeneralizedHyperbolic family.

Parameters:

boundary_eps (float)

Normal Inverse Gaussian

Normal-Inverse Gaussian (NIG) distribution.

Special case of GH with GIG → InverseGaussian subordinator (\(p = -1/2\)). \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\), i.e. GIG(\(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\mu_{IG}\) (IG mean), \(\lambda\) (IG shape).

class normix.distributions.normal_inverse_gaussian.JointNormalInverseGaussian(mu, gamma, L_Sigma, mu_ig, lam)[source]

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\).

Stored: \(\mu_{IG}\) (IG mean) and \(\lambda\) (IG shape) directly. GIG params: \(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\).

Parameters:
mu_ig: Array
lam: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

\(\theta = [-3/2-d/2,\; -(\lambda/2+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -(\lambda/(2\mu_{IG}^2)+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

where \(p=-1/2\), \(a=\lambda/\mu_{IG}^2\), \(b=\lambda\), aligned with GIG natural parameters on \([\log y,\,1/y,\,y]\).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]
classmethod from_natural(theta)[source]

Recover classical parameters from \(\theta\).

From \(b = -2\theta_2 - 2\mu_{\mathrm{quad}} = \lambda\) and \(a = -2\theta_3 - 2\gamma_{\mathrm{quad}} = \lambda/\mu_{IG}^2\), so \(\mu_{IG} = \sqrt{b/a}\).

Parameters:

theta (Array)

Return type:

JointNormalInverseGaussian

to_joint_generalized_hyperbolic()[source]

Exact embedding into JointGeneralizedHyperbolic.

Lifts the InverseGaussian subordinator to GIG via InverseGaussian.to_gig() (no boundary approximation) and keeps the Normal block unchanged.

class normix.distributions.normal_inverse_gaussian.NormalInverseGaussian(joint)[source]

Bases: NormalMixture

Marginal Normal-Inverse Gaussian distribution f(x).

Parameters:

joint (JointNormalInverseGaussian)

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]
log_prob(x)[source]

Marginal NIG log-density.

Uses \(K_{-1/2}(z) = \sqrt{\pi/(2z)}\,e^{-z}\) for the normalisation, leaving only one \(\log K_\nu\) call at order \(\nu = -1/2 - d/2\).

Parameters:

x (Array)

Return type:

Array

property mu_ig: Array

\(\mu_{IG}\) — InverseGaussian mean (forwarded from the joint).

property lam: Array

\(\lambda\) — InverseGaussian shape (forwarded from the joint).

regularize_a_eq_b()[source]

Rescale so \(a = b = \sqrt{ab}\).

For NIG, \(a = \lambda/\mu_{IG}^2,\;b = \lambda\), so \(s = \sqrt{a/b} = 1/\mu_{IG}\). After rescaling \(\mu_{IG}' = 1\), i.e. the InverseGaussian has unit mean.

Return type:

NormalInverseGaussian

to_generalized_hyperbolic()[source]

Exact embedding into the GeneralizedHyperbolic family.

No boundary approximation: NIG sits in the strict interior of GH’s parameter space (\(p = -1/2,\; a = \lambda/\mu_{IG}^2,\; b = \lambda\)).

Generalized Hyperbolic

Generalized Hyperbolic (GH) distribution.

Joint: \(X \mid Y \sim \mathcal{N}(\mu + \gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).

Marginal: \(\mathrm{GH}(\mu, \gamma, \Sigma, p, a, b)\).

Marginal log-density (closed form via Bessel functions):

Let \(Q(x) = (x-\mu)^\top \Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top \Sigma^{-1} \gamma\).

\[\log f(x) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| + \tfrac{p}{2}(\log a - \log b) - \log K_p(\sqrt{ab}) + \tfrac{d/2-p}{2}\log\frac{A}{Q(x)+b} + \log K_{p-d/2}\!\left(\sqrt{A(Q(x)+b)}\right) + \gamma^\top \Sigma^{-1}(x - \mu)\]

Posterior: \(Y \mid X = x \sim \mathrm{GIG}(p - d/2,\; a + \gamma^\top\Sigma^{-1}\gamma,\; b + (x-\mu)^\top\Sigma^{-1}(x-\mu))\).

class normix.distributions.generalized_hyperbolic.JointGeneralizedHyperbolic(mu, gamma, L_Sigma, p, a, b)[source]

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).

Stored: mu, gamma, L_Sigma (from JointNormalMixture) + p, a, b (GIG parameters).

Parameters:
p: Array
a: Array
b: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

\(\theta = [p-1-d/2,\; -(b/2+\tfrac{1}{2}\mu^\top\Sigma^{-1}\mu),\; -(a/2+\tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma),\; \Sigma^{-1}\gamma,\; \Sigma^{-1}\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Sigma^{-1})]\)

The scalar coefficients on sufficient statistics \(1/y\) and \(y\) match the GIG convention \(\theta_{\mathrm{GIG}} = [p-1,\,-b/2,\,-a/2]\) on \(t_Y = [\log y,\,1/y,\,y]\).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]

Construct from classical parameters.

classmethod from_natural(theta)[source]

Recover classical parameters from \(\theta\).

\(p = \theta_1 + 1 + d/2\), \(b = -2\theta_2 - 2\mu_{\mathrm{quad}}\), \(a = -2\theta_3 - 2\gamma_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointGeneralizedHyperbolic

to_joint_variance_gamma()[source]

KL projection onto JointVarianceGamma (Gamma subordinator).

Return type:

JointVarianceGamma

to_joint_normal_inverse_gamma()[source]

KL projection onto JointNormalInverseGamma.

Return type:

JointNormalInverseGamma

to_joint_normal_inverse_gaussian()[source]

KL projection onto JointNormalInverseGaussian.

Return type:

JointNormalInverseGaussian

class normix.distributions.generalized_hyperbolic.GeneralizedHyperbolic(joint)[source]

Bases: NormalMixture

Marginal Generalized Hyperbolic distribution \(f(x)\).

Stores a JointGeneralizedHyperbolic. Provides:

  • log_prob(x) — closed-form Bessel expression

  • e_step, m_step — for EM fitting

  • fit(X, ...) — convenience fitting method

Parameters:

joint (JointGeneralizedHyperbolic)

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]
Return type:

GeneralizedHyperbolic

log_prob(x)[source]

Marginal \(\log f(x)\).

\[f(x) \propto \left(\frac{A}{Q(x)+b}\right)^{(d/2-p)/2} K_{p-d/2}\!\left(\sqrt{A(Q(x)+b)}\right) \exp\!\left(\gamma^\top\Sigma^{-1}(x-\mu)\right)\]

where \(Q(x) = (x-\mu)^\top\Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top\Sigma^{-1}\gamma\).

Parameters:

x (Array)

Return type:

Array

property p: Array

\(p\) — GIG order (forwarded from the joint).

property a: Array

\(a\) — GIG concentration (forwarded from the joint).

property b: Array

\(b\) — GIG concentration (forwarded from the joint).

m_step(eta, **kwargs)[source]

Full M-step with warm-started + sanity-checked subordinator solve.

Overrides the base cls.from_expectation(eta) because the GIG solver benefits substantially from warm-starting at the current \(\theta\) and from a fallback when the solver wanders into implausible regions; both are managed by m_step_subordinator().

Return type:

GeneralizedHyperbolic

m_step_subordinator(eta, **kwargs)[source]

M-step for the subordinator only (MCECM Cycle 2).

Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.

Return type:

GeneralizedHyperbolic

regularize_a_eq_b()[source]

Rescale so \(a = b = \sqrt{ab}\) (orbit invariant).

Picks \(s = \sqrt{a/b}\). Idempotent: applying twice leaves the model unchanged.

Return type:

GeneralizedHyperbolic

to_variance_gamma()[source]

KL projection onto the VarianceGamma family.

Return type:

VarianceGamma

to_normal_inverse_gamma()[source]

KL projection onto the NormalInverseGamma family.

Return type:

NormalInverseGamma

to_normal_inverse_gaussian()[source]

KL projection onto the NormalInverseGaussian family.

Return type:

NormalInverseGaussian

classmethod default_init(X)[source]

Warm-start from the best of NIG / VG / NInvG sub-model fits.

Runs 5 EM iterations (JAX backend, Newton method) for each special case, converts each to GH parametrisation, and selects the candidate with the highest marginal log-likelihood. A moment-based fallback (\(p=1, a=1, b=1\)) is included as a fourth candidate.

Fully JAX-native: no try/except, no Python branching on data values.

Parameters:

X (Array)

Return type:

GeneralizedHyperbolic

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='det_sigma_one', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit GH distribution using EM or MCECM.

Defaults to CPU backends and det_sigma_one regularization (GH has a scale non-identifiability that is resolved by fixing \(|\Sigma| = 1\)).

Fitting

EM Fitters

EM fitters for normix distributions.

Model knows math, fitter knows iteration.

  • BatchEMFitter — standard batch EM with dual-loop architecture: lax.scan (JIT-able) or Python for-loop (CPU-compatible).

  • IncrementalEMFitter — online / mini-batch EM with pluggable eta update rules.

class normix.fitting.em.EMResult(model, log_likelihoods, param_changes, n_iter, converged, elapsed_time)[source]

Bases: object

Result of an EM fitting procedure.

Parameters:
model: Any
log_likelihoods: Array | None
param_changes: Array
n_iter: int
converged: bool | None
elapsed_time: float
class normix.fitting.em.BatchEMFitter(*, algorithm='em', max_iter=200, tol=0.001, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton', eta_update=None)[source]

Bases: object

Batch EM / MCECM algorithm with dual-loop architecture.

EM (default): E-step → M-step (all params) → regularize.

MCECM: E-step → M-step (normal params only) → regularize → E-step → M-step (subordinator only).

Convergence is measured by relative parameter change in the normal parameters (mu, gamma, L_Sigma), excluding subordinator (GIG) parameters.

Loop selection is automatic:
  • lax.scan when both backends are ‘jax’, verbose <= 1, algorithm=’em’, and no eta_update rule

  • Python for-loop otherwise

Parameters:
  • algorithm (str) – ‘em’ (default) or ‘mcecm’.

  • max_iter (int) – Maximum number of iterations.

  • tol (float) – Convergence tolerance on max relative parameter change.

  • verbose (int) – 0 = silent, 1 = summary, 2 = per-iteration table.

  • regularization (str) –

    Strategy applied after each M-step:

    • 'none' — no regularization.

    • 'det_sigma_one' — enforce \(|\Sigma| = 1\) (the original GH convention; equivalent to regularize_det_sigma(target_log_det=0)).

    • 'det_sigma_x' — enforce \(|\Sigma| = |\Sigma_0|\) where \(\Sigma_0\) is the dispersion of the initial model passed to fit(). Useful when comparing GH / FactorGH parameters against VG / NIG / NInvG, which leave \(|\Sigma|\) at the empirical scale.

    • 'a_eq_b' — rescale the GIG subordinator so \(a = b = \sqrt{ab}\). Trivial no-op for VG / NInvG / MultivariateNormal.

  • e_step_backend (str) – ‘jax’ (default) or ‘cpu’.

  • m_step_backend (str) – ‘jax’ or ‘cpu’ (default, faster for GIG).

  • m_step_method (str) – ‘newton’ (default), ‘lbfgs’, or ‘bfgs’.

  • eta_update (EtaUpdateRule or None) – Optional eta combination rule (e.g. Shrinkage(IdentityUpdate(), eta0, tau)). When set, the E-step output is transformed before the M-step.

fit(model, X)[source]

Run batch EM or MCECM. Auto-selects lax.scan or Python loop.

Parameters:
  • model (NormalMixture subclass (used as initial parameters))

  • X ((n, d) data array)

Return type:

EMResult with fitted model, convergence diagnostics, and timing.
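
A typical invocation (the data array here is hypothetical):

>>> import jax.numpy as jnp
>>> from normix.fitting.em import BatchEMFitter
>>> from normix.distributions.generalized_hyperbolic import GeneralizedHyperbolic
>>> X = jnp.load('returns.npy')                     # (n, d) data, hypothetical
>>> init = GeneralizedHyperbolic.default_init(X)
>>> fitter = BatchEMFitter(algorithm='mcecm', regularization='det_sigma_one')
>>> result = fitter.fit(init, X)                    # EMResult: model + diagnostics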

class normix.fitting.em.IncrementalEMFitter(*, eta_update=None, batch_size=256, max_steps=200, inner_iter=1, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]

Bases: object

Incremental EM with pluggable eta update rules.

Replaces OnlineEMFitter and MiniBatchEMFitter. Processes data in random mini-batches, applies an EtaUpdateRule to combine the running \(\eta\) with each batch estimate, then M-steps on the combined \(\eta\).

Parameters:
  • eta_update (EtaUpdateRule) – How to combine running η with each batch estimate.

  • batch_size (int) – Observations per batch.

  • max_steps (int) – Number of batches to process (total budget).

  • inner_iter (int) – 1 = online (default); >1 = fine-tuning on each batch.

  • verbose (int) –

    0 = silent; 1 = periodic summary.

    Scan path (JAX backends only):

    verbose must be 0 so diagnostics do not rely on Python side effects each step.

  • regularization (str) – Same options as BatchEMFitter'none' | 'det_sigma_one' | 'det_sigma_x' | 'a_eq_b'.

  • e_step_backend (str) – Passed through to e_step / m_step.

  • m_step_backend (str) – Passed through to e_step / m_step.

  • m_step_method (str) – Passed through to e_step / m_step.

fit(model, X, *, key)[source]

Run incremental EM. Returns EMResult.

Parameters:
  • model (NormalMixture subclass (used as initial parameters))

  • X ((n, d) data array)

  • key (jax.Array) – PRNG key driving the random mini-batches.

Return type:

EMResult

Solvers

Bregman divergence solvers.

Minimises f(θ) − θ·η over θ, where f is any convex function (e.g. the log-partition ψ for an exponential family). At the minimum ∇f(θ*) = η.

Public API

  • solve_bregman — single starting point.

  • solve_bregman_multistart — multiple starting points (vmap for JAX Newton; for-loop for quasi-Newton and CPU).

  • bregman_objective — utility: f(θ) − θ·η.

  • make_jit_newton_solver — build a stable @jax.jit Newton solve specialised to a fixed (f, grad_fn, hess_fn, bounds); repeated calls with matching shapes/dtypes hit the XLA cache, avoiding the per-call re-tracing that solve_bregman incurs from fresh closures.

Backends × methods

  • backend='jax', method='newton' — custom lax.scan Newton, autodiff or analytical Hessian

  • backend='jax', method='lbfgs' — jaxopt LBFGSB (bounds native) or LBFGS (reparam)

  • backend='jax', method='bfgs' — jaxopt BFGS with reparameterization for bounds

  • backend='cpu', method='lbfgs' — scipy L-BFGS-B

  • backend='cpu', method='bfgs' — scipy BFGS

  • backend='cpu', method='newton' — scipy trust-exact with Hessian

Gradient / Hessian sources

  • grad_fn — θ → ∇f(θ). For backend='cpu': pure CPU (numpy) gradient. For backend='jax', method='newton': JAX-traceable ∇ψ(θ). If None with backend='cpu': hybrid — jax.grad compiled to NumPy callbacks.

  • hess_fn — θ → ∇²f(θ). Required for method='newton'. For backend='cpu': pure CPU (numpy) Hessian. For backend='jax', method='newton': JAX-traceable ∇²ψ(θ). If None with method='newton': jax.hessian of the full objective is used.

Both grad_fn and hess_fn operate in theta-space only. The solver handles all reparameterization internally via the chain rule.

class normix.fitting.solvers.BregmanResult(theta, fun, grad_norm, num_steps, converged, elapsed_time=0.0)[source]

Bases: object

Result of a Bregman divergence minimization.

Scalar fields accept both Python and JAX types so the result can live inside a lax.scan carry without concretization errors.

Parameters:
theta: Array
fun: Any
grad_norm: Any
num_steps: int
converged: Any
elapsed_time: float = 0.0
normix.fitting.solvers.bregman_objective(theta, eta, f)[source]

f(θ) − θ·η — convex dual whose minimum gives ∇f(θ*) = η.

Parameters:
  • theta (Array)

  • eta (Array)

  • f (Callable[[Array], Array])

Return type:

Array

normix.fitting.solvers.solve_bregman(f, eta, theta0, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]

Minimise f(θ) − θ·η over θ.

Parameters:
  • f (convex function θ → scalar (e.g. log-partition ψ))

  • eta (target vector (e.g. expectation parameters η))

  • theta0 (initial guess)

  • backend ('jax' (JIT-able) or 'cpu' (scipy, not JIT-able))

  • method ('lbfgs', 'bfgs', or 'newton')

  • bounds ((lower, upper) pair of JAX arrays, each shape (d,);) – None → unconstrained. For backend=’jax’: enforced via reparameterization. For backend=’cpu’: converted to scipy format internally.

  • max_steps (iteration budget)

  • tol (convergence tolerance on ‖∇f(θ) − η‖∞)

  • grad_fn (θ → ∇f(θ).) – For backend=’cpu’: must accept and return numpy arrays. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with backend=’cpu’: falls back to jax.grad (hybrid mode).

  • hess_fn (θ → ∇²f(θ).) – Required for method=’newton’. For backend=’cpu’: must accept numpy arrays and return numpy array. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with method=’newton’: jax.hessian of the full objective is used.

  • verbose (int) – 0 = silent, >= 1 = print summary after solve.

Return type:

BregmanResult
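
As an illustration of the contract (Gamma.from_expectation already has a specialised digamma-based inverse, so this is purely a sketch):

import jax
import jax.numpy as jnp
from normix.fitting.solvers import solve_bregman
from normix.distributions.gamma import Gamma

# Gamma log-partition in natural coordinates theta = [alpha - 1, -beta]
def psi(theta):
    return (jax.scipy.special.gammaln(theta[0] + 1.0)
            - (theta[0] + 1.0) * jnp.log(-theta[1]))

eta = Gamma(alpha=jnp.array(3.0), beta=jnp.array(2.0)).expectation_params()
# bounds can be supplied to keep theta inside the family's domain
res = solve_bregman(psi, eta, theta0=jnp.array([0.0, -1.0]), method='lbfgs')
# res.theta should approach [2.0, -2.0], i.e. alpha = 3, beta = 2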

normix.fitting.solvers.solve_bregman_multistart(f, eta, theta0_batch, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]

Run solve_bregman from multiple starting points; return the best result.

Parameters:
  • theta0_batch ((K, dim) jax.Array for backend='jax', method='newton') – (parallel via vmap); list of arrays otherwise (sequential for-loop).

  • verbose (int) – 0 = silent, >= 1 = print summary.

  • f (Callable[[Array], Array])

  • eta (Array)

  • backend (str)

  • method (str)

  • bounds (Tuple[Array, Array] | None)

  • max_steps (int)

  • tol (float)

  • grad_fn (Callable | None)

  • hess_fn (Callable | None)

Return type:

BregmanResult

normix.fitting.solvers.make_jit_newton_solver(f, grad_fn, hess_fn, bounds=None)[source]

Build a @jax.jit-decorated Newton solver specialised to one problem.

The returned callable has signature solve(eta, theta0, max_steps=20, tol=1e-10) -> (theta_opt, fun, grad_norm, converged) where max_steps is a static argument (required by lax.scan).

All distribution-level inputs (f, grad_fn, hess_fn, bounds) are baked into the closure at construction time. Repeated calls with the same array shapes and dtypes therefore reuse the compiled XLA executable.

Use this in EM hot paths where solve_bregman would otherwise build a fresh Python closure on every call and force JAX to re-trace the same Newton kernel on each iteration.

Parameters:
  • f (convex objective ψ(θ) → scalar (must be JAX-traceable).)

  • grad_fn (∇ψ(θ) (d,).)

  • hess_fn (∇²ψ(θ) (d, d).)

  • bounds ((lower, upper) of shape (d,) each, or None.)

Returns:

Jit-compiled Newton solver. Returns a 4-tuple of JAX arrays (theta, fun, grad_norm, converged); wrap in BregmanResult externally if needed.

Return type:

Callable
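
A sketch of the intended hot-path pattern with a toy convex objective (in practice f, grad_fn, hess_fn come from a distribution's log-partition triad):

import jax
import jax.numpy as jnp
from normix.fitting.solvers import make_jit_newton_solver

f = lambda th: jnp.sum(jnp.exp(th))      # toy psi; grad f = exp(theta)
solve = make_jit_newton_solver(f, jax.grad(f), jax.hessian(f))

eta = jnp.array([1.0, 2.0])
theta, fun, grad_norm, converged = solve(eta, jnp.zeros(2))
# exp(theta*) = eta, so theta* = log(eta); repeated calls with the same
# shapes/dtypes reuse the compiled XLA executable.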

Utilities

Bessel Functions

JAX-compatible log modified Bessel function of the second kind.

log_kv(v, z) = log K_v(z), fully pure-JAX with zero scipy callbacks.

Regime-specific methods, selected via lax.cond (only one branch executes):
  1. Hankel asymptotic (DLMF 10.40.2) — large z

  2. Olver uniform expansion (DLMF 10.41.4) — large v

  3. Small-z leading asymptotic (DLMF 10.30.2) — z → 0

  4. Gauss-Legendre quadrature (Takekawa 2022) — moderate z, v

Using lax.cond (not jnp.where) means only the selected branch executes at runtime. lax.cond requires scalar conditions, so the core scalar function _log_kv_scalar is vmapped over array inputs.

Custom JVP for full autodiff:
  • ∂/∂z : exact recurrence K’_v = −(K_{v−1}+K_{v+1})/2

  • ∂/∂v : central FD on log_kv itself (ε = BESSEL_EPS_V)

  • backend='jax' (default): pure-JAX, JIT-able, differentiable.

  • backend='cpu': scipy.special.kve, fully vectorized numpy. Not JIT-able. Fast for the EM hot path.

normix.utils.bessel.log_kv(v, z, backend='jax')[source]

\(\log K_v(z)\) — log modified Bessel function of the second kind.

Parameters:
  • v (scalar or array) – Order (any real; \(K_v = K_{-v}\)).

  • z (scalar or array) – Argument (must be > 0).

  • backend (str, optional) –

    'jax' (default) or 'cpu'.

    • 'jax': pure-JAX, lax.cond regime selection, custom JVP. JIT-able, differentiable. Default for log_prob, pdf, etc.

    • 'cpu': scipy.special.kve, fully vectorised NumPy. Not JIT-able. Fast for EM hot path.

Returns:

Same broadcast shape as (v, z).

Return type:

jax.Array

Examples

Evaluate at a single point (JAX backend, JIT-able):

>>> import jax.numpy as jnp
>>> from normix import log_kv
>>> float(log_kv(v=0.5, z=1.0))
-0.112...

CPU backend (uses scipy, faster for EM hot-paths):

>>> float(log_kv(v=0.5, z=1.0, backend='cpu'))
-0.112...

Symmetry \(K_v(z) = K_{-v}(z)\):

>>> abs(float(log_kv(0.5, 2.0)) - float(log_kv(-0.5, 2.0))) < 1e-10
True

Differentiable via JAX:

>>> import jax
>>> dlogkv_dz = jax.grad(lambda z: log_kv(0.5, z))(jnp.array(1.0))
>>> float(dlogkv_dz) < 0   # K_v decreases with z
True

Constants

Shared numerical constants for normix.