# Architecture

This page describes the package structure and class hierarchy of normix.

## Package Layout
```text
normix/
├── __init__.py                          # Public API, enables float64
├── exponential_family.py                # ExponentialFamily(eqx.Module) base class
├── distributions/
│   ├── gamma.py                         # Gamma(α, β)
│   ├── inverse_gamma.py                 # InverseGamma(α, β)
│   ├── inverse_gaussian.py              # InverseGaussian(μ, λ)
│   ├── generalized_inverse_gaussian.py  # GIG(p, a, b)
│   ├── normal.py                        # MultivariateNormal(μ, L_Σ)
│   ├── variance_gamma.py                # VarianceGamma / JointVarianceGamma
│   ├── normal_inverse_gamma.py          # NormalInverseGamma / JointNormalInverseGamma
│   ├── normal_inverse_gaussian.py       # NormalInverseGaussian / JointNormalInverseGaussian
│   └── generalized_hyperbolic.py        # GeneralizedHyperbolic / JointGeneralizedHyperbolic
├── mixtures/
│   ├── joint.py                         # JointNormalMixture(ExponentialFamily)
│   └── marginal.py                      # NormalMixture (owns a JointNormalMixture)
├── fitting/
│   ├── em.py                            # EMResult; Batch / Online / MiniBatch EM fitters
│   └── solvers.py                       # Bregman divergence solvers (η → θ)
└── utils/
    ├── bessel.py                        # log_kv with custom JVP
    ├── constants.py                     # Shared numerical constants
    ├── plotting.py                      # Notebook plotting helpers
    └── validation.py                    # EM validation helpers
```
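A typical import then looks like the sketch below. The root re-export of the distribution classes is an assumption based on the `__init__.py` comment above.

```python
import jax.numpy as jnp
from normix import Gamma  # root re-export assumed from the layout above

# Importing normix enables float64 (per the __init__.py note), so this
# array is created as float64 rather than JAX's default float32.
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))
```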
## Class Hierarchy
```text
eqx.Module
├── ExponentialFamily (abstract)
│   ├── Gamma
│   ├── InverseGamma
│   ├── InverseGaussian
│   ├── GeneralizedInverseGaussian (alias: GIG)
│   ├── MultivariateNormal
│   └── JointNormalMixture (abstract)
│       ├── JointVarianceGamma
│       ├── JointNormalInverseGamma
│       ├── JointNormalInverseGaussian
│       └── JointGeneralizedHyperbolic
│
└── NormalMixture (abstract)
    ├── VarianceGamma
    ├── NormalInverseGamma
    ├── NormalInverseGaussian
    └── GeneralizedHyperbolic
```
## ExponentialFamily

All distributions with a density of the form

\[
f(x \mid \theta) = h(x)\, \exp\big( t(x) \cdot \theta - \psi(\theta) \big)
\]

subclass `ExponentialFamily`. Subclasses implement four abstract methods:
- the log-partition function \(\psi(\theta)\)
- computing \(\theta\) from the stored classical parameters
- computing \(t(x)\) for a single observation
- computing \(\log h(x)\)
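For concreteness, here is a minimal sketch of what such a subclass can look like for the Gamma family. The method names are hypothetical, since the list above does not fix normix's actual abstract method names:

```python
import equinox as eqx
import jax.numpy as jnp
from jax.scipy.special import gammaln

class SketchGamma(eqx.Module):
    """Illustrative Gamma(α, β) with θ = (α − 1, −β) and t(x) = (log x, x).

    Method names below are hypothetical, not normix's actual abstract API.
    """
    alpha: jnp.ndarray
    beta: jnp.ndarray

    def log_partition(self, theta):
        # ψ(θ) = log Γ(θ₁ + 1) − (θ₁ + 1) log(−θ₂)
        return gammaln(theta[0] + 1.0) - (theta[0] + 1.0) * jnp.log(-theta[1])

    def natural_params(self):
        # θ from the stored classical parameters (α, β)
        return jnp.array([self.alpha - 1.0, -self.beta])

    def sufficient_statistics(self, x):
        # t(x) for a single observation x
        return jnp.array([jnp.log(x), x])

    def log_base_measure(self, x):
        # log h(x) = 0 under this parameterization of the Gamma family
        return jnp.zeros(())
```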
Everything else is derived automatically:

- `log_prob(x)` = \(\log h(x) + t(x) \cdot \theta - \psi(\theta)\)
- `expectation_params()` = \(\nabla\psi(\theta)\) via `jax.grad`
- `fisher_information()` = \(\nabla^2\psi(\theta)\) via `jax.hessian`
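This differentiation trick is standard for exponential families. A self-contained illustration, again using the Gamma family with \(\theta = (\alpha - 1, -\beta)\) (the local function name `psi` is just for the sketch):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# For Gamma with θ = (α − 1, −β): ∇ψ(θ) gives η = (E[log X], E[X]),
# and ∇²ψ(θ) gives the Fisher information (covariance of t(X)).
def psi(theta):
    return gammaln(theta[0] + 1.0) - (theta[0] + 1.0) * jnp.log(-theta[1])

theta = jnp.array([1.0, -1.0])      # α = 2, β = 1
eta = jax.grad(psi)(theta)          # ≈ (digamma(2), 2.0)
fisher = jax.hessian(psi)(theta)    # 2×2 matrix
```

For Gamma, \(\nabla\psi\) evaluates to \((\text{digamma}(\alpha) - \log\beta,\; \alpha/\beta)\), matching the known moments \(E[\log X]\) and \(E[X]\).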
### Constructors

```python
# From classical parameters (human-readable)
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))
dist = Gamma.from_classical(alpha=2.0, beta=1.0)

# From natural parameters θ
dist = Gamma.from_natural(theta)

# From expectation parameters η (may involve optimization for GIG)
dist = Gamma.from_expectation(eta)

# MLE: η̂ = mean t(xᵢ), then from_expectation
dist = Gamma.fit_mle(X)

# Warm-start fit from current instance
dist = dist.fit(X)
```
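These constructors are mutually consistent, so parameterizations can be round-tripped. A sketch, assuming the root re-export shown earlier:

```python
from normix import Gamma  # root re-export assumed

# Round trip: classical → expectation parameters → back to a distribution.
dist = Gamma.from_classical(alpha=2.0, beta=1.0)
eta = dist.expectation_params()      # ∇ψ(θ) = (E[log X], E[X])
same = Gamma.from_expectation(eta)   # should recover α ≈ 2, β ≈ 1
```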
Distributions
Distribution |
Stored Attributes |
Notes |
|---|---|---|
|
|
Shape, rate |
|
|
Shape, rate |
|
|
Mean, shape |
|
|
Generalized Inverse Gaussian |
|
|
Mean, Cholesky of covariance |
## Mixture Structure

The GH family is modelled as a normal variance-mean mixture. The joint distribution \(f(x, y)\) is an exponential family; the marginal distribution \(f(x)\) is not.
```text
JointNormalMixture(ExponentialFamily)    f(x, y)
├── JointVarianceGamma                   Y ~ Gamma
├── JointNormalInverseGamma              Y ~ InverseGamma
├── JointNormalInverseGaussian           Y ~ InverseGaussian
└── JointGeneralizedHyperbolic           Y ~ GIG

NormalMixture(eqx.Module)                f(x) = ∫ f(x, y) dy
├── VarianceGamma
├── NormalInverseGamma
├── NormalInverseGaussian
└── GeneralizedHyperbolic
```
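To make the variance-mean mixture construction concrete: conditional on the mixing variable \(Y = y\), \(X \sim \mathcal{N}(\mu + y\gamma,\; y\Sigma)\). A hedged sampling sketch follows; the Gamma mixing draw is a stand-in, since normix's joints mix over Gamma, InverseGamma, InverseGaussian, or GIG:

```python
import jax
import jax.numpy as jnp

def sample_variance_mean_mixture(key, mu, gamma, L_sigma, n):
    """X = μ + Yγ + √Y · L_Σ Z with Z ~ N(0, I); Gamma Y is a placeholder."""
    k1, k2 = jax.random.split(key)
    y = jax.random.gamma(k1, 2.0, shape=(n,))           # stand-in mixing draw
    z = jax.random.normal(k2, shape=(n, mu.shape[0]))
    return mu + y[:, None] * gamma + jnp.sqrt(y)[:, None] * (z @ L_sigma.T)
```

Integrating \(Y\) out of this construction yields the marginal \(f(x)\), which is why the marginal classes are not exponential families.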
`NormalMixture` owns a `JointNormalMixture`. The joint provides:

- `conditional_expectations(x)` — E[log Y | x], E[1/Y | x], E[Y | x] for the EM E-step
- `_mstep_normal_params(...)` — closed-form M-step for μ, γ, L_Σ
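The E-step wiring might look like the following fragment; the `joint` attribute name and the three-tuple return structure are assumptions for illustration, not documented API:

```python
def e_step(marginal, x):
    # Pull the three conditional moments from the owned joint distribution.
    # `marginal.joint` and the tuple return structure are assumed here.
    e_log_y, e_inv_y, e_y = marginal.joint.conditional_expectations(x)
    return e_log_y, e_inv_y, e_y
```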
| Marginal Class | Joint Class | Mixing Distribution |
|---|---|---|
| `VarianceGamma` | `JointVarianceGamma` | \(Y \sim \text{Gamma}(\alpha, \beta)\) |
| `NormalInverseGamma` | `JointNormalInverseGamma` | \(Y \sim \text{InverseGamma}(\alpha, \beta)\) |
| `NormalInverseGaussian` | `JointNormalInverseGaussian` | \(Y \sim \text{InverseGaussian}(\mu, \lambda)\) |
| `GeneralizedHyperbolic` | `JointGeneralizedHyperbolic` | \(Y \sim \text{GIG}(p, a, b)\) |
## EM Algorithm

The EM fitters implement the expectation-maximisation algorithm.

```python
from normix.fitting.em import BatchEMFitter, EMResult

fitter = BatchEMFitter(max_iter=200, tol=1e-4)
result = fitter.fit(model, X)
```
`EMResult` contains:

- `model` — the fitted distribution
- `n_iter` — number of iterations
- `converged` — whether the algorithm converged
- `elapsed_time` — wall-clock seconds
- `param_changes` — per-iteration max relative parameter change
- `log_likelihoods` — per-iteration log-likelihood (optional)
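Continuing the `BatchEMFitter` example above, a quick way to inspect these fields:

```python
# Inspect the EMResult returned by fitter.fit (fields listed above).
result = fitter.fit(model, X)
if result.converged:
    print(f"converged in {result.n_iter} iterations "
          f"({result.elapsed_time:.2f} s)")
fitted_model = result.model
```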
Available fitters (the online and mini-batch class names below are inferred from the `fitting/em.py` comment in the package layout):

| Fitter | Description |
|---|---|
| `BatchEMFitter` | Standard batch EM |
| `OnlineEMFitter` | Online EM, one sample at a time, Robbins-Monro averaging |
| `MiniBatchEMFitter` | Mini-batch EM with Robbins-Monro averaging |
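The Robbins-Monro averaging mentioned above blends sufficient statistics across batches with a decaying step size. A conceptual sketch; the schedule \(\rho_t = (t+1)^{-\kappa}\) is an illustrative choice, not necessarily normix's:

```python
def robbins_monro_update(eta_avg, eta_batch, t, kappa=0.7):
    """Blend running expectation statistics with the current batch.

    Step sizes ρ_t = (t + 1)^(−κ) with κ ∈ (0.5, 1] satisfy the
    Robbins-Monro conditions Σ ρ_t = ∞ and Σ ρ_t² < ∞.
    """
    rho = (t + 1.0) ** (-kappa)
    return (1.0 - rho) * eta_avg + rho * eta_batch
```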