Architecture

This page describes the package structure and class hierarchy of normix.

Package Layout

normix/
├── __init__.py                  # Public API, enables float64
├── exponential_family.py        # ExponentialFamily(eqx.Module) base class
├── distributions/
│   ├── gamma.py                 # Gamma(α, β)
│   ├── inverse_gamma.py         # InverseGamma(α, β)
│   ├── inverse_gaussian.py      # InverseGaussian(μ, λ)
│   ├── generalized_inverse_gaussian.py  # GIG(p, a, b)
│   ├── normal.py                # MultivariateNormal(μ, L_Σ)
│   ├── variance_gamma.py        # VarianceGamma / JointVarianceGamma
│   ├── normal_inverse_gamma.py  # NormalInverseGamma / JointNormalInverseGamma
│   ├── normal_inverse_gaussian.py  # NormalInverseGaussian / JointNormalInverseGaussian
│   └── generalized_hyperbolic.py   # GeneralizedHyperbolic / JointGeneralizedHyperbolic
├── mixtures/
│   ├── joint.py                 # JointNormalMixture(ExponentialFamily)
│   └── marginal.py              # NormalMixture (owns a JointNormalMixture)
├── fitting/
│   ├── em.py                    # EMResult; Batch / Online / MiniBatch EM fitters
│   └── solvers.py               # Bregman divergence solvers (η→θ)
└── utils/
    ├── bessel.py                # log_kv with custom JVP
    ├── constants.py             # Shared numerical constants
    ├── plotting.py              # Notebook plotting helpers
    └── validation.py            # EM validation helpers
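
Importing the package enables 64-bit floats globally (presumably via jax.config.update("jax_enable_x64", True), the standard JAX mechanism). A quick sanity check:

import normix            # side effect: enables float64 for all of JAX
import jax.numpy as jnp

assert jnp.array(1.0).dtype == jnp.float64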

Class Hierarchy

eqx.Module
├── ExponentialFamily (abstract)
│   ├── Gamma
│   ├── InverseGamma
│   ├── InverseGaussian
│   ├── GeneralizedInverseGaussian (alias: GIG)
│   ├── MultivariateNormal
│   └── JointNormalMixture (abstract)
│       ├── JointVarianceGamma
│       ├── JointNormalInverseGamma
│       ├── JointNormalInverseGaussian
│       └── JointGeneralizedHyperbolic
│
└── NormalMixture (abstract)
    ├── VarianceGamma
    ├── NormalInverseGamma
    ├── NormalInverseGaussian
    └── GeneralizedHyperbolic

ExponentialFamily

All distributions with a density of the form

\[p(x \mid \theta) = h(x) \exp\!\bigl(\theta^T t(x) - \psi(\theta)\bigr)\]

subclass ExponentialFamily. Subclasses implement four abstract methods:

Method                              Purpose
_log_partition_from_theta(theta)    Log-partition function \(\psi(\theta)\)
natural_params()                    Compute \(\theta\) from stored classical parameters
sufficient_statistics(x)            Compute \(t(x)\) for a single observation
log_base_measure(x)                 Compute \(\log h(x)\)
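
To make the contract concrete, here is a toy sketch of a one-parameter exponential distribution, \(p(x) = \lambda e^{-\lambda x}\), for which \(\theta = -\lambda\), \(t(x) = x\), \(h(x) = 1\), and \(\psi(\theta) = -\log(-\theta)\). It assumes ExponentialFamily requires only these four hooks; the real base class may expect more:

import jax
import jax.numpy as jnp
from normix.exponential_family import ExponentialFamily

class Exponential(ExponentialFamily):
    """Toy Exp(rate): p(x) = rate * exp(-rate * x) for x > 0."""
    rate: jax.Array

    def natural_params(self):
        return jnp.atleast_1d(-self.rate)        # theta = -rate

    def _log_partition_from_theta(self, theta):
        return -jnp.log(-theta[0])               # psi(theta) = -log(-theta)

    def sufficient_statistics(self, x):
        return jnp.atleast_1d(x)                 # t(x) = x

    def log_base_measure(self, x):
        return jnp.zeros(())                     # log h(x) = 0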

Everything else is derived automatically:

  • log_prob(x) = \(\log h(x) + t(x) \cdot \theta - \psi(\theta)\)

  • expectation_params() = \(\nabla\psi(\theta)\) via jax.grad

  • fisher_information() = \(\nabla^2\psi(\theta)\) via jax.hessian
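
These identities are ordinary JAX autodiff applied to the log-partition. A standalone check, independent of normix, for Gamma(α, β) with \(\theta = (\alpha - 1, -\beta)\) and \(\psi(\theta) = \log\Gamma(\theta_1 + 1) - (\theta_1 + 1)\log(-\theta_2)\):

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def log_partition(theta):
    alpha, beta = theta[0] + 1.0, -theta[1]      # recover (alpha, beta) from theta
    return gammaln(alpha) - alpha * jnp.log(beta)

theta = jnp.array([1.0, -1.0])                   # alpha = 2, beta = 1
eta = jax.grad(log_partition)(theta)             # grad psi = (E[log X], E[X]) ~ (0.4228, 2.0)
fisher = jax.hessian(log_partition)(theta)       # hessian psi = Cov[t(X)]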

Constructors

import jax.numpy as jnp
from normix import Gamma

# From classical parameters (human-readable)
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))
dist = Gamma.from_classical(alpha=2.0, beta=1.0)

# From natural parameters θ
dist = Gamma.from_natural(theta)

# From expectation parameters η (may involve optimization for GIG)
dist = Gamma.from_expectation(eta)

# MLE: η̂ = mean t(xᵢ), then from_expectation
dist = Gamma.fit_mle(X)        # X: observations, shape (n,)

# Warm-start fit from current instance
dist = dist.fit(X)
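
The fit_mle pipeline is just the exponential-family MLE. A hand-rolled equivalent, sketched using only the documented methods (note that \(t(x)\) does not depend on the parameters, so any instance can evaluate it):

import jax
import jax.numpy as jnp
from normix import Gamma

X = jax.random.gamma(jax.random.PRNGKey(0), 2.0, (1000,))      # synthetic data

# η̂ = (1/n) Σ t(xᵢ), then invert η → θ (closed form, or a Bregman solver for GIG)
ref = Gamma.from_classical(alpha=1.0, beta=1.0)
eta_hat = jnp.mean(jax.vmap(ref.sufficient_statistics)(X), axis=0)
dist = Gamma.from_expectation(eta_hat)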

Distributions

Distribution          Stored Attributes    Notes
Gamma                 alpha, beta          Shape, rate
InverseGamma          alpha, beta          Shape, rate
InverseGaussian       mu, lam              Mean, shape
GIG                   p, a, b              Generalized Inverse Gaussian
MultivariateNormal    mu, L_Sigma          Mean, Cholesky of covariance
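
For example, MultivariateNormal stores the Cholesky factor rather than \(\Sigma\) itself. Assuming the constructor style mirrors the Gamma example above, with the keyword names from the table:

import jax.numpy as jnp
from normix import MultivariateNormal

Sigma = jnp.array([[2.0, 0.5],
                   [0.5, 1.0]])
mvn = MultivariateNormal(mu=jnp.zeros(2), L_Sigma=jnp.linalg.cholesky(Sigma))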

Mixture Structure

The GH family is modelled as a normal variance-mean mixture. The joint distribution \(f(x, y)\) is an exponential family. The marginal distribution \(f(x)\) is not.
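Concretely, with mixing variable \(Y > 0\) and the normal-component parameters \(\mu, \gamma, \Sigma\) (the same ones the M-step updates below), the standard variance-mean mixture construction is

\[X \mid Y = y \;\sim\; \mathcal{N}(\mu + y\gamma,\; y\Sigma), \qquad f(x) = \int_0^\infty f(x \mid y)\, f_Y(y)\, dy.\]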

JointNormalMixture(ExponentialFamily)     f(x, y)
    ├── JointVarianceGamma                Y ~ Gamma
    ├── JointNormalInverseGamma           Y ~ InverseGamma
    ├── JointNormalInverseGaussian        Y ~ InverseGaussian
    └── JointGeneralizedHyperbolic        Y ~ GIG

NormalMixture(eqx.Module)                f(x) = ∫ f(x,y) dy
    ├── VarianceGamma
    ├── NormalInverseGamma
    ├── NormalInverseGaussian
    └── GeneralizedHyperbolic
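
Sampling from any of the marginals therefore reduces to drawing \(Y\) and then the conditional normal. A minimal sketch of that construction (not the library's sampler; parameter values are arbitrary):

import jax
import jax.numpy as jnp

def sample_mixture(key, mu, gamma, L, sample_y, n):
    # X = mu + Y * gamma + sqrt(Y) * L @ z,  z ~ N(0, I),  Sigma = L @ L.T
    k1, k2 = jax.random.split(key)
    y = sample_y(k1, n)                                  # (n,) mixing draws
    z = jax.random.normal(k2, (n, mu.shape[0]))
    return mu + y[:, None] * gamma + jnp.sqrt(y)[:, None] * (z @ L.T)

# Variance Gamma: Y ~ Gamma(2, 1)
X = sample_mixture(
    jax.random.PRNGKey(0),
    jnp.zeros(2), jnp.array([0.3, 0.0]), jnp.eye(2),
    lambda k, n: jax.random.gamma(k, 2.0, (n,)),         # shape 2, rate 1
    n=1000,
)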

NormalMixture owns a JointNormalMixture. The joint provides:

  • conditional_expectations(x) — E[log Y|x], E[1/Y|x], E[Y|x] for the EM E-step

  • _mstep_normal_params(...) — closed-form M-step for μ, γ, L_Σ
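
For the GH case these conditional expectations are available in closed form because the mixing posterior is itself GIG. Assuming the density convention \(f_Y(y) \propto y^{p-1} \exp\bigl(-(a y + b/y)/2\bigr)\) for \(\text{GIG}(p, a, b)\) (normix's exact convention may differ),

\[Y \mid X = x \;\sim\; \text{GIG}\Bigl(p - \tfrac{d}{2},\; a + \gamma^T \Sigma^{-1} \gamma,\; b + (x - \mu)^T \Sigma^{-1} (x - \mu)\Bigr),\]

whose moments \(E[Y \mid x]\), \(E[1/Y \mid x]\), \(E[\log Y \mid x]\) reduce to ratios and order-derivatives of the Bessel function \(K_\nu\); this is why utils/bessel.py provides log_kv with a custom JVP.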

Marginal Class           Joint Class                   Mixing Distribution
VarianceGamma            JointVarianceGamma            \(Y \sim \text{Gamma}(\alpha, \beta)\)
NormalInverseGamma       JointNormalInverseGamma       \(Y \sim \text{InverseGamma}(\alpha, \beta)\)
NormalInverseGaussian    JointNormalInverseGaussian    \(Y \sim \text{InverseGaussian}(\mu, \lambda)\)
GeneralizedHyperbolic    JointGeneralizedHyperbolic    \(Y \sim \text{GIG}(p, a, b)\)

EM Algorithm

The fitters in normix.fitting.em estimate mixture parameters by expectation-maximisation (EM).

from normix.fitting.em import BatchEMFitter, EMResult

fitter = BatchEMFitter(max_iter=200, tol=1e-4)
result = fitter.fit(model, X)   # model: a NormalMixture instance; X: (n, d) data

EMResult contains:

  • model — the fitted distribution

  • n_iter — number of iterations

  • converged — whether the algorithm converged

  • elapsed_time — wall-clock seconds

  • param_changes — per-iteration max relative parameter change

  • log_likelihoods — per-iteration log-likelihood (optional)
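
Continuing the example above, the result can be inspected via these fields:

model = result.model            # the fitted distribution
if not result.converged:
    print(f"stopped after {result.n_iter} iterations in {result.elapsed_time:.1f}s; "
          f"last max relative parameter change: {result.param_changes[-1]:.2e}")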

Available fitters:

Fitter              Description
BatchEMFitter       Standard batch EM; supports lax.scan (JIT) or Python loop (CPU)
OnlineEMFitter      Online EM, one sample at a time, Robbins-Monro averaging
MiniBatchEMFitter   Mini-batch EM with Robbins-Monro averaging
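
Robbins-Monro averaging here refers to the standard stochastic-approximation update of the running sufficient-statistics estimate. A generic sketch (the step-size schedule and kappa default are assumptions, not necessarily normix's choices):

import jax.numpy as jnp

def robbins_monro_update(running_stats, batch_stats, t, kappa=0.7):
    # Step size rho_t = (t + 1)^(-kappa); kappa in (0.5, 1] gives the usual
    # stochastic-approximation convergence guarantees.
    rho = (t + 1.0) ** (-kappa)
    return (1.0 - rho) * running_stats + rho * batch_stats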