API Reference


Base Classes#

class normix.exponential_family.ExponentialFamily[source]#

Bases: Module

Abstract base class for exponential family distributions.

Concrete subclasses must implement:

_log_partition_from_theta, natural_params, sufficient_statistics, log_base_measure

abstractmethod natural_params()[source]#

\(\theta\) from stored classical parameters.

Return type:

Array

abstractmethod static sufficient_statistics(x)[source]#

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

abstractmethod static log_base_measure(x)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

log_partition()[source]#

\(\psi(\theta)\) at current parameters.

Return type:

Array

expectation_params(backend='jax')[source]#

\(\eta = \nabla\psi(\theta)\).

Parameters:

backend (str) – 'jax' (default, JIT-able) or 'cpu' (numpy/scipy).

Return type:

Array

fisher_information(backend='jax')[source]#

\(I(\theta) = \nabla^2\psi(\theta)\).

Parameters:

backend (str) – 'jax' (default, JIT-able) or 'cpu' (numpy/scipy).

Return type:

Array

log_prob(x)[source]#

\(\log p(x\mid\theta) = \log h(x) + \theta^\top t(x) - \psi(\theta)\), single observation.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]#

p(x|θ), single observation. Batch via jax.vmap.

Parameters:

x (Array)

Return type:

Array

mean()[source]#

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]#

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

std()[source]#

\(\mathrm{Std}[X] = \sqrt{\mathrm{Var}[X]}\).

Return type:

Array

cdf(x)[source]#

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]#

Sample n observations via JAX PRNG (JIT-able).

Parameters:
  • n (int)

  • seed (int)

Return type:

Array

squared_hellinger(other)[source]#

Squared Hellinger distance \(H^2(p, q)\).

Default uses the general exponential-family formula via \(\psi\). Subclasses may override for numerically improved variants.

Parameters:

other (ExponentialFamily)

Return type:

Array

kl_divergence(other)[source]#

KL divergence \(D_{\mathrm{KL}}(\mathrm{self} \| \mathrm{other})\).

Default uses the Bregman-divergence formula via \(\psi\) and \(\nabla\psi\). Subclasses may override.

Parameters:

other (ExponentialFamily)

Return type:

Array
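As a rough illustration of the Bregman form, the sketch below evaluates \(D_{\mathrm{KL}}(p \| q) = \psi(\theta_q) - \psi(\theta_p) - (\theta_q - \theta_p)\cdot\eta_p\) for two Inverse Gaussian distributions, using the \(\theta\), \(\psi\) and \(\eta\) listed in the Inverse Gaussian section, and compares it with the same quantity obtained by taking the expectation of the log-density ratio directly. It is standalone Python; the helper names (ig_theta, ig_psi, ig_eta, kl_bregman, kl_direct) are hypothetical and are not part of normix.

```python
import math

def ig_theta(mu, lam):
    """Natural parameters [-lam / (2 mu^2), -lam / 2]."""
    return (-lam / (2.0 * mu * mu), -lam / 2.0)

def ig_psi(theta):
    """psi(theta) = -0.5 * log(-2 theta2) - sqrt((-2 theta1) * (-2 theta2))."""
    t1, t2 = theta
    return -0.5 * math.log(-2.0 * t2) - math.sqrt((-2.0 * t1) * (-2.0 * t2))

def ig_eta(mu, lam):
    """Expectation parameters [E[X], E[1/X]]."""
    return (mu, 1.0 / mu + 1.0 / lam)

def kl_bregman(p, q):
    """KL(p || q) = psi(theta_q) - psi(theta_p) - (theta_q - theta_p) . eta_p."""
    tp, tq, eta_p = ig_theta(*p), ig_theta(*q), ig_eta(*p)
    inner = sum((b - a) * e for a, b, e in zip(tp, tq, eta_p))
    return ig_psi(tq) - ig_psi(tp) - inner

def kl_direct(p, q):
    """Same KL from E_p[log p(X) - log q(X)], using only E_p[X] and E_p[1/X]."""
    (mu_p, lam_p), (mu_q, lam_q) = p, q
    e_x, e_inv = ig_eta(mu_p, lam_p)
    # E_p[(X - m)^2 / X] = E[X] - 2 m + m^2 E[1/X]; for m = mu_p it equals mu_p^2 / lam_p,
    # which is where the constant -0.5 below comes from.
    quad_q = e_x - 2.0 * mu_q + mu_q * mu_q * e_inv
    return 0.5 * math.log(lam_p / lam_q) - 0.5 + lam_q * quad_q / (2.0 * mu_q * mu_q)

p, q = (1.0, 2.0), (1.5, 3.0)   # (mu, lambda) pairs chosen for illustration
kl_b, kl_d = kl_bregman(p, q), kl_direct(p, q)
```

Both routes should agree to rounding error, and kl_bregman(p, p) should be zero.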

classmethod from_natural(theta)[source]#

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

ExponentialFamily

classmethod bregman_divergence(theta, eta)[source]#

Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) (conjugate dual).

Minimising over \(\theta\) yields \(\nabla\psi(\theta^*) = \eta\), i.e. the natural parameters corresponding to expectation parameters \(\eta\).

Parameters:
  • theta (Array)

  • eta (Array)

Return type:

Array
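The stationarity condition can be checked numerically. The sketch below, which assumes the Inverse Gaussian \(\psi\) from the section further down and uses a hypothetical finite-difference helper, confirms that the gradient of \(\psi(\theta) - \theta\cdot\eta\) vanishes at the closed-form \(\theta^*\) and that a perturbed point gives a larger objective. It is not the solver used by the library.

```python
import math

def ig_psi(theta):
    """Inverse Gaussian log-partition in natural coordinates."""
    t1, t2 = theta
    return -0.5 * math.log(-2.0 * t2) - math.sqrt((-2.0 * t1) * (-2.0 * t2))

def objective(theta, eta):
    """psi(theta) - theta . eta, the quantity minimised over theta."""
    return ig_psi(theta) - sum(t * e for t, e in zip(theta, eta))

def num_grad(f, theta, h=1e-6):
    """Central finite-difference gradient (illustration only)."""
    grad = []
    for i in range(len(theta)):
        up, dn = list(theta), list(theta)
        up[i] += h
        dn[i] -= h
        grad.append((f(up) - f(dn)) / (2.0 * h))
    return grad

mu, lam = 1.5, 3.0
eta = (mu, 1.0 / mu + 1.0 / lam)                   # [E[X], E[1/X]]
theta_star = (-lam / (2.0 * mu * mu), -lam / 2.0)  # closed-form natural parameters

grad_at_star = num_grad(lambda th: objective(th, eta), theta_star)
at_star = objective(theta_star, eta)
# A nearby point inside the domain (both components stay negative).
bumped = objective((1.1 * theta_star[0], 0.9 * theta_star[1]), eta)
```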

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='lbfgs', verbose=0)[source]#

Construct from expectation parameters \(\eta\) by solving \(\nabla\psi(\theta) = \eta\).

Minimises the Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) via solve_bregman. Subclasses can override for closed-form inverses.

Parameters:
  • backend (str) – 'jax' (default, JIT-able) or 'cpu' (scipy, more robust).

  • method (str) – 'lbfgs' (default), 'bfgs', or 'newton'.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)

Return type:

ExponentialFamily

classmethod fit_mle(X, *, theta0=None, maxiter=500, tol=1e-10, verbose=0)[source]#

MLE via exponential family identity: \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\).

Batches over X using jax.vmap, then calls from_expectation(\(\hat\eta\)).

Parameters:
  • X (jax.Array) – (n, ...) array of observations.

  • theta0 (jax.Array, optional) – Initial natural parameters \(\theta_0\) for the \(\eta\to\theta\) solver.

  • maxiter (int) – Maximum iterations for the \(\eta\to\theta\) solver.

  • tol (float) – Convergence tolerance for the \(\eta\to\theta\) solver.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

Return type:

ExponentialFamily

fit(X, *, maxiter=500, tol=1e-10, verbose=0, **kwargs)[source]#

Fit using self as initialization (warm start).

Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and solves from_expectation(\(\hat\eta\)) using self.natural_params() as the initial \(\theta_0\).

Parameters:
  • X (jax.Array) – (n, ...) array of observations.

  • maxiter (int) – Maximum iterations for the \(\eta\to\theta\) solver.

  • tol (float) – Convergence tolerance for the \(\eta\to\theta\) solver.

  • verbose (int) – 0 = silent, >= 1 = print solver summary.

Return type:

ExponentialFamily

classmethod default_init(X)[source]#

Moment-based initialisation from data.

Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and inverts to get an initial model. For distributions with closed-form from_expectation (Gamma, InverseGamma, InverseGaussian), this gives the MLE directly.

Parameters:

X (Array)

Return type:

ExponentialFamily

Univariate Distributions#

Gamma#

Gamma distribution as an exponential family.

\[p(x \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{\alpha-1} e^{-\beta x}, \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [\log x,\; x]\]
\[\theta = [\alpha-1,\; -\beta], \quad \theta_1 > -1,\; \theta_2 < 0\]
\[\psi(\theta) = \log\Gamma(\theta_1+1) - (\theta_1+1)\log(-\theta_2)\]
\[\eta = [\psi(\alpha) - \log\beta,\; \alpha/\beta] \quad \text{(digamma, mean)}\]
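A quick way to sanity-check this structure is to evaluate \(\log h(x) + \theta^\top t(x) - \psi(\theta)\) and compare it with the density written in \((\alpha, \beta)\). The sketch below does that with the standard library only; the function names are hypothetical and this is not the library's log_prob.

```python
import math

def gamma_log_pdf(x, alpha, beta):
    """Direct log-density of Gamma(shape alpha, rate beta)."""
    return (alpha * math.log(beta) - math.lgamma(alpha)
            + (alpha - 1.0) * math.log(x) - beta * x)

def gamma_log_pdf_ef(x, alpha, beta):
    """Same value assembled from the exponential-family pieces."""
    theta = (alpha - 1.0, -beta)        # natural parameters
    t = (math.log(x), x)                # sufficient statistics
    log_h = 0.0                         # h(x) = 1
    psi = math.lgamma(theta[0] + 1.0) - (theta[0] + 1.0) * math.log(-theta[1])
    return log_h + theta[0] * t[0] + theta[1] * t[1] - psi

direct = gamma_log_pdf(0.7, 2.5, 1.7)
via_ef = gamma_log_pdf_ef(0.7, 2.5, 1.7)
```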
class normix.distributions.gamma.Gamma(alpha, beta)[source]#

Bases: ExponentialFamily

Gamma(\(\alpha\), \(\beta\)) distribution — shape \(\alpha > 0\), rate \(\beta > 0\).

Parameters:
  • alpha (Array)

  • beta (Array)

alpha: Array#
beta: Array#
natural_params()[source]#

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]#

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]#

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]#

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

mode()[source]#

Mode \((\alpha - 1)/\beta\) for \(\alpha \ge 1\).

For \(\alpha < 1\) the density diverges at \(x=0\) and the formula returns a non-positive value; we clip to LOG_EPS so the return value is always a valid positive sample location.

Return type:

Array

cdf(x)[source]#

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

ppf(q)[source]#

Quantile function \(F^{-1}(q) = \mathrm{gammaincinv}(\alpha, q) / \beta\).

Parameters:

q (Array)

Return type:

Array

rvs(n, seed=42)[source]#

Sample n observations from \(\mathrm{Gamma}(\alpha, \beta)\) via JAX PRNG.

Parameters:
  • n (int)

  • seed (int)

Return type:

Array

classmethod from_natural(theta)[source]#

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

Gamma

to_gig(*, boundary_eps=0.0)[source]#

Exact embedding into the GIG family.

Gamma(\(\alpha\), \(\beta\)) is the \(b \to 0\) limit of GIG(\(p = \alpha,\; a = 2\beta,\; b\)). With boundary_eps = 0 the embedding stores b = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain (matters only for downstream expectation_params calls on the lifted GIG).

Parameters:

boundary_eps (float)

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, backend='jax', alpha_min=None, **kwargs)[source]#

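Closed-form \(\eta \to \theta\) up to a one-dimensional root-find: scalar Newton on \(\psi(\alpha) - \log\alpha = \eta_1 - \log\eta_2\), where \(\psi\) here denotes the digamma function.

\(\eta = [E[\log X],\; E[X]]\); \(\alpha\) comes from the digamma inversion, then \(\beta = \alpha / \eta_2\).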

Parameters:
  • backend (str) – 'jax' (default): lax.fori_loop Newton (JIT-compatible). 'cpu': scipy.special digamma/polygamma (no XLA tracing).

  • alpha_min (float, optional) – Lower bound on the fitted shape \(\alpha\). When set, the digamma-inversion result is clamped to jnp.maximum(alpha, alpha_min) before \(\beta = \alpha/\eta_2\) is formed, so the projected estimate stays self-consistent. This is the opt-in VG estimand control (the ghyp “fix-\(\lambda\)” analogue): it restricts the estimator to a region where the VG marginal likelihood is bounded. None (default) leaves the estimate unconstrained. JIT-safe (no Python branch on traced values).

  • eta (Array)

  • maxiter (int)

  • tol (float)

Return type:

Gamma
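The digamma inversion can be sketched in plain Python as follows. This is an assumption-laden illustration rather than the library code: digamma and trigamma are hand-rolled (recurrence plus asymptotic series) because the standard library has neither, the starting point is a common closed-form approximation, and the positivity guard is an addition of this sketch. The alpha_min clamp mirrors the behaviour described above.

```python
import math

def digamma(x):
    """Digamma for x > 0: shift x upward, then use the asymptotic series."""
    acc = 0.0
    while x < 6.0:
        acc -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return acc + math.log(x) - 0.5 / x - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))

def trigamma(x):
    """Trigamma for x > 0, same shift-then-series approach."""
    acc = 0.0
    while x < 6.0:
        acc += 1.0 / (x * x)
        x += 1.0
    inv2 = 1.0 / (x * x)
    return acc + 1.0 / x + 0.5 * inv2 + (inv2 / x) * (1.0 / 6.0 - inv2 * (1.0 / 30.0 - inv2 / 42.0))

def gamma_from_expectation(eta, alpha_min=None, maxiter=100, tol=1e-12):
    """Solve digamma(alpha) - log(alpha) = eta1 - log(eta2); beta = alpha / eta2."""
    eta1, eta2 = eta
    s = math.log(eta2) - eta1                     # positive by Jensen's inequality
    alpha = (3.0 - s + math.sqrt((s - 3.0) ** 2 + 24.0 * s)) / (12.0 * s)
    for _ in range(maxiter):
        f = digamma(alpha) - math.log(alpha) + s
        new_alpha = alpha - f / (trigamma(alpha) - 1.0 / alpha)
        if new_alpha <= 0.0:                      # guard: keep the iterate positive
            new_alpha = 0.5 * alpha
        converged = abs(new_alpha - alpha) <= tol * max(1.0, alpha)
        alpha = new_alpha
        if converged:
            break
    if alpha_min is not None:
        alpha = max(alpha, alpha_min)             # clamp before forming beta
    return alpha, alpha / eta2

# Round trip: build eta from known (alpha, beta), then invert.
eta = (digamma(2.5) - math.log(1.7), 2.5 / 1.7)
alpha_hat, beta_hat = gamma_from_expectation(eta)
alpha_clamped, beta_clamped = gamma_from_expectation(eta, alpha_min=4.0)
```

With the clamp active, \(\beta\) is recomputed from the clamped \(\alpha\), so the fitted mean \(\alpha/\beta\) still equals \(\eta_2\).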

Inverse Gamma#

InverseGamma distribution as an exponential family.

\[p(x \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{-\alpha-1} e^{-\beta/x}, \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [-1/x,\; \log x]\]
\[\theta = [\beta,\; -(\alpha+1)], \quad \theta_1 > 0,\; \theta_2 < -1\]
\[\psi(\theta) = \log\Gamma(-\theta_2-1) - (-\theta_2-1)\log\theta_1 = \log\Gamma(\alpha) - \alpha\log\beta\]
\[\eta = [-\alpha/\beta,\; \log\beta - \psi(\alpha)] \quad (E[-1/X],\; E[\log X])\]
class normix.distributions.inverse_gamma.InverseGamma(alpha, beta)[source]#

Bases: ExponentialFamily

InverseGamma(\(\alpha\), \(\beta\)): shape \(\alpha > 0\), scale \(\beta > 0\).

Parameters:
  • alpha (Array)

  • beta (Array)

alpha: Array#
beta: Array#
natural_params()[source]#

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]#

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]#

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]#

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

mode()[source]#

Mode \(\beta / (\alpha + 1)\) (closed form, valid for all \(\alpha > 0\)).

Return type:

Array

cdf(x)[source]#

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

ppf(q)[source]#

Quantile function \(F^{-1}(q) = \beta / \mathrm{gammaincinv}(\alpha, 1-q)\).

Follows from \(F(x) = 1 - P(\alpha, \beta/x) = q\).

Parameters:

q (Array)

Return type:

Array

rvs(n, seed=42)[source]#

Sample n observations from \(\mathrm{InvGamma}(\alpha, \beta)\) via JAX PRNG.

Uses the relation: if \(G \sim \mathrm{Gamma}\) with shape \(\alpha\) and scale \(1/\beta\) (that is, rate \(\beta\) in the convention of the Gamma section), then \(1/G \sim \mathrm{InvGamma}(\alpha, \beta)\).

Parameters:
  • n (int)

  • seed (int)

Return type:

Array
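The reciprocal-Gamma relation is easy to try out with the standard library. The sketch below uses Python's random module instead of the JAX PRNG, so it only illustrates the transformation; inverse_gamma_rvs is a hypothetical name.

```python
import random

def inverse_gamma_rvs(n, alpha, beta, seed=42):
    """Draw n samples of 1/G with G ~ Gamma(shape alpha, rate beta)."""
    rng = random.Random(seed)
    # random.gammavariate takes (shape, scale), and scale = 1 / rate.
    return [1.0 / rng.gammavariate(alpha, 1.0 / beta) for _ in range(n)]

samples = inverse_gamma_rvs(20000, alpha=5.0, beta=2.0)
# For alpha > 1 the mean is beta / (alpha - 1), here 0.5.
sample_mean = sum(samples) / len(samples)
```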

classmethod from_natural(theta)[source]#

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

InverseGamma

to_gig(*, boundary_eps=0.0)[source]#

Exact embedding into the GIG family.

InverseGamma(\(\alpha\), \(\beta\)) is the \(a \to 0\) limit of GIG(\(p = -\alpha,\; a,\; b = 2\beta\)). With boundary_eps = 0 the embedding stores a = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain.

Parameters:

boundary_eps (float)

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, backend='jax', **kwargs)[source]#

\(\eta = [-\alpha/\beta,\; \log\beta - \psi(\alpha)]\).

\(\beta = \alpha / (-\eta_1)\); solve \(\psi(\alpha) - \log\alpha = -\eta_2 - \log(-\eta_1)\) via Newton.

Parameters:
  • backend (str) – 'jax' (default): lax.fori_loop Newton (JIT-compatible). 'cpu': scipy.special digamma/polygamma (no XLA tracing).

  • eta (Array)

  • maxiter (int)

  • tol (float)

Return type:

InverseGamma

Inverse Gaussian#

Inverse Gaussian (Wald) distribution as an exponential family.

\[f(x \mid \mu, \lambda) = \sqrt{\frac{\lambda}{2\pi}}\, x^{-3/2} \exp\!\left(-\frac{\lambda(x-\mu)^2}{2\mu^2 x}\right), \quad x > 0\]

Exponential family structure:

\[h(x) = (2\pi)^{-1/2} x^{-3/2}, \quad t(x) = [x,\; 1/x]\]
\[\theta = \Bigl[-\tfrac{\lambda}{2\mu^2},\; -\tfrac{\lambda}{2}\Bigr], \quad \theta_1 < 0,\; \theta_2 < 0\]
\[\psi(\theta) = -\tfrac{1}{2}\log(-2\theta_2) - \sqrt{(-2\theta_1)(-2\theta_2)}\]
\[\bigl(\tfrac{1}{2}\log(2\pi)\ \text{is absorbed into}\ \log h(x) = -\tfrac{1}{2}\log(2\pi) - \tfrac{3}{2}\log x\bigr)\]
\[\eta = [E[X],\; E[1/X]] = [\mu,\; 1/\mu + 1/\lambda]\]
class normix.distributions.inverse_gaussian.InverseGaussian(mu, lam)[source]#

Bases: ExponentialFamily

InverseGaussian(\(\mu\), \(\lambda\)) — mean \(\mu > 0\), shape \(\lambda > 0\).

Parameters:
  • mu (Array)

  • lam (Array)

mu: Array#
lam: Array#
natural_params()[source]#

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]#

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]#

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]#

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

mode()[source]#

Mode \(\mu\bigl(\sqrt{1 + 9\mu^2/(4\lambda^2)} - 3\mu/(2\lambda)\bigr)\).

Closed-form maximiser of the IG density on \((0, \infty)\).

Return type:

Array

cdf(x)[source]#

CDF of the Inverse Gaussian distribution (log-space stable).

\[F(x) = \Phi(t_1) + \exp\!\bigl(2\lambda/\mu + \log\Phi(-t_2)\bigr)\]

where \(t_1 = \sqrt{\lambda/x}\,(x/\mu - 1)\) and \(t_2 = \sqrt{\lambda/x}\,(x/\mu + 1)\). The second term uses log_ndtr to avoid overflow when \(\lambda/\mu\) is large.

Parameters:

x (Array)

Return type:

Array
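A scalar sketch of this formula, using only math.erfc, is shown below. It combines \(2\lambda/\mu\) with \(\log\Phi(-t_2)\) before exponentiating, as described above, but takes the logarithm of erfc directly; that is an assumption of this sketch and breaks down in the far tail, where a dedicated log-CDF routine such as log_ndtr is the robust choice.

```python
import math

def norm_cdf(t):
    """Standard normal CDF via the complementary error function."""
    return 0.5 * math.erfc(-t / math.sqrt(2.0))

def log_norm_cdf(t):
    """log Phi(t). Plain log of erfc: fine for moderate t, underflows for very negative t."""
    return math.log(0.5 * math.erfc(-t / math.sqrt(2.0)))

def inverse_gaussian_cdf(x, mu, lam):
    root = math.sqrt(lam / x)
    t1 = root * (x / mu - 1.0)
    t2 = root * (x / mu + 1.0)
    # Add the exponent 2*lam/mu to log Phi(-t2) first, so the large factor
    # exp(2*lam/mu) is never formed on its own.
    return norm_cdf(t1) + math.exp(2.0 * lam / mu + log_norm_cdf(-t2))

f_half = inverse_gaussian_cdf(0.5, 1.0, 1.0)
f_one = inverse_gaussian_cdf(1.0, 1.0, 1.0)
f_two = inverse_gaussian_cdf(2.0, 1.0, 1.0)
```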

ppf(q)[source]#

Quantile function via a PINV table built from log_prob().

Trapezoidal-CDF lookup on \(w = \log x\), seeded at \(\log\) mode().

Parameters:

q (Array)

Return type:

Array

rvs(n, seed=42)[source]#

Sample n observations from \(\mathrm{InvGaussian}(\mu, \lambda)\) via JAX PRNG.

Uses the algorithm from Michael, Schucany & Haas (1976):

  1. \(\nu \sim \mathcal{N}(0,1)\), \(y = \nu^2\)

  2. \(x = \mu + \frac{\mu^2 y}{2\lambda} - \frac{\mu}{2\lambda}\sqrt{4\mu\lambda y + \mu^2 y^2}\)

  3. \(z \sim \mathrm{Uniform}(0,1)\); return \(x\) if \(z \le \mu/(\mu+x)\), else \(\mu^2/x\)

Uses jnp.where for vectorized branching over the full sample array.
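The three steps above can be sketched as a plain Python loop. This version uses the standard random module and a per-sample conditional rather than the JAX PRNG and jnp.where, so it shows the algorithm and not the library's vectorised implementation; inverse_gaussian_rvs is a hypothetical name.

```python
import math
import random

def inverse_gaussian_rvs(n, mu, lam, seed=42):
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        # Step 1: squared standard normal.
        nu = rng.gauss(0.0, 1.0)
        y = nu * nu
        # Step 2: smaller root of the associated quadratic.
        x = (mu + mu * mu * y / (2.0 * lam)
             - mu / (2.0 * lam) * math.sqrt(4.0 * mu * lam * y + mu * mu * y * y))
        # Step 3: keep x with probability mu / (mu + x), otherwise take mu^2 / x.
        z = rng.random()
        out.append(x if z <= mu / (mu + x) else mu * mu / x)
    return out

samples = inverse_gaussian_rvs(20000, mu=1.0, lam=2.0)
sample_mean = sum(samples) / len(samples)   # should be close to mu
```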

Parameters:
  • n (int)

  • seed (int)

Return type:

Array

classmethod from_natural(theta)[source]#

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

InverseGaussian

to_gig()[source]#

Exact embedding into the GIG family.

InverseGaussian(\(\mu\), \(\lambda\)) is GIG with \(p = -1/2,\; a = \lambda/\mu^2,\; b = \lambda\). No boundary approximation: the embedding lands strictly inside GIG’s domain.

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, **kwargs)[source]#

Closed-form from \(\eta = [E[X],\; E[1/X]] = [\mu,\; 1/\mu + 1/\lambda]\):

\(\mu = \eta_1\), \(\lambda = 1/(\eta_2 - 1/\eta_1)\).

Parameters:
  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)

Return type:

InverseGaussian
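Because the inversion is closed-form, maximum likelihood for this family reduces to averaging \(t(x) = [x,\; 1/x]\) and applying the two formulas above. The sketch below illustrates that pipeline in plain Python; the function names and the small data set are made up for illustration and do not come from the library.

```python
def inverse_gaussian_from_expectation(eta):
    """mu = eta1, lam = 1 / (eta2 - 1 / eta1)."""
    eta1, eta2 = eta
    return eta1, 1.0 / (eta2 - 1.0 / eta1)

def fit_mle(xs):
    """eta_hat = mean of t(x) = [x, 1/x], then invert in closed form."""
    n = len(xs)
    eta_hat = (sum(xs) / n, sum(1.0 / x for x in xs) / n)
    return inverse_gaussian_from_expectation(eta_hat)

# Round trip on exact expectation parameters.
mu, lam = 1.5, 3.0
mu_rt, lam_rt = inverse_gaussian_from_expectation((mu, 1.0 / mu + 1.0 / lam))

# Tiny illustrative data set.
mu_hat, lam_hat = fit_mle([0.8, 1.1, 1.9, 0.6, 2.4])
```

The fitted \(\hat\mu\) is the sample mean, and \(\hat\lambda\) is positive whenever the data are not all equal, since the mean of \(1/x\) then exceeds the reciprocal of the mean.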

Generalized Inverse Gaussian#

Generalized Inverse Gaussian (GIG) distribution as an exponential family.

\[f(x \mid p, a, b) = \frac{(a/b)^{p/2}}{2 K_p(\sqrt{ab})} \, x^{p-1} \exp\!\left(-\frac{ax + b/x}{2}\right), \quad x > 0\]

Exponential family structure:

\[h(x) = 1, \quad t(x) = [\log x,\; 1/x,\; x]\]
\[\theta = [p-1,\; -b/2,\; -a/2], \quad \theta_2 \le 0,\; \theta_3 \le 0\]
\[\psi(\theta) = \log 2 + \log K_p(\sqrt{ab}) + \tfrac{p}{2}\log(b/a), \quad p = \theta_1+1,\; a = -2\theta_3,\; b = -2\theta_2\]
\[\eta = [E[\log X],\; E[1/X],\; E[X]]\]

Special cases:

  • \(b \to 0,\; p > 0\): GIG → \(\mathrm{Gamma}(p,\; a/2)\)

  • \(a \to 0,\; p < 0\): GIG → \(\mathrm{InvGamma}(-p,\; b/2)\)

  • \(p = -1/2\): GIG → InverseGaussian

η→θ rescaling (reduces Fisher condition number):

\[s = \sqrt{\eta_2/\eta_3}, \quad \tilde{\eta} = \bigl(\eta_1 + \tfrac{1}{2}\log(\eta_2/\eta_3),\; \sqrt{\eta_2\eta_3},\; \sqrt{\eta_2\eta_3}\bigr)\]

Solve \(\tilde{\eta} \to \tilde{\theta}\) with symmetric GIG (\(\tilde{a} = \tilde{b}\)), then unscale.
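The rescaling corresponds to the change of variable \(\tilde{X} = sX\). A minimal sketch of the forward map on \(\eta\) and of the matching back-transformation of \((p, a, b)\) follows. The unscale rule is derived here from that change of variable (if \(sX \sim \mathrm{GIG}(p, \tilde{a}, \tilde{b})\) then \(X \sim \mathrm{GIG}(p, s\tilde{a}, \tilde{b}/s)\)) and is an assumption about what "unscale" means, not a transcription of the library code; the function names are hypothetical.

```python
import math

def rescale_eta(eta):
    """Return (s, eta_tilde) with equal second and third components."""
    eta1, eta2, eta3 = eta
    s = math.sqrt(eta2 / eta3)
    g = math.sqrt(eta2 * eta3)
    return s, (eta1 + 0.5 * math.log(eta2 / eta3), g, g)

def unscale_params(p, a_tilde, b_tilde, s):
    """Map GIG parameters of s*X back to those of X."""
    return p, s * a_tilde, b_tilde / s

# Illustrative eta = [E[log X], E[1/X], E[X]], chosen to satisfy Jensen's bounds.
s, eta_tilde = rescale_eta((-1.0, 4.0, 0.5))

# Made-up solver output in the rescaled problem, mapped back.
p, a, b = unscale_params(0.5, 2.0, 2.0, s)
```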

Log-Partition Triad Overrides:

  • _log_partition_from_theta : JAX, uses log_kv(backend='jax')

  • _grad_log_partition : analytical Bessel ratios (5 \(K_\nu\) calls)

  • _hessian_log_partition : analytical 7-Bessel Hessian in \(\theta\)-space

  • _log_partition_cpu : numpy + log_kv(backend='cpu')

  • _grad_log_partition_cpu : analytical Bessel ratios via scipy.kve

  • _hessian_log_partition_cpu: central FD on _log_partition_cpu

class normix.distributions.generalized_inverse_gaussian.GeneralizedInverseGaussian(p, a, b)[source]#

Bases: ExponentialFamily

Generalized Inverse Gaussian distribution.

Stored: \(p\) (shape, any real), \(a > 0\), \(b > 0\).

Parameters:
  • p (Array)

  • a (Array)

  • b (Array)

p: Array#
a: Array#
b: Array#
to_gig()[source]#

Identity embedding into the GIG family.

GIG is already in GIG coordinates, so this returns self. It exists so that every subordinator family exposes a uniform to_gig() for the shared prior-to-posterior conjugacy map in the EM E-step (see _posterior_gig_params()).

Return type:

GeneralizedInverseGaussian

natural_params()[source]#

\(\theta\) from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]#

\(t(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static expectation_params_batch(p, a, b, backend='jax')[source]#

Vectorized η for arrays of (p, a, b), each shape (N,). Returns (N, 3) array where columns are [E_log_X, E_inv_X, E_X].

  • backend='jax': vmap over the scalar JAX grad.

  • backend='cpu': vectorized scipy.kve (6 C-level array calls).

Parameters:

backend (str)

Return type:

Array

mean()[source]#

\(E[X] = \eta_3\) from expectation parameters.

Return type:

Array

var()[source]#

\(\mathrm{Var}[X] = \partial^2\psi/\partial\theta_3^2\) = Fisher information [2,2].

Return type:

Array

mode()[source]#

Interior mode \(\bigl((p-1) + \sqrt{(p-1)^2 + ab}\bigr) / a\).

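Closed-form positive critical point of the log-density, i.e. the positive root of \(a x^2 - 2(p-1)x - b = 0\). For \(a, b > 0\) the factor \(e^{-b/(2x)}\) forces the density to 0 as \(x \to 0\), so this root is the unique global maximum on \((0,\infty)\) for every real \(p\). Only in the boundary case \(b = 0\) with \(p < 1\) does the density diverge at 0, and the formula then returns 0.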

Return type:

Array

cdf(x)[source]#

CDF \(F(x) = P(X \le x)\).

Trapezoidal CDF on a \(w = \log x\) grid built from log_prob(); seeded at \(\log\) mode(). In the degenerate regimes (\(\sqrt{ab} <\) GIG_DEGEN_THRESHOLD) delegates to the limiting Gamma / InverseGamma CDF for accuracy.

Parameters:

x (Array)

Return type:

Array

ppf(q)[source]#

Quantile (inverse CDF) \(F^{-1}(q)\) via the PINV table.

Parameters:

q (Array)

Return type:

Array

rvs(n, seed=42, method='devroye')[source]#

Sample n observations from \(\mathrm{GIG}(p, a, b)\).

Parameters:
  • n (int) – Sample size.

  • seed (int) – Integer seed for JAX PRNG (or scipy random_state for 'scipy').

  • method (str) –

    Sampling algorithm:

    • 'devroye' (default) — Transformed density rejection (TDR) on \(\log(x)\), pure JAX, no Bessel functions.

    • 'pinv' — Numerical inverse CDF (CPU table build + JAX sampling), no Bessel. Best for large n with fixed parameters.

    • 'scipy' – scipy.stats.geninvgauss (CPU, original fallback).

Return type:

Array

to_gamma()[source]#

KL projection onto the Gamma family.

Minimises \(D_{\mathrm{KL}}(\mathrm{GIG}\,\|\,q)\) over \(q \in \mathrm{Gamma}\) by matching the Gamma sufficient statistics under the source GIG: \(E_{q^*}[\log X] = \eta_1,\; E_{q^*}[X] = \eta_3\). Solved via Gamma.from_expectation().

to_inverse_gamma()[source]#

KL projection onto the InverseGamma family.

Matches \(E[-1/X] = -\eta_2,\; E[\log X] = \eta_1\).

to_inverse_gaussian()[source]#

KL projection onto the InverseGaussian family.

Matches \(E[X] = \eta_3,\; E[1/X] = \eta_2\). The closed form \(\lambda = 1/(\eta_2 - 1/\eta_3)\) is well-defined whenever \(\eta_2 > 1/\eta_3\) (Jensen, true for every non-degenerate GIG); InverseGaussian.from_expectation() clamps near degenerate inputs.

classmethod from_natural(theta)[source]#

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

GeneralizedInverseGaussian

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='newton', verbose=0)[source]#

\(\eta \to \theta\) via \(\eta\)-rescaling + optimization.

Rescaling maps the problem to the balanced case \(\tilde{\eta}_2 = \tilde{\eta}_3\) (the symmetric GIG regime described above), reducing the Fisher condition number by up to \(10^{30}\) for extreme \(a/b\) ratios.

Parameters:
  • theta0 (jax.Array, optional) – Warm-start point (required for JAX solvers; if None, uses multi-start CPU solver with Gamma/InvGamma/InvGauss seeds).

  • backend (str) – 'jax' (default, JIT-able) or 'cpu' (scipy, more robust).

  • method (str) – 'newton', 'lbfgs', or 'bfgs'.

  • eta (Array)

  • maxiter (int)

  • tol (float)

  • verbose (int)

Return type:

GeneralizedInverseGaussian

normix.distributions.generalized_inverse_gaussian.GIG#

alias of GeneralizedInverseGaussian

Multivariate Distributions#

Multivariate Normal distribution.

Stored: mu (d,), L_Sigma (d×d) lower-triangular Cholesky of \(\Sigma\). All linear algebra via L_Sigma — never form \(\Sigma^{-1}\) explicitly.

Exponential family structure#

\[t(x) = [x,\; \operatorname{vec}(xx^\top)], \quad \theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})], \quad \log h(x) = 0\]
\[\psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu - \tfrac{1}{2}\log|\Sigma^{-1}| + \tfrac{d}{2}\log(2\pi)\]

where \(\operatorname{vec}\) uses row-major order (numpy.ndarray.ravel()).

All parametrization conversions are analytical (closed-form):

  • classical \(\leftrightarrow\) natural: natural_params / from_natural

  • natural \(\to\) expectation: _grad_log_partition (analytical override)

  • expectation \(\to\) classical: from_expectation

No Bregman solver is ever invoked. fit_mle computes \(\hat\eta = n^{-1}\sum_i t(x_i)\) and calls from_expectation (closed-form). log_prob overrides the inherited EF formula with a direct Cholesky computation for efficiency.

class normix.distributions.normal.MultivariateNormal(mu, L_Sigma)[source]#

Bases: ExponentialFamily

Multivariate Normal distribution as an exponential family.

Parameters:
  • mu (jax.Array) – (d,) mean vector.

  • L_Sigma (jax.Array) – (d, d) lower-triangular Cholesky factor of \(\Sigma\).

Notes

Natural parameters: \(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\). Sufficient statistics: \(t(x) = [x,\; \operatorname{vec}(xx^\top)]\). Log-partition: \(\psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu - \tfrac{1}{2}\log|\Sigma^{-1}| + \tfrac{d}{2}\log(2\pi)\). Log base measure: \(\log h(x) = 0\).

mu: Array#
L_Sigma: Array#
natural_params()[source]#

\(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\)

Return type:

Array

static sufficient_statistics(x)[source]#

\(t(x) = [x,\; \operatorname{vec}(xx^\top)]\).

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]#

\(\log h(x) = 0\) (base measure is Lebesgue).

Parameters:

x (Array)

Return type:

Array

classmethod from_classical(mu, sigma)[source]#

Construct from mean μ and covariance matrix Σ.

Return type:

MultivariateNormal

classmethod from_expectation(eta, **_kwargs)[source]#

Closed-form inversion \(\eta \to \theta \to (\mu, L_\Sigma)\).

\(\eta = [E[X],\; \operatorname{vec}(E[XX^\top])]\), so

\[\mu = \eta_1, \qquad \Sigma = \operatorname{reshape}(\eta_2, d, d) - \mu\mu^\top\]
Parameters:
  • eta (jax.Array) – Expectation parameters of shape (d + d²,).

  • **_kwargs – Ignored (accepts backend, theta0, etc. for API compatibility).

Return type:

MultivariateNormal

classmethod from_natural(theta)[source]#

Construct from natural parameters \(\theta = [\theta_1, \theta_2]\).

Recovers \(\Lambda = -2\,\mathrm{reshape}(\theta_2, d, d)\), then \(\mu = \Lambda^{-1}\theta_1\) and \(L_\Sigma = \mathrm{chol}(\Lambda^{-1})\).

Parameters:

theta (Array)

Return type:

MultivariateNormal

log_prob(x)[source]#

\(\log f(x) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{1}{2}\|L_\Sigma^{-1}(x-\mu)\|^2\).

Overrides the inherited EF formula for numerical efficiency (Cholesky-direct).

Parameters:

x (Array)

Return type:

Array
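The Cholesky-direct evaluation can be sketched without any linear-algebra library: solve \(L_\Sigma z = x - \mu\) by forward substitution and read \(\log|\Sigma|\) off the diagonal of \(L_\Sigma\). The code below is a standalone illustration with hypothetical names, not the library's implementation.

```python
import math

def forward_solve(L, b):
    """Solve L z = b for a lower-triangular L given as a list of rows."""
    z = []
    for i, row in enumerate(L):
        acc = b[i] - sum(row[j] * z[j] for j in range(i))
        z.append(acc / row[i])
    return z

def mvn_log_prob(x, mu, L):
    d = len(mu)
    z = forward_solve(L, [xi - mi for xi, mi in zip(x, mu)])
    log_det_sigma = 2.0 * sum(math.log(L[i][i]) for i in range(d))
    return (-0.5 * d * math.log(2.0 * math.pi)
            - 0.5 * log_det_sigma
            - 0.5 * sum(v * v for v in z))

mu = [0.5, -1.0]
L = [[2.0, 0.0], [0.6, 1.5]]          # Sigma = L L^T = [[4, 1.2], [1.2, 2.61]]
lp = mvn_log_prob([1.0, 0.0], mu, L)
```

No inverse of \(\Sigma\) is formed at any point; the quadratic form is the squared norm of the solved vector.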

mean()[source]#

\(E[X] = \mu\).

Return type:

Array

cov()[source]#

\(\mathrm{Cov}[X] = \Sigma = L_\Sigma L_\Sigma^\top\).

Return type:

Array

rvs(n, seed=42)[source]#

Draw n i.i.d. samples via JAX PRNG.

Returns:

Shape (n, d).

Return type:

jax.Array

Parameters:
  • n (int)

  • seed (int)

sample(key, shape=())[source]#

Draw samples via an explicit JAX key. Legacy API — prefer rvs.

Parameters:
  • key – explicit JAX PRNG key.

  • shape – leading sample shape; defaults to ().

Return type:

Array

property d: int#

Dimensionality.

property dim: int#

Dimensionality (alias for d).

property sigma: Array#

Covariance matrix \(\Sigma = L_\Sigma L_\Sigma^\top\) (alias for cov()).

Mixture Base Classes#

JointNormalMixture#

JointNormalMixture — abstract exponential family for normal variance-mean mixtures.

Joint distribution \(f(x, y)\):

\[X \mid Y \sim \mathcal{N}(\mu + \gamma y,\; \Sigma y), \quad Y \sim \text{subordinator (GIG, Gamma, InvGamma, InvGaussian)}\]

Sufficient statistics:

\[t(x, y) = [\log y,\; 1/y,\; y,\; x,\; x/y,\; \mathrm{vec}(xx^\top/y)]\]

Natural parameters:

\[\theta_1 = p_{\mathrm{sub}} - 1 - d/2, \quad \theta_2 = -(b_{\mathrm{sub}}/2 + \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu) < 0\]
\[\theta_3 = -(a_{\mathrm{sub}}/2 + \tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma) < 0, \quad \theta_4 = \Sigma^{-1}\gamma, \quad \theta_5 = \Sigma^{-1}\mu, \quad \theta_6 = -\tfrac{1}{2}\mathrm{vec}(\Sigma^{-1})\]

For a GIG subordinator, \(\theta_2,\theta_3\) combine \(-b/2,-a/2\) with the normal quadratic forms so that \(\theta^{\top} t\) matches \(-(a y + b/y)/2\) from \(f_Y\) plus the \(y\)-dependent terms from \(f_{X\mid Y}\).

Log-partition:

\[\psi = \psi_{\mathrm{sub}}(p, a, b) + \tfrac{1}{2}\log|\Sigma| + \mu^\top\Sigma^{-1}\gamma\]

Expectation parameters (EM E-step quantities):

\[\eta_1 = E[\log Y], \quad \eta_2 = E[1/Y], \quad \eta_3 = E[Y]\]
\[\eta_4 = E[X] = \mu + \gamma E[Y], \quad \eta_5 = E[X/Y] = \mu E[1/Y] + \gamma\]
\[\eta_6 = E[XX^\top/Y] = \Sigma + \mu\mu^\top E[1/Y] + \gamma\gamma^\top E[Y] + \mu\gamma^\top + \gamma\mu^\top\]

EM M-step closed-form (let \(D = 1 - E[1/Y] \cdot E[Y]\)):

\[\mu = \frac{E[X] - E[Y] E[X/Y]}{D}, \quad \gamma = \frac{E[X/Y] - E[1/Y] E[X]}{D}\]
\[\Sigma = E[XX^\top/Y] - E[X/Y]\mu^\top - \mu E[X/Y]^\top + E[1/Y]\mu\mu^\top - E[Y]\gamma\gamma^\top\]
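For \(d = 1\) every quantity above is a scalar, which makes the M-step easy to verify by a round trip: build \(\eta_4, \eta_5, \eta_6\) from known \((\mu, \gamma, \Sigma)\) and recover them with the closed-form update. The sketch below does this in plain Python; function names are hypothetical and the subordinator moments are those of a Gamma(3, 2) chosen for illustration.

```python
def e_params(mu, gamma, sigma2, e_y, e_inv_y):
    """eta_4 = E[X], eta_5 = E[X/Y], eta_6 = E[X^2/Y] for d = 1."""
    e_x = mu + gamma * e_y
    e_x_over_y = mu * e_inv_y + gamma
    e_xx_over_y = sigma2 + mu * mu * e_inv_y + gamma * gamma * e_y + 2.0 * mu * gamma
    return e_x, e_x_over_y, e_xx_over_y

def m_step(e_y, e_inv_y, e_x, e_x_over_y, e_xx_over_y):
    """Closed-form update for (mu, gamma, Sigma) with scalar Sigma."""
    D = 1.0 - e_inv_y * e_y
    mu = (e_x - e_y * e_x_over_y) / D
    gamma = (e_x_over_y - e_inv_y * e_x) / D
    sigma2 = (e_xx_over_y - 2.0 * e_x_over_y * mu
              + e_inv_y * mu * mu - e_y * gamma * gamma)
    return mu, gamma, sigma2

# Gamma(shape 3, rate 2): E[Y] = 3/2 and E[1/Y] = 2 / (3 - 1) = 1.
e_y, e_inv_y = 1.5, 1.0
mu, gamma, sigma2 = m_step(e_y, e_inv_y, *e_params(0.3, -0.7, 2.0, e_y, e_inv_y))
```

Note that \(D < 0\) for any non-degenerate subordinator, since \(E[1/Y]\,E[Y] > 1\) by Jensen's inequality.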
class normix.mixtures.joint.JointNormalMixture(mu, gamma, L_Sigma)[source]#

Bases: ExponentialFamily

Abstract joint distribution \(f(x, y)\) for normal variance-mean mixtures.

Stored: mu (d,), gamma (d,), L_Sigma (d×d lower Cholesky of \(\Sigma\)). Subordinator parameters defined by concrete subclasses.

Parameters:
  • mu (Array)

  • gamma (Array)

  • L_Sigma (Array)

mu: Array#
gamma: Array#
L_Sigma: Array#
abstractmethod subordinator()[source]#

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

property d: int#
sigma()[source]#

Covariance matrix \(\Sigma = L_\Sigma L_\Sigma^\top\).

Return type:

Array

log_det_sigma()[source]#

\(\log|\Sigma| = 2\sum_i \log L_{ii}\), via Cholesky diagonal.

Return type:

Array

rvs(n, seed=42)[source]#

Sample \((X, Y)\) from the joint distribution via JAX PRNG.

Returns:

  • X (jax.Array) – Shape (n, d).

  • Y (jax.Array) – Shape (n,).

Parameters:
  • n (int)

  • seed (int)

Return type:

Tuple[Array, Array]

log_prob_joint(x, y)[source]#

\(\log f(x, y) = \log f(x\mid y) + \log f_Y(y)\).

\[\log f(x\mid y) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{d}{2}\log y - \frac{1}{2y}\|L^{-1}(x-\mu)\|^2 + \gamma^\top\Sigma^{-1}(x-\mu) - \tfrac{y}{2}\gamma^\top\Sigma^{-1}\gamma\]

\(\log f_Y(y)\) from subordinator.

Parameters:
  • x (Array)

  • y (Array)

Return type:

Array

conditional_expectations(x)[source]#

Compute \(E[g(Y)\mid X=x]\) for the EM E-step.

The posterior \(Y\mid X=x\) is

\[\mathrm{GIG}\!\left(p - \tfrac{d}{2},\; a + \gamma^\top\Sigma^{-1}\gamma,\; b + (x-\mu)^\top\Sigma^{-1}(x-\mu)\right),\]

with the family-specific prior \((p, a, b)\) resolved by _posterior_gig_params(). Returns a dict with keys E_log_Y, E_inv_Y, E_Y.

Parameters:

x (Array)

Return type:

Dict[str, Array]
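The posterior parameter update itself is simple arithmetic. A minimal sketch follows, assuming a diagonal \(\Sigma\) so that the two quadratic forms reduce to weighted sums; the function name and argument layout are hypothetical and this is not the library's _posterior_gig_params().

```python
def posterior_gig_params(x, mu, gamma, sigma_diag, p, a, b):
    """Posterior (p - d/2, a + gamma' S^-1 gamma, b + (x - mu)' S^-1 (x - mu))."""
    d = len(x)
    q_gamma = sum(g * g / s for g, s in zip(gamma, sigma_diag))
    q_x = sum((xi - mi) ** 2 / s for xi, mi, s in zip(x, mu, sigma_diag))
    return p - 0.5 * d, a + q_gamma, b + q_x

post = posterior_gig_params(
    x=[1.0, 2.0], mu=[0.0, 1.0], gamma=[0.5, -0.5], sigma_diag=[1.0, 4.0],
    p=-0.5, a=1.0, b=1.0,
)
```

The three conditional expectations returned by the E-step are then the expectation parameters of a GIG with these posterior values.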

static sufficient_statistics(xy)[source]#

\(t(x,y) = [\log y,\; 1/y,\; y,\; x,\; x/y,\; \mathrm{vec}(xx^\top/y)]\).

Input: flat vector \([x_1,\ldots,x_d, y]\).

Parameters:

xy (Array)

Return type:

Array

static log_base_measure(xy)[source]#

\(\log h(x)\) for a single unbatched observation.

Parameters:

xy (Array)

Return type:

Array

classmethod from_expectation(eta, **kwargs)[source]#

Construct from expectation parameters \(\eta\).

Two input forms are supported:

  • NormalMixtureEta — the natural η pytree of the joint normal-variance-mean mixture; uses the closed-form M-step (\(\mu, \gamma, \Sigma\) analytical; subordinator via _subordinator_from_eta()).

  • flat jax.Array — the generic Bregman solver inherited from ExponentialFamily.

The pytree path is the canonical η→θ map for these distributions: it is exact (no Bregman iterations on the normal block) and uses the subordinator’s own from_expectation (closed-form for Gamma / InverseGamma / InverseGaussian; numerical for GIG).

Parameters:
  • eta (NormalMixtureEta or jax.Array) – Expectation parameters.

  • **kwargs – For the pytree path: forwarded to _subordinator_from_eta() (e.g. backend, method, maxiter, theta0 for warm-starting GIG). For the flat-array path: forwarded to the parent solver.

Return type:

JointNormalMixture

NormalMixture#

Marginal mixture base classes.

MarginalMixture is the abstract interface fitters and the divergences module depend on. NormalMixture is the full-covariance implementation, owning a JointNormalMixture. A factor-analysis implementation lives in the sibling class FactorNormalMixture (see docs/design/mixtures.md § 6).

NormalMixture provides:

  • log_prob(x) — closed-form marginal log-density

  • e_step(X) – jax.vmap() over conditional expectations

  • m_step(X, expectations) — returns new NormalMixture

  • fit(X, ...) — convenience EM fitting with multi-start

class normix.mixtures.marginal.MarginalMixture[source]#

Bases: Module

Abstract interface for marginal mixtures used by the EM fitters.

Concrete subclasses pick the storage of the Gaussian dispersion (full Cholesky in NormalMixture, low-rank-plus-diagonal in FactorNormalMixture) and the type of the EM expectation pytree (NormalMixtureEta vs. FactorMixtureStats).

The fitter depends only on this contract; it does not know which storage form a model uses.

abstractmethod log_prob(x)[source]#

Marginal \(\log f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]#

Marginal \(f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

marginal_log_likelihood(X)[source]#

Mean log-likelihood over a dataset.

Parameters:

X (Array)

Return type:

Array

abstractmethod e_step(X, *, backend='jax')[source]#

E-step: aggregated expectation parameters for the batch.

Parameters:
  • X (Array)

  • backend (str)

Return type:

Any

abstractmethod m_step(eta, **kwargs)[source]#

Full M-step: updates all parameters; returns a new model.

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod m_step_normal(eta)[source]#

M-step for normal parameters only (MCECM cycle 1).

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod m_step_subordinator(eta, **kwargs)[source]#

M-step for subordinator parameters only (MCECM cycle 2).

Parameters:

eta (Any)

Return type:

MarginalMixture

abstractmethod compute_eta_from_model()[source]#

Reconstruct the expectation pytree from the model’s own parameters.

Return type:

Any

abstractmethod em_convergence_params()[source]#

Pytree whose leaf-wise change measures EM convergence.

Subordinator parameters are intentionally excluded (their solver has its own tolerance, and including them inflates iteration counts). For full-covariance models this is (mu, gamma, L_Sigma); for factor-analysis models it is (mu, gamma, F F^T + D) to sidestep the rotational gauge.

Return type:

Any

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton', alpha_min=None)[source]#

Fit using self as initialisation. Returns EMResult.

Parameters:
  • alpha_min (float or {'density', 'inverse_moment'}, optional) – Opt-in lower bound on the Gamma subordinator shape \(\alpha\) (Variance Gamma only — the unique family whose marginal likelihood is unbounded at \(x=\mu\)). Restricts the estimator to a region where that degeneracy cannot occur; the ghyp “fix-\(\lambda\)” analogue. None (default) leaves \(\alpha\) unconstrained. A float is used directly; the \(d\)-aware sentinels resolve to \(d/2 + \varepsilon\) ('density' — marginal density bounded) or \(d/2 + 1 + \varepsilon\) ('inverse_moment' — \(E[1/Y\mid x]\) also bounded), with \(\varepsilon=\) ALPHA_MIN_MARGIN. Has no effect on NInvG / NIG / GH (their prior \(b>0\) keeps the likelihood bounded for every \(\alpha\)).

  • X (jax.Array)

  • algorithm (str)

  • verbose (int)

  • max_iter (int)

  • tol (float)

  • regularization (str)

  • e_step_backend (str)

  • m_step_backend (str)

  • m_step_method (str)

Return type:

EMResult
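The way the alpha_min options resolve to a numeric bound can be sketched as follows. This is only an illustration of the rule described above: the helper name resolve_alpha_min and the constant EPS_ASSUMED are hypothetical, and the actual value of ALPHA_MIN_MARGIN is not stated in this reference.

```python
from typing import Optional, Union

# Hypothetical stand-in for ALPHA_MIN_MARGIN; the library's real value is not given here.
EPS_ASSUMED = 1e-3


def resolve_alpha_min(alpha_min: Union[None, float, str], d: int,
                      eps: float = EPS_ASSUMED) -> Optional[float]:
    """Map the documented alpha_min options to a numeric lower bound (or None)."""
    if alpha_min is None:
        return None                    # alpha is left unconstrained
    if alpha_min == "density":
        return d / 2 + eps             # keeps the marginal density bounded
    if alpha_min == "inverse_moment":
        return d / 2 + 1 + eps         # also keeps E[1/Y | x] bounded
    return float(alpha_min)            # a float is used directly


print(resolve_alpha_min("density", d=4))
print(resolve_alpha_min("inverse_moment", d=4))
```

Both sentinels depend on the data dimension d, so the same keyword gives a different bound for univariate and multivariate fits.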

class normix.mixtures.marginal.NormalMixture(joint)[source]#

Bases: MarginalMixture

Marginal \(f(x) = \int_0^\infty f(x,y)\,dy\) for a normal variance-mean mixture.

Not an exponential family. Owns a JointNormalMixture (which is). The classical parameters \((\mu, \gamma, \Sigma, \text{subordinator})\) are forwarded as read-only properties; use replace() to obtain a new model with updated parameters (modules are immutable).

property joint: JointNormalMixture#
property d: int#
property mu: Array#

\(\mu\) — location parameter (forwarded from the joint).

property gamma: Array#

\(\gamma\) — skewness parameter (forwarded from the joint).

property L_Sigma: Array#

Lower Cholesky factor of \(\Sigma\) (forwarded from the joint).

sigma()[source]#

Dispersion \(\Sigma = L_\Sigma L_\Sigma^\top\) (forwarded from the joint).

Distinct from cov(), which returns the marginal covariance \(E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).

Return type:

Array

log_det_sigma()[source]#

\(\log|\Sigma|\) (forwarded from the joint).

Return type:

Array

project(w)[source]#

Return the univariate normal mixture \(w^\top X\) as a Univariate* instance.

Parameters:

w (Array)

Return type:

_UnivariateNormalMixtureMixin

mean()[source]#

\(E[X] = \mu + \gamma E[Y]\).

Return type:

Array

cov()[source]#

\(\mathrm{Cov}[X] = E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).

Return type:

Array
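The mean() and cov() formulas can be checked by hand with a small pure-Python sketch. The helper marginal_moments is hypothetical and not part of normix; the Gamma moments \(E[Y] = \alpha/\beta\) and \(\mathrm{Var}[Y] = \alpha/\beta^2\) assume the shape/rate parameterisation used for the Variance Gamma subordinator.

```python
def marginal_moments(mu, gamma, sigma, e_y, var_y):
    """E[X] = mu + gamma E[Y];  Cov[X] = E[Y] Sigma + Var[Y] gamma gamma^T."""
    d = len(mu)
    mean = [mu[i] + gamma[i] * e_y for i in range(d)]
    cov = [[e_y * sigma[i][j] + var_y * gamma[i] * gamma[j] for j in range(d)]
           for i in range(d)]
    return mean, cov


# Gamma(shape=alpha, rate=beta) subordinator: E[Y] = alpha/beta, Var[Y] = alpha/beta^2.
alpha, beta = 2.0, 4.0
mean, cov = marginal_moments(
    mu=[0.0, 1.0],
    gamma=[1.0, -2.0],
    sigma=[[1.0, 0.0], [0.0, 1.0]],
    e_y=alpha / beta,
    var_y=alpha / beta**2,
)
print(mean)
print(cov)
```

Even with an identity dispersion \(\Sigma\), the marginal covariance picks up off-diagonal terms from \(\gamma\gamma^\top\), which is why sigma() and cov() differ.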

rvs(n, seed=42)[source]#

Sample X from the marginal distribution.

Parameters:
Return type:

Array

squared_hellinger(other)[source]#

Squared Hellinger distance via joint distributions (upper bound on marginal).

Parameters:

other (NormalMixture)

Return type:

Array

kl_divergence(other)[source]#

KL divergence via joint distributions.

Parameters:

other (NormalMixture)

Return type:

Array

compute_eta_from_model()[source]#

Reconstruct \(\eta\) from the model’s own parameters.

Uses the marginal expectations of the joint sufficient statistics:

\[\eta_4 = \mu + \gamma\,E[Y], \quad \eta_5 = \mu\,E[1/Y] + \gamma, \quad \eta_6 = \Sigma + \mu\mu^\top E[1/Y] + \gamma\gamma^\top E[Y] + \mu\gamma^\top + \gamma\mu^\top\]
Return type:

NormalMixtureEta
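For intuition, the three formulas above reduce to scalar arithmetic when \(d = 1\). The sketch below is a hypothetical helper, not the library's code, and it assumes that \(\eta_4, \eta_5, \eta_6\) are the expectations of \(X\), \(X/Y\) and \(XX^\top/Y\) respectively, which is one reading consistent with the displayed expressions.

```python
def eta_normal_block(mu, gamma, sigma2, e_y, e_inv_y):
    """Scalar (d = 1) version of the eta_4, eta_5, eta_6 formulas."""
    eta4 = mu + gamma * e_y
    eta5 = mu * e_inv_y + gamma
    eta6 = sigma2 + mu * mu * e_inv_y + gamma * gamma * e_y + 2.0 * mu * gamma
    return eta4, eta5, eta6


# Gamma(shape=3, rate=2) subordinator: E[Y] = 3/2 and E[1/Y] = rate/(shape - 1) = 1.
alpha, beta = 3.0, 2.0
eta4, eta5, eta6 = eta_normal_block(
    mu=1.0, gamma=0.5, sigma2=2.0,
    e_y=alpha / beta, e_inv_y=beta / (alpha - 1.0),
)
print(eta4, eta5, eta6)
```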

e_step(X, backend='jax')[source]#

Full E-step: subordinator conditionals + batch aggregation.

Returns a NormalMixtureEta with the six aggregated expectation parameters.

Parameters:
  • X ((n, d) data array)

  • backend (str) – 'jax' (default): jax.vmap over conditional_expectations. 'cpu': quad forms in JAX + GIG Bessel on CPU.

Return type:

NormalMixtureEta

classmethod from_expectation(eta, **kwargs)[source]#

Construct from expectation parameters \(\eta\).

Wraps JointNormalMixture.from_expectation(), which performs the exact closed-form M-step on the normal block and the subordinator’s from_expectation on the subordinator block.

This is the canonical η→model map: any prior or shrinkage target \(\eta_0\) can be inspected as a concrete model via cls.from_expectation(eta_0).sigma() etc.

Parameters:
  • eta (NormalMixtureEta) – Aggregated expectation parameters.

  • **kwargs – Forwarded to JointNormalMixture.from_expectation() (e.g. backend, method, maxiter, theta0).

Return type:

NormalMixture

m_step(eta, **kwargs)[source]#

Full M-step: update normal params + subordinator from \(\eta\).

Equivalent to type(self).from_expectation(eta, **kwargs); self is only used to dispatch on the subclass. Subclasses with iterative subordinator solvers (e.g. GeneralizedHyperbolic) may override to inject a warm-start \(\theta_0\).

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

m_step_normal(eta)[source]#

M-step for normal parameters only (MCECM Cycle 1).

Updates \(\mu, \gamma, \Sigma\); subordinator unchanged.

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

m_step_subordinator(eta, **kwargs)[source]#

M-step for the subordinator only (MCECM Cycle 2).

Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.

Parameters:

eta (NormalMixtureEta)

Return type:

NormalMixture

replace(**updates)[source]#

Return a new model with selected parameters replaced.

Accepts any subset of:

  • normal parameters: mu, gamma, L_Sigma;

  • dispersion alias sigma — converted to L_Sigma via Cholesky (mutually exclusive with L_Sigma);

  • subordinator parameters declared by _subordinator_keys() (e.g. alpha, beta for VG / NInvG, mu_ig, lam for NIG, p, a, b for GH).

The actual storage lives in joint; this method does an immutable update via equinox.tree_at().

Examples

>>> vg2 = vg.replace(mu=new_mu)                    # change μ
>>> vg3 = vg.replace(alpha=2.5, beta=0.5)          # change subordinator
>>> vg4 = vg.replace(sigma=sigma2 * jnp.eye(d))    # set Σ via covariance
Return type:

NormalMixture

regularize_det_sigma(target_log_det=0.0)[source]#

Rescale to enforce \(\log|\Sigma| = \mathrm{target\_log\_det}\).

Picks \(s = \exp((\log|\Sigma| - \tau)/d)\), where \(\tau\) is target_log_det, and applies _rescale(). The default target_log_det = 0 recovers the \(|\Sigma| = 1\) convention; passing the log-determinant of an initial reference Σ implements the det_sigma_x family.

Parameters:

target_log_det (float)

Return type:

NormalMixture
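The choice of \(s\) can be illustrated for a diagonal \(\Sigma\). This sketch assumes that the rescaling acts as \(\Sigma \to \Sigma/s\), \(\gamma \to \gamma/s\), \(Y \to sY\), which leaves the law of \(X = \mu + \gamma Y + \sqrt{Y}\,\Sigma^{1/2} Z\) unchanged. Whether _rescale() is implemented exactly this way is an assumption, and rescale_to_log_det is a hypothetical name.

```python
import math


def rescale_to_log_det(sigma_diag, gamma, target_log_det=0.0):
    """Diagonal-Sigma sketch: Sigma -> Sigma / s and gamma -> gamma / s (with Y -> s * Y)."""
    d = len(sigma_diag)
    log_det = sum(math.log(v) for v in sigma_diag)
    s = math.exp((log_det - target_log_det) / d)
    return [v / s for v in sigma_diag], [g / s for g in gamma], s


new_sigma, new_gamma, s = rescale_to_log_det([4.0, 9.0], [3.0, -6.0])
new_log_det = sum(math.log(v) for v in new_sigma)
print(s, new_sigma, new_gamma, new_log_det)
```

Dividing each of the \(d\) diagonal entries by \(s\) lowers the log-determinant by \(d \log s\), which is exactly the gap between \(\log|\Sigma|\) and the target.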

regularize_det_sigma_one()[source]#

Enforce \(|\Sigma| = 1\). Alias for regularize_det_sigma() with target_log_det = 0.

Return type:

NormalMixture

regularize_a_eq_b()[source]#

Rescale subordinator so that \(a = b = \sqrt{ab}\) for GIG-parameterised families.

Default implementation is a no-op; override in subclasses with both \(a, b > 0\) (currently GH and NIG; VG and NInvG have a degenerate a=0 or b=0 and the default no-op is the right behaviour).

Return type:

NormalMixture

em_convergence_params()[source]#

Pytree whose leaf-wise change measures EM convergence.

Returns (mu, gamma, L_Sigma). Subordinator parameters (p, a, b) are excluded — their solver has its own tolerance and including them inflates iteration counts.

classmethod default_init(X)[source]#

Moment-based initialisation from data.

Returns a model with:

  • mu = sample mean

  • gamma = zeros

  • Sigma = empirical covariance (regularized)

  • subordinator = distribution-specific defaults

Useful as a starting point for model.fit(X).

Parameters:

X (Array)

Return type:

NormalMixture

Mixture Distributions#

Variance Gamma#

Variance Gamma (VG) distribution.

Special case of GH with GIG → Gamma subordinator (\(b \to 0\), \(p > 0\)). \(Y \sim \mathrm{Gamma}(\alpha, \beta)\), i.e. GIG(\(p = \alpha\), \(a = 2\beta\), \(b \to 0\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (rate) of Gamma.

class normix.distributions.variance_gamma.JointVarianceGamma(mu, gamma, L_Sigma, alpha, beta)[source]#

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{Gamma}(\alpha, \beta)\).

GIG limit: \(p = \alpha\), \(a = 2\beta\), \(b \to 0\).

Parameters:
alpha: Array#
beta: Array#
subordinator()[source]#

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]#

\(\theta = [\alpha-1-d/2,\; -\tfrac{1}{2}\mu^\top\Lambda\mu,\; -(\beta+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

(Gamma subordinator: \(p=\alpha\), \(a=2\beta\), \(b\to 0\)).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#
classmethod from_natural(theta)[source]#

Recover classical parameters from \(\theta\).

\(\alpha = \theta_1 + 1 + d/2\), \(\beta = -\theta_3 - \gamma_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointVarianceGamma

to_joint_generalized_hyperbolic(*, boundary_eps=0.0)[source]#

Exact embedding into JointGeneralizedHyperbolic.

Lifts the Gamma subordinator to GIG via Gamma.to_gig() and keeps the Normal block (\(\mu, \gamma, L_\Sigma\)) unchanged.

Parameters:

boundary_eps (float)

class normix.distributions.variance_gamma.VarianceGamma(joint)[source]#

Bases: NormalMixture

Marginal Variance Gamma distribution f(x).

Parameters:

joint (JointVarianceGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#
Return type:

VarianceGamma

log_prob(x)[source]#

Marginal VG log-density (own formula, no GH delegation).

\[f(x) \propto \left(\frac{q}{2c}\right)^{\nu/2} K_\nu\!\left(\sqrt{2qc}\right) \exp(\gamma^\top\Sigma^{-1}(x-\mu))\]

where \(\Lambda = \Sigma^{-1}\), \(\nu = \alpha - d/2\), \(c = \beta + \tfrac{1}{2}\gamma^\top\Lambda\gamma\), \(q = (x-\mu)^\top\Lambda(x-\mu)\).

Parameters:

x (Array)

Return type:

Array

property alpha: Array#

\(\alpha\) — Gamma shape (forwarded from the joint).

property beta: Array#

\(\beta\) — Gamma rate (forwarded from the joint).

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton', alpha_min=None)[source]#

Fit VG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the Gamma subordinator).

alpha_min (float or 'density' / 'inverse_moment') is the opt-in lower bound on the Gamma shape \(\alpha\) that keeps the VG marginal likelihood bounded; see fit().

to_generalized_hyperbolic(*, boundary_eps=0.0)[source]#

Exact embedding into the GeneralizedHyperbolic family.

Parameters:

boundary_eps (float)

class normix.distributions.variance_gamma.UnivariateVarianceGamma(joint)[source]#

Bases: _UnivariateNormalMixtureMixin, VarianceGamma

Univariate (d=1) Variance Gamma distribution.

Sibling of VarianceGamma for 1-D problems: exposes scalar mean/var/std, (n,)-shaped rvs, and cdf/ppf backed by a PINV table over the marginal log-density. EM, fit, replace, and regularisation are inherited from VarianceGamma.

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#

Build from scalar or 1-D classical parameters.

mu, gamma may be scalars or (1,); sigma may be a scalar variance, (1,), or (1, 1).

Return type:

UnivariateVarianceGamma

class normix.distributions.variance_gamma.FactorVarianceGamma(mu, gamma, F, D, *, alpha, beta)[source]#

Bases: FactorNormalMixture

Factor-analysis Variance Gamma: \(Y \sim \mathrm{Gamma}(\alpha, \beta)\), \(\Sigma = F F^\top + \mathrm{diag}(D)\).

GIG limit of the subordinator: \(p = \alpha\), \(a = 2\beta\), \(b \to 0\).

Parameters:
classmethod from_classical(*, mu, gamma, F, D, alpha, beta)[source]#
Return type:

FactorVarianceGamma

property alpha: Array#
property beta: Array#
log_prob(x)[source]#

Marginal VG log-density evaluated with Woodbury Σ-solve.

Parameters:

x (Array)

Return type:

Array

Normal Inverse Gamma#

Normal-Inverse Gamma (NInvG) distribution.

Special case of GH with GIG → InverseGamma subordinator (\(a \to 0\), \(p < 0\)). \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\), i.e. GIG(\(p = -\alpha\), \(a \to 0\), \(b = 2\beta\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (scale) of InverseGamma.

class normix.distributions.normal_inverse_gamma.JointNormalInverseGamma(mu, gamma, L_Sigma, alpha, beta)[source]#

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\).

GIG limit: \(p = -\alpha\), \(a \to 0\), \(b = 2\beta\).

Parameters:
alpha: Array#
beta: Array#
subordinator()[source]#

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]#

\(\theta = [-(\alpha+1)-d/2,\; -(\beta+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -\tfrac{1}{2}\gamma^\top\Lambda\gamma,\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

(InverseGamma subordinator: \(p=-\alpha\), \(a\to 0\), \(b=2\beta\)).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#
classmethod from_natural(theta)[source]#

Recover classical parameters from \(\theta\).

\(\alpha = -(\theta_1 + d/2) - 1\), \(\beta = -\theta_2 - \mu_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointNormalInverseGamma

to_joint_generalized_hyperbolic(*, boundary_eps=0.0)[source]#

Exact embedding into JointGeneralizedHyperbolic.

Lifts the InverseGamma subordinator to GIG via InverseGamma.to_gig() and keeps the Normal block unchanged.

Parameters:

boundary_eps (float)

class normix.distributions.normal_inverse_gamma.NormalInverseGamma(joint)[source]#

Bases: NormalMixture

Marginal Normal-Inverse Gamma distribution f(x).

Parameters:

joint (JointNormalInverseGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#
log_prob(x)[source]#

Marginal NInvG log-density (own formula, no GH delegation).

Posterior GIG parameters: \(p=-\alpha-d/2\), \(a=\gamma^\top\Lambda\gamma\), \(b=2\beta+Q(x)\), where \(Q(x) = (x-\mu)^\top\Lambda(x-\mu)\). The normalising integral is \(2(b/a)^{p/2} K_p(\sqrt{ab})\).

Parameters:

x (Array)

Return type:

Array

property alpha: Array#

\(\alpha\) — InverseGamma shape (forwarded from the joint).

property beta: Array#

\(\beta\) — InverseGamma scale (forwarded from the joint).

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]#

Fit NInvG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the InverseGamma subordinator).

to_generalized_hyperbolic(*, boundary_eps=0.0)[source]#

Exact embedding into the GeneralizedHyperbolic family.

Parameters:

boundary_eps (float)

class normix.distributions.normal_inverse_gamma.UnivariateNormalInverseGamma(joint)[source]#

Bases: _UnivariateNormalMixtureMixin, NormalInverseGamma

Univariate (d=1) Normal-Inverse-Gamma distribution.

Sibling of NormalInverseGamma for 1-D problems; see UnivariateVarianceGamma for the contract.

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]#
Return type:

UnivariateNormalInverseGamma

class normix.distributions.normal_inverse_gamma.FactorNormalInverseGamma(mu, gamma, F, D, *, alpha, beta)[source]#

Bases: FactorNormalMixture

Factor-analysis Normal-Inverse-Gamma: \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\), \(\Sigma = F F^\top + \mathrm{diag}(D)\).

GIG limit: \(p = -\alpha\), \(a \to 0\), \(b = 2\beta\).

Parameters:
classmethod from_classical(*, mu, gamma, F, D, alpha, beta)[source]#
Return type:

FactorNormalInverseGamma

property alpha: Array#
property beta: Array#
log_prob(x)[source]#

Marginal \(\log f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

Normal Inverse Gaussian#

Normal-Inverse Gaussian (NIG) distribution.

Special case of GH with GIG → InverseGaussian subordinator (\(p = -1/2\)). \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\), i.e. GIG(\(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\)).

Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\mu_{IG}\) (IG mean), \(\lambda\) (IG shape).

class normix.distributions.normal_inverse_gaussian.JointNormalInverseGaussian(mu, gamma, L_Sigma, mu_ig, lam)[source]#

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\).

Stored: \(\mu_{IG}\) (IG mean) and \(\lambda\) (IG shape) directly. GIG params: \(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\).

Parameters:
mu_ig: Array#
lam: Array#
subordinator()[source]#

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]#

\(\theta = [-3/2-d/2,\; -(\lambda/2+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -(\lambda/(2\mu_{IG}^2)+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)

where \(p=-1/2\), \(a=\lambda/\mu_{IG}^2\), \(b=\lambda\), aligned with GIG natural parameters on \([\log y,\,1/y,\,y]\).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]#
classmethod from_natural(theta)[source]#

Recover classical parameters from \(\theta\).

From \(b = -2\theta_2 - 2\mu_{\mathrm{quad}} = \lambda\) and \(a = -2\theta_3 - 2\gamma_{\mathrm{quad}} = \lambda/\mu_{IG}^2\), so \(\mu_{IG} = \sqrt{b/a}\).

Parameters:

theta (Array)

Return type:

JointNormalInverseGaussian

to_joint_generalized_hyperbolic()[source]#

Exact embedding into JointGeneralizedHyperbolic.

Lifts the InverseGaussian subordinator to GIG via InverseGaussian.to_gig() (no boundary approximation) and keeps the Normal block unchanged.

class normix.distributions.normal_inverse_gaussian.NormalInverseGaussian(joint)[source]#

Bases: NormalMixture

Marginal Normal-Inverse Gaussian distribution f(x).

Parameters:

joint (JointNormalInverseGaussian)

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]#
log_prob(x)[source]#

Marginal NIG log-density.

Uses \(K_{-1/2}(z) = \sqrt{\pi/(2z)}\,e^{-z}\) for the normalisation, leaving only one \(\log K_\nu\) call at order \(\nu = -1/2 - d/2\).

Parameters:

x (Array)

Return type:

Array
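The half-integer identity used here can be checked numerically from the standard integral representation \(K_\nu(z) = \int_0^\infty e^{-z\cosh t}\cosh(\nu t)\,dt\). The sketch below is a plain trapezoid-rule quadrature written for illustration; bessel_k is a hypothetical helper and is not how normix evaluates \(\log K_\nu\).

```python
import math


def bessel_k(nu, z, h=0.01, t_max=20.0):
    """K_nu(z) via the trapezoid rule on integral_0^inf exp(-z cosh t) cosh(nu t) dt."""
    n = int(t_max / h)
    total = 0.5 * math.exp(-z)  # half weight for the t = 0 endpoint, where cosh(0) = 1
    for k in range(1, n + 1):
        t = k * h
        # The integrand underflows to 0.0 long before t_max, so the tail is negligible.
        total += math.exp(-z * math.cosh(t)) * math.cosh(nu * t)
    return h * total


z = 2.0
numeric = bessel_k(-0.5, z)
closed_form = math.sqrt(math.pi / (2.0 * z)) * math.exp(-z)
print(numeric, closed_form)
```

Because \(K_\nu\) is even in \(\nu\), the same closed form holds for order \(+1/2\).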

property mu_ig: Array#

\(\mu_{IG}\) — InverseGaussian mean (forwarded from the joint).

property lam: Array#

\(\lambda\) — InverseGaussian shape (forwarded from the joint).

regularize_a_eq_b()[source]#

Rescale so \(a = b = \sqrt{ab}\).

For NIG, \(a = \lambda/\mu_{IG}^2,\;b = \lambda\), so \(s = \sqrt{a/b} = 1/\mu_{IG}\). After rescaling \(\mu_{IG}' = 1\), i.e. the InverseGaussian has unit mean.

Return type:

NormalInverseGaussian

to_generalized_hyperbolic()[source]#

Exact embedding into the GeneralizedHyperbolic family.

No boundary approximation: NIG sits in the strict interior of GH’s parameter space (\(p = -1/2,\; a = \lambda/\mu_{IG}^2,\; b = \lambda\)).

class normix.distributions.normal_inverse_gaussian.UnivariateNormalInverseGaussian(joint)[source]#

Bases: _UnivariateNormalMixtureMixin, NormalInverseGaussian

Univariate (d=1) Normal-Inverse-Gaussian distribution.

Sibling of NormalInverseGaussian for 1-D problems; see UnivariateVarianceGamma for the contract.

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]#
Return type:

UnivariateNormalInverseGaussian

class normix.distributions.normal_inverse_gaussian.FactorNormalInverseGaussian(mu, gamma, F, D, *, mu_ig, lam)[source]#

Bases: FactorNormalMixture

Factor-analysis Normal-Inverse-Gaussian: \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\), \(\Sigma = F F^\top + \mathrm{diag}(D)\).

GIG params: \(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\).

Parameters:
classmethod from_classical(*, mu, gamma, F, D, mu_ig, lam)[source]#
Return type:

FactorNormalInverseGaussian

property mu_ig: Array#
property lam: Array#
log_prob(x)[source]#

Marginal \(\log f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

regularize_a_eq_b()[source]#

Rescale so \(a = b\). For NIG (\(a = \lambda/\mu_{IG}^2\), \(b = \lambda\)), this means \(\mu_{IG} \to 1\).

Return type:

FactorNormalInverseGaussian

Generalized Hyperbolic#

Generalized Hyperbolic (GH) distribution.

Joint: \(X \mid Y \sim \mathcal{N}(\mu + \gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).

Marginal: \(\mathrm{GH}(\mu, \gamma, \Sigma, p, a, b)\).

Marginal log-density (closed form via Bessel functions):

Let \(Q(x) = (x-\mu)^\top \Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top \Sigma^{-1} \gamma\).

\[\log f(x) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| + \tfrac{p}{2}(\log a - \log b) - \log K_p(\sqrt{ab}) + \tfrac{d/2-p}{2}\log\frac{A}{Q(x)+b} + \log K_{p-d/2}\!\left(\sqrt{A(Q(x)+b)}\right) + \gamma^\top \Sigma^{-1}(x - \mu)\]

Posterior: \(Y \mid X = x \sim \mathrm{GIG}(p - d/2,\; a + \gamma^\top\Sigma^{-1}\gamma,\; b + (x-\mu)^\top\Sigma^{-1}(x-\mu))\).
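The posterior update is a simple shift of the three GIG parameters. The following sketch evaluates it for a diagonal \(\Sigma\); posterior_gig_params is a hypothetical helper for illustration and does not mirror the library's E-step code.

```python
def posterior_gig_params(x, mu, gamma, sigma_diag, p, a, b):
    """Parameters of Y | X = x for a diagonal Sigma, following the posterior formula."""
    d = len(x)
    # gamma^T Sigma^{-1} gamma for a diagonal Sigma
    gamma_quad = sum(g * g / s for g, s in zip(gamma, sigma_diag))
    # Q(x) = (x - mu)^T Sigma^{-1} (x - mu)
    q = sum((xi - mi) ** 2 / s for xi, mi, s in zip(x, mu, sigma_diag))
    return p - d / 2, a + gamma_quad, b + q


p_post, a_post, b_post = posterior_gig_params(
    x=[1.0, 2.0], mu=[0.0, 0.0], gamma=[1.0, 0.0],
    sigma_diag=[1.0, 4.0], p=-0.5, a=1.0, b=1.0,
)
print(p_post, a_post, b_post)
```

Only \(b\) depends on the observation, so \(p - d/2\) and \(a + \gamma^\top\Sigma^{-1}\gamma\) can be computed once per E-step.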

class normix.distributions.generalized_hyperbolic.JointGeneralizedHyperbolic(mu, gamma, L_Sigma, p, a, b)[source]#

Bases: JointNormalMixture

Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).

Stored: mu, gamma, L_Sigma (from JointNormalMixture) + p, a, b (GIG parameters).

Parameters:
p: Array#
a: Array#
b: Array#
subordinator()[source]#

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]#

\(\theta = [p-1-d/2,\; -(b/2+\tfrac{1}{2}\mu^\top\Sigma^{-1}\mu),\; -(a/2+\tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma),\; \Sigma^{-1}\gamma,\; \Sigma^{-1}\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Sigma^{-1})]\)

The scalar coefficients on sufficient statistics \(1/y\) and \(y\) match the GIG convention \(\theta_{\mathrm{GIG}} = [p-1,\,-b/2,\,-a/2]\) on \(t_Y = [\log y,\,1/y,\,y]\).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]#

Construct from classical parameters.

classmethod from_natural(theta)[source]#

Recover classical parameters from \(\theta\).

\(p = \theta_1 + 1 + d/2\), \(b = -2\theta_2 - 2\mu_{\mathrm{quad}}\), \(a = -2\theta_3 - 2\gamma_{\mathrm{quad}}\).

Parameters:

theta (Array)

Return type:

JointGeneralizedHyperbolic
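A scalar (\(d = 1\)) round trip makes the two maps concrete. This sketch assumes \(\mu_{\mathrm{quad}} = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu\) and \(\gamma_{\mathrm{quad}} = \tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma\), which is the reading that makes the recovery formulas invert natural_params(). The function names are hypothetical, and the code indexes \(\theta\) from 0 while the formulas index from 1.

```python
def to_natural(mu, gamma, sigma2, p, a, b, d=1):
    """Scalar version of the GH natural parameters (theta[0] is theta_1 in the text)."""
    lam = 1.0 / sigma2
    return [
        p - 1 - d / 2,
        -(b / 2 + 0.5 * mu * lam * mu),
        -(a / 2 + 0.5 * gamma * lam * gamma),
        lam * gamma,
        lam * mu,
        -0.5 * lam,
    ]


def from_natural(theta, d=1):
    """Invert to_natural using the recovery formulas for p, b and a."""
    lam = -2.0 * theta[5]
    gamma = theta[3] / lam
    mu = theta[4] / lam
    mu_quad = 0.5 * mu * lam * mu            # assumed definition of mu_quad
    gamma_quad = 0.5 * gamma * lam * gamma   # assumed definition of gamma_quad
    p = theta[0] + 1 + d / 2
    b = -2.0 * theta[1] - 2.0 * mu_quad
    a = -2.0 * theta[2] - 2.0 * gamma_quad
    return mu, gamma, 1.0 / lam, p, a, b


classical = (0.3, -0.7, 2.0, -0.5, 1.5, 0.8)
recovered = from_natural(to_natural(*classical))
print(recovered)
```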

to_joint_variance_gamma()[source]#

KL projection onto JointVarianceGamma (Gamma subordinator).

Return type:

JointVarianceGamma

to_joint_normal_inverse_gamma()[source]#

KL projection onto JointNormalInverseGamma.

Return type:

JointNormalInverseGamma

to_joint_normal_inverse_gaussian()[source]#

KL projection onto JointNormalInverseGaussian.

Return type:

JointNormalInverseGaussian

class normix.distributions.generalized_hyperbolic.GeneralizedHyperbolic(joint)[source]#

Bases: NormalMixture

Marginal Generalized Hyperbolic distribution \(f(x)\).

Stores a JointGeneralizedHyperbolic. Provides:

  • log_prob(x) — closed-form Bessel expression

  • e_step, m_step — for EM fitting

  • fit(X, ...) — convenience fitting method

Parameters:

joint (JointGeneralizedHyperbolic)

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]#
Return type:

GeneralizedHyperbolic

log_prob(x)[source]#

Marginal \(\log f(x)\).

\[f(x) \propto \left(\frac{A}{Q(x)+b}\right)^{(d/2-p)/2} K_{p-d/2}\!\left(\sqrt{A(Q(x)+b)}\right) \exp\!\left(\gamma^\top\Sigma^{-1}(x-\mu)\right)\]

where \(Q(x) = (x-\mu)^\top\Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top\Sigma^{-1}\gamma\).

Parameters:

x (Array)

Return type:

Array

property p: Array#

\(p\) — GIG order (forwarded from the joint).

property a: Array#

\(a\) — GIG concentration (forwarded from the joint).

property b: Array#

\(b\) — GIG concentration (forwarded from the joint).

m_step(eta, **kwargs)[source]#

Full M-step with warm-started + sanity-checked subordinator solve.

Overrides the base cls.from_expectation(eta) because the GIG solver benefits substantially from warm-starting at the current \(\theta\) and from a fall-back when the solver wanders into regions that fail the sanity check; both are managed by m_step_subordinator().

Return type:

GeneralizedHyperbolic

m_step_subordinator(eta, **kwargs)[source]#

M-step for the subordinator only (MCECM Cycle 2).

Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.

Return type:

GeneralizedHyperbolic

regularize_a_eq_b()[source]#

Rescale so \(a = b = \sqrt{ab}\) (orbit invariant).

Picks \(s = \sqrt{a/b}\). Idempotent: applying twice leaves the model unchanged.

Return type:

GeneralizedHyperbolic
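The effect of \(s = \sqrt{a/b}\) on the GIG parameters, including the idempotence claim, can be seen in a few lines. The sketch assumes the scaling \(Y \to sY\) maps \(\mathrm{GIG}(p, a, b)\) to \(\mathrm{GIG}(p, a/s, bs)\), which follows from the density \(\propto y^{p-1}\exp(-(ay + b/y)/2)\). The helper a_eq_b is hypothetical and ignores the matching rescale of \(\gamma\) and \(\Sigma\).

```python
import math


def a_eq_b(a, b):
    """Scale Y -> s * Y with s = sqrt(a / b), so GIG(p, a, b) -> GIG(p, a / s, b * s)."""
    s = math.sqrt(a / b)
    return a / s, b * s, s


a1, b1, s1 = a_eq_b(8.0, 2.0)   # first pass: s = 2, both parameters become sqrt(16) = 4
a2, b2, s2 = a_eq_b(a1, b1)     # second pass: s = 1, nothing changes
print(a1, b1, s1)
print(a2, b2, s2)
```

The product \(ab\) is unchanged by the scaling, which is why it is called the orbit invariant.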

to_variance_gamma()[source]#

KL projection onto the VarianceGamma family.

Return type:

VarianceGamma

to_normal_inverse_gamma()[source]#

KL projection onto the NormalInverseGamma family.

Return type:

NormalInverseGamma

to_normal_inverse_gaussian()[source]#

KL projection onto the NormalInverseGaussian family.

Return type:

NormalInverseGaussian

classmethod default_init(X)[source]#

Warm-start from the best of NIG / VG / NInvG sub-model fits.

Runs 5 EM iterations (JAX backend, Newton method) for each special case, converts each to GH parametrisation, and selects the candidate with the highest marginal log-likelihood. A moment-based fallback (\(p=1, a=1, b=1\)) is included as a fourth candidate.

Fully JAX-native: no try/except, no Python branching on data values.

Parameters:

X (Array)

Return type:

GeneralizedHyperbolic

fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='det_sigma_one', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]#

Fit GH distribution using EM or MCECM.

Defaults to CPU backends and det_sigma_one regularization (GH has scale non-identifiability requiring |Sigma| = 1).

class normix.distributions.generalized_hyperbolic.UnivariateGeneralizedHyperbolic(joint)[source]#

Bases: _UnivariateNormalMixtureMixin, GeneralizedHyperbolic

Univariate (d=1) Generalized Hyperbolic distribution.

Sibling of GeneralizedHyperbolic for 1-D problems; see UnivariateVarianceGamma for the contract.

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]#
Return type:

UnivariateGeneralizedHyperbolic

class normix.distributions.generalized_hyperbolic.FactorGeneralizedHyperbolic(mu, gamma, F, D, *, p, a, b)[source]#

Bases: FactorNormalMixture

Factor-analysis Generalized Hyperbolic: \(Y \sim \mathrm{GIG}(p, a, b)\), \(\Sigma = F F^\top + \mathrm{diag}(D)\).

The GIG subordinator’s M-step uses warm-started numerical optimisation with the same sanity-check fall-back as the standard GeneralizedHyperbolic.

Parameters:
classmethod from_classical(*, mu, gamma, F, D, p, a, b)[source]#
Return type:

FactorGeneralizedHyperbolic

property p: Array#
property a: Array#
property b: Array#
log_prob(x)[source]#

Marginal \(\log f(x)\) for a single observation.

Parameters:

x (Array)

Return type:

Array

m_step(eta, **kwargs)[source]#

Full M-step with warm-started, sanity-checked GIG solve.

Mirrors GeneralizedHyperbolic.m_step(): do the closed-form factor M-step first, then run the subordinator update with warm-start and fall-back.

Return type:

FactorGeneralizedHyperbolic

m_step_subordinator(eta, **kwargs)[source]#

M-step for the subordinator only (MCECM cycle 2).

Reads the subordinator-relevant fields of eta; \((\mu, \gamma, F, D)\) are read from self and copied unchanged. Subclasses with iterative solvers may override to add a warm-start or sanity-check fallback.

Return type:

FactorGeneralizedHyperbolic

regularize_a_eq_b()[source]#

Rescale so \(a = b = \sqrt{ab}\) (orbit invariant).

Mirrors GeneralizedHyperbolic.regularize_a_eq_b(). Picks \(s = \sqrt{a/b}\).

Return type:

FactorGeneralizedHyperbolic

classmethod default_init(X, *, r=1)[source]#

Warm-start from the best of FactorNIG / FactorVG / FactorNInvG fits.

Mirrors GeneralizedHyperbolic.default_init(): runs a few EM iterations of each special-case factor family, converts each subordinator to the GIG embedding, and selects the candidate with the highest marginal log-likelihood. A moment-based fallback (\(p = 1, a = 1, b = 1\)) is included as a fourth candidate.

Fully JAX-native: no try/except, no Python branching on data values.

Parameters:
Return type:

FactorGeneralizedHyperbolic

Fitting#

EM Fitters#

EM fitters for normix distributions.

Model knows math, fitter knows iteration.

  • BatchEMFitter — standard batch EM with dual-loop architecture: lax.scan (JIT-able) or Python for-loop (CPU-compatible)

  • IncrementalEMFitter — online / mini-batch EM with pluggable eta update rules

class normix.fitting.em.EMResult(model, log_likelihoods, param_changes, n_iter, converged, elapsed_time, diverged=False)[source]#

Bases: object

Result of an EM fitting procedure.

converged and diverged are separate because batch EM has three outcomes: tolerance met (converged=True), a non-finite iterate was reverted (diverged=True), or max_iter was exhausted with neither (both False). They are not opposites — converged=False covers both divergence and a finite but not-yet-converged stop.

Parameters:
model: Any#
log_likelihoods: Array | None#
param_changes: Array#
n_iter: int#
converged: bool | None#
elapsed_time: float#
diverged: bool = False#
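The three batch-EM outcomes described above can be told apart from the two flags alone. The helper below is a hypothetical convenience for illustration, not part of normix; it treats converged=None the same as False, which is an assumption.

```python
from typing import Optional


def em_outcome(converged: Optional[bool], diverged: bool) -> str:
    """Classify a result from its converged / diverged flags."""
    if diverged:
        return "diverged"            # a non-finite iterate was reverted
    if converged:
        return "converged"           # tolerance met
    return "max_iter exhausted"      # finite, but stopped before meeting the tolerance


print(em_outcome(True, False))
print(em_outcome(False, True))
print(em_outcome(False, False))
```

Checking diverged first matters because converged=False alone does not distinguish the last two cases.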
class normix.fitting.em.BatchEMFitter(*, algorithm='em', max_iter=200, tol=0.001, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton', eta_update=None, track_ll=False, m_step_kwargs=None)[source]#

Bases: object

Batch EM / MCECM algorithm with dual-loop architecture.

EM (default): E-step → M-step (all params) → regularize.

MCECM: E-step → M-step (normal params only) → regularize → E-step → M-step (subordinator only).

Convergence is measured by relative parameter change in the normal parameters (mu, gamma, L_Sigma), excluding subordinator (GIG) parameters.

Loop selection is automatic:
  • lax.scan when both backends are 'jax', verbose <= 1, algorithm='em', and no eta_update rule is set

  • Python for-loop otherwise
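The selection rule above amounts to a single boolean test. The function below is a hypothetical restatement of that rule for illustration; the fitter's internal check may be structured differently.

```python
def uses_scan(e_step_backend, m_step_backend, verbose, algorithm, eta_update):
    """True when every documented condition for the lax.scan loop holds."""
    return (
        e_step_backend == "jax"
        and m_step_backend == "jax"
        and verbose <= 1
        and algorithm == "em"
        and eta_update is None
    )


# With the constructor defaults (m_step_backend='cpu') the Python loop is selected.
print(uses_scan("jax", "cpu", 0, "em", None))
# Switching the M-step to the JAX backend satisfies all conditions.
print(uses_scan("jax", "jax", 0, "em", None))
```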

Parameters:
  • algorithm (str) – 'em' (default) or 'mcecm'.

  • max_iter (int) – Maximum number of iterations.

  • tol (float) – Convergence tolerance on max relative parameter change.

  • verbose (int) – 0 = silent, 1 = summary, 2 = per-iteration table.

  • regularization (str) –

    Strategy applied after each M-step:

    • 'none' — no regularization.

    • 'det_sigma_one' — enforce \(|\Sigma| = 1\) (the original GH convention; equivalent to regularize_det_sigma(target_log_det=0)).

    • 'det_sigma_x' — enforce \(|\Sigma| = |\Sigma_0|\) where \(\Sigma_0\) is the dispersion of the initial model passed to fit(). Useful when comparing GH / FactorGH parameters against VG / NIG / NInvG, which leave \(|\Sigma|\) at the empirical scale.

    • 'a_eq_b' — rescale the GIG subordinator so \(a = b = \sqrt{ab}\). Trivial no-op for VG / NInvG / MultivariateNormal.

  • e_step_backend (str) – 'jax' (default) or 'cpu'.

  • m_step_backend (str) – 'jax' or 'cpu' (default, faster for GIG).

  • m_step_method (str) – 'newton' (default), 'lbfgs', or 'bfgs'.

  • eta_update (EtaUpdateRule or None) – Optional eta combination rule (e.g. Shrinkage(IdentityUpdate(), eta0, tau)). When set, the E-step output is transformed before the M-step.

  • track_ll (bool) – When True, record per-iteration marginal log-likelihood in EMResult.log_likelihoods without requiring verbose >= 1.

  • m_step_kwargs (dict or None) – Extra keyword arguments forwarded verbatim to every m_step / m_step_subordinator call (in addition to backend and method). Used to thread estimand controls such as the VG alpha_min shape bound down to the subordinator’s from_expectation. Values must be static (e.g. Python floats) to stay compatible with the lax.scan path. None = no extras.

fit(model, X)

Run batch EM or MCECM. Auto-selects lax.scan or Python loop.

Parameters:
  • model (NormalMixture subclass) – Used as initial parameters.

  • X (array of shape (n, d)) – Data.

Returns:

EMResult with fitted model, convergence diagnostics, and timing.

Return type:

EMResult

class normix.fitting.em.IncrementalEMFitter(*, eta_update=None, batch_size=256, max_steps=200, inner_iter=1, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')

Bases: object

Incremental EM with pluggable eta update rules.

Replaces OnlineEMFitter and MiniBatchEMFitter. Processes data in random mini-batches, applies an EtaUpdateRule to combine the running \(\eta\) with each batch estimate, then M-steps on the combined \(\eta\).
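The batch-combine-then-M-step loop described above can be sketched on a toy problem. The example uses a 1-D Gaussian with no latent variable, so the "E-step" reduces to averaging sufficient statistics, and the `combine` rule (a convex combination with step size tau) is a stand-in chosen for illustration, not one of normix's EtaUpdateRule classes.

```python
import random

def batch_eta(batch):
    """Batch estimate of eta = (E[x], E[x^2]) for a 1-D Gaussian."""
    n = len(batch)
    return (sum(batch) / n, sum(x * x for x in batch) / n)

def combine(eta_running, eta_batch, tau):
    """Hypothetical update rule: convex combination with step size tau."""
    return tuple((1.0 - tau) * r + tau * b for r, b in zip(eta_running, eta_batch))

def m_step(eta):
    """Map expectation parameters back to (mean, variance)."""
    mean = eta[0]
    return mean, eta[1] - mean * mean

rng = random.Random(0)
data = [rng.gauss(2.0, 0.5) for _ in range(5000)]

eta = (0.0, 1.0)                      # initial running eta
for step in range(1, 201):            # 200 batches, mirroring max_steps=200
    batch = rng.sample(data, 256)     # random mini-batch, mirroring batch_size=256
    eta = combine(eta, batch_eta(batch), tau=1.0 / step)
    mean, var = m_step(eta)           # M-step on the combined eta
```

With tau = 1/step the running eta is the plain average of all batch estimates seen so far, so the final (mean, var) settles near the generating values (2.0, 0.25).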

Parameters:
  • eta_update (EtaUpdateRule) – How to combine running η with each batch estimate.

  • batch_size (int) – Observations per batch.

  • max_steps (int) – Number of batches to process (total budget).

  • inner_iter (int) – 1 = online (default); >1 = fine-tuning on each batch.

  • verbose (int) – 0 = silent; 1 = periodic summary. On the scan path (JAX backends only), verbose must be 0 so that diagnostics do not rely on Python side effects at each step.

  • regularization (str) – Same options as BatchEMFitter: 'none' | 'det_sigma_one' | 'det_sigma_x' | 'a_eq_b'.

  • e_step_backend (str) – Backend passed through to e_step.

  • m_step_backend (str) – Backend passed through to m_step.

  • m_step_method (str) – Method passed through to m_step.

fit(model, X, *, key)

Run incremental EM. Returns EMResult.

Return type:

EMResult

Solvers

Bregman divergence solvers.

Minimises f(θ) − θ·η over θ, where f is any convex function (e.g. the log-partition ψ for an exponential family). At the minimum ∇f(θ*) = η.
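A minimal 1-D sketch of this minimisation, using the Bernoulli log-partition ψ(θ) = log(1 + e^θ) and a plain Newton iteration. The function names are hypothetical and the loop is far simpler than the solvers documented below; it only shows that the minimiser satisfies ∇ψ(θ*) = η.

```python
import math

def psi(theta):
    """Bernoulli log-partition psi(theta) = log(1 + exp(theta))."""
    return math.log1p(math.exp(theta))

def grad_psi(theta):
    """Gradient of psi: the sigmoid of theta."""
    return 1.0 / (1.0 + math.exp(-theta))

def hess_psi(theta):
    """Second derivative of psi: sigmoid * (1 - sigmoid)."""
    s = grad_psi(theta)
    return s * (1.0 - s)

def newton_bregman(eta, theta0, max_steps=50, tol=1e-10):
    """Minimise psi(theta) - theta*eta by Newton's method (1-D sketch)."""
    theta = theta0
    for step in range(max_steps):
        g = grad_psi(theta) - eta          # gradient of the objective
        if abs(g) < tol:
            return theta, step, True
        theta -= g / hess_psi(theta)       # Newton update
    return theta, max_steps, abs(grad_psi(theta) - eta) < tol

eta = 0.3
theta_star, steps, converged = newton_bregman(eta, theta0=0.0)
# At the minimum grad_psi(theta_star) = eta, i.e. theta_star = log(eta / (1 - eta)).
```

For η = 0.3 the iteration lands on the logit of 0.3, which is the natural parameter whose mean is η.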

Public API

  • solve_bregman – single starting point.

  • solve_bregman_multistart – multiple starting points (vmap for JAX Newton; for-loop for quasi-Newton and CPU).

  • bregman_objective – utility: f(θ) − θ·η.

  • make_jit_newton_solver – build a stable @jax.jit Newton solve specialised to a fixed (f, grad_fn, hess_fn, bounds). Repeated calls with matching shapes/dtypes hit the XLA cache, avoiding the per-call re-tracing that solve_bregman incurs from fresh closures.

Backends × methods

  • backend='jax', method='newton' – custom lax.scan Newton, autodiff or analytical Hessian.

  • backend='jax', method='lbfgs' – jaxopt LBFGSB (bounds native) or LBFGS (reparam).

  • backend='jax', method='bfgs' – jaxopt BFGS with reparameterization for bounds.

  • backend='cpu', method='lbfgs' – scipy L-BFGS-B.

  • backend='cpu', method='bfgs' – scipy BFGS.

  • backend='cpu', method='newton' – scipy trust-exact with Hessian.

Gradient / Hessian sources

  • grad_fn – θ → ∇f(θ). For backend='cpu': pure CPU (numpy) gradient. For backend='jax', method='newton': JAX-traceable ∇ψ(θ). If None with backend='cpu': hybrid (jax.grad compiled → NumPy callbacks).

  • hess_fn – θ → ∇²f(θ). Required for method='newton'. For backend='cpu': pure CPU (numpy) Hessian. For backend='jax', method='newton': JAX-traceable ∇²ψ(θ). If None with method='newton': jax.hessian of the full objective.

Both grad_fn and hess_fn operate in theta-space only. The solver handles all reparameterization internally via the chain rule.
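To illustrate what "via the chain rule" means here, the sketch below maps an unconstrained variable u into a box (lower, upper) and converts a theta-space gradient into a u-space gradient. The sigmoid map is an assumption chosen for the example; the transformation normix actually applies is not specified in this section, and all names are hypothetical.

```python
import math

def to_theta(u, lower, upper):
    """Map unconstrained u to theta in (lower, upper) via a sigmoid."""
    s = 1.0 / (1.0 + math.exp(-u))
    return lower + (upper - lower) * s

def dtheta_du(u, lower, upper):
    """Derivative of the sigmoid box map with respect to u."""
    s = 1.0 / (1.0 + math.exp(-u))
    return (upper - lower) * s * (1.0 - s)

def grad_f(theta):
    """User-supplied theta-space gradient; here f(theta) = exp(theta)."""
    return math.exp(theta)

def objective_u(u, eta, lower, upper):
    """Objective f(theta) - theta*eta evaluated through the map theta(u)."""
    theta = to_theta(u, lower, upper)
    return math.exp(theta) - theta * eta

def objective_grad_u(u, eta, lower, upper):
    """Chain rule: (grad_f(theta) - eta) * dtheta/du."""
    theta = to_theta(u, lower, upper)
    return (grad_f(theta) - eta) * dtheta_du(u, lower, upper)

# Compare the chain-rule gradient with a central finite difference.
u, eta, lo, hi, h = 0.4, 2.0, -1.0, 3.0, 1e-6
analytic = objective_grad_u(u, eta, lo, hi)
numeric = (objective_u(u + h, eta, lo, hi) - objective_u(u - h, eta, lo, hi)) / (2 * h)
```

The user-facing gradient never sees u: it is evaluated at θ(u) and multiplied by dθ/du inside the solver-side wrapper.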

class normix.fitting.solvers.BregmanResult(theta, fun, grad_norm, num_steps, converged, elapsed_time=0.0)

Bases: object

Result of a Bregman divergence minimization.

Scalar fields accept both Python and JAX types so the result can live inside a lax.scan carry without concretization errors.

theta: Array

fun: Any

grad_norm: Any

num_steps: int

converged: Any

elapsed_time: float = 0.0

normix.fitting.solvers.bregman_objective(theta, eta, f)

f(θ) − θ·η — convex dual whose minimum gives ∇f(θ*) = η.

Return type:

Array

normix.fitting.solvers.solve_bregman(f, eta, theta0, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)

Minimise f(θ) − θ·η over θ.

Parameters:
  • f (Callable[[Array], Array]) – Convex function θ → scalar (e.g. log-partition ψ).

  • eta (Array) – Target vector (e.g. expectation parameters η).

  • theta0 (Array) – Initial guess.

  • backend (str) – 'jax' (JIT-able) or 'cpu' (scipy, not JIT-able).

  • method (str) – 'lbfgs', 'bfgs', or 'newton'.

  • bounds (Tuple[Array, Array] | None) – (lower, upper) pair of JAX arrays, each of shape (d,); None → unconstrained. For backend='jax': enforced via reparameterization. For backend='cpu': converted to scipy format internally.

  • max_steps (int) – Iteration budget.

  • tol (float) – Convergence tolerance on ‖∇f(θ) − η‖∞.

  • grad_fn (Callable | None) – θ → ∇f(θ). For backend='cpu': must accept and return numpy arrays. For backend='jax', method='newton': must be JAX-traceable. If None with backend='cpu': falls back to jax.grad (hybrid mode).

  • hess_fn (Callable | None) – θ → ∇²f(θ). Required for method='newton'. For backend='cpu': must accept numpy arrays and return a numpy array. For backend='jax', method='newton': must be JAX-traceable. If None with method='newton': jax.hessian of the full objective is used.

  • verbose (int) – 0 = silent, >= 1 = print summary after solve.

Return type:

BregmanResult

normix.fitting.solvers.solve_bregman_multistart(f, eta, theta0_batch, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)

Run solve_bregman from multiple starting points; return the best result.

Parameters:
  • theta0_batch – (K, dim) jax.Array for backend='jax', method='newton' (parallel via vmap); list of arrays otherwise (sequential for-loop).

  • verbose (int) – 0 = silent, >= 1 = print summary.

  • f (Callable[[Array], Array])

  • eta (Array)

  • backend (str)

  • method (str)

  • bounds (Tuple[Array, Array] | None)

  • max_steps (int)

  • tol (float)

  • grad_fn (Callable | None)

  • hess_fn (Callable | None)

Return type:

BregmanResult
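Why several starting points help even for a convex objective: with a small iteration budget, a poor start may not converge in time. The sketch below runs a 1-D Newton solve for ψ(θ) = exp(θ) from three starts and keeps the result with the lowest objective value. Selecting by objective value is an assumption about what "best" means, and the dictionary result and function names are hypothetical stand-ins for BregmanResult and the real solvers.

```python
import math

def newton_poisson(eta, theta0, max_steps=10, tol=1e-10):
    """Newton on psi(theta) - theta*eta with psi(theta) = exp(theta)."""
    theta = theta0
    for _ in range(max_steps):
        g = math.exp(theta) - eta              # gradient of the objective
        if abs(g) < tol:
            break
        theta -= g / math.exp(theta)           # Hessian is exp(theta)
    grad_norm = abs(math.exp(theta) - eta)
    fun = math.exp(theta) - theta * eta
    return {"theta": theta, "fun": fun, "grad_norm": grad_norm,
            "converged": grad_norm < tol}

def multistart(eta, starts):
    """Solve from every start and keep the result with the lowest objective."""
    results = [newton_poisson(eta, t0) for t0 in starts]
    return min(results, key=lambda r: r["fun"]), results

best, results = multistart(eta=2.0, starts=[-2.0, 0.0, 3.0])
```

The start at θ0 = -2 overshoots on its first step and then creeps back by roughly one unit per iteration, so it exhausts the 10-step budget, while the other two starts reach θ* = log 2.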

normix.fitting.solvers.make_jit_newton_solver(f, grad_fn, hess_fn, bounds=None)

Build a @jax.jit-decorated Newton solver specialised to one problem.

The returned callable has signature solve(eta, theta0, max_steps=20, tol=1e-10) -> (theta_opt, fun, grad_norm, converged) where max_steps is a static argument (required by lax.scan).

All distribution-level inputs (f, grad_fn, hess_fn, bounds) are baked into the closure at construction time. Repeated calls with the same array shapes and dtypes therefore reuse the compiled XLA executable.

Use this in EM hot paths where solve_bregman would otherwise build a fresh Python closure on every call and force JAX to re-trace the same Newton kernel on each iteration.

Parameters:
  • f (Callable) – Convex objective ψ(θ) → scalar (must be JAX-traceable).

  • grad_fn (Callable) – ∇ψ(θ), shape (d,).

  • hess_fn (Callable) – ∇²ψ(θ), shape (d, d).

  • bounds (Tuple[Array, Array] | None) – (lower, upper), each of shape (d,), or None.

Returns:

Jit-compiled Newton solver. Returns a 4-tuple of JAX arrays (theta, fun, grad_norm, converged); wrap in BregmanResult externally if needed.

Return type:

Callable

Utilities

Bessel Functions

JAX-compatible log modified Bessel function of the second kind.

log_kv(v, z) = log K_v(z), fully pure-JAX with zero scipy callbacks.

Regime-specific methods, selected via lax.cond (only one branch executes):
  1. Hankel asymptotic (DLMF 10.40.2) — large z

  2. Olver uniform expansion (DLMF 10.41.4) — large v

  3. Small-z leading asymptotic (DLMF 10.30.2) — z → 0

  4. Gauss-Legendre quadrature (Takekawa 2022) — moderate z, v

Using lax.cond (not jnp.where) means only the selected branch executes at runtime. lax.cond requires scalar conditions, so the core scalar function _log_kv_scalar is vmapped over array inputs.
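For intuition about the quadrature regime, the sketch below evaluates log K_v(z) from the integral representation K_v(z) = ∫₀^∞ exp(−z cosh t) cosh(v t) dt. It swaps in a plain composite trapezoidal rule for the Gauss-Legendre scheme named above, uses only the standard library, and is meant for moderate v and z only; `log_kv_quad` is a hypothetical name, not the normix implementation.

```python
import math

def log_kv_quad(v, z, n=4000, t_max=12.0):
    """log K_v(z) by trapezoidal quadrature of exp(-z*cosh(t)) * cosh(v*t).

    Integrates over [0, t_max]; the integrand is negligible beyond that for
    moderate z. No overflow or underflow protection is attempted.
    """
    h = t_max / n
    total = 0.5 * math.exp(-z)               # t = 0 endpoint, half weight
    for i in range(1, n + 1):
        t = i * h
        total += math.exp(-z * math.cosh(t)) * math.cosh(v * t)
    return math.log(total * h)

# Closed form for order 1/2: K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z).
exact = 0.5 * math.log(math.pi / 2.0) - 1.0
approx = log_kv_quad(0.5, 1.0)
```

Because the integrand depends on v only through cosh(v t), the symmetry K_v = K_{−v} holds by construction in this sketch.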

Custom JVP for full autodiff:
  • ∂/∂z : exact recurrence K'_v = −(K_{v−1} + K_{v+1})/2

  • ∂/∂v : central FD on log_kv itself (ε = BESSEL_EPS_V)
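The ∂/∂z recurrence can be checked by hand at half-integer orders, where K_v has a closed form. The sketch below does this for v = 1/2 with the standard library only; the function names are hypothetical and it does not call normix.

```python
import math

def k_half(z):
    """K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z)."""
    return math.sqrt(math.pi / (2.0 * z)) * math.exp(-z)

def k_three_halves(z):
    """K_{3/2}(z) = K_{1/2}(z) * (1 + 1/z)."""
    return k_half(z) * (1.0 + 1.0 / z)

def dlog_k_half_dz(z):
    """d/dz log K_v at v = 1/2 via K'_v = -(K_{v-1} + K_{v+1}) / 2.

    Uses K_{-1/2} = K_{1/2} for the lower-order term.
    """
    return -(k_half(z) + k_three_halves(z)) / (2.0 * k_half(z))

z, h = 1.0, 1e-6
from_recurrence = dlog_k_half_dz(z)   # simplifies to -1 - 1/(2 z)
finite_diff = (math.log(k_half(z + h)) - math.log(k_half(z - h))) / (2.0 * h)
```

The central difference on the last line is the same device the custom JVP uses for ∂/∂v, applied here to z only as a cross-check.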

Backends:
  • backend='jax' (default): pure-JAX, JIT-able, differentiable.

  • backend='cpu': scipy.special.kve, fully vectorized numpy. Not JIT-able. Fast for EM hot path.

normix.utils.bessel.log_kv(v, z, backend='jax')

\(\log K_v(z)\) — log modified Bessel function of the second kind.

Parameters:
  • v (scalar or array) – Order (any real; \(K_v = K_{-v}\)).

  • z (scalar or array) – Argument (must be > 0).

  • backend (str, optional) –

    'jax' (default) or 'cpu'.

    • 'jax': pure-JAX, lax.cond regime selection, custom JVP. JIT-able, differentiable. Default for log_prob, pdf, etc.

    • 'cpu': scipy.special.kve, fully vectorised NumPy. Not JIT-able. Fast for EM hot path.

Returns:

Same broadcast shape as (v, z).

Return type:

jax.Array

Examples

Evaluate at a single point (JAX backend, JIT-able):

>>> import jax.numpy as jnp
>>> from normix import log_kv
>>> float(log_kv(v=0.5, z=1.0))
-0.774...

CPU backend (uses scipy, faster for EM hot-paths):

>>> float(log_kv(v=0.5, z=1.0, backend='cpu'))
-0.774...

Symmetry \(K_v(z) = K_{-v}(z)\):

>>> abs(float(log_kv(0.5, 2.0)) - float(log_kv(-0.5, 2.0))) < 1e-10
True

Differentiable via JAX:

>>> import jax
>>> dlogkv_dz = jax.grad(lambda z: log_kv(0.5, z))(jnp.array(1.0))
>>> float(dlogkv_dz) < 0   # K_v decreases with z
True

Constants

Shared numerical constants for normix.