API Reference

Base Classes

class normix.exponential_family.ExponentialFamily[source]

Bases: Module

Abstract base class for exponential family distributions.

Concrete subclasses must implement:

_log_partition_from_theta, natural_params, sufficient_statistics, log_base_measure
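As a standalone illustration of these four pieces (plain functions, not the normix classes), the Exponential(λ) distribution fits the template with θ = −λ, t(x) = x, h(x) = 1, ψ(θ) = −log(−θ); the identity log p(x|θ) = log h(x) + θᵀt(x) − ψ(θ) then reproduces the familiar density:

```python
import numpy as np
from scipy import stats

# Hypothetical standalone sketch of the four pieces a subclass supplies,
# using Exponential(rate=lam): theta = -lam, t(x) = x, h(x) = 1,
# psi(theta) = -log(-theta).
def natural_params(lam):
    return np.array([-lam])

def sufficient_statistics(x):
    return np.array([x])

def log_base_measure(x):
    return 0.0

def log_partition(theta):
    return -np.log(-theta[0])

def log_prob(x, lam):
    theta = natural_params(lam)
    return log_base_measure(x) + theta @ sufficient_statistics(x) - log_partition(theta)

# log p(x) = log(lam) - lam * x, matching scipy's expon logpdf
x, lam = 1.3, 2.0
assert np.isclose(log_prob(x, lam), stats.expon(scale=1 / lam).logpdf(x))
```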

abstractmethod natural_params()[source]

θ from stored classical parameters.

Return type:

Array

abstractmethod static sufficient_statistics(x)[source]

t(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

abstractmethod static log_base_measure(x)[source]

log h(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

log_partition()[source]

ψ(θ) at current parameters.

Return type:

Array

expectation_params(backend='jax')[source]

η = ∇ψ(θ).

Parameters:

backend ('jax' (default, JIT-able) or 'cpu' (numpy/scipy))

Return type:

Array

fisher_information(backend='jax')[source]

I(θ) = ∇²ψ(θ).

Parameters:

backend ('jax' (default, JIT-able) or 'cpu' (numpy/scipy))

Return type:

Array

log_prob(x)[source]

log p(x|θ) = log h(x) + θᵀt(x) − ψ(θ), single observation.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]

p(x|θ), single observation. Batch via jax.vmap.

Parameters:

x (Array)

Return type:

Array
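The single-observation convention means batching is always explicit. A minimal sketch of the `jax.vmap` pattern, using a stand-in standard-normal log-density rather than a normix class:

```python
import jax
import jax.numpy as jnp

# pdf / log_prob operate on one observation; batching is done with jax.vmap.
# Stand-in single-observation log-density: standard normal.
def log_prob(x):
    return -0.5 * (x**2 + jnp.log(2 * jnp.pi))

X = jnp.array([-1.0, 0.0, 2.5])
batched = jax.vmap(log_prob)(X)   # one log-density per entry of X
assert batched.shape == (3,)
```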

mean()[source]

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

std()[source]

Std[X] = √Var[X].

Return type:

Array

cdf(x)[source]

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations. Uses numpy/scipy (not JIT-able).

Parameters:
  • n (int — number of samples)

  • seed (int — RNG seed)

Return type:

ndarray

classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

ExponentialFamily

classmethod bregman_divergence(theta, eta)[source]

Bregman divergence ψ(θ) − θ·η (conjugate dual).

Minimising over θ yields ∇ψ(θ*) = η, i.e. the natural parameters corresponding to expectation parameters η.

Parameters:
  • theta (Array)

  • eta (Array)
Return type:

Array

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='lbfgs', verbose=0)[source]

Construct from expectation parameters η by solving ∇ψ(θ) = η.

Minimises the Bregman divergence ψ(θ) − θ·η via solve_bregman. Subclasses can override for closed-form inverses.

Parameters:
  • backend ('jax' (default, JIT-able) or 'cpu' (scipy, more robust))

  • method ('lbfgs' (default), 'bfgs', or 'newton')

  • verbose (0 = silent, >= 1 = print solver summary)

  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)

Return type:

ExponentialFamily

classmethod fit_mle(X, *, theta0=None, maxiter=500, tol=1e-10, verbose=0)[source]

MLE via exponential family identity: η̂ = mean_i t(xᵢ).

Batches over X using jax.vmap, then calls from_expectation(η̂).

Parameters:
  • X ((n, ...) array of observations)

  • theta0 (optional initial natural parameters θ₀ for the η→θ solver)

  • maxiter (maximum iterations for the η→θ solver)

  • tol (convergence tolerance for the η→θ solver)

  • verbose (0 = silent, >= 1 = print solver summary)

Return type:

ExponentialFamily
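The identity can be checked end to end with plain numpy/scipy for the Gamma family (a standalone re-derivation, not the normix code path): η̂ = [mean(log x), mean(x)] pins down the MLE through digamma(α) − log α = mean(log x) − log(mean x) and β = α/mean(x).

```python
import numpy as np
from scipy import stats, special, optimize

# eta_hat = mean_i t(x_i) determines the Gamma MLE.
rng = np.random.default_rng(0)
X = rng.gamma(shape=3.0, scale=1 / 2.0, size=50_000)   # true alpha=3, beta=2

rhs = np.mean(np.log(X)) - np.log(np.mean(X))
alpha = optimize.brentq(lambda a: special.digamma(a) - np.log(a) - rhs, 1e-3, 1e3)
beta = alpha / np.mean(X)

# Agrees with scipy's direct numerical MLE (loc fixed at 0)
a_scipy, _, scale_scipy = stats.gamma.fit(X, floc=0)
assert abs(alpha - a_scipy) < 1e-2 and abs(beta - 1 / scale_scipy) < 1e-2
```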

fit(X, *, maxiter=500, tol=1e-10, verbose=0, **kwargs)[source]

Fit using self as initialization (warm start).

Computes η̂ = mean_i t(xᵢ) and solves from_expectation(η̂) using self.natural_params() as the initial theta0.

Parameters:
  • X ((n, ...) array of observations)

  • maxiter (maximum iterations for the η→θ solver)

  • tol (convergence tolerance for the η→θ solver)

  • verbose (0 = silent, >= 1 = print solver summary)

Return type:

ExponentialFamily

classmethod default_init(X)[source]

Moment-based initialisation from data.

Computes η̂ = mean_i t(xᵢ) and inverts to get an initial model. For distributions with closed-form from_expectation (Gamma, InverseGamma, InverseGaussian), this gives the MLE directly.

Parameters:

X (Array)

Return type:

ExponentialFamily

Univariate Distributions

Gamma

Gamma distribution as an exponential family.

PDF: p(x|α,β) = β^α/Γ(α) · x^{α-1} · exp(-βx), x > 0

Exponential family:

h(x) = 1
t(x) = [log x, x]
θ = [α-1, -β]   (θ₁ > -1, θ₂ < 0)
ψ(θ) = log Γ(θ₁+1) − (θ₁+1) log(−θ₂)
η = [ψ(α) − log β, α/β]   (digamma, mean)
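A quick numerical check of this parametrisation against scipy (h(x) = 1, so log p = θᵀt(x) − ψ(θ)):

```python
import numpy as np
from scipy import stats, special

# theta = [alpha-1, -beta], psi = logGamma(theta1+1) - (theta1+1)*log(-theta2).
alpha, beta, x = 2.5, 1.7, 0.8
theta = np.array([alpha - 1, -beta])
t_x = np.array([np.log(x), x])
psi = special.gammaln(theta[0] + 1) - (theta[0] + 1) * np.log(-theta[1])

log_p = theta @ t_x - psi              # h(x) = 1, so log h = 0
assert np.isclose(log_p, stats.gamma(a=alpha, scale=1 / beta).logpdf(x))
```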

class normix.distributions.gamma.Gamma(alpha, beta)[source]

Bases: ExponentialFamily

Gamma(α, β) distribution — shape α > 0, rate β > 0.

Parameters:
alpha: Array
beta: Array
natural_params()[source]

θ from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

t(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

log h(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

cdf(x)[source]

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations. Uses numpy/scipy (not JIT-able).

Parameters:
  • n (int)

  • seed (int)

Return type:

ndarray
classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

Gamma

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, **kwargs)[source]

Closed-form η → θ via Newton’s method on ψ(α) − log(α) = η₁ − log(η₂).

η = [E[log X], E[X]] → α from digamma inversion, β = α/η₂.

Parameters:
  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)
Return type:

Gamma
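A sketch of that Newton iteration in plain numpy/scipy (hypothetical helper mirroring, but not calling, the normix implementation):

```python
import numpy as np
from scipy import special

# Newton on g(alpha) = digamma(alpha) - log(alpha) - (eta1 - log(eta2)) = 0,
# then beta = alpha / eta2.
def gamma_from_expectation(eta1, eta2, alpha0=1.0, maxiter=100, tol=1e-12):
    rhs = eta1 - np.log(eta2)
    alpha = alpha0
    for _ in range(maxiter):
        g = special.digamma(alpha) - np.log(alpha) - rhs
        g_prime = special.polygamma(1, alpha) - 1 / alpha   # trigamma - 1/alpha > 0
        step = g / g_prime
        alpha = alpha - step
        if abs(step) < tol:
            break
    return alpha, alpha / eta2

# Round trip: eta for Gamma(3, 2) is [digamma(3) - log 2, 3/2]
eta1, eta2 = special.digamma(3.0) - np.log(2.0), 1.5
alpha, beta = gamma_from_expectation(eta1, eta2)
assert abs(alpha - 3.0) < 1e-8 and abs(beta - 2.0) < 1e-8
```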

Inverse Gamma

InverseGamma distribution as an exponential family.

PDF: p(x|α,β) = β^α/Γ(α) · x^{-α-1} · exp(-β/x), x > 0

Exponential family:

h(x) = 1
t(x) = [-1/x, log x]
θ = [β, -(α+1)]   (θ₁ > 0, θ₂ < -1)
ψ(θ) = log Γ(−θ₂−1) − (−θ₂−1) log(θ₁) = log Γ(α) − α log β
η = [-α/β, log β − ψ(α)]   (E[-1/X], E[log X])

class normix.distributions.inverse_gamma.InverseGamma(alpha, beta)[source]

Bases: ExponentialFamily

InverseGamma(α, β) — shape α > 0, rate β > 0.

Parameters:
alpha: Array
beta: Array
natural_params()[source]

θ from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

t(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

log h(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

cdf(x)[source]

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations. Uses numpy/scipy (not JIT-able).

Parameters:
  • n (int)

  • seed (int)

Return type:

ndarray
classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

InverseGamma

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, **kwargs)[source]

η = [-α/β, log β − ψ(α)].

β = α/(-η₁), solve ψ(α) − log α = −η₂ − log(-η₁) via Newton.

Parameters:
  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)
Return type:

InverseGamma

Inverse Gaussian

Inverse Gaussian (Wald) distribution as an exponential family.

PDF: f(x|μ,λ) = √(λ/(2π)) · x^{-3/2} · exp(−λ(x−μ)²/(2μ²x)), x > 0

Exponential family:

h(x) = (2π)^{-1/2} · x^{-3/2}
t(x) = [x, 1/x]
θ = [−λ/(2μ²), −λ/2]   (θ₁ < 0, θ₂ < 0)
ψ(θ) = −½log(−2θ₂) − √((−2θ₁)(−2θ₂)) = −½log λ − λ/μ

(the ½log(2π) is absorbed into log h(x) = −½log(2π) − (3/2) log x)

η = [E[X], E[1/X]] = [μ, 1/μ + 1/λ]

class normix.distributions.inverse_gaussian.InverseGaussian(mu, lam)[source]

Bases: ExponentialFamily

InverseGaussian(μ, λ) — mean μ > 0, shape λ > 0.

Parameters:
mu: Array
lam: Array
natural_params()[source]

θ from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

t(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

log h(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

E[X]. Subclasses should override with analytical formulas.

Return type:

Array

var()[source]

Var[X]. Subclasses should override with analytical formulas.

Return type:

Array

cdf(x)[source]

CDF F(x). Subclasses should override with analytical formulas.

Parameters:

x (Array)

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations. Uses numpy/scipy (not JIT-able).

Parameters:
  • n (int)

  • seed (int)

Return type:

ndarray

classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

InverseGaussian

classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, **kwargs)[source]
Closed-form from η = [E[X], E[1/X]] = [μ, 1/μ + 1/λ]:

μ = η₁, 1/λ = η₂ - 1/η₁ → λ = 1/(η₂ - 1/η₁)

Parameters:
  • eta (Array)

  • theta0 (Array | None)

  • maxiter (int)

  • tol (float)
Return type:

InverseGaussian
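This closed form is easy to sanity-check on simulated data. Note the assumed scipy mapping: `stats.invgauss(mu, scale)` has mean = mu·scale and λ = scale, which the assertions verify indirectly:

```python
import numpy as np
from scipy import stats

# Closed-form eta -> (mu, lam): mu = eta1, lam = 1 / (eta2 - 1/eta1).
rng = np.random.default_rng(1)
mu_true, lam_true = 2.0, 5.0
X = stats.invgauss(mu_true / lam_true, scale=lam_true).rvs(100_000, random_state=rng)

eta1, eta2 = X.mean(), (1 / X).mean()   # [E[X], E[1/X]]
mu_hat = eta1
lam_hat = 1 / (eta2 - 1 / eta1)
assert abs(mu_hat - mu_true) < 0.05 and abs(lam_hat - lam_true) < 0.2
```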

Generalized Inverse Gaussian

Generalized Inverse Gaussian (GIG) distribution as an exponential family.

PDF: f(x|p,a,b) = (a/b)^{p/2} / (2 K_p(√(ab))) · x^{p-1} · exp(-(ax+b/x)/2)

Exponential family:

h(x) = 1
t(x) = [log x, 1/x, x]
θ = [p-1, -b/2, -a/2]   (θ₂ ≤ 0, θ₃ ≤ 0)
ψ(θ) = log 2 + log K_p(√(ab)) + (p/2) log(b/a)

where p = θ₁+1, a = -2θ₃, b = -2θ₂

η = [E[log X], E[1/X], E[X]]

Special cases:

  • b → 0, p > 0: GIG → Gamma(p, a/2)

  • a → 0, p < 0: GIG → InverseGamma(-p, b/2)

  • p = -1/2: GIG → InverseGaussian
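The p = −1/2 case can be checked numerically against scipy. The assumed convention mapping (which the assertion verifies) is that `stats.geninvgauss(p, √(ab), scale=√(b/a))` equals GIG(p, a, b) in this document's parametrisation:

```python
import numpy as np
from scipy import stats

# GIG(-1/2, a, b) equals InverseGaussian(mu, lam) with a = lam/mu**2, b = lam.
mu, lam = 2.0, 5.0
a, b = lam / mu**2, lam
gig = stats.geninvgauss(-0.5, np.sqrt(a * b), scale=np.sqrt(b / a))
ig = stats.invgauss(mu / lam, scale=lam)

x = np.array([0.5, 1.0, 3.0])
assert np.allclose(gig.logpdf(x), ig.logpdf(x))
```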

η→θ optimization uses η-rescaling to reduce the Fisher condition number:

s = √(η₂/η₃)
η̃ = (η₁ + ½log(η₂/η₃), √(η₂η₃), √(η₂η₃))

Solve η̃→θ̃ with the symmetric GIG (ã = b̃), then unscale.

Log-Partition Triad Overrides

_log_partition_from_theta : JAX, uses log_kv(backend='jax')
_grad_log_partition : inherits default jax.grad (custom JVP on log_kv)
_hessian_log_partition : analytical 7-Bessel Hessian in θ-space
_log_partition_cpu : numpy + log_kv(backend='cpu')
_grad_log_partition_cpu : analytical Bessel ratios via scipy.kve
_hessian_log_partition_cpu : central FD on _log_partition_cpu

class normix.distributions.generalized_inverse_gaussian.GeneralizedInverseGaussian(p, a, b)[source]

Bases: ExponentialFamily

Generalized Inverse Gaussian distribution.

Stored: p (shape, any real), a > 0, b > 0.

Parameters:
p: Array
a: Array
b: Array
natural_params()[source]

θ from stored classical parameters.

Return type:

Array

static sufficient_statistics(x)[source]

t(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static log_base_measure(x)[source]

log h(x) for a single unbatched observation.

Parameters:

x (Array)

Return type:

Array

static expectation_params_batch(p, a, b, backend='jax')[source]

Vectorized η for arrays of (p, a, b), each shape (N,). Returns (N, 3) array where columns are [E_log_X, E_inv_X, E_X].

backend='jax' : vmap over scalar JAX grad
backend='cpu' : vectorized scipy.kve (6 C-level array calls)

Parameters:

backend (str)

Return type:

Array

mean()[source]

E[X] = η₃ from expectation parameters.

Return type:

Array

var()[source]

Var[X] = ∂²ψ/∂θ₃² = Fisher information [2,2].

Return type:

Array

rvs(n, seed=42)[source]

Sample n observations. Uses numpy/scipy (not JIT-able).

Parameters:
  • n (int)

  • seed (int)
Return type:

ndarray

classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

GeneralizedInverseGaussian

classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='newton', verbose=0)[source]

η → θ via η-rescaling + optimization.

Rescaling makes the Fisher matrix symmetric (ã = b̃), reducing condition number by up to \(10^{30}\) for extreme a/b ratios.

Parameters:
  • theta0 (warm-start point; required for the JAX solvers) – if None, falls back to the multi-start CPU solver with Gamma/InvGamma/InvGauss seeds

  • backend ('jax' (default, JIT-able) or 'cpu' (scipy, no JAX dispatch))

  • method ('newton', 'lbfgs', 'bfgs')

  • eta (Array)

  • maxiter (int)

  • tol (float)

  • verbose (int)

Return type:

GeneralizedInverseGaussian

normix.distributions.generalized_inverse_gaussian.GIG

alias of GeneralizedInverseGaussian

Multivariate Distributions

Multivariate Normal distribution.

Stored: mu (d,), L_Sigma (d,d) lower-triangular Cholesky of Σ. All linear algebra via L_Sigma — never form Σ⁻¹ explicitly.

PDF: f(x) = (2π)^{-d/2} |Σ|^{-1/2} exp(-½(x-μ)ᵀΣ⁻¹(x-μ))

Exponential family:

\[t(x) = [x, \operatorname{vec}(xx^\top)],\quad \theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})],\quad \psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu + \tfrac{1}{2}\log|\Sigma| + \tfrac{d}{2}\log(2\pi)\]

In practice we use the Cholesky log_prob formula directly.
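The Cholesky formula in question is log f(x) = −(d/2) log(2π) − Σᵢ log Lᵢᵢ − ½‖L⁻¹(x−μ)‖², which never forms Σ⁻¹ explicitly. A numpy check against scipy:

```python
import numpy as np
from scipy import stats
from scipy.linalg import solve_triangular

# Cholesky log-density with Sigma = L L^T.
rng = np.random.default_rng(2)
d = 3
A = rng.normal(size=(d, d))
sigma = A @ A.T + d * np.eye(d)       # symmetric positive definite
mu = rng.normal(size=d)
L = np.linalg.cholesky(sigma)

x = rng.normal(size=d)
z = solve_triangular(L, x - mu, lower=True)   # z = L^{-1}(x - mu)
log_f = -0.5 * d * np.log(2 * np.pi) - np.log(np.diag(L)).sum() - 0.5 * z @ z
assert np.isclose(log_f, stats.multivariate_normal(mu, sigma).logpdf(x))
```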

class normix.distributions.normal.MultivariateNormal(mu, L_Sigma)[source]

Bases: Module

Multivariate Normal with Cholesky parametrization.

Parameters:
  • mu ((d,) array)

  • L_Sigma ((d,d) lower-triangular Cholesky factor of Σ)

mu: Array
L_Sigma: Array
classmethod from_classical(mu, sigma)[source]

Construct from mean μ and covariance matrix Σ.

Return type:

MultivariateNormal

log_prob(x)[source]

log f(x) for a single observation x of shape (d,).

Parameters:

x (Array)

Return type:

Array

sample(key, shape=())[source]

Draw samples of shape (*shape, d).

Parameters:
Return type:

Array

property dim: int
property sigma: Array

Covariance matrix Σ = L_Sigma L_Sigmaᵀ.

Mixture Base Classes

JointNormalMixture

JointNormalMixture — abstract exponential family for normal variance-mean mixtures.

Joint distribution f(x,y):

X|Y ~ N(μ + γy, Σy)
Y ~ subordinator (GIG, Gamma, InverseGamma, InverseGaussian)

Sufficient statistics:

t(x,y) = [log y, 1/y, y, x, x/y, vec(xxᵀ/y)]

Natural parameters:

θ₁ = p_sub - 1 - d/2   (GIG p; scalar, depends on subordinator)
θ₂ = -(b_sub + ½μᵀΣ⁻¹μ) < 0
θ₃ = -(a_sub + ½γᵀΣ⁻¹γ) < 0
θ₄ = Σ⁻¹γ   (d-vector)
θ₅ = Σ⁻¹μ   (d-vector)
θ₆ = -½vec(Σ⁻¹)   (d²-vector)

Log partition:

ψ = ψ_sub(p, a, b) + ½log|Σ| + μᵀΣ⁻¹γ

Expectation parameters (EM E-step quantities):

η₁ = E[log Y]
η₂ = E[1/Y]
η₃ = E[Y]
η₄ = E[X] = μ + γ E[Y]
η₅ = E[X/Y] = μ E[1/Y] + γ
η₆ = E[XXᵀ/Y] = Σ + μμᵀ E[1/Y] + γγᵀ E[Y] + μγᵀ + γμᵀ

EM M-step closed-form (from η):

Let D = 1 - E[1/Y]·E[Y]
μ = (E[X] - E[Y]·E[X/Y]) / D
γ = (E[X/Y] - E[1/Y]·E[X]) / D
Σ = E[XXᵀ/Y] - E[X/Y]μᵀ - μE[X/Y]ᵀ + E[1/Y]μμᵀ - E[Y]γγᵀ
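The μ and γ updates amount to an exact 2×2 solve in the moment identities η₄ = μ + γE[Y] and η₅ = μE[1/Y] + γ. A round-trip check in numpy, with synthetic moments rather than a real E-step:

```python
import numpy as np

# Given mu, gamma and subordinator moments, form E[X] and E[X/Y],
# then invert with the M-step formulas; recovery should be exact.
mu, gamma = np.array([1.0, -2.0]), np.array([0.5, 0.3])
E_Y, E_inv_Y = 1.4, 0.9          # any moments with 1 - E[1/Y]*E[Y] != 0

E_X = mu + gamma * E_Y
E_X_over_Y = mu * E_inv_Y + gamma

D = 1 - E_inv_Y * E_Y
mu_hat = (E_X - E_Y * E_X_over_Y) / D
gamma_hat = (E_X_over_Y - E_inv_Y * E_X) / D
assert np.allclose(mu_hat, mu) and np.allclose(gamma_hat, gamma)
```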

class normix.mixtures.joint.JointNormalMixture(mu, gamma, L_Sigma)[source]

Bases: ExponentialFamily

Abstract joint distribution f(x, y) for normal variance-mean mixtures.

Stored: mu (d,), gamma (d,), L_Sigma (d,d lower Cholesky of Σ) Subordinator parameters defined by concrete subclasses.

Parameters:
mu: Array
gamma: Array
L_Sigma: Array
abstractmethod subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

property d: int
sigma()[source]

Covariance matrix Σ = L_Sigma L_Sigmaᵀ.

Return type:

Array

log_det_sigma()[source]

log|Σ| = 2 Σᵢ log Lᵢᵢ, via Cholesky diagonal.

Return type:

Array

rvs(n, seed=42)[source]

Sample (X, Y) from the joint distribution.

Returns:

  • X ((n, d) array)

  • Y ((n,) array)

Parameters:
  • n (int)

  • seed (int)
Return type:

Tuple[ndarray, ndarray]

log_prob_joint(x, y)[source]

log f(x, y) = log f(x|y) + log f_Y(y).

log f(x|y) = -d/2 log(2π) - ½ log|Σy| - ½(x-μ-γy)ᵀ(Σy)⁻¹(x-μ-γy)
           = -d/2 log(2π) - ½ log|Σ| - d/2 log y
             - 1/(2y) ‖L⁻¹(x-μ)‖² + γᵀΣ⁻¹(x-μ) - y/2 γᵀΣ⁻¹γ

log f_Y(y) from subordinator.

Parameters:
  • x (Array)

  • y (Array)
Return type:

Array

conditional_expectations(x)[source]

Compute E[g(Y)|X=x] for EM E-step.

The posterior Y|X follows a GIG-like distribution with parameters:

p_post = p_eff - d/2
a_post = a_eff + γᵀΣ⁻¹γ
b_post = b_eff + (x-μ)ᵀΣ⁻¹(x-μ)

Returns dict with keys: E_log_Y, E_inv_Y, E_Y. These are then used in the M-step.

Parameters:

x (Array)

Return type:

Dict[str, Array]

static sufficient_statistics(xy)[source]

t(x,y) = [log y, 1/y, y, x, x/y, vec(xxᵀ/y)]

Input: flat vector [x…, y] where x is d-dimensional.

Parameters:

xy (Array)

Return type:

Array

static log_base_measure(xy)[source]

log h(x) for a single unbatched observation.

Parameters:

xy (Array)

Return type:

Array

NormalMixture

NormalMixture — marginal f(x) = ∫ f(x,y) dy.

Owns a JointNormalMixture. Provides:

log_prob(x) — closed-form marginal log-density
e_step(X) — jax.vmap over conditional_expectations
m_step(X, expectations) — returns a new NormalMixture
fit(X, …) — convenience EM fitting with multi-start

class normix.mixtures.marginal.NormalMixture(joint)[source]

Bases: Module

Marginal f(x) = ∫₀^∞ f(x,y) dy.

Not an exponential family. Owns a JointNormalMixture (which is).

property joint
property d: int
log_prob(x)[source]

Marginal log f(x). Subclasses provide closed-form; default raises.

Parameters:

x (Array)

Return type:

Array

pdf(x)[source]

Marginal f(x), single observation.

Parameters:

x (Array)

Return type:

Array

mean()[source]

E[X] = μ + γ E[Y].

Return type:

Array

cov()[source]

Cov[X] = E[Y] Σ + Var[Y] γγᵀ.

Return type:

Array

rvs(n, seed=42)[source]

Sample X from the marginal distribution.

Parameters:
  • n (int)

  • seed (int)
Return type:

ndarray

e_step(X, backend='jax')[source]

E-step: compute conditional expectations E[g(Y)|X=xᵢ] for each i.

Returns dict of arrays with shape (n, …) for each expectation.

backend='jax' (default): jax.vmap over conditional_expectations. JIT-able, differentiable.

backend='cpu': quad forms in JAX (vmapped) + GIG Bessel on CPU. Faster for large N; not JIT-able.

Parameters:
  • X ((n, d) data array)

  • backend (str)
Return type:

Dict[str, Array]

m_step(X, expectations, **kwargs)[source]

M-step: update model parameters from sufficient statistics.

Returns a NEW NormalMixture with updated parameters. Subclasses must override _m_step_subordinator.

Parameters:
  • X ((n, d) data array)

  • expectations (dict of arrays from e_step)
Return type:

NormalMixture

regularize_det_sigma_one()[source]

Enforce |Σ| = 1 by rescaling.

Σ → Σ/s, γ → γ/s, subordinator params scaled via _scale_subordinator. s = det(Σ)^{1/d}.

Return type:

NormalMixture

marginal_log_likelihood(X)[source]

Mean log-likelihood over dataset.

Parameters:

X (Array)

Return type:

Array

fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]

Fit using self as initialization. Returns EMResult.

Parameters:
  • X ((n, d) data array)

  • verbose (0 = silent, 1 = summary, 2 = per-iteration table)

  • max_iter (EM convergence parameters)

  • tol (EM convergence parameters)

  • regularization ('det_sigma_one' or 'none')

  • e_step_backend ('jax' or 'cpu')

  • m_step_backend ('jax' or 'cpu')

  • m_step_method ('newton', 'lbfgs', or 'bfgs')

Return type:

EMResult with .model, .log_likelihoods, .param_changes, etc.

classmethod default_init(X)[source]

Moment-based initialisation from data.

Returns a model with:

  • mu = sample mean

  • gamma = zeros

  • Sigma = empirical covariance (regularized)

  • subordinator = distribution-specific defaults

Useful as a starting point for model.fit(X).

Parameters:

X (Array)

Return type:

NormalMixture

Mixture Distributions

Variance Gamma

Variance Gamma (VG) distribution.

Special case of GH with GIG → Gamma subordinator (b → 0, p > 0). Y ~ Gamma(α, β), i.e. GIG(p = α, a = 2β, b → 0).

Stored: μ, γ, L_Σ (Cholesky of Σ), α (shape), β (rate) of Gamma.

class normix.distributions.variance_gamma.JointVarianceGamma(mu, gamma, L_Sigma, alpha, beta)[source]

Bases: JointNormalMixture

Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ Gamma(α, β).

GIG limit: p = α, a = 2β, b → 0.

Parameters:
alpha: Array
beta: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

θ = [α-1-d/2, -½μᵀΛμ, -(β+½γᵀΛγ), Λγ, Λμ, -½vec(Λ)] (Gamma subordinator: p=α, a=2β, b→0).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

JointVarianceGamma

class normix.distributions.variance_gamma.VarianceGamma(joint)[source]

Bases: NormalMixture

Marginal Variance Gamma distribution f(x).

Parameters:

joint (JointVarianceGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
Return type:

VarianceGamma

log_prob(x)[source]

Marginal VG log-density (own formula, no GH delegation).

f(x) = C · (q/(2c))^{ν/2} · K_ν(√(2qc)) · exp(γᵀΛ(x-μ)), where ν = α - d/2, c = β + ½γᵀΛγ, q = (x-μ)ᵀΛ(x-μ), Λ = Σ⁻¹.

Parameters:

x (Array)

Return type:

Array

fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit VG using EM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the Gamma subordinator).

Normal Inverse Gamma

Normal-Inverse Gamma (NInvG) distribution.

Special case of GH with GIG → InverseGamma subordinator (a → 0, p < 0). Y ~ InverseGamma(α, β), i.e. GIG(p = −α, a → 0, b = 2β).

Stored: μ, γ, L_Σ (Cholesky of Σ), α (shape), β (rate) of InverseGamma.

class normix.distributions.normal_inverse_gamma.JointNormalInverseGamma(mu, gamma, L_Sigma, alpha, beta)[source]

Bases: JointNormalMixture

Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ InverseGamma(α, β).

GIG limit: p = −α, a → 0, b = 2β.

Parameters:
alpha: Array
beta: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

θ = [-(α+1)-d/2, -(β+½μᵀΛμ), -½γᵀΛγ, Λγ, Λμ, -½vec(Λ)] (InverseGamma subordinator: p=-α, a→0, b=2β).

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

class normix.distributions.normal_inverse_gamma.NormalInverseGamma(joint)[source]

Bases: NormalMixture

Marginal Normal-Inverse Gamma distribution f(x).

Parameters:

joint (JointNormalInverseGamma)

classmethod from_classical(*, mu, gamma, sigma, alpha, beta)[source]
log_prob(x)[source]

Marginal NInvG log-density (own formula, no GH delegation).

GIG params: p=-α, a=γᵀΛγ, b=2β+Q(x). The normalising integral is 2(b/a)^{p/2} K_p(sqrt(ab)). For the symmetric case (γ≈0, a→0), uses Γ-function closed form.

Parameters:

x (Array)

Return type:

Array

fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit NInvG using EM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the InverseGamma subordinator).

Normal Inverse Gaussian

Normal-Inverse Gaussian (NIG) distribution.

Special case of GH with GIG → InverseGaussian subordinator (p = −½). Y ~ InverseGaussian(μ_IG, λ), i.e. GIG(p = −½, a = λ/μ_IG², b = λ).

Stored: μ, γ, L_Σ (Cholesky of Σ), μ_IG (IG mean), λ (IG shape).

class normix.distributions.normal_inverse_gaussian.JointNormalInverseGaussian(mu, gamma, L_Sigma, mu_ig, lam)[source]

Bases: JointNormalMixture

Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ InverseGaussian(μ_IG, λ).

Stored: μ_IG (IG mean) and λ (IG shape) directly. GIG params: p = −½, a = λ/μ_IG², b = λ.

Parameters:
mu_ig: Array
lam: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

θ = [-3/2-d/2, -(b+½μᵀΛμ), -(a+½γᵀΛγ), Λγ, Λμ, -½vec(Λ)] where p=-½, a=λ/μ_IG², b=λ.

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]
classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

class normix.distributions.normal_inverse_gaussian.NormalInverseGaussian(joint)[source]

Bases: NormalMixture

Marginal Normal-Inverse Gaussian distribution f(x).

Parameters:

joint (JointNormalInverseGaussian)

classmethod from_classical(*, mu, gamma, sigma, mu_ig, lam)[source]
log_prob(x)[source]

Marginal NIG log-density.

Uses K_{-1/2}(z) = sqrt(pi/(2z)) exp(-z) for the normalisation, leaving only one log_kv call at order nu = -1/2 - d/2.

Parameters:

x (Array)

Return type:

Array

Generalized Hyperbolic

Generalized Hyperbolic (GH) distribution.

Joint: X|Y ~ N(μ+γy, Σy), Y ~ GIG(p, a, b) Marginal: GH(μ, γ, Σ, p, a, b)

Marginal log-density (closed form via Bessel functions):

Let Q(x) = (x-μ)ᵀΣ⁻¹(x-μ) and A = a + γᵀΣ⁻¹γ. Then

log f(x) = const + (p - d/2)/2 · log((Q(x)+b)/A)
           + log K_{p-d/2}(√(A(Q(x)+b)))
           - log K_p(√(ab))
           + γᵀΣ⁻¹(x-μ)   (skewness term)

Full formula:

f(x) = C · ((Q(x)+b)/A)^{(p-d/2)/2} · K_{p-d/2}(√(A(Q(x)+b))) · exp(γᵀΣ⁻¹(x-μ))

C = (2π)^{-d/2} |Σ|^{-1/2} (a/b)^{p/2} / K_p(√(ab))

(see GeneralizedHyperbolic.log_prob)

Posterior Y|X = x ~ GIG(p - d/2, a + γᵀΣ⁻¹γ, b + (x-μ)ᵀΣ⁻¹(x-μ)) → conditional expectations computed from GIG.
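The posterior parameters are simple quadratic forms, computable from the stored Cholesky factor without forming Σ⁻¹. A minimal numpy sketch with hypothetical parameter values:

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

# Posterior subordinator parameters for Y | X = x:
# p_post = p - d/2, a_post = a + gamma' Sigma^{-1} gamma, b_post = b + Q(x).
p, a, b = 0.5, 1.0, 2.0
mu = np.array([0.0, 1.0])
gamma = np.array([0.3, -0.1])
sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
x = np.array([1.5, 0.2])

c = cho_factor(sigma, lower=True)            # quadratic forms via Cholesky
a_post = a + gamma @ cho_solve(c, gamma)
b_post = b + (x - mu) @ cho_solve(c, x - mu)
p_post = p - len(x) / 2
assert p_post == -0.5 and a_post > a and b_post > b
```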

class normix.distributions.generalized_hyperbolic.JointGeneralizedHyperbolic(mu, gamma, L_Sigma, p, a, b)[source]

Bases: JointNormalMixture

Joint f(x,y): X|Y~N(μ+γy, Σy), Y~GIG(p,a,b).

Stored: mu, gamma, L (from JointNormalMixture) + p, a, b (GIG parameters).

Parameters:
p: Array
a: Array
b: Array
subordinator()[source]

Return the fitted subordinator distribution.

Return type:

ExponentialFamily

natural_params()[source]

θ = [p-1-d/2, -(b+½μᵀΣ⁻¹μ), -(a+½γᵀΣ⁻¹γ), Σ⁻¹γ, Σ⁻¹μ, -½vec(Σ⁻¹)]

Return type:

Array

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]

Construct from classical parameters.

classmethod from_natural(theta)[source]

Construct from natural parameters θ. Subclasses must override.

Parameters:

theta (Array)

Return type:

JointGeneralizedHyperbolic

class normix.distributions.generalized_hyperbolic.GeneralizedHyperbolic(joint)[source]

Bases: NormalMixture

Marginal Generalized Hyperbolic distribution f(x).

Stores a JointGeneralizedHyperbolic. Provides:

log_prob(x) — closed-form Bessel expression
e_step, m_step — for EM fitting
fit(X, …) — convenience EM fitting

Parameters:

joint (JointGeneralizedHyperbolic)

classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]
Return type:

GeneralizedHyperbolic

log_prob(x)[source]

Marginal log f(x).

f(x) ∝ ((Q(x)+b)/A)^{(p-d/2)/2} · K_{p-d/2}(√(A(Q(x)+b))) · exp(γᵀΣ⁻¹(x-μ))

where Q(x) = (x-μ)ᵀΣ⁻¹(x-μ), A = a + γᵀΣ⁻¹γ.

Full normalizing constant:

C = (2π)^{-d/2} |Σ|^{-1/2} · (a/b)^{p/2} / K_p(√(ab))

Parameters:

x (Array)

Return type:

Array

classmethod default_init(X)[source]

Warm-start from the best of NIG / VG / NInvG sub-model fits.

Runs 5 EM iterations for each special case, converts the winner to GH parametrisation, and uses it as the starting point. Falls back to moment-based default if all sub-models fail.

Parameters:

X (Array)

Return type:

GeneralizedHyperbolic

fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='det_sigma_one', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]

Fit GH distribution using EM.

Defaults to CPU backends and det_sigma_one regularization (GH has scale non-identifiability requiring |Sigma| = 1).

Fitting

EM Fitters

EM fitters for normix distributions.

Model knows math, fitter knows iteration.

BatchEMFitter — standard batch EM with dual-loop architecture:
  lax.scan (JIT-able) or Python for-loop (CPU-compatible)
OnlineEMFitter — online EM, one sample at a time
MiniBatchEMFitter — mini-batch EM, Robbins-Monro averaging

class normix.fitting.em.EMResult(model, log_likelihoods, param_changes, n_iter, converged, elapsed_time)[source]

Bases: object

Result of an EM fitting procedure.

Parameters:
model: Any
log_likelihoods: Array | None
param_changes: Array
n_iter: int
converged: bool
elapsed_time: float
class normix.fitting.em.BatchEMFitter(*, max_iter=200, tol=0.001, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]

Bases: object

Batch EM algorithm with dual-loop architecture.

Runs E-step (all data) -> M-step -> regularize until convergence. Convergence is measured by relative parameter change in the normal parameters (mu, gamma, L_Sigma), excluding subordinator (GIG) parameters.

Loop selection is automatic:
  • lax.scan when both backends are ‘jax’ and verbose <= 1

  • Python for-loop otherwise (CPU backends, or verbose >= 2)

Parameters:
  • max_iter (int) – Maximum number of EM iterations.

  • tol (float) – Convergence tolerance on max relative parameter change.

  • verbose (int) – 0 = silent, 1 = summary, 2 = per-iteration table.

  • regularization (str) – Regularization strategy (‘det_sigma_one’ or ‘none’).

  • e_step_backend (str) – ‘jax’ (default) or ‘cpu’.

  • m_step_backend (str) – ‘jax’ or ‘cpu’ (default, faster for GIG).

  • m_step_method (str) – ‘newton’ (default), ‘lbfgs’, or ‘bfgs’.

fit(model, X)[source]

Run batch EM. Auto-selects lax.scan or Python loop.

Parameters:
  • model (NormalMixture subclass (used as initial parameters))

  • X ((n, d) data array)

Return type:

EMResult with fitted model, convergence diagnostics, and timing.

class normix.fitting.em.OnlineEMFitter(*, tau0=10.0, max_epochs=5, verbose=0, regularization='none')[source]

Bases: object

Online EM algorithm (Robbins-Monro stochastic approximation).

Updates running sufficient statistics with step size 1/(tau0 + t). One epoch = one full pass through data in random order.
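The step-size schedule alone can be illustrated on a scalar statistic: Robbins-Monro averaging with γ_t = 1/(τ₀ + t) converges to the sample mean (a toy sketch, not the fitter itself):

```python
import numpy as np

# Running statistic s_t = (1 - step) * s_{t-1} + step * x_t with
# step = 1 / (tau0 + t); converges to the mean of the stream.
rng = np.random.default_rng(3)
x = rng.normal(loc=4.0, size=20_000)

tau0, s = 10.0, 0.0
for t, xi in enumerate(x, start=1):
    step = 1.0 / (tau0 + t)
    s = (1 - step) * s + step * xi
assert abs(s - 4.0) < 0.1
```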

Parameters:
fit(model, X, *, key)[source]

Online EM. Returns EMResult.

Parameters:
  • model (NormalMixture subclass, used as initial parameters)

  • X ((n, d) data array)

  • key (JAX PRNG key)
Return type:

EMResult

class normix.fitting.em.MiniBatchEMFitter(*, batch_size=256, max_iter=200, tol=0.001, tau0=10.0, verbose=0, regularization='none')[source]

Bases: object

Mini-batch EM with Robbins-Monro averaging of sufficient statistics.

Parameters:
fit(model, X, *, key)[source]

Mini-batch EM. Returns EMResult.

Parameters:
  • model (NormalMixture subclass, used as initial parameters)

  • X ((n, d) data array)

  • key (JAX PRNG key)
Return type:

EMResult

Solvers

Bregman divergence solvers.

Minimises f(θ) − θ·η over θ, where f is any convex function (e.g. the log-partition ψ for an exponential family). At the minimum ∇f(θ*) = η.
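The principle can be demonstrated with scipy directly, bypassing solve_bregman: minimising ψ(θ) − θ·η for the Gamma log-partition recovers θ* = [α−1, −β] from η = [digamma(α) − log β, α/β].

```python
import numpy as np
from scipy import optimize, special

# Bregman principle check (scipy, not the normix solver): at the minimum
# of psi(theta) - theta.eta the gradient condition grad psi(theta*) = eta holds.
def psi(theta):
    return special.gammaln(theta[0] + 1) - (theta[0] + 1) * np.log(-theta[1])

alpha, beta = 3.0, 2.0
eta = np.array([special.digamma(alpha) - np.log(beta), alpha / beta])

res = optimize.minimize(
    lambda th: psi(th) - th @ eta,
    x0=np.array([0.5, -1.0]),
    bounds=[(-1 + 1e-6, None), (None, -1e-6)],   # theta1 > -1, theta2 < 0
    options={"ftol": 1e-15, "gtol": 1e-12},
)
assert np.allclose(res.x, [alpha - 1, -beta], atol=1e-4)
```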

Public API

solve_bregman — single starting point
solve_bregman_multistart — multiple starting points (vmap for JAX Newton; for-loop for quasi-Newton and CPU)
bregman_objective — utility: f(θ) − θ·η

Backends × methods

backend='jax', method='newton' — custom lax.scan Newton, autodiff or analytical Hessian
backend='jax', method='lbfgs' — jaxopt LBFGSB (bounds native) or LBFGS (reparam)
backend='jax', method='bfgs' — jaxopt BFGS with reparameterization for bounds
backend='cpu', method='lbfgs' — scipy L-BFGS-B
backend='cpu', method='bfgs' — scipy BFGS
backend='cpu', method='newton' — scipy trust-exact with Hessian

Gradient / Hessian sources

grad_fn : θ → ∇f(θ).
  For backend='cpu': pure CPU (numpy) gradient.
  For backend='jax', method='newton': JAX-traceable ∇ψ(θ).
  If None with backend='cpu': hybrid — jax.grad compiled → NumPy callbacks.

hess_fn : θ → ∇²f(θ). Required for method='newton'.
  For backend='cpu': pure CPU (numpy) Hessian.
  For backend='jax', method='newton': JAX-traceable ∇²ψ(θ).
  If None with method='newton': jax.hessian of the full objective.

Both grad_fn and hess_fn operate in theta-space only. The solver handles all reparameterization internally via the chain rule.

class normix.fitting.solvers.BregmanResult(theta, fun, grad_norm, num_steps, converged, elapsed_time=0.0)[source]

Bases: object

Result of a Bregman divergence minimization.

Scalar fields accept both Python and JAX types so the result can live inside a lax.scan carry without concretization errors.

Parameters:
theta: Array
fun: Any
grad_norm: Any
num_steps: int
converged: Any
elapsed_time: float = 0.0
normix.fitting.solvers.bregman_objective(theta, eta, f)[source]

f(θ) − θ·η — convex dual whose minimum gives ∇f(θ*) = η.

Parameters:
  • theta (Array)

  • eta (Array)

  • f (convex function θ → scalar)
Return type:

Array

normix.fitting.solvers.solve_bregman(f, eta, theta0, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]

Minimise f(θ) − θ·η over θ.

Parameters:
  • f (convex function θ → scalar (e.g. log-partition ψ))

  • eta (target vector (e.g. expectation parameters η))

  • theta0 (initial guess)

  • backend ('jax' (JIT-able) or 'cpu' (scipy, not JIT-able))

  • method ('lbfgs', 'bfgs', or 'newton')

  • bounds (list of (lo, hi) per dimension; None → unconstrained.) – For backend=’jax’: enforced via reparameterization (except method=’lbfgs’ which uses jaxopt.LBFGSB natively). For backend=’cpu’: passed directly as scipy bounds.

  • max_steps (iteration budget)

  • tol (convergence tolerance on ‖∇f(θ) − η‖∞)

  • grad_fn (θ → ∇f(θ).) – For backend=’cpu’: must accept and return numpy arrays. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with backend=’cpu’: falls back to jax.grad (hybrid mode).

  • hess_fn (θ → ∇²f(θ).) – Required for method=’newton’. For backend=’cpu’: must accept numpy arrays and return numpy array. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with method=’newton’: jax.hessian of the full objective is used.

  • verbose (int) – 0 = silent, >= 1 = print summary after solve.

Return type:

BregmanResult

normix.fitting.solvers.solve_bregman_multistart(f, eta, theta0_batch, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]

Run solve_bregman from multiple starting points; return the best result.

Parameters:
  • theta0_batch ((k, dim) array of starting points)

  • (remaining parameters as in solve_bregman)
Return type:

BregmanResult

Utilities

Bessel Functions

JAX-compatible log modified Bessel function of the second kind.

log_kv(v, z) = log K_v(z), fully pure-JAX with zero scipy callbacks.

Regime-specific methods, selected via lax.cond (only one branch executes):
  1. Hankel asymptotic (DLMF 10.40.2) — large z

  2. Olver uniform expansion (DLMF 10.41.4) — large v

  3. Small-z leading asymptotic (DLMF 10.30.2) — z → 0

  4. Gauss-Legendre quadrature (Takekawa 2022) — moderate z, v

Using lax.cond (not jnp.where) means only the selected branch executes at runtime. lax.cond requires scalar conditions, so the core scalar function _log_kv_scalar is vmapped over array inputs.
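The lax.cond-plus-vmap pattern, shown on a toy two-regime function rather than the real Bessel branches:

```python
import jax
import jax.numpy as jnp

# Scalar core with lax.cond (only the selected branch executes),
# then vmap over array inputs, as described above.
def _f_scalar(z):
    return jax.lax.cond(
        z > 1.0,
        lambda z: jnp.log(z),    # stand-in "large z" branch
        lambda z: z - 1.0,       # stand-in "small z" branch
        z,
    )

f = jax.jit(jax.vmap(_f_scalar))
out = f(jnp.array([0.5, 2.0]))
assert jnp.allclose(out, jnp.array([-0.5, jnp.log(2.0)]))
```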

Custom JVP for full autodiff:
  • ∂/∂z : exact recurrence K’_v = −(K_{v−1}+K_{v+1})/2

  • ∂/∂v : central FD on log_kv itself (ε = BESSEL_EPS_V)
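The ∂/∂z recurrence is easy to verify against scipy with a central finite difference:

```python
import numpy as np
from scipy import special

# d/dz K_v(z) = -(K_{v-1}(z) + K_{v+1}(z)) / 2, checked against central FD.
v, z, eps = 1.3, 2.7, 1e-6
exact = -(special.kv(v - 1, z) + special.kv(v + 1, z)) / 2
fd = (special.kv(v, z + eps) - special.kv(v, z - eps)) / (2 * eps)
assert np.isclose(exact, fd, rtol=1e-6)
```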

backend='jax' (default) : pure-JAX, JIT-able, differentiable.
backend='cpu' : scipy.special.kve, fully vectorized numpy; not JIT-able, fast for the EM hot path.

normix.utils.bessel.log_kv(v, z, backend='jax')[source]

log K_v(z) — log modified Bessel function of the second kind.

Parameters:
  • v (scalar or array — order (any real; K_v = K_{-v}))

  • z (scalar or array — argument (must be > 0))

  • backend ('jax' (default) or 'cpu') –

    'jax' — pure-JAX, lax.cond regime selection, custom JVP.
    JIT-able, differentiable. Default for log_prob, pdf, etc.

    'cpu' — scipy.special.kve, fully vectorized numpy.
    Not JIT-able. Fast for the EM hot path.

Return type:

scalar or array of same broadcast shape as (v, z).

Constants

Shared numerical constants for normix.