API Reference
Base Classes
- class normix.exponential_family.ExponentialFamily[source]
Bases:
Module
Abstract base class for exponential family distributions.
- Concrete subclasses must implement:
_log_partition_from_theta, natural_params, sufficient_statistics, log_base_measure
- abstractmethod static sufficient_statistics(x)[source]
\(t(x)\) for a single unbatched observation.
- abstractmethod static log_base_measure(x)[source]
\(\log h(x)\) for a single unbatched observation.
- log_prob(x)[source]
\(\log p(x\mid\theta) = \log h(x) + \theta^\top t(x) - \psi(\theta)\), single observation.
- squared_hellinger(other)[source]
Squared Hellinger distance \(H^2(p, q)\).
Default uses the general exponential-family formula via \(\psi\). Subclasses may override for numerically improved variants.
- Parameters:
other (ExponentialFamily)
- Return type:
- kl_divergence(other)[source]
KL divergence \(D_{\mathrm{KL}}(\mathrm{self} \| \mathrm{other})\).
Default uses the Bregman-divergence formula via \(\psi\) and \(\nabla\psi\). Subclasses may override.
- Parameters:
other (ExponentialFamily)
- Return type:
- classmethod from_natural(theta)[source]
Construct from natural parameters θ. Subclasses must override.
- Parameters:
theta (Array)
- Return type:
- classmethod bregman_divergence(theta, eta)[source]
Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) (conjugate dual).
Minimising over \(\theta\) yields \(\nabla\psi(\theta^*) = \eta\), i.e. the natural parameters corresponding to expectation parameters \(\eta\).
- classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='lbfgs', verbose=0)[source]
Construct from expectation parameters \(\eta\) by solving \(\nabla\psi(\theta) = \eta\).
Minimises the Bregman divergence \(\psi(\theta) - \theta\cdot\eta\) via
solve_bregman. Subclasses can override for closed-form inverses.
- classmethod fit_mle(X, *, theta0=None, maxiter=500, tol=1e-10, verbose=0)[source]
MLE via exponential family identity: \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\).
Batches over X using jax.vmap, then calls from_expectation(\(\hat\eta\)).
- Parameters:
X (jax.Array) – (n, ...) array of observations.
theta0 (jax.Array, optional) – Initial natural parameters \(\theta_0\) for the \(\eta\to\theta\) solver.
maxiter (int) – Maximum iterations for the \(\eta\to\theta\) solver.
tol (float) – Convergence tolerance for the \(\eta\to\theta\) solver.
verbose (int) – 0 = silent, >= 1 = print solver summary.
- Return type:
- fit(X, *, maxiter=500, tol=1e-10, verbose=0, **kwargs)[source]
Fit using self as initialization (warm start).
Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and solves from_expectation(\(\hat\eta\)) using self.natural_params() as the initial \(\theta_0\).
- Parameters:
- Return type:
- classmethod default_init(X)[source]
Moment-based initialisation from data.
Computes \(\hat\eta = \frac{1}{n}\sum_i t(x_i)\) and inverts to get an initial model. For distributions with closed-form from_expectation (Gamma, InverseGamma, InverseGaussian), this gives the MLE directly.
- Parameters:
X (Array)
- Return type:
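A minimal usage sketch of the base-class workflow, assuming the Gamma subclass documented below and toy one-dimensional data (illustrative only, not part of the reference):

import jax.numpy as jnp
from normix.distributions.gamma import Gamma

X = jnp.array([0.3, 1.2, 0.8, 2.1, 0.5])      # any (n,) array of positive observations

# MLE via the EF identity: eta_hat = mean of t(x_i), then invert grad(psi)(theta) = eta_hat.
g = Gamma.fit_mle(X)

lp = g.log_prob(jnp.array(1.0))                # log h(x) + theta^T t(x) - psi(theta)
kl = g.kl_divergence(Gamma(2.0, 1.0))          # Bregman-divergence formula via psi
h2 = g.squared_hellinger(Gamma(2.0, 1.0))      # general EF formula via psi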
Univariate Distributions
Gamma
Gamma distribution as an exponential family.
Exponential family structure:
- class normix.distributions.gamma.Gamma(alpha, beta)[source]
Bases:
ExponentialFamily
Gamma(\(\alpha\), \(\beta\)) distribution — shape \(\alpha > 0\), rate \(\beta > 0\).
- classmethod from_natural(theta)[source]
Construct from natural parameters θ.
- to_gig(*, boundary_eps=0.0)[source]
Exact embedding into the GIG family.
Gamma(\(\alpha\), \(\beta\)) is the \(b \to 0\) limit of GIG(\(p = \alpha,\; a = 2\beta,\; b\)). With boundary_eps = 0 the embedding stores b = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain (matters only for downstream expectation_params calls on the lifted GIG).
- Parameters:
boundary_eps (float)
- classmethod from_expectation(eta, *, theta0=None, maxiter=100, tol=1e-12, backend='jax', **kwargs)[source]
Closed-form \(\eta \to \theta\) via Newton on \(\psi(\alpha) - \log\alpha = \eta_1 - \log\eta_2\).
\(\eta = [E[\log X],\; E[X]]\) → \(\alpha\) from digamma inversion, \(\beta = \alpha / \eta_2\).
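A short sketch of this inversion, assuming jax.scipy.special.digamma for the forward map (illustrative only):

import jax.numpy as jnp
from jax.scipy.special import digamma
from normix.distributions.gamma import Gamma

alpha, beta = 3.0, 2.0
# Expectation parameters of Gamma(alpha, beta): eta = [E[log X], E[X]].
eta = jnp.array([digamma(alpha) - jnp.log(beta), alpha / beta])

g = Gamma.from_expectation(eta)   # digamma inversion recovers alpha; beta = alpha / eta[1]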
Inverse Gamma
InverseGamma distribution as an exponential family.
Exponential family structure:
- class normix.distributions.inverse_gamma.InverseGamma(alpha, beta)[source]
Bases:
ExponentialFamily
InverseGamma(\(\alpha\), \(\beta\)) — shape \(\alpha > 0\), rate \(\beta > 0\).
- rvs(n, seed=42)[source]
Sample n observations from \(\mathrm{InvGamma}(\alpha, \beta)\) via JAX PRNG.
Uses the relation: if \(X \sim \mathrm{Gamma}(\alpha, 1/\beta)\) then \(1/X \sim \mathrm{InvGamma}(\alpha, \beta)\).
- classmethod from_natural(theta)[source]
Construct from natural parameters θ.
- Parameters:
theta (Array)
- Return type:
- to_gig(*, boundary_eps=0.0)[source]
Exact embedding into the GIG family.
InverseGamma(\(\alpha\), \(\beta\)) is the \(a \to 0\) limit of GIG(\(p = -\alpha,\; a,\; b = 2\beta\)). With boundary_eps = 0 the embedding stores a = 0 exactly; pass a small positive value to stay in the strict interior of GIG’s domain.
- Parameters:
boundary_eps (float)
Inverse Gaussian
Inverse Gaussian (Wald) distribution as an exponential family.
Exponential family structure:
- class normix.distributions.inverse_gaussian.InverseGaussian(mu, lam)[source]
Bases:
ExponentialFamily
InverseGaussian(\(\mu\), \(\lambda\)) — mean \(\mu > 0\), shape \(\lambda > 0\).
- cdf(x)[source]
CDF of the Inverse Gaussian distribution (log-space stable).
\[F(x) = \Phi(t_1) + \exp\!\bigl(2\lambda/\mu + \log\Phi(-t_2)\bigr)\]
where \(t_1 = \sqrt{\lambda/x}\,(x/\mu - 1)\) and \(t_2 = \sqrt{\lambda/x}\,(x/\mu + 1)\). The second term uses log_ndtr to avoid overflow when \(\lambda/\mu\) is large.
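A standalone sketch of this formula (not the library implementation), using jax.scipy.special.ndtr and log_ndtr; the helper name is hypothetical:

import jax.numpy as jnp
from jax.scipy.special import ndtr, log_ndtr

def invgauss_cdf(x, mu, lam):
    # t1, t2 as defined above.
    t1 = jnp.sqrt(lam / x) * (x / mu - 1.0)
    t2 = jnp.sqrt(lam / x) * (x / mu + 1.0)
    # Second term in log space: exp(2*lam/mu + log Phi(-t2)) avoids overflowing exp(2*lam/mu).
    return ndtr(t1) + jnp.exp(2.0 * lam / mu + log_ndtr(-t2))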
- rvs(n, seed=42)[source]
Sample n observations from \(\mathrm{InvGaussian}(\mu, \lambda)\) via JAX PRNG.
Uses the algorithm from Michael, Schucany & Haas (1976):
\(\nu \sim \mathcal{N}(0,1)\), \(y = \nu^2\)
\(x = \mu + \frac{\mu^2 y}{2\lambda} - \frac{\mu}{2\lambda}\sqrt{4\mu\lambda y + \mu^2 y^2}\)
\(z \sim \mathrm{Uniform}(0,1)\); return \(x\) if \(z \le \mu/(\mu+x)\), else \(\mu^2/x\)
Uses jnp.where for vectorized branching over the full sample array.
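A sketch of the Michael, Schucany &amp; Haas algorithm described above, written with jax.random; the helper is hypothetical and the library's rvs may differ in details such as key handling:

import jax
import jax.numpy as jnp

def invgauss_rvs(key, n, mu, lam):
    k_norm, k_unif = jax.random.split(key)
    nu = jax.random.normal(k_norm, (n,))
    y = nu ** 2
    x = mu + mu**2 * y / (2 * lam) - mu / (2 * lam) * jnp.sqrt(4 * mu * lam * y + mu**2 * y**2)
    z = jax.random.uniform(k_unif, (n,))
    # Vectorised accept/swap step: keep x with probability mu / (mu + x), else return mu^2 / x.
    return jnp.where(z <= mu / (mu + x), x, mu**2 / x)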
- classmethod from_natural(theta)[source]
Construct from natural parameters θ.
- Parameters:
theta (Array)
- Return type:
- to_gig()[source]
Exact embedding into the GIG family.
InverseGaussian(\(\mu\), \(\lambda\)) is GIG with \(p = -1/2,\; a = \lambda/\mu^2,\; b = \lambda\). No boundary approximation: the embedding lands strictly inside GIG’s domain.
Generalized Inverse Gaussian
Generalized Inverse Gaussian (GIG) distribution as an exponential family.
Exponential family structure:
Special cases:
\(b \to 0,\; p > 0\): GIG → \(\mathrm{Gamma}(p,\; a/2)\)
\(a \to 0,\; p < 0\): GIG → \(\mathrm{InvGamma}(-p,\; b/2)\)
\(p = -1/2\): GIG → InverseGaussian
η→θ rescaling (reduces Fisher condition number):
Solve \(\tilde{\eta} \to \tilde{\theta}\) with symmetric GIG (\(\tilde{a} = \tilde{b}\)), then unscale.
Log-Partition Triad Overrides:
_log_partition_from_theta: JAX, uses log_kv(backend='jax')
_grad_log_partition: analytical Bessel ratios (5 \(K_\nu\) calls)
_hessian_log_partition: analytical 7-Bessel Hessian in \(\theta\)-space
_log_partition_cpu: numpy + log_kv(backend='cpu')
_grad_log_partition_cpu: analytical Bessel ratios via scipy.kve
_hessian_log_partition_cpu: central FD on _log_partition_cpu
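For orientation, a sketch of the JAX log-partition in the GIG natural parametrisation \(\theta = [p-1,\,-b/2,\,-a/2]\) on \(t(y) = [\log y,\,1/y,\,y]\); the library's triad additionally handles bounds and numerical edge cases:

import jax.numpy as jnp
from normix.utils.bessel import log_kv

def gig_log_partition(theta):
    p = theta[0] + 1.0
    b = -2.0 * theta[1]
    a = -2.0 * theta[2]
    # Normalising constant of GIG(p, a, b) is 2 (b/a)^{p/2} K_p(sqrt(ab)).
    return jnp.log(2.0) + 0.5 * p * jnp.log(b / a) + log_kv(p, jnp.sqrt(a * b))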
- class normix.distributions.generalized_inverse_gaussian.GeneralizedInverseGaussian(p, a, b)[source]
Bases:
ExponentialFamily
Generalized Inverse Gaussian distribution.
Stored: \(p\) (shape, any real), \(a > 0\), \(b > 0\).
- static expectation_params_batch(p, a, b, backend='jax')[source]
Vectorized η for arrays of (p, a, b), each shape (N,). Returns (N, 3) array where columns are [E_log_X, E_inv_X, E_X].
backend='jax': vmap over the scalar JAX grad.
backend='cpu': vectorized scipy.kve (6 C-level array calls).
- var()[source]
\(\mathrm{Var}[X] = \partial^2\psi/\partial\theta_3^2\) = Fisher information [2,2].
- Return type:
- rvs(n, seed=42, method='devroye')[source]
Sample n observations from \(\mathrm{GIG}(p, a, b)\).
- Parameters:
n (int) – Sample size.
seed (int) – Integer seed for JAX PRNG (or scipy random_state for 'scipy').
method (str) – Sampling algorithm:
'devroye' (default) — Transformed density rejection (TDR) on \(\log(x)\), pure JAX, no Bessel functions.
'pinv' — Numerical inverse CDF (CPU table build + JAX sampling), no Bessel. Best for large n with fixed parameters.
'scipy' — scipy.stats.geninvgauss (CPU, original fallback).
- Return type:
- to_gamma()[source]
KL projection onto the Gamma family.
Minimises \(D_{\mathrm{KL}}(\mathrm{GIG}\,\|\,q)\) over \(q \in \mathrm{Gamma}\) by matching the Gamma sufficient statistics under the source GIG: \(E_{q^*}[\log X] = \eta_1,\; E_{q^*}[X] = \eta_3\). Solved via Gamma.from_expectation(). See docs/tech_notes/distribution_conversions.md for the derivation.
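A sketch of what this projection amounts to in code, assuming expectation_params() returns the three expectations in the [E log X, E 1/X, E X] order used by expectation_params_batch (illustrative only):

import jax.numpy as jnp
from normix.distributions.gamma import Gamma
from normix.distributions.generalized_inverse_gaussian import GeneralizedInverseGaussian

gig = GeneralizedInverseGaussian(0.5, 2.0, 1.0)   # (p, a, b)
eta = gig.expectation_params()                    # [E[log X], E[1/X], E[X]] under the GIG

# Matching the Gamma sufficient statistics [log X, X] reproduces what to_gamma() does.
gamma_proj = Gamma.from_expectation(jnp.array([eta[0], eta[2]]))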
- to_inverse_gamma()[source]
KL projection onto the InverseGamma family.
Matches \(E[-1/X] = -\eta_2,\; E[\log X] = \eta_1\).
- to_inverse_gaussian()[source]
KL projection onto the InverseGaussian family.
Matches \(E[X] = \eta_3,\; E[1/X] = \eta_2\). The closed form \(\lambda = 1/(\eta_2 - 1/\eta_3)\) is well-defined whenever \(\eta_2 > 1/\eta_3\) (Jensen, true for every non-degenerate GIG); InverseGaussian.from_expectation() clamps near-degenerate inputs.
- classmethod from_natural(theta)[source]
Construct from natural parameters θ.
- Parameters:
theta (Array)
- Return type:
- classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='newton', verbose=0)[source]
\(\eta \to \theta\) via \(\eta\)-rescaling + optimization.
Rescaling makes the Fisher matrix symmetric (\(\tilde{a} = \tilde{b}\)), reducing condition number by up to \(10^{30}\) for extreme \(a/b\) ratios.
- Parameters:
theta0 (jax.Array, optional) – Warm-start point (required for JAX solvers; if None, uses the multi-start CPU solver with Gamma/InvGamma/InvGauss seeds).
backend (str) – 'jax' (default, JIT-able) or 'cpu' (scipy, more robust).
method (str) – 'newton', 'lbfgs', or 'bfgs'.
eta (Array)
maxiter (int)
tol (float)
verbose (int)
- Return type:
- normix.distributions.generalized_inverse_gaussian.GIG
alias of GeneralizedInverseGaussian
Multivariate Distributions
Multivariate Normal distribution.
Stored: mu (d,), L_Sigma (d×d) lower-triangular Cholesky of \(\Sigma\).
All linear algebra via L_Sigma — never form \(\Sigma^{-1}\) explicitly.
Exponential family structure
where \(\operatorname{vec}\) uses row-major order (numpy.ndarray.ravel()).
All parametrization conversions are analytical (closed-form):
classical \(\leftrightarrow\) natural: natural_params / from_natural
natural \(\to\) expectation: _grad_log_partition (analytical override)
expectation \(\to\) classical: from_expectation
No Bregman solver is ever invoked. fit_mle computes
\(\hat\eta = n^{-1}\sum_i t(x_i)\) and calls from_expectation
(closed-form). log_prob overrides the inherited EF formula with a
direct Cholesky computation for efficiency.
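A usage sketch of this closed-form pipeline (toy data, illustrative only):

import jax
import jax.numpy as jnp
from normix.distributions.normal import MultivariateNormal

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))   # (n, d) toy data

mvn = MultivariateNormal.fit_mle(X)       # eta_hat -> closed-form from_expectation
theta = mvn.natural_params()              # [Sigma^{-1} mu, -1/2 vec(Sigma^{-1})]
lp = mvn.log_prob(X[0])                   # Cholesky-direct log-density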
- class normix.distributions.normal.MultivariateNormal(mu, L_Sigma)[source]
Bases:
ExponentialFamily
Multivariate Normal distribution as an exponential family.
- Parameters:
Notes
Natural parameters: \(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\). Sufficient statistics: \(t(x) = [x,\; \operatorname{vec}(xx^\top)]\). Log-partition: \(\psi(\theta) = \tfrac{1}{2}\mu^\top\Sigma^{-1}\mu - \tfrac{1}{2}\log|\Sigma^{-1}| + \tfrac{d}{2}\log(2\pi)\). Log base measure: \(\log h(x) = 0\).
- natural_params()[source]
\(\theta = [\Sigma^{-1}\mu,\; -\tfrac{1}{2}\operatorname{vec}(\Sigma^{-1})]\)
- Return type:
- classmethod from_classical(mu, sigma)[source]
Construct from mean μ and covariance matrix Σ.
- Return type:
- classmethod from_expectation(eta, **_kwargs)[source]
Closed-form inversion \(\eta \to \theta \to (\mu, L_\Sigma)\).
\(\eta = [E[X],\; \operatorname{vec}(E[XX^\top])]\), so
\[\mu = \eta_1, \qquad \Sigma = \operatorname{reshape}(\eta_2, d, d) - \mu\mu^\top\]
- Parameters:
eta (jax.Array) – Expectation parameters of shape (d + d²,).
**_kwargs – Ignored (accepts backend, theta0, etc. for API compatibility).
- Return type:
- classmethod from_natural(theta)[source]
Construct from natural parameters \(\theta = [\theta_1, \theta_2]\).
Recovers \(\Lambda = -2\,\mathrm{reshape}(\theta_2, d, d)\), then \(\mu = \Lambda^{-1}\theta_1\) and \(L_\Sigma = \mathrm{chol}(\Lambda^{-1})\).
- Parameters:
theta (Array)
- Return type:
- log_prob(x)[source]
\(\log f(x) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{1}{2}\|L_\Sigma^{-1}(x-\mu)\|^2\).
Overrides the inherited EF formula for numerical efficiency (Cholesky-direct).
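A standalone sketch of the Cholesky-direct evaluation (hypothetical helper, not the library code):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def mvn_log_prob(x, mu, L_Sigma):
    d = mu.shape[0]
    z = solve_triangular(L_Sigma, x - mu, lower=True)         # L_Sigma^{-1} (x - mu)
    log_det_sigma = 2.0 * jnp.sum(jnp.log(jnp.diag(L_Sigma)))
    return -0.5 * d * jnp.log(2.0 * jnp.pi) - 0.5 * log_det_sigma - 0.5 * jnp.dot(z, z)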
Mixture Base Classes
JointNormalMixture
JointNormalMixture — abstract exponential family for normal variance-mean mixtures.
Joint distribution \(f(x, y)\):
Sufficient statistics:
Natural parameters:
For a GIG subordinator, \(\theta_2,\theta_3\) combine \(-b/2,-a/2\) with the normal quadratic forms so that \(\theta^{\top} t\) matches \(-(a y + b/y)/2\) from \(f_Y\) plus the \(y\)-dependent terms from \(f_{X\mid Y}\).
Log-partition:
Expectation parameters (EM E-step quantities):
EM M-step closed-form (let \(D = 1 - E[1/Y] \cdot E[Y]\)):
- class normix.mixtures.joint.JointNormalMixture(mu, gamma, L_Sigma)[source]
Bases:
ExponentialFamily
Abstract joint distribution \(f(x, y)\) for normal variance-mean mixtures.
Stored: mu (d,), gamma (d,), L_Sigma (d×d lower Cholesky of \(\Sigma\)). Subordinator parameters defined by concrete subclasses.
- log_det_sigma()[source]
\(\log|\Sigma| = 2\sum_i \log L_{ii}\), via Cholesky diagonal.
- Return type:
- log_prob_joint(x, y)[source]
\(\log f(x, y) = \log f(x\mid y) + \log f_Y(y)\).
\[\log f(x\mid y) = -\tfrac{d}{2}\log(2\pi) - \tfrac{1}{2}\log|\Sigma| - \tfrac{d}{2}\log y - \frac{1}{2y}\|L^{-1}(x-\mu)\|^2 + \gamma^\top\Sigma^{-1}(x-\mu) - \tfrac{y}{2}\gamma^\top\Sigma^{-1}\gamma\]\(\log f_Y(y)\) from subordinator.
- conditional_expectations(x)[source]
Compute \(E[g(Y)\mid X=x]\) for EM E-step.
The posterior \(Y\mid X\) follows a GIG-like distribution:
\[p_{\mathrm{post}} = p_{\mathrm{eff}} - d/2, \quad a_{\mathrm{post}} = a_{\mathrm{eff}} + \gamma^\top\Sigma^{-1}\gamma, \quad b_{\mathrm{post}} = b_{\mathrm{eff}} + (x-\mu)^\top\Sigma^{-1}(x-\mu)\]
Returns dict with keys: E_log_Y, E_inv_Y, E_Y.
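A sketch of the posterior-parameter update above (hypothetical helper; p_eff, a_eff, b_eff stand for the subordinator's effective GIG parameters):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def posterior_gig_params(x, mu, gamma, L_Sigma, p_eff, a_eff, b_eff):
    d = mu.shape[0]
    z_x = solve_triangular(L_Sigma, x - mu, lower=True)   # L^{-1}(x - mu)
    z_g = solve_triangular(L_Sigma, gamma, lower=True)    # L^{-1} gamma
    p_post = p_eff - d / 2.0
    a_post = a_eff + jnp.dot(z_g, z_g)                    # + gamma^T Sigma^{-1} gamma
    b_post = b_eff + jnp.dot(z_x, z_x)                    # + (x - mu)^T Sigma^{-1} (x - mu)
    return p_post, a_post, b_post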
- static sufficient_statistics(xy)[source]
\(t(x,y) = [\log y,\; 1/y,\; y,\; x,\; x/y,\; \mathrm{vec}(xx^\top/y)]\).
Input: flat vector \([x_1,\ldots,x_d, y]\).
- classmethod from_expectation(eta, **kwargs)[source]
Construct from expectation parameters \(\eta\).
Two input forms are supported:
NormalMixtureEta — the native η pytree of the joint normal variance-mean mixture; uses the closed-form M-step (\(\mu, \gamma, \Sigma\) analytical; subordinator via _subordinator_from_eta()).
flat jax.Array — the generic Bregman solver inherited from ExponentialFamily.
The pytree path is the canonical η→θ map for these distributions: it is exact (no Bregman iterations on the normal block) and uses the subordinator’s own from_expectation (closed-form for Gamma / InverseGamma / InverseGaussian; numerical for GIG).
- Parameters:
eta (NormalMixtureEta or jax.Array) – Expectation parameters.
**kwargs – For the pytree path: forwarded to _subordinator_from_eta() (e.g. backend, method, maxiter, theta0 for warm-starting GIG). For the flat-array path: forwarded to the parent solver.
- Return type:
NormalMixture
Marginal mixture base classes.
MarginalMixture is the abstract interface fitters and the
divergences module depend on. NormalMixture is the
full-covariance implementation, owning a
JointNormalMixture. A factor-analysis
implementation lives in the sibling class FactorNormalMixture
(see docs/design/mixtures.md § 6).
NormalMixture provides:
log_prob(x) — closed-form marginal log-density
e_step(X) — jax.vmap() over conditional expectations
m_step(X, expectations) — returns new NormalMixture
fit(X, ...) — convenience EM fitting with multi-start
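The fit(X, ...) convenience wraps a loop over these methods; a minimal manual EM sketch, assuming `model` is a concrete NormalMixture subclass instance and X an (n, d) array (the m_step signature follows the method documentation further down):

for _ in range(50):
    eta = model.e_step(X)                      # aggregated expectation parameters
    model = model.m_step(eta)                  # closed-form normal block + subordinator solve
    model = model.regularize_det_sigma_one()   # optional: keep the |Sigma| = 1 convention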
- class normix.mixtures.marginal.MarginalMixture[source]
Bases:
Module
Abstract interface for marginal mixtures used by the EM fitters.
Concrete subclasses pick the storage of the Gaussian dispersion (full Cholesky in NormalMixture, low-rank-plus-diagonal in FactorNormalMixture) and the type of the EM expectation pytree (NormalMixtureEta vs. FactorMixtureStats).
The fitter depends only on this contract; it does not know which storage form a model uses.
- abstractmethod e_step(X, *, backend='jax')[source]
E-step: aggregated expectation parameters for the batch.
- abstractmethod m_step(eta, **kwargs)[source]
Full M-step: updates all parameters; returns a new model.
- Parameters:
eta (Any)
- Return type:
- abstractmethod m_step_normal(eta)[source]
M-step for normal parameters only (MCECM cycle 1).
- Parameters:
eta (Any)
- Return type:
- abstractmethod m_step_subordinator(eta, **kwargs)[source]
M-step for subordinator parameters only (MCECM cycle 2).
- Parameters:
eta (Any)
- Return type:
- abstractmethod compute_eta_from_model()[source]
Reconstruct the expectation pytree from the model’s own parameters.
- Return type:
- abstractmethod em_convergence_params()[source]
Pytree whose leaf-wise change measures EM convergence.
Subordinator parameters are intentionally excluded (their solver has its own tolerance, and including them inflates iteration counts). For full-covariance models this is (mu, gamma, L_Sigma); for factor-analysis models it is (mu, gamma, F F^T + D) to sidestep the rotational gauge.
- Return type:
- class normix.mixtures.marginal.NormalMixture(joint)[source]
Bases:
MarginalMixture
Marginal \(f(x) = \int_0^\infty f(x,y)\,dy\) for a normal variance-mean mixture.
Not an exponential family. Owns a JointNormalMixture (which is). The classical parameters \((\mu, \gamma, \Sigma, \text{subordinator})\) are forwarded as read-only properties; use replace() to obtain a new model with updated parameters (modules are immutable).
- property joint: JointNormalMixture
- sigma()[source]
Dispersion \(\Sigma = L_\Sigma L_\Sigma^\top\) (forwarded from the joint).
Distinct from cov(), which returns the marginal covariance \(E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).
- Return type:
- cov()[source]
\(\mathrm{Cov}[X] = E[Y]\,\Sigma + \mathrm{Var}[Y]\,\gamma\gamma^\top\).
- Return type:
- squared_hellinger(other)[source]
Squared Hellinger distance via joint distributions (upper bound on marginal).
- Parameters:
other (NormalMixture)
- Return type:
- kl_divergence(other)[source]
KL divergence via joint distributions.
- Parameters:
other (NormalMixture)
- Return type:
- compute_eta_from_model()[source]
Reconstruct \(\eta\) from the model’s own parameters.
Uses the marginal expectations of the joint sufficient statistics:
\[\eta_4 = \mu + \gamma\,E[Y], \quad \eta_5 = \mu\,E[1/Y] + \gamma, \quad \eta_6 = \Sigma + \mu\mu^\top E[1/Y] + \gamma\gamma^\top E[Y] + \mu\gamma^\top + \gamma\mu^\top\]
- Return type:
NormalMixtureEta
- e_step(X, backend='jax')[source]
Full E-step: subordinator conditionals + batch aggregation.
Returns a NormalMixtureEta with the six aggregated expectation parameters.
- classmethod from_expectation(eta, **kwargs)[source]
Construct from expectation parameters \(\eta\).
Wraps JointNormalMixture.from_expectation(), which performs the exact closed-form M-step on the normal block and the subordinator’s from_expectation on the subordinator block.
This is the canonical η→model map: any prior or shrinkage target \(\eta_0\) can be inspected as a concrete model via cls.from_expectation(eta_0).sigma() etc.
- Parameters:
eta (NormalMixtureEta) – Aggregated expectation parameters.
**kwargs – Forwarded to JointNormalMixture.from_expectation() (e.g. backend, method, maxiter, theta0).
- Return type:
- m_step(eta, **kwargs)[source]
Full M-step: update normal params + subordinator from \(\eta\).
Equivalent to type(self).from_expectation(eta, **kwargs); self is only used to dispatch on the subclass. Subclasses with iterative subordinator solvers (e.g. GeneralizedHyperbolic) may override to inject a warm-start \(\theta_0\).
- Parameters:
eta (NormalMixtureEta)
- Return type:
- m_step_normal(eta)[source]
M-step for normal parameters only (MCECM Cycle 1).
Updates \(\mu, \gamma, \Sigma\); subordinator unchanged.
- Parameters:
eta (NormalMixtureEta)
- Return type:
- m_step_subordinator(eta, **kwargs)[source]
M-step for the subordinator only (MCECM Cycle 2).
Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.
- Parameters:
eta (NormalMixtureEta)
- Return type:
- replace(**updates)[source]
Return a new model with selected parameters replaced.
Accepts any subset of:
normal parameters: mu, gamma, L_Sigma;
dispersion alias sigma — converted to L_Sigma via Cholesky (mutually exclusive with L_Sigma);
subordinator parameters declared by _subordinator_keys() (e.g. alpha, beta for VG / NInvG, mu_ig, lam for NIG, p, a, b for GH).
The actual storage lives in joint; this method does an immutable update via equinox.tree_at().
Examples
>>> vg2 = vg.replace(mu=new_mu)                  # change μ
>>> vg3 = vg.replace(alpha=2.5, beta=0.5)        # change subordinator
>>> vg4 = vg.replace(sigma=sigma2 * jnp.eye(d))  # set Σ via covariance
- Return type:
- regularize_det_sigma(target_log_det=0.0)[source]
Rescale to enforce \(\log|\Sigma| = \mathrm{target\_log\_det}\).
Picks \(s = \exp((\log|\Sigma| - \tau)/d)\) and applies _rescale(). The default target_log_det = 0 recovers the \(|\Sigma| = 1\) convention; passing the log-determinant of an initial reference Σ implements the det_sigma_x family.
- Parameters:
target_log_det (float)
- Return type:
- regularize_det_sigma_one()[source]
Enforce \(|\Sigma| = 1\). Alias for regularize_det_sigma() with target_log_det = 0.
- Return type:
- regularize_a_eq_b()[source]
Rescale subordinator so that \(a = b = \sqrt{ab}\) for GIG-parameterised families.
Default implementation is a no-op; override in subclasses with both \(a, b > 0\) (currently GH and NIG; VG and NInvG have a degenerate a=0 or b=0 and the default no-op is the right behaviour).
- Return type:
- em_convergence_params()[source]
Pytree whose leaf-wise change measures EM convergence.
Returns (mu, gamma, L_Sigma). Subordinator parameters (p, a, b) are excluded — their solver has its own tolerance and including them inflates iteration counts.
Mixture Distributions
Variance Gamma
Variance Gamma (VG) distribution.
Special case of GH with GIG → Gamma subordinator (\(b \to 0\), \(p > 0\)). \(Y \sim \mathrm{Gamma}(\alpha, \beta)\), i.e. GIG(\(p = \alpha\), \(a = 2\beta\), \(b \to 0\)).
Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (rate) of Gamma.
- class normix.distributions.variance_gamma.JointVarianceGamma(mu, gamma, L_Sigma, alpha, beta)[source]
Bases:
JointNormalMixture
Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{Gamma}(\alpha, \beta)\).
GIG limit: \(p = \alpha\), \(a = 2\beta\), \(b \to 0\).
- natural_params()[source]
\(\theta = [\alpha-1-d/2,\; -\tfrac{1}{2}\mu^\top\Lambda\mu,\; -(\beta+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)
(Gamma subordinator: \(p=\alpha\), \(a=2\beta\), \(b\to 0\)).
- Return type:
- class normix.distributions.variance_gamma.VarianceGamma(joint)[source]
Bases:
NormalMixture
Marginal Variance Gamma distribution f(x).
- Parameters:
joint (JointVarianceGamma)
- log_prob(x)[source]
Marginal VG log-density (own formula, no GH delegation).
\[f(x) \propto \left(\frac{q}{2c}\right)^{\nu/2} K_\nu\!\left(\sqrt{2qc}\right) \exp(\gamma^\top\Sigma^{-1}(x-\mu))\]where \(\nu = \alpha - d/2\), \(c = \beta + \tfrac{1}{2}\gamma^\top\Lambda\gamma\), \(q = (x-\mu)^\top\Lambda(x-\mu)\).
- fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]
Fit VG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the Gamma subordinator).
Normal Inverse Gamma
Normal-Inverse Gamma (NInvG) distribution.
Special case of GH with GIG → InverseGamma subordinator (\(a \to 0\), \(p < 0\)). \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\), i.e. GIG(\(p = -\alpha\), \(a \to 0\), \(b = 2\beta\)).
Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\alpha\) (shape), \(\beta\) (rate) of InverseGamma.
- class normix.distributions.normal_inverse_gamma.JointNormalInverseGamma(mu, gamma, L_Sigma, alpha, beta)[source]
Bases:
JointNormalMixture
Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGamma}(\alpha, \beta)\).
GIG limit: \(p = -\alpha\), \(a \to 0\), \(b = 2\beta\).
- natural_params()[source]
\(\theta = [-(\alpha+1)-d/2,\; -(\beta+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -\tfrac{1}{2}\gamma^\top\Lambda\gamma,\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)
(InverseGamma subordinator: \(p=-\alpha\), \(a\to 0\), \(b=2\beta\)).
- Return type:
- class normix.distributions.normal_inverse_gamma.NormalInverseGamma(joint)[source]
Bases:
NormalMixture
Marginal Normal-Inverse Gamma distribution f(x).
- Parameters:
joint (JointNormalInverseGamma)
- log_prob(x)[source]
Marginal NInvG log-density (own formula, no GH delegation).
GIG params: \(p=-\alpha\), \(a=\gamma^\top\Lambda\gamma\), \(b=2\beta+Q(x)\). The normalising integral is \(2(b/a)^{p/2} K_p(\sqrt{ab})\).
- fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]
Fit NInvG using EM or MCECM. Defaults to CPU E-step (faster than JAX vmap for the degenerate-GIG posterior arising from the InverseGamma subordinator).
Normal Inverse Gaussian
Normal-Inverse Gaussian (NIG) distribution.
Special case of GH with GIG → InverseGaussian subordinator (\(p = -1/2\)). \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\), i.e. GIG(\(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\)).
Stored: \(\mu\), \(\gamma\), \(L_\Sigma\) (Cholesky of \(\Sigma\)), \(\mu_{IG}\) (IG mean), \(\lambda\) (IG shape).
- class normix.distributions.normal_inverse_gaussian.JointNormalInverseGaussian(mu, gamma, L_Sigma, mu_ig, lam)[source]
Bases:
JointNormalMixture
Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{InvGaussian}(\mu_{IG}, \lambda)\).
Stored: \(\mu_{IG}\) (IG mean) and \(\lambda\) (IG shape) directly. GIG params: \(p = -1/2\), \(a = \lambda/\mu_{IG}^2\), \(b = \lambda\).
- natural_params()[source]
\(\theta = [-3/2-d/2,\; -(\lambda/2+\tfrac{1}{2}\mu^\top\Lambda\mu),\; -(\lambda/(2\mu_{IG}^2)+\tfrac{1}{2}\gamma^\top\Lambda\gamma),\; \Lambda\gamma,\; \Lambda\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Lambda)]\)
where \(p=-1/2\), \(a=\lambda/\mu_{IG}^2\), \(b=\lambda\), aligned with GIG natural parameters on \([\log y,\,1/y,\,y]\).
- Return type:
- class normix.distributions.normal_inverse_gaussian.NormalInverseGaussian(joint)[source]
Bases:
NormalMixture
Marginal Normal-Inverse Gaussian distribution f(x).
- Parameters:
joint (JointNormalInverseGaussian)
- log_prob(x)[source]
Marginal NIG log-density.
Uses \(K_{-1/2}(z) = \sqrt{\pi/(2z)}\,e^{-z}\) for the normalisation, leaving only one \(\log K_\nu\) call at order \(\nu = -1/2 - d/2\).
Generalized Hyperbolic
Generalized Hyperbolic (GH) distribution.
Joint: \(X \mid Y \sim \mathcal{N}(\mu + \gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).
Marginal: \(\mathrm{GH}(\mu, \gamma, \Sigma, p, a, b)\).
Marginal log-density (closed form via Bessel functions):
Let \(Q(x) = (x-\mu)^\top \Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top \Sigma^{-1} \gamma\).
Posterior: \(Y \mid X = x \sim \mathrm{GIG}(p - d/2,\; a + \gamma^\top\Sigma^{-1}\gamma,\; b + (x-\mu)^\top\Sigma^{-1}(x-\mu))\).
- class normix.distributions.generalized_hyperbolic.JointGeneralizedHyperbolic(mu, gamma, L_Sigma, p, a, b)[source]
Bases:
JointNormalMixture
Joint \(f(x,y)\): \(X\mid Y \sim \mathcal{N}(\mu+\gamma y, \Sigma y)\), \(Y \sim \mathrm{GIG}(p, a, b)\).
Stored: mu, gamma, L_Sigma (from JointNormalMixture) + p, a, b (GIG parameters).
- natural_params()[source]
\(\theta = [p-1-d/2,\; -(b/2+\tfrac{1}{2}\mu^\top\Sigma^{-1}\mu),\; -(a/2+\tfrac{1}{2}\gamma^\top\Sigma^{-1}\gamma),\; \Sigma^{-1}\gamma,\; \Sigma^{-1}\mu,\; -\tfrac{1}{2}\mathrm{vec}(\Sigma^{-1})]\)
The scalar coefficients on sufficient statistics \(1/y\) and \(y\) match the GIG convention \(\theta_{\mathrm{GIG}} = [p-1,\,-b/2,\,-a/2]\) on \(t_Y = [\log y,\,1/y,\,y]\).
- Return type:
- classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]
Construct from classical parameters.
- classmethod from_natural(theta)[source]
Recover classical parameters from \(\theta\).
\(p = \theta_1 + 1 + d/2\), \(b = -2\theta_2 - 2\mu_{\mathrm{quad}}\), \(a = -2\theta_3 - 2\gamma_{\mathrm{quad}}\).
- Parameters:
theta (Array)
- Return type:
- to_joint_variance_gamma()[source]
KL projection onto JointVarianceGamma (Gamma subordinator).
- Return type:
- class normix.distributions.generalized_hyperbolic.GeneralizedHyperbolic(joint)[source]
Bases:
NormalMixture
Marginal Generalized Hyperbolic distribution \(f(x)\).
Stores a JointGeneralizedHyperbolic. Provides:
log_prob(x) — closed-form Bessel expression
e_step, m_step — for EM fitting
fit(X, ...) — convenience fitting method
- Parameters:
joint (JointGeneralizedHyperbolic)
- log_prob(x)[source]
Marginal \(\log f(x)\).
\[f(x) \propto \left(\frac{A}{Q(x)+b}\right)^{(d/2-p)/2} K_{p-d/2}\!\left(\sqrt{A(Q(x)+b)}\right) \exp\!\left(\gamma^\top\Sigma^{-1}(x-\mu)\right)\]where \(Q(x) = (x-\mu)^\top\Sigma^{-1}(x-\mu)\), \(A = a + \gamma^\top\Sigma^{-1}\gamma\).
- m_step(eta, **kwargs)[source]
Full M-step with warm-started + sanity-checked subordinator solve.
Overrides the base cls.from_expectation(eta) because the GIG solver benefits substantially from warm-starting at the current \(\theta\) and from a fallback when the solver wanders into implausible regions; both are managed by m_step_subordinator().
- Return type:
- m_step_subordinator(eta, **kwargs)[source]
M-step for the subordinator only (MCECM Cycle 2).
Reads the subordinator-relevant fields of eta; normal parameters are read from self._joint and copied unchanged. Subclasses with iterative solvers may override to add warm-start or sanity-check fallbacks.
- Return type:
- regularize_a_eq_b()[source]
Rescale so \(a = b = \sqrt{ab}\) (orbit invariant).
Picks \(s = \sqrt{a/b}\). Idempotent: applying twice leaves the model unchanged.
- Return type:
- to_normal_inverse_gaussian()[source]
KL projection onto the NormalInverseGaussian family.
- Return type:
- classmethod default_init(X)[source]
Warm-start from the best of NIG / VG / NInvG sub-model fits.
Runs 5 EM iterations (JAX backend, Newton method) for each special case, converts each to GH parametrisation, and selects the candidate with the highest marginal log-likelihood. A moment-based fallback (\(p=1, a=1, b=1\)) is included as a fourth candidate.
Fully JAX-native: no try/except, no Python branching on data values.
- Parameters:
X (Array)
- Return type:
- fit(X, *, algorithm='em', verbose=0, max_iter=200, tol=0.001, regularization='det_sigma_one', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]
Fit GH distribution using EM or MCECM.
Defaults to CPU backends and det_sigma_one regularization (GH has a scale non-identifiability, which the \(|\Sigma| = 1\) convention removes).
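A usage sketch with toy Gaussian data, assuming default_init returns a GeneralizedHyperbolic instance on which fit can be called (the structure of fit's return value is not shown here; see the Fitting section below):

import jax
import jax.numpy as jnp
from normix.distributions.generalized_hyperbolic import GeneralizedHyperbolic

X = jax.random.normal(jax.random.PRNGKey(0), (2000, 2))   # toy (n, d) data

gh0 = GeneralizedHyperbolic.default_init(X)     # best of NIG / VG / NInvG warm starts
res = gh0.fit(X, algorithm='mcecm', max_iter=100, verbose=1)
# |Sigma| = 1 is enforced by the default regularization='det_sigma_one'.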
Fitting
EM Fitters
EM fitters for normix distributions.
Model knows math, fitter knows iteration.
- BatchEMFitter — standard batch EM with dual-loop architecture:
lax.scan (JIT-able) or Python for-loop (CPU-compatible)
IncrementalEMFitter — online / mini-batch EM with pluggable eta update rules
- class normix.fitting.em.EMResult(model, log_likelihoods, param_changes, n_iter, converged, elapsed_time)[source]
Bases:
object
Result of an EM fitting procedure.
- Parameters:
- class normix.fitting.em.BatchEMFitter(*, algorithm='em', max_iter=200, tol=0.001, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton', eta_update=None)[source]
Bases:
object
Batch EM / MCECM algorithm with dual-loop architecture.
EM (default): E-step → M-step (all params) → regularize.
MCECM: E-step → M-step (normal params only) → regularize → E-step → M-step (subordinator only).
Convergence is measured by relative parameter change in the normal parameters (mu, gamma, L_Sigma), excluding subordinator (GIG) parameters.
- Loop selection is automatic:
lax.scan when both backends are ‘jax’, verbose <= 1, algorithm=’em’, and no eta_update rule
Python for-loop otherwise
- Parameters:
algorithm (str) – ‘em’ (default) or ‘mcecm’.
max_iter (int) – Maximum number of iterations.
tol (float) – Convergence tolerance on max relative parameter change.
verbose (int) – 0 = silent, 1 = summary, 2 = per-iteration table.
regularization (str) –
Strategy applied after each M-step:
'none' — no regularization.
'det_sigma_one' — enforce \(|\Sigma| = 1\) (the original GH convention; equivalent to regularize_det_sigma(target_log_det=0)).
'det_sigma_x' — enforce \(|\Sigma| = |\Sigma_0|\) where \(\Sigma_0\) is the dispersion of the initial model passed to fit(). Useful when comparing GH / FactorGH parameters against VG / NIG / NInvG, which leave \(|\Sigma|\) at the empirical scale.
'a_eq_b' — rescale the GIG subordinator so \(a = b = \sqrt{ab}\). Trivial no-op for VG / NInvG / MultivariateNormal.
e_step_backend (str) – ‘jax’ (default) or ‘cpu’.
m_step_backend (str) – ‘jax’ or ‘cpu’ (default, faster for GIG).
m_step_method (str) – ‘newton’ (default), ‘lbfgs’, or ‘bfgs’.
eta_update (EtaUpdateRule or None) – Optional eta combination rule (e.g. Shrinkage(IdentityUpdate(), eta0, tau)). When set, the E-step output is transformed before the M-step.
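A construction sketch using only the documented constructor arguments; the call that actually runs the fitter on an initial model and data is not reproduced here:

from normix.fitting.em import BatchEMFitter

fitter = BatchEMFitter(
    algorithm='mcecm',              # E-step -> normal M-step -> E-step -> subordinator M-step
    max_iter=300,
    tol=1e-4,
    regularization='det_sigma_one', # enforce |Sigma| = 1 after each M-step
    e_step_backend='jax',
    m_step_backend='cpu',           # CPU is the faster default for the GIG solve
)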
- class normix.fitting.em.IncrementalEMFitter(*, eta_update=None, batch_size=256, max_steps=200, inner_iter=1, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]
Bases:
object
Incremental EM with pluggable eta update rules.
Replaces OnlineEMFitter and MiniBatchEMFitter. Processes data in random mini-batches, applies an EtaUpdateRule to combine the running \(\eta\) with each batch estimate, then M-steps on the combined \(\eta\).
- Parameters:
eta_update (EtaUpdateRule) – How to combine running η with each batch estimate.
batch_size (int) – Observations per batch.
max_steps (int) – Number of batches to process (total budget).
inner_iter (int) – 1 = online (default); >1 = fine-tuning on each batch.
verbose (int) – 0 = silent; 1 = periodic summary. Scan path (JAX backends only): verbose must be 0 so diagnostics do not rely on Python side effects each step.
regularization (str) – Same options as BatchEMFitter — 'none' | 'det_sigma_one' | 'det_sigma_x' | 'a_eq_b'.
e_step_backend (str) – Passed through to e_step / m_step.
m_step_backend (str) – Passed through to e_step / m_step.
m_step_method (str) – Passed through to e_step / m_step.
Solvers
Bregman divergence solvers.
Minimises f(θ) − θ·η over θ, where f is any convex function (e.g. the log-partition ψ for an exponential family). At the minimum ∇f(θ*) = η.
Public API
solve_bregman: single starting point
solve_bregman_multistart: multiple starting points (vmap for JAX Newton; for-loop for quasi-Newton and CPU)
bregman_objective: utility f(θ) − θ·η
make_jit_newton_solver: build a stable @jax.jit Newton solve specialised to a fixed (f, grad_fn, hess_fn, bounds) — repeated calls with matching shapes/dtypes hit the XLA cache, avoiding the per-call re-tracing that solve_bregman incurs from fresh closures.
Backends × methods
backend='jax', method='newton': custom lax.scan Newton, autodiff or analytical Hessian
backend='jax', method='lbfgs': jaxopt LBFGSB (bounds native) or LBFGS (reparam)
backend='jax', method='bfgs': jaxopt BFGS with reparameterization for bounds
backend='cpu', method='lbfgs': scipy L-BFGS-B
backend='cpu', method='bfgs': scipy BFGS
backend='cpu', method='newton': scipy trust-exact with Hessian
Gradient / Hessian sources
- grad_fn: θ → ∇f(θ). For backend='cpu': pure CPU (numpy) gradient. For backend='jax', method='newton': JAX-traceable ∇ψ(θ). If None with backend='cpu': hybrid — jax.grad compiled → NumPy callbacks.
- hess_fn: θ → ∇²f(θ). Required for method='newton'. For backend='cpu': pure CPU (numpy) Hessian. For backend='jax', method='newton': JAX-traceable ∇²ψ(θ). If None with method='newton': jax.hessian of the full objective.
Both grad_fn and hess_fn operate in theta-space only. The solver handles all reparameterization internally via the chain rule.
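A minimal sketch with a toy convex f, for which the minimiser of f(θ) − θ·η is θ* = η, assuming the BregmanResult return documented below (illustrative only):

import jax.numpy as jnp
from normix.fitting.solvers import solve_bregman

f = lambda theta: 0.5 * jnp.sum(theta ** 2)      # grad f(theta) = theta
eta = jnp.array([1.0, -2.0])

res = solve_bregman(f, eta, theta0=jnp.zeros(2), backend='jax', method='lbfgs')
# res.theta is approximately eta and res.grad_norm is approximately 0.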
- class normix.fitting.solvers.BregmanResult(theta, fun, grad_norm, num_steps, converged, elapsed_time=0.0)[source]
Bases:
object
Result of a Bregman divergence minimization.
Scalar fields accept both Python and JAX types so the result can live inside a lax.scan carry without concretization errors.
- Parameters:
- normix.fitting.solvers.bregman_objective(theta, eta, f)[source]
f(θ) − θ·η — convex dual whose minimum gives ∇f(θ*) = η.
- normix.fitting.solvers.solve_bregman(f, eta, theta0, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]
Minimise f(θ) − θ·η over θ.
- Parameters:
f (convex function θ → scalar (e.g. log-partition ψ))
eta (target vector (e.g. expectation parameters η))
theta0 (initial guess)
backend ('jax' (JIT-able) or 'cpu' (scipy, not JIT-able))
method ('lbfgs', 'bfgs', or 'newton')
bounds ((lower, upper) pair of JAX arrays, each shape (d,)) – None → unconstrained. For backend='jax': enforced via reparameterization. For backend='cpu': converted to scipy format internally.
max_steps (iteration budget)
tol (convergence tolerance on ‖∇f(θ) − η‖∞)
grad_fn (θ → ∇f(θ).) – For backend=’cpu’: must accept and return numpy arrays. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with backend=’cpu’: falls back to jax.grad (hybrid mode).
hess_fn (θ → ∇²f(θ).) – Required for method=’newton’. For backend=’cpu’: must accept numpy arrays and return numpy array. For backend=’jax’, method=’newton’: must be JAX-traceable. If None with method=’newton’: jax.hessian of the full objective is used.
verbose (int) – 0 = silent, >= 1 = print summary after solve.
- Return type:
- normix.fitting.solvers.solve_bregman_multistart(f, eta, theta0_batch, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]
Run solve_bregman from multiple starting points; return the best result.
- Parameters:
theta0_batch – (K, dim) jax.Array for backend='jax' with method='newton' (parallel via vmap); a list of arrays otherwise (sequential for-loop).
verbose (int) – 0 = silent, >= 1 = print summary.
eta (Array)
backend (str)
method (str)
max_steps (int)
tol (float)
grad_fn (Callable | None)
hess_fn (Callable | None)
- Return type:
- normix.fitting.solvers.make_jit_newton_solver(f, grad_fn, hess_fn, bounds=None)[source]
Build a @jax.jit-decorated Newton solver specialised to one problem.
The returned callable has signature solve(eta, theta0, max_steps=20, tol=1e-10) -> (theta_opt, fun, grad_norm, converged), where max_steps is a static argument (required by lax.scan).
All distribution-level inputs (f, grad_fn, hess_fn, bounds) are baked into the closure at construction time. Repeated calls with the same array shapes and dtypes therefore reuse the compiled XLA executable.
Use this in EM hot paths where solve_bregman would otherwise build a fresh Python closure on every call and force JAX to re-trace the same Newton kernel on each iteration.
- Parameters:
- Returns:
Jit-compiled Newton solver. Returns a 4-tuple of JAX arrays (theta, fun, grad_norm, converged); wrap in BregmanResult externally if needed.
- Return type:
Callable
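A construction sketch with the same toy convex f used in the solve_bregman sketch above, supplying explicit gradient and Hessian as method='newton' requires:

import jax
import jax.numpy as jnp
from normix.fitting.solvers import make_jit_newton_solver

f = lambda theta: 0.5 * jnp.sum(theta ** 2)
solve = make_jit_newton_solver(f, jax.grad(f), jax.hessian(f))   # bounds=None -> unconstrained

theta_opt, fun, grad_norm, converged = solve(jnp.array([1.0, -2.0]),  # eta
                                             jnp.zeros(2))            # theta0
# Later calls with the same shapes/dtypes reuse the compiled XLA executable.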
Utilities
Bessel Functions
JAX-compatible log modified Bessel function of the second kind.
log_kv(v, z) = log K_v(z), fully pure-JAX with zero scipy callbacks.
- Regime-specific methods, selected via lax.cond (only one branch executes):
Hankel asymptotic (DLMF 10.40.2) — large z
Olver uniform expansion (DLMF 10.41.4) — large v
Small-z leading asymptotic (DLMF 10.30.2) — z → 0
Gauss-Legendre quadrature (Takekawa 2022) — moderate z, v
Using lax.cond (not jnp.where) means only the selected branch executes at runtime. lax.cond requires scalar conditions, so the core scalar function _log_kv_scalar is vmapped over array inputs.
- Custom JVP for full autodiff:
∂/∂z : exact recurrence K’_v = −(K_{v−1}+K_{v+1})/2
∂/∂v : central FD on log_kv itself (ε = BESSEL_EPS_V)
backend='jax' (default): pure-JAX, JIT-able, differentiable.
backend='cpu': scipy.special.kve, fully vectorized numpy. Not JIT-able. Fast for EM hot path.
- normix.utils.bessel.log_kv(v, z, backend='jax')[source]
\(\log K_v(z)\) — log modified Bessel function of the second kind.
- Parameters:
v (scalar or array) – Order (any real; \(K_v = K_{-v}\)).
z (scalar or array) – Argument (must be > 0).
backend (str, optional) – 'jax' (default) or 'cpu'.
'jax': pure-JAX, lax.cond regime selection, custom JVP. JIT-able, differentiable. Default for log_prob, pdf, etc.
'cpu': scipy.special.kve, fully vectorised NumPy. Not JIT-able. Fast for EM hot path.
- Returns:
Same broadcast shape as (v, z).
- Return type:
Examples
Evaluate at a single point (JAX backend, JIT-able):
>>> import jax.numpy as jnp
>>> from normix import log_kv
>>> float(log_kv(v=0.5, z=1.0))
-0.112...
CPU backend (uses scipy, faster for EM hot-paths):
>>> float(log_kv(v=0.5, z=1.0, backend='cpu'))
-0.112...
Symmetry \(K_v(z) = K_{-v}(z)\):
>>> abs(float(log_kv(0.5, 2.0)) - float(log_kv(-0.5, 2.0))) < 1e-10
True
Differentiable via JAX:
>>> import jax
>>> dlogkv_dz = jax.grad(lambda z: log_kv(0.5, z))(jnp.array(1.0))
>>> float(dlogkv_dz) < 0  # K_v decreases with z
True
Constants
Shared numerical constants for normix.