Quickstart

Fit a Generalized Hyperbolic distribution to data and evaluate it: the whole loop in about a dozen lines.

import jax
jax.config.update("jax_enable_x64", True)   # always enable float64 first
import jax.numpy as jnp

from normix import GeneralizedHyperbolic

# Some 3-D data (your returns, measurements, ...)
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (2000, 3))

# Initialize from data moments, then fit by EM
model = GeneralizedHyperbolic.default_init(X)
result = model.fit(X, max_iter=100, tol=1e-3)
fitted = result.model

print(f"converged: {result.converged} in {result.n_iter} iters")
converged: True in 75 iters
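
The comment on default_init says it starts from data moments. The sketch below computes the first two sample moments of the same X as a reference point. It assumes, without checking the normix source, that the sample mean and covariance are among the moments default_init draws on. It is not the library's initialization routine.

```python
import jax
jax.config.update("jax_enable_x64", True)   # enable float64 before creating arrays
import jax.numpy as jnp

# Same synthetic data as above: 2000 observations of a 3-D standard normal
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (2000, 3))

# First two sample moments (rows are observations, columns are dimensions)
sample_mean = jnp.mean(X, axis=0)        # shape (3,)
sample_cov = jnp.cov(X, rowvar=False)    # shape (3, 3); divides by n - 1 by default

print("sample mean:\n", sample_mean)
print("sample covariance:\n", sample_cov)
```

These empirical moments are a useful yardstick for the fitted mean() and cov() printed further down.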

fit returns an EMResult, and result.model is the trained distribution. The usual quantities (log-density, moments, log-likelihood, samples) are each one call away:

# Log-density on a batch (log_prob takes a single observation; vmap it over the rows)
log_p = jax.vmap(fitted.log_prob)(X)

print("mean:\n", fitted.mean())
print("covariance:\n", fitted.cov())
print("mean log-likelihood:", float(fitted.marginal_log_likelihood(X)))

# Draw fresh samples from the fitted model
samples = fitted.rvs(5, seed=1)
print("samples shape:", samples.shape)
mean:
 [-0.04316198 -0.0334625   0.01275726]
covariance:
 [[ 1.01905896 -0.02488201 -0.00869068]
 [-0.02488201  0.98313119  0.00842108]
 [-0.00869068  0.00842108  1.01237155]]
mean log-likelihood: -4.253829846135175
samples shape: (5, 3)
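
log_prob acts on a single observation, so a batch goes through jax.vmap. The sketch below shows that pattern with a hand-written standard-normal log-density (std_normal_log_prob is a hypothetical stand-in, not a normix function). It then averages over the batch and compares the result with the closed-form expectation -(d/2)(1 + log 2π) for a d-dimensional standard normal. It assumes marginal_log_likelihood reports a per-observation average, as the printed label "mean log-likelihood" suggests.

```python
import math

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def std_normal_log_prob(x):
    """Hypothetical stand-in: log N(x; 0, I) for ONE observation x of shape (d,)."""
    d = x.shape[0]
    return -0.5 * (d * jnp.log(2 * jnp.pi) + jnp.dot(x, x))


key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (2000, 3))

# vmap maps the single-observation function over the leading (batch) axis
log_p = jax.vmap(std_normal_log_prob)(X)   # shape (2000,)
mean_log_p = float(jnp.mean(log_p))

# E[log p(x)] for x ~ N(0, I_d) is -(d/2) * (log(2*pi) + 1), since E[x.x] = d
d = X.shape[1]
expected = -0.5 * d * (1 + math.log(2 * math.pi))

print("per-observation log-densities:", log_p.shape)
print("empirical mean log-density:", mean_log_p)
print("closed-form expectation:", expected)
```

For d = 3 the closed form is about -4.257. Because the data here really are standard normal, the printed mean log-likelihood of about -4.254 landing next to that figure is a reasonable sign that the fit is sensible.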

That is the entire workflow: default_init → fit → use the model. Swap GeneralizedHyperbolic for any other family (VarianceGamma, NormalInverseGaussian, …) and the rest of the code stays the same.
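
Because every family shares the same loop, it can be wrapped once and run over several families. The sketch below is a minimal version of that idea. It assumes each family class exposes default_init, fit, and the EMResult fields (model, converged, n_iter) exactly as GeneralizedHyperbolic does above. The helper name fit_family is hypothetical, and the import line in the usage comment assumes the other families are exported from normix the same way GeneralizedHyperbolic is.

```python
def fit_family(family, X, max_iter=100, tol=1e-3):
    """Hypothetical helper: default_init -> fit -> summarize, for any family class."""
    model = family.default_init(X)                       # start from data moments
    result = model.fit(X, max_iter=max_iter, tol=tol)    # EM; returns an EMResult
    fitted = result.model                                # the trained distribution
    return {
        "family": family.__name__,
        "converged": bool(result.converged),
        "n_iter": int(result.n_iter),
        "mean_loglik": float(fitted.marginal_log_likelihood(X)),
    }


# Usage (assumes these names are importable from normix like GeneralizedHyperbolic):
# from normix import GeneralizedHyperbolic, VarianceGamma, NormalInverseGaussian
# for family in (GeneralizedHyperbolic, VarianceGamma, NormalInverseGaussian):
#     print(fit_family(family, X))
```

Raw mean log-likelihoods ignore differences in parameter count between families, so a side-by-side comparison like this is a first look rather than a model-selection criterion.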

Where to next