Initialization and multi-start
EM converges to a local optimum of the likelihood, so where it starts matters.
normix gives you three tools: default_init for a data-driven starting model,
warm-starting through theta0, and jax.vmap to run many starts in parallel.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
from normix import NormalInverseGaussian, Gamma
from normix.utils.plotting import set_theme
set_theme()
np.set_printoptions(precision=4, suppress=True)
default_init: a data-driven start
default_init(X) matches the empirical mean and covariance of X and picks reasonable
subordinator parameters, giving EM a sensible place to begin. Starting there, EM
lifts the mean log-likelihood from -2.8339 to -2.7405 in 70 iterations:
true = NormalInverseGaussian.from_classical(
    mu=jnp.array([0.0, 0.0]),
    gamma=jnp.array([0.5, -0.4]),
    sigma=jnp.array([[1.0, 0.4], [0.4, 1.0]]),
    mu_ig=1.0, lam=1.2)
X = true.rvs(3000, seed=0)
init = NormalInverseGaussian.default_init(X)
ll0 = float(init.marginal_log_likelihood(X))
res = init.fit(X, max_iter=100, tol=1e-4, e_step_backend="cpu")
ll1 = float(res.model.marginal_log_likelihood(X))
print(f"mean log-lik at default_init : {ll0:.4f}")
print(f"mean log-lik after EM : {ll1:.4f} ({res.n_iter} iters)")
mean log-lik at default_init : -2.8339
mean log-lik after EM : -2.7405 (70 iters)
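To make "matches the empirical mean and covariance" concrete, here is a minimal NumPy sketch of a moment-matched start. It assumes the NIG is the normal variance-mean mixture X = mu + gamma*Y + sqrt(Y)*Z, with Y inverse Gaussian (mean mu_ig, shape lam) and Z ~ N(0, Sigma). That parameterization and the helper name moment_matched_start are assumptions made for illustration; normix's actual default_init may choose its parameters differently.

```python
import numpy as np

def moment_matched_start(X):
    """Hypothetical helper: a symmetric start with gamma = 0 and E[Y] = 1.

    Under the assumed mixture, gamma = 0 and E[Y] = 1 give mean mu and
    covariance Sigma, so the sample moments are matched exactly.
    """
    mu0 = X.mean(axis=0)
    sigma0 = np.cov(X, rowvar=False)
    gamma0 = np.zeros(X.shape[1])
    return mu0, gamma0, sigma0

# Simulate from the assumed mixture representation.
rng = np.random.default_rng(0)
n = 200_000
mu = np.array([0.0, 0.0])
gamma = np.array([0.5, -0.4])
sigma = np.array([[1.0, 0.4], [0.4, 1.0]])
mu_ig, lam = 1.0, 1.2

Y = rng.wald(mean=mu_ig, scale=lam, size=n)              # inverse Gaussian subordinator
Z = rng.multivariate_normal(np.zeros(2), sigma, size=n)  # Gaussian part
X = mu + Y[:, None] * gamma + np.sqrt(Y)[:, None] * Z

# Moments implied by the mixture:
#   E[X]   = mu + gamma * E[Y]
#   Cov[X] = E[Y] * Sigma + Var[Y] * gamma gamma^T,  Var[Y] = mu_ig**3 / lam
mean_theory = mu + gamma * mu_ig
cov_theory = mu_ig * sigma + (mu_ig**3 / lam) * np.outer(gamma, gamma)

mu0, gamma0, sigma0 = moment_matched_start(X)
print("sample mean vs theory:", mu0, mean_theory)
print("sample cov vs theory:\n", sigma0, "\n", cov_theory)
```

The start ignores skewness (gamma = 0) and leaves EM to recover it, which is one reason the log-likelihood still has room to improve after initialization.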
Multi-start for robustness
A single arbitrary start can land in a worse optimum. Running EM from several
random initializations and keeping the best fit guards against this. Here we
compare random starts against default_init:
emp_cov = jnp.asarray(np.cov(np.asarray(X), rowvar=False))
rng = np.random.default_rng(1)
def random_init(seed):
    # Draws come from the shared `rng` above; `seed` only labels the start.
    g = jnp.asarray(rng.normal(scale=0.6, size=2))
    return NormalInverseGaussian.from_classical(
        mu=jnp.asarray(X.mean(axis=0)), gamma=g, sigma=emp_cov,
        mu_ig=float(rng.uniform(0.5, 2.0)), lam=float(rng.uniform(0.5, 2.0)))
scores = []
for s in range(5):
    r = random_init(s).fit(X, max_iter=100, tol=1e-4, e_step_backend="cpu")
    scores.append(float(r.model.marginal_log_likelihood(X)))
print("random-start mean log-liks:", np.round(scores, 4))
print(f"best random start : {max(scores):.4f}")
print(f"default_init start : {ll1:.4f}")
random-start mean log-liks: [-2.7405 -2.7405 -2.7405 -2.7405 -2.7405]
best random start : -2.7405
default_init start : -2.7405
All five random starts reach the same mean log-likelihood as default_init
(-2.7405 to four decimals), so these starts reveal no competing local optimum
on this dataset. The takeaway is to use default_init (it is already a strong
start) and, when the likelihood surface is rugged, to keep the best of several
starts.
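The keep-the-best pattern is independent of the model being fitted. As a minimal sketch, the code below runs a toy EM (a two-component Gaussian mixture with equal weights and unit variances) from several random starts and keeps the highest-likelihood result. The names em_two_means and best_of_starts are hypothetical and not part of normix.

```python
import numpy as np

def em_two_means(x, m0, n_iter=200):
    """Toy EM for the mixture 0.5*N(m[0], 1) + 0.5*N(m[1], 1).

    Returns the fitted means and the mean log-likelihood.
    """
    m = np.asarray(m0, dtype=float)
    for _ in range(n_iter):
        # E-step: responsibilities of each component for each point.
        logp = -0.5 * (x[:, None] - m[None, :]) ** 2
        r = np.exp(logp - logp.max(axis=1, keepdims=True))
        r /= r.sum(axis=1, keepdims=True)
        # M-step: responsibility-weighted means.
        m = (r * x[:, None]).sum(axis=0) / r.sum(axis=0)
    logp = -0.5 * (x[:, None] - m[None, :]) ** 2
    ll = np.logaddexp(logp[:, 0], logp[:, 1]) + np.log(0.5) - 0.5 * np.log(2 * np.pi)
    return m, float(ll.mean())

def best_of_starts(fit, starts):
    """Run `fit` from every start; return the best (params, ll) and all results."""
    results = [fit(s) for s in starts]
    return max(results, key=lambda res: res[1]), results

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 1.0, 1000), rng.normal(2.0, 1.0, 1000)])
starts = rng.uniform(-5.0, 5.0, size=(5, 2))  # five random initial mean pairs

(best_means, best_ll), results = best_of_starts(lambda s: em_two_means(x, s), starts)
print("mean log-liks:", np.round([ll for _, ll in results], 4))
print("best means   :", np.sort(best_means))
```

Comparing starts by log-likelihood on the same data is what makes the selection meaningful; the parameters themselves may differ by a relabeling of components.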
Vectorized starts with jax.vmap
For exponential-family distributions, the \(\eta \mapsto \theta\) solve in
from_expectation is pure JAX, so it can be wrapped in jax.vmap. That lets us
fit many datasets, or many bootstrap resamples, in a single vectorized call
instead of a Python loop:
g_true = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.5))
datasets = jnp.stack([g_true.rvs(2000, seed=s) for s in range(8)]) # (8, 2000)
# Each dataset's expectation parameters, then a single batched solve.
etas = jax.vmap(lambda d: jax.vmap(Gamma.sufficient_statistics)(d).mean(0))(datasets)
fits = jax.vmap(Gamma.from_expectation)(etas) # batched Gamma pytree
print("fitted alphas:", np.asarray(fits.alpha))
print("fitted betas :", np.asarray(fits.beta))
print("alpha mean ± sd: %.3f ± %.3f" % (float(fits.alpha.mean()), float(fits.alpha.std())))
fitted alphas: [2.1107 1.9357 1.9964 2.0283 1.9852 1.993 2.0164 2.0651]
fitted betas : [1.6125 1.4889 1.4707 1.5019 1.4967 1.515 1.5366 1.5359]
alpha mean ± sd: 2.016 ± 0.050
The result fits is a single Gamma pytree whose leaves carry a leading batch
axis — exactly the shape you want for bootstrap confidence intervals.
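For the Gamma family, the \(\eta \mapsto \theta\) solve is short enough to write out. The NumPy sketch below assumes sufficient statistics ordered as (log x, x) and a rate parameterization, and solves log(alpha) - psi(alpha) = log E[x] - E[log x] by Newton's method for a whole batch at once. NumPy broadcasting over the leading axis stands in for jax.vmap here. The function gamma_from_expectation and the digamma/trigamma helpers are hypothetical, written for this illustration, and do not reproduce normix's implementation.

```python
import numpy as np

def digamma(x):
    """psi(x) for x > 0: shift by 6 with the recurrence, then an asymptotic series."""
    shift = sum(1.0 / (x + k) for k in range(6))
    y = x + 6.0
    series = np.log(y) - 1 / (2 * y) - 1 / (12 * y**2) + 1 / (120 * y**4) - 1 / (252 * y**6)
    return series - shift

def trigamma(x):
    """psi'(x) for x > 0, using the same shift-then-series approach."""
    shift = sum(1.0 / (x + k) ** 2 for k in range(6))
    y = x + 6.0
    series = 1 / y + 1 / (2 * y**2) + 1 / (6 * y**3) - 1 / (30 * y**5) + 1 / (42 * y**7)
    return series + shift

def gamma_from_expectation(eta, n_newton=20):
    """Batched eta -> (alpha, beta) for Gamma(shape=alpha, rate=beta).

    Assumes eta[..., 0] = E[log x] and eta[..., 1] = E[x].
    """
    s = np.log(eta[..., 1]) - eta[..., 0]  # positive by Jensen's inequality
    # Closed-form approximation used as the starting point.
    alpha = (3 - s + np.sqrt((s - 3) ** 2 + 24 * s)) / (12 * s)
    for _ in range(n_newton):
        f = np.log(alpha) - digamma(alpha) - s
        alpha = alpha - f / (1 / alpha - trigamma(alpha))
    return alpha, alpha / eta[..., 1]

rng = np.random.default_rng(0)
datasets = rng.gamma(shape=2.0, scale=1 / 1.5, size=(8, 2000))  # rate beta = 1.5

# One row of expectation parameters per dataset, then one batched solve.
etas = np.stack([np.log(datasets).mean(axis=1), datasets.mean(axis=1)], axis=-1)
alphas, betas = gamma_from_expectation(etas)

print("fitted alphas:", np.round(alphas, 4))
print("fitted betas :", np.round(betas, 4))
```

For bootstrap intervals, the rows of datasets would be resamples of one dataset, and np.percentile over the resulting alphas gives a percentile interval.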
Warm-starting with theta0
fit_mle and from_expectation accept a theta0 to seed the solver. A warm
start near the solution typically needs fewer solver iterations and lands at
the same optimum, as the matching estimates below confirm:
X1 = g_true.rvs(5000, seed=0)
cold = Gamma.fit_mle(X1)
warm = Gamma.fit_mle(X1, theta0=g_true.natural_params())
print("cold start (alpha, beta):", float(cold.alpha), float(cold.beta))
print("warm start (alpha, beta):", float(warm.alpha), float(warm.beta))
cold start (alpha, beta): 2.022789272342743 1.509325547948308
warm start (alpha, beta): 2.022789272342743 1.509325547948308
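The printed parameters show that both starts land at the same optimum, but not how much work each took. The sketch below counts Newton iterations for the Gamma shape equation from a deliberately poor cold start and from a warm start at the true shape. It is a stand-in solver written for this illustration (solve_shape and the digamma/trigamma helpers are hypothetical), not normix's fit_mle, whose internal solver and default start may differ.

```python
import numpy as np

def digamma(x):
    """psi(x) for x > 0: shift by 6 with the recurrence, then an asymptotic series."""
    shift = sum(1.0 / (x + k) for k in range(6))
    y = x + 6.0
    return np.log(y) - 1 / (2 * y) - 1 / (12 * y**2) + 1 / (120 * y**4) - 1 / (252 * y**6) - shift

def trigamma(x):
    """psi'(x) for x > 0, using the same shift-then-series approach."""
    shift = sum(1.0 / (x + k) ** 2 for k in range(6))
    y = x + 6.0
    return 1 / y + 1 / (2 * y**2) + 1 / (6 * y**3) - 1 / (30 * y**5) + 1 / (42 * y**7) + shift

def solve_shape(s, alpha0, tol=1e-10, max_iter=100):
    """Newton solve of log(alpha) - psi(alpha) = s; returns (alpha, iterations used)."""
    alpha = float(alpha0)
    for n_iter in range(1, max_iter + 1):
        f = np.log(alpha) - digamma(alpha) - s
        step = f / (1 / alpha - trigamma(alpha))
        alpha -= step
        if abs(step) < tol:
            return alpha, n_iter
    return alpha, max_iter

rng = np.random.default_rng(0)
x = rng.gamma(shape=2.0, scale=1 / 1.5, size=5000)
s = np.log(x.mean()) - np.log(x).mean()

a_cold, n_cold = solve_shape(s, alpha0=0.1)  # far from the solution
a_warm, n_warm = solve_shape(s, alpha0=2.0)  # seeded at the true shape

print(f"cold start: alpha = {a_cold:.6f} in {n_cold} iterations")
print(f"warm start: alpha = {a_warm:.6f} in {n_warm} iterations")
```

Both runs solve the same equation, so they agree on alpha; only the iteration count depends on the seed.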
Takeaways
- default_init(X) is a moment-matched starting model; EM converges quickly from it.
- For rugged likelihoods, run several starts and keep the highest-likelihood fit.
- jax.vmap over from_expectation fits many datasets/resamples at once; theta0 warm-starts the solver.
Next, the Divergences between models tutorial measures how close two fitted models are.