Quickstart
Fit a Generalized Hyperbolic distribution to data and evaluate it — the whole loop in a dozen lines.
```python
import jax
jax.config.update("jax_enable_x64", True)  # enable float64 at startup, before any arrays exist
import jax.numpy as jnp
from normix import GeneralizedHyperbolic

# Some 3-D data (your returns, measurements, ...)
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (2000, 3))

# Initialize from data moments, then fit by EM
model = GeneralizedHyperbolic.default_init(X)
result = model.fit(X, max_iter=100, tol=1e-3)
fitted = result.model
print(f"converged: {result.converged} in {result.n_iter} iters")
```

```
converged: True in 75 iters
```
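The float64 flag goes first because JAX creates float32 arrays by default, and arrays that already exist keep the dtype they were created with. The check below uses only JAX itself: once jax_enable_x64 is set at startup, both the jnp constructors and jax.random produce float64 arrays. It is a minimal sketch of the setting's effect, not part of normix.

```python
import jax

# Set at startup: JAX defaults to float32, and arrays created before this
# call keep whatever dtype they were created with.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

ones = jnp.ones(3)                                      # array constructor
draws = jax.random.normal(jax.random.PRNGKey(0), (4,))  # random sampler

print(ones.dtype, draws.dtype)  # float64 float64
```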
model.fit returns an EMResult, and result.model is the trained
distribution. The usual quantities are each one call away:
```python
# log_prob evaluates one observation; jax.vmap maps it over the rows of X
log_p = jax.vmap(fitted.log_prob)(X)

print("mean:\n", fitted.mean())
print("covariance:\n", fitted.cov())
print("mean log-likelihood:", float(fitted.marginal_log_likelihood(X)))

# Draw fresh samples from the fitted model
samples = fitted.rvs(5, seed=1)
print("samples shape:", samples.shape)
```

```
mean:
[-0.04316198 -0.0334625   0.01275726]
covariance:
[[ 1.01905896 -0.02488201 -0.00869068]
 [-0.02488201  0.98313119  0.00842108]
 [-0.00869068  0.00842108  1.01237155]]
mean log-likelihood: -4.253829846135175
samples shape: (5, 3)
```
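The log_prob comment reflects a common JAX pattern: write the density for a single observation, then let jax.vmap add the batch axis. A minimal sketch of that pattern follows, using a plain Gaussian with the sample mean and covariance as a stand-in for the fitted model; gaussian_log_prob is a hypothetical helper, not part of normix. Because the quickstart data are standard normal, the average per-row log-density should land near the theoretical value -(d/2)(log 2π + 1) ≈ -4.257 for d = 3. Assuming marginal_log_likelihood averages over rows, as its printed label suggests, that gives a rough baseline for the -4.254 shown above.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

# Same data as the quickstart: standard normal, 2000 rows, 3 columns.
X = jax.random.normal(jax.random.PRNGKey(0), (2000, 3))

# Stand-in "model" (hypothetical baseline): Gaussian with sample mean and covariance.
mu = X.mean(axis=0)
sigma = jnp.cov(X, rowvar=False)
chol = jnp.linalg.cholesky(sigma)  # lower-triangular factor, sigma = chol @ chol.T


def gaussian_log_prob(x):
    """Log-density of ONE observation x with shape (d,); returns a scalar."""
    d = x.shape[0]
    z = solve_triangular(chol, x - mu, lower=True)    # whitened residual
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))  # log |sigma|
    return -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det + z @ z)


# vmap adds the batch axis: one scalar per row of X.
log_p = jax.vmap(gaussian_log_prob)(X)
mean_ll = log_p.mean()

# Cross-check against JAX's built-in (already batched) Gaussian log-pdf.
ref = multivariate_normal.logpdf(X, mu, sigma)

# Theoretical mean log-density for N(0, I_d) with d = 3: -(d/2) * (log(2*pi) + 1).
theory = -1.5 * (jnp.log(2.0 * jnp.pi) + 1.0)

print(log_p.shape, float(mean_ll), float(theory))
```

Writing the single-observation version keeps the formula readable; vmap then vectorizes it without any manual broadcasting over the batch dimension.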
That is the entire workflow: default_init → fit → use the model. To fit
another family (VarianceGamma, NormalInverseGaussian, …), swap
GeneralizedHyperbolic for that class in the import and the default_init call;
the rest of the code stays the same.
Where to next
Your first model, step by step: the same workflow, with each stage explained.
Distributions: choosing the right distribution.
The exponential family: the structure underneath it all.