Quick Start
Installation
# Using uv (recommended)
uv sync
# Or pip
pip install -e .
normix requires Python ≥ 3.12 and computes in float64 precision: the package
calls jax.config.update("jax_enable_x64", True) automatically on import.
Core Dependencies
| Package | Role |
|---|---|
| jax | Array computation, autodiff, JIT, vmap |
| | Immutable pytree-based modules |
| scipy | CPU Bessel evaluation (EM hot path) |
| | L-BFGS-B for GIG η→θ optimization |
Univariate Distributions
All univariate distributions are exponential families and expose three parametrizations (classical, natural θ, and expectation η):
import jax.numpy as jnp
from normix import Gamma
# Create from classical parameters
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))
# Evaluate log-density on a single observation
dist.log_prob(jnp.array(1.5))
# Natural and expectation parameters, plus the Fisher information
theta = dist.natural_params() # natural parameters θ
eta = dist.expectation_params() # expectation parameters η = E[t(X)]
I = dist.fisher_information() # Fisher information I(θ) = ∇²ψ(θ)
# Reconstruct from natural or expectation parameters
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
# Maximum likelihood estimation from simulated draws
samples = dist.rvs(1000, seed=42)
dist_mle = Gamma.fit_mle(samples)
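To make the link between the parametrizations concrete, here is a minimal plain-Python sketch for the Gamma case. It assumes sufficient statistics t(x) = (log x, x), so that θ = (α − 1, −β); normix's internal convention may differ. The helper names are hypothetical, and a central finite difference stands in for automatic differentiation when checking that η = ∇ψ(θ).

```python
import math

def gamma_natural(alpha, beta):
    """Natural parameters for t(x) = (log x, x): theta = (alpha - 1, -beta).
    Assumed convention, not necessarily the one normix uses."""
    return (alpha - 1.0, -beta)

def gamma_log_partition(theta):
    """Log-partition psi(theta) = lgamma(theta1 + 1) - (theta1 + 1) * log(-theta2)."""
    a = theta[0] + 1.0
    return math.lgamma(a) - a * math.log(-theta[1])

def numerical_gradient(f, theta, h=1e-5):
    """Central-difference gradient, standing in for autodiff."""
    grad = []
    for i in range(len(theta)):
        up = list(theta)
        dn = list(theta)
        up[i] += h
        dn[i] -= h
        grad.append((f(up) - f(dn)) / (2.0 * h))
    return tuple(grad)

theta = gamma_natural(2.0, 1.0)                       # alpha = 2, beta = 1
eta = numerical_gradient(gamma_log_partition, theta)  # eta = grad psi = E[t(X)]
# In closed form, eta = (digamma(alpha) - log(beta), alpha / beta).
```

Under this convention the second component of eta is E[X] = α/β, and the Hessian of the same log-partition function would give the Fisher information.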
Available univariate distributions:
| Class | Parameters | Description |
|---|---|---|
| Gamma | alpha, beta | Shape α > 0, rate β > 0 |
| InverseGamma | | Shape α > 0, rate β > 0 |
| InverseGaussian | | Mean μ > 0, shape λ > 0 |
| GIG | | Generalized Inverse Gaussian |
Multivariate Normal
import jax
import jax.numpy as jnp
from normix import MultivariateNormal
mu = jnp.zeros(3)
L = jnp.eye(3) # Cholesky factor of covariance
dist = MultivariateNormal(mu=mu, L_Sigma=L)
# Log-density of a single observation
x = jnp.ones(3)
dist.log_prob(x)
# Log-density over a batch of observations with shape (n, 3), via vmap
X = jnp.ones((10, 3))
log_probs = jax.vmap(dist.log_prob)(X)
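Parametrizing by the Cholesky factor means the log-density needs only a triangular solve and the diagonal of the factor. The plain-Python sketch below illustrates this, assuming a lower-triangular L with Σ = L Lᵀ. The function names are hypothetical and this is not normix's implementation.

```python
import math

def forward_substitution(L, b):
    """Solve L z = b for a lower-triangular matrix L given as nested lists."""
    z = []
    for i, row in enumerate(L):
        partial = sum(row[j] * z[j] for j in range(i))
        z.append((b[i] - partial) / row[i])
    return z

def mvn_log_prob(x, mu, L):
    """log N(x; mu, L L^T), using log det(Sigma) = 2 * sum(log L_ii)."""
    d = len(x)
    residual = [xi - mi for xi, mi in zip(x, mu)]
    z = forward_substitution(L, residual)        # z = L^{-1} (x - mu)
    quad = sum(zi * zi for zi in z)              # Mahalanobis distance squared
    half_log_det = sum(math.log(L[i][i]) for i in range(d))
    return -0.5 * quad - half_log_det - 0.5 * d * math.log(2.0 * math.pi)

identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
value = mvn_log_prob([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], identity)
scaled = mvn_log_prob([0.0], [0.0], [[2.0]])     # 1-D normal with std 2 at its mean
```

With the identity factor, the result reduces to the standard normal log-density summed over the three coordinates.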
Normal Variance-Mean Mixtures
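The GH distribution family is modelled as a normal variance-mean mixture, which in its standard form reads:
X | Y ~ N(μ + γY, YΣ), where Y > 0 is the mixing variable (the subordinator).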
Each mixture has a marginal class (what users interact with) and a joint class (used internally for the EM E-step).
| Marginal Class | Subordinator | Parameters |
|---|---|---|
| | Gamma | |
| | InverseGamma | |
| | InverseGaussian | |
| GeneralizedHyperbolic | GIG | |
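As an illustration of the mixture construction, the sketch below draws from a one-dimensional normal variance-mean mixture using the usual stochastic representation X = μ + γY + σ√Y·Z with Z standard normal. It assumes a Gamma subordinator with shape α and rate β, which yields a variance-gamma law. The function name and all parameter values are hypothetical, and only the standard library is used.

```python
import math
import random

def sample_nvm_mixture(n, mu, gamma, sigma, alpha, beta, seed=0):
    """Draw n values of X = mu + gamma * Y + sigma * sqrt(Y) * Z.

    Y ~ Gamma(shape=alpha, rate=beta) is the subordinator and Z ~ N(0, 1).
    """
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        # random.gammavariate takes a scale parameter, hence 1 / rate
        y = rng.gammavariate(alpha, 1.0 / beta)
        z = rng.gauss(0.0, 1.0)
        draws.append(mu + gamma * y + sigma * math.sqrt(y) * z)
    return draws

samples = sample_nvm_mixture(20000, mu=0.5, gamma=1.0, sigma=1.0, alpha=2.0, beta=1.0)
sample_mean = sum(samples) / len(samples)
# Theory: E[X] = mu + gamma * E[Y] = mu + gamma * alpha / beta
theoretical_mean = 0.5 + 1.0 * 2.0 / 1.0
```

The skewness parameter γ shifts the mean in proportion to E[Y], which is why the sample mean sits well away from μ here.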
Fitting with EM
The simplest way to fit a distribution is via the fit convenience method:
import jax
import jax.numpy as jnp
from normix import GeneralizedHyperbolic
# Generate sample data
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 5))
# Initialize from data moments, then fit
model = GeneralizedHyperbolic.default_init(X)
result = model.fit(X, max_iter=200, tol=1e-4)
# result is an EMResult with diagnostics
print(f"Converged: {result.converged}")
print(f"Iterations: {result.n_iter}")
print(f"Time: {result.elapsed_time:.2f}s")
# The fitted model
fitted = result.model
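The roles of max_iter, tol and the returned diagnostics can be sketched with a toy EM loop. The example below fits the two means of an equal-weight mixture of unit-variance normals, which is a deliberately simpler model than the GH family. ToyEMResult, em_step and fit_toy_em are hypothetical names, and the stopping rule (relative change in log-likelihood) is an assumption, not necessarily the criterion normix applies.

```python
import math
import random
import time
from dataclasses import dataclass

@dataclass
class ToyEMResult:
    """Hypothetical diagnostics record with fields mirroring those printed above."""
    means: tuple
    converged: bool
    n_iter: int
    elapsed_time: float

def em_step(data, means):
    """One EM update; returns new means and the log-likelihood at the old means."""
    sum_r0 = sum_r1 = sum_x0 = sum_x1 = log_lik = 0.0
    for x in data:
        a = -0.5 * (x - means[0]) ** 2
        b = -0.5 * (x - means[1]) ** 2
        top = max(a, b)
        norm = math.exp(a - top) + math.exp(b - top)
        r0 = math.exp(a - top) / norm      # E-step: responsibility of component 0
        sum_r0 += r0
        sum_r1 += 1.0 - r0
        sum_x0 += r0 * x
        sum_x1 += (1.0 - r0) * x
        log_lik += top + math.log(0.5 * norm) - 0.5 * math.log(2.0 * math.pi)
    # M-step: responsibility-weighted means
    return (sum_x0 / sum_r0, sum_x1 / sum_r1), log_lik

def fit_toy_em(data, means, max_iter=200, tol=1e-4):
    start = time.perf_counter()
    previous = -math.inf
    converged = False
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        means, log_lik = em_step(data, means)
        # Assumed stopping rule: relative log-likelihood change below tol
        if abs(log_lik - previous) <= tol * (1.0 + abs(log_lik)):
            converged = True
            break
        previous = log_lik
    return ToyEMResult(means, converged, n_iter, time.perf_counter() - start)

rng = random.Random(0)
data = [rng.gauss(-2.0, 1.0) for _ in range(500)]
data += [rng.gauss(3.0, 1.0) for _ in range(500)]
result = fit_toy_em(data, means=(-1.0, 1.0))
```

A result with converged set to False would mean the loop stopped at max_iter, in which case raising max_iter or loosening tol is the natural next step.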
For more control, use BatchEMFitter directly:
from normix.fitting.em import BatchEMFitter
fitter = BatchEMFitter(
max_iter=200,
tol=1e-4,
e_step_backend='cpu', # CPU Bessel for speed
m_step_backend='cpu', # CPU solver for GIG η→θ
verbose=1, # print summary
)
result = fitter.fit(model, X)
The e_step_backend='cpu' option routes Bessel function evaluations through
scipy.special.kve instead of JAX, yielding a ~15× speedup for large datasets.
See Design Decisions for the rationale behind this hybrid approach.
Bessel Functions
normix provides a JIT-able, differentiable log_kv, the logarithm of the modified
Bessel function of the second kind, log K_ν(z):
from normix import log_kv
# JAX backend (JIT-able, differentiable)
log_kv(0.5, 2.0)
# CPU backend (fast, for EM hot path)
log_kv(0.5, 2.0, backend='cpu')
The JAX backend uses a 4-regime dispatch (Hankel, Olver, small-z, Gauss-Legendre)
with @jax.custom_jvp for exact derivatives.
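For intuition about what log_kv computes, the sketch below evaluates log K_ν(z) directly from the integral representation K_ν(z) = ∫₀^∞ exp(−z cosh t) cosh(νt) dt, valid for z > 0. It uses a plain trapezoid rule accumulated in log space. This is a different and much simpler technique than the 4-regime dispatch described above; log_kv_quadrature and its defaults are hypothetical.

```python
import math

def log_cosh(u):
    """Numerically stable log(cosh(u))."""
    u = abs(u)
    return u + math.log1p(math.exp(-2.0 * u)) - math.log(2.0)

def log_kv_quadrature(v, z, upper=20.0, n=4000):
    """Approximate log K_v(z) for z > 0 with a trapezoid rule on [0, upper]."""
    h = upper / n
    log_terms = []
    for i in range(n + 1):
        t = i * h
        weight = 0.5 if i in (0, n) else 1.0   # trapezoid end-point weights
        log_terms.append(math.log(weight * h) - z * math.cosh(t) + log_cosh(v * t))
    # Log-sum-exp keeps the sum finite when every term is tiny
    peak = max(log_terms)
    return peak + math.log(sum(math.exp(term - peak) for term in log_terms))

approx = log_kv_quadrature(0.5, 2.0)
# Closed form for half-integer order: K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z)
exact = 0.5 * math.log(math.pi / 4.0) - 2.0
```

Working in log space matters because K_ν(z) decays like exp(−z), so the function itself underflows for large z while its logarithm stays well behaved.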
Batching with vmap
All core methods operate on single observations. Use jax.vmap for batching:
# Log-density over a batch
log_probs = jax.vmap(dist.log_prob)(X)
# Sufficient statistics over a batch
T = jax.vmap(type(dist).sufficient_statistics)(X)
This keeps each method's implementation simple and leaves vectorization to JAX.