API Reference
Base Classes
- class normix.exponential_family.ExponentialFamily[source]
Bases: Module
Abstract base class for exponential family distributions.
- Concrete subclasses must implement:
_log_partition_from_theta, natural_params, sufficient_statistics, log_base_measure
- expectation_params(backend='jax')[source]
η = ∇ψ(θ).
- Parameters:
backend ('jax' (default, JIT-able) or 'cpu' (numpy/scipy))
- Return type:
- fisher_information(backend='jax')[source]
I(θ) = ∇²ψ(θ).
- Parameters:
backend ('jax' (default, JIT-able) or 'cpu' (numpy/scipy))
- Return type:
- classmethod from_natural(theta)[source]
Construct from natural parameters θ. Subclasses must override.
- Parameters:
theta (Array)
- Return type:
- classmethod bregman_divergence(theta, eta)[source]
Bregman divergence ψ(θ) − θ·η (conjugate dual).
Minimising over θ yields ∇ψ(θ*) = η, i.e. the natural parameters corresponding to expectation parameters η.
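As a concrete sketch (plain Python, not the normix API): for the 1-D exponential distribution with natural parameter θ < 0 we have ψ(θ) = −log(−θ) and ∇ψ(θ) = −1/θ, so a Newton minimisation of the Bregman objective recovers θ* = −1/η:

```python
import math

# 1-D exponential distribution: psi(theta) = -log(-theta), theta < 0.
# Minimising psi(theta) - theta*eta gives grad psi(theta*) = eta.
def bregman_newton(eta, theta0=-1.0, tol=1e-12, maxiter=100):
    theta = theta0
    for _ in range(maxiter):
        grad = -1.0 / theta - eta      # d/dtheta [psi(theta) - theta*eta]
        hess = 1.0 / theta ** 2        # psi''(theta) > 0 (convexity)
        step = grad / hess
        theta -= step
        if abs(step) < tol:
            break
    return theta

eta = 0.25                             # target expectation E[X]
theta_star = bregman_newton(eta)       # converges to -1/eta = -4
```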
- classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='lbfgs', verbose=0)[source]
Construct from expectation parameters η by solving ∇ψ(θ) = η.
Minimises the Bregman divergence ψ(θ) − θ·η via solve_bregman. Subclasses can override for closed-form inverses.
- Parameters:
- Return type:
- classmethod fit_mle(X, *, theta0=None, maxiter=500, tol=1e-10, verbose=0)[source]
MLE via exponential family identity: η̂ = mean_i t(xᵢ).
Batches over X using jax.vmap, then calls from_expectation(η̂).
- Parameters:
X ((n, ...) array of observations)
theta0 (optional initial natural parameters θ₀ for the η→θ solver)
maxiter (maximum iterations for the η→θ solver)
tol (convergence tolerance for the η→θ solver)
verbose (0 = silent, >= 1 = print solver summary)
- Return type:
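A plain-Python sketch of this identity (not the normix API), again for the 1-D exponential distribution, where t(x) = x and the η→θ inverse is closed-form, θ̂ = −1/η̂:

```python
import random

random.seed(0)
rate = 2.0
X = [random.expovariate(rate) for _ in range(100_000)]

# MLE via the exponential family identity: eta_hat = mean of t(x) = x
eta_hat = sum(X) / len(X)
theta_hat = -1.0 / eta_hat     # closed-form inverse of grad psi for this family
rate_hat = -theta_hat          # recovers the rate parameter
```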
- fit(X, *, maxiter=500, tol=1e-10, verbose=0, **kwargs)[source]
Fit using self as the initialization (warm start).
Computes η̂ = mean_i t(xᵢ) and solves from_expectation(η̂) using self.natural_params() as the initial theta0.
- Parameters:
X ((n, ...) array of observations)
maxiter (maximum iterations for the η→θ solver)
tol (convergence tolerance for the η→θ solver)
verbose (0 = silent, >= 1 = print solver summary)
- Return type:
Univariate Distributions
Gamma
Gamma distribution as an exponential family.
PDF: p(x|α,β) = β^α/Γ(α) · x^{α-1} · exp(-βx), x > 0
- Exponential family:
h(x) = 1
t(x) = [log x, x]
θ = [α-1, -β] (θ₁ > -1, θ₂ < 0)
ψ(θ) = log Γ(θ₁+1) − (θ₁+1) log(−θ₂)
η = [ψ(α) − log β, α/β] (digamma, mean)
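This triad can be sanity-checked numerically; a standard-library sketch where central finite differences stand in for jax.grad:

```python
import math

def psi_gamma(t1, t2):
    # Gamma log-partition: log Gamma(t1+1) - (t1+1) log(-t2)
    return math.lgamma(t1 + 1.0) - (t1 + 1.0) * math.log(-t2)

alpha, beta = 3.0, 2.0
t1, t2 = alpha - 1.0, -beta            # natural parameters

# central finite differences approximate grad psi = eta
h = 1e-6
g1 = (psi_gamma(t1 + h, t2) - psi_gamma(t1 - h, t2)) / (2 * h)  # ~ digamma(alpha) - log beta
g2 = (psi_gamma(t1, t2 + h) - psi_gamma(t1, t2 - h)) / (2 * h)  # ~ alpha / beta

# digamma itself approximated the same way, for comparison
digamma_alpha = (math.lgamma(alpha + h) - math.lgamma(alpha - h)) / (2 * h)
```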
- class normix.distributions.gamma.Gamma(alpha, beta)[source]
Bases: ExponentialFamily
Gamma(α, β) distribution — shape α > 0, rate β > 0.
Inverse Gamma
InverseGamma distribution as an exponential family.
PDF: p(x|α,β) = β^α/Γ(α) · x^{-α-1} · exp(-β/x), x > 0
- Exponential family:
h(x) = 1
t(x) = [-1/x, log x]
θ = [β, -(α+1)] (θ₁ > 0, θ₂ < -1)
ψ(θ) = log Γ(-θ₂-1) − (-θ₂-1) log(θ₁) = log Γ(α) − α log β
η = [-α/β, log β − ψ(α)] (E[-1/X], E[log X])
- class normix.distributions.inverse_gamma.InverseGamma(alpha, beta)[source]
Bases: ExponentialFamily
InverseGamma(α, β) — shape α > 0, rate β > 0.
Inverse Gaussian
Inverse Gaussian (Wald) distribution as an exponential family.
PDF: f(x|μ,λ) = √(λ/(2π)) · x^{-3/2} · exp(−λ(x−μ)²/(2μ²x)), x > 0
- Exponential family:
h(x) = (2π)^{-1/2} · x^{-3/2}
t(x) = [x, 1/x]
θ = [−λ/(2μ²), −λ/2] (θ₁ < 0, θ₂ < 0)
ψ(θ) = −½log(−2θ₂) − √((−2θ₁)(−2θ₂))
(the ½log(2π) is absorbed into log h(x) = −½log(2π) − 3/2 log x)
η = [E[X], E[1/X]] = [μ, 1/μ + 1/λ]
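A finite-difference check of this triad (standard library only; ψ is written with the signs that make ∇ψ(θ) = η = [μ, 1/μ + 1/λ]):

```python
import math

def psi_ig(t1, t2):
    # Inverse Gaussian log-partition (the 2*pi term is absorbed into h(x))
    return -0.5 * math.log(-2.0 * t2) - math.sqrt((-2.0 * t1) * (-2.0 * t2))

mu, lam = 1.4, 2.5
t1, t2 = -lam / (2.0 * mu ** 2), -lam / 2.0    # natural parameters

h = 1e-6
g1 = (psi_ig(t1 + h, t2) - psi_ig(t1 - h, t2)) / (2 * h)  # ~ E[X] = mu
g2 = (psi_ig(t1, t2 + h) - psi_ig(t1, t2 - h)) / (2 * h)  # ~ E[1/X] = 1/mu + 1/lam
```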
- class normix.distributions.inverse_gaussian.InverseGaussian(mu, lam)[source]
Bases: ExponentialFamily
InverseGaussian(μ, λ) — mean μ > 0, shape λ > 0.
Generalized Inverse Gaussian
Generalized Inverse Gaussian (GIG) distribution as an exponential family.
PDF: f(x|p,a,b) = (a/b)^{p/2} / (2 K_p(√(ab))) · x^{p-1} · exp(-(ax+b/x)/2)
- Exponential family:
h(x) = 1
t(x) = [log x, 1/x, x]
θ = [p-1, -b/2, -a/2] (θ₂ ≤ 0, θ₃ ≤ 0)
ψ(θ) = log 2 + log K_p(√(ab)) + (p/2) log(b/a), where p = θ₁+1, a = -2θ₃, b = -2θ₂
η = [E[log X], E[1/X], E[X]]
- Special cases:
b→0, p>0: GIG → Gamma(p, a/2)
a→0, p<0: GIG → InverseGamma(-p, b/2)
p=-1/2: GIG → InverseGaussian
- η→θ optimization uses η-rescaling to reduce the Fisher condition number:
s = √(η₂/η₃)
η̃ = (η₁ + ½log(η₂/η₃), √(η₂η₃), √(η₂η₃))
Solve η̃→θ̃ with symmetric GIG (ã = b̃), then unscale.
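A small sketch of the rescaling itself (plain Python; the η values are made up to show an extreme E[1/X]/E[X] ratio):

```python
import math

# hypothetical eta = (E[log X], E[1/X], E[X]) with badly mismatched scales
eta = (0.3, 1.0e6, 1.0e-6)

s = math.sqrt(eta[1] / eta[2])                    # rescaling factor
eta_tilde = (eta[0] + 0.5 * math.log(eta[1] / eta[2]),
             math.sqrt(eta[1] * eta[2]),
             math.sqrt(eta[1] * eta[2]))
# after rescaling, the second and third components coincide (symmetric GIG),
# so the inner solve is well-conditioned; theta is then unscaled via s
```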
Log-Partition Triad Overrides
_log_partition_from_theta : JAX, uses log_kv(backend='jax')
_grad_log_partition : inherits default jax.grad (custom JVP on log_kv)
_hessian_log_partition : analytical 7-Bessel Hessian in θ-space
_log_partition_cpu : numpy + log_kv(backend='cpu')
_grad_log_partition_cpu : analytical Bessel ratios via scipy.kve
_hessian_log_partition_cpu : central FD on _log_partition_cpu
- class normix.distributions.generalized_inverse_gaussian.GeneralizedInverseGaussian(p, a, b)[source]
Bases: ExponentialFamily
Generalized Inverse Gaussian distribution.
Stored: p (shape, any real), a > 0, b > 0.
- static expectation_params_batch(p, a, b, backend='jax')[source]
Vectorized η for arrays of (p, a, b), each shape (N,). Returns (N, 3) array where columns are [E_log_X, E_inv_X, E_X].
backend='jax' : vmap over scalar JAX grad
backend='cpu' : vectorized scipy.kve (6 C-level array calls)
- classmethod from_natural(theta)[source]
Construct from natural parameters θ. Subclasses must override.
- Parameters:
theta (Array)
- Return type:
- classmethod from_expectation(eta, *, theta0=None, maxiter=500, tol=1e-10, backend='jax', method='newton', verbose=0)[source]
η → θ via η-rescaling + optimization.
Rescaling makes the Fisher matrix symmetric (ã = b̃), reducing the condition number by up to 10^30 for extreme a/b ratios.
- Parameters:
- Return type:
- normix.distributions.generalized_inverse_gaussian.GIG
alias of GeneralizedInverseGaussian
Multivariate Distributions
Multivariate Normal distribution.
Stored: mu (d,), L_Sigma (d,d) lower-triangular Cholesky of Σ. All linear algebra via L_Sigma — never form Σ⁻¹ explicitly.
PDF: f(x) = (2π)^{-d/2} |Σ|^{-1/2} exp(-½(x-μ)ᵀΣ⁻¹(x-μ))
Exponential family:
In practice we use the Cholesky log_prob formula directly.
- class normix.distributions.normal.MultivariateNormal(mu, L_Sigma)[source]
Bases: Module
Multivariate Normal with Cholesky parametrization.
- classmethod from_classical(mu, sigma)[source]
Construct from mean μ and covariance matrix Σ.
- Return type:
Mixture Base Classes
JointNormalMixture
JointNormalMixture — abstract exponential family for normal variance-mean mixtures.
- Joint distribution f(x,y):
X|Y ~ N(μ + γy, Σy)
Y ~ subordinator (GIG, Gamma, InverseGamma, InverseGaussian)
- Sufficient statistics:
t(x,y) = [log y, 1/y, y, x, x/y, vec(xxᵀ/y)]
- Natural parameters:
θ₁ = p_sub - 1 - d/2 (GIG p; scalar, depends on subordinator)
θ₂ = -(b_sub + ½μᵀΣ⁻¹μ) < 0
θ₃ = -(a_sub + ½γᵀΣ⁻¹γ) < 0
θ₄ = Σ⁻¹γ (d-vector)
θ₅ = Σ⁻¹μ (d-vector)
θ₆ = -½vec(Σ⁻¹) (d²-vector)
- Log partition:
ψ = ψ_sub(p, a, b) + ½log|Σ| + μᵀΣ⁻¹γ
- Expectation parameters (EM E-step quantities):
η₁ = E[log Y]
η₂ = E[1/Y]
η₃ = E[Y]
η₄ = E[X] = μ + γ E[Y]
η₅ = E[X/Y] = μ E[1/Y] + γ
η₆ = E[XXᵀ/Y] = Σ + μμᵀ E[1/Y] + γγᵀ E[Y] + μγᵀ + γμᵀ
- EM M-step closed-form (from η):
Let D = 1 - E[1/Y]·E[Y]
μ = (E[X] - E[Y]·E[X/Y]) / D
γ = (E[X/Y] - E[1/Y]·E[X]) / D
Σ = E[XXᵀ/Y] - E[X/Y]μᵀ - μE[X/Y]ᵀ + E[1/Y]μμᵀ - E[Y]γγᵀ
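The closed form can be verified in the scalar case d = 1 (plain Python; the expectation values below are constructed from known parameters and then inverted):

```python
# forward map: build eta from known (mu, gamma, Sigma) and E[Y], E[1/Y]
E_Y, E_invY = 2.0, 0.8
mu0, gamma0, Sigma0 = 0.5, -0.3, 1.2
E_X = mu0 + gamma0 * E_Y
E_XoverY = mu0 * E_invY + gamma0
E_XXoverY = (Sigma0 + mu0 ** 2 * E_invY + gamma0 ** 2 * E_Y
             + 2.0 * mu0 * gamma0)

# M-step closed form recovers the original parameters
D = 1.0 - E_invY * E_Y
mu = (E_X - E_Y * E_XoverY) / D
gamma = (E_XoverY - E_invY * E_X) / D
Sigma = (E_XXoverY - 2.0 * E_XoverY * mu + E_invY * mu ** 2
         - E_Y * gamma ** 2)
```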
- class normix.mixtures.joint.JointNormalMixture(mu, gamma, L_Sigma)[source]
Bases: ExponentialFamily
Abstract joint distribution f(x, y) for normal variance-mean mixtures.
Stored: mu (d,), gamma (d,), L_Sigma (d,d lower Cholesky of Σ). Subordinator parameters are defined by concrete subclasses.
- log_prob_joint(x, y)[source]
log f(x, y) = log f(x|y) + log f_Y(y).
log f(x|y) = -d/2 log(2π) - ½ log|Σy| - ½(x-μ-γy)ᵀ(Σy)⁻¹(x-μ-γy)
           = -d/2 log(2π) - ½ log|Σ| - d/2 log y - 1/(2y) ‖L⁻¹(x-μ)‖² + γᵀΣ⁻¹(x-μ) - y/2 γᵀΣ⁻¹γ
log f_Y(y) from subordinator.
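The expansion can be checked in the scalar case d = 1 (plain Python, with L = √Σ):

```python
import math

mu, gamma, Sigma = 0.2, -0.4, 1.5
x, y = 0.9, 0.7
L = math.sqrt(Sigma)                       # Cholesky factor for d = 1

# direct conditional log-density N(x | mu + gamma*y, Sigma*y)
direct = (-0.5 * math.log(2.0 * math.pi * Sigma * y)
          - (x - mu - gamma * y) ** 2 / (2.0 * Sigma * y))

# expanded form from the docstring
expanded = (-0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(Sigma)
            - 0.5 * math.log(y)
            - ((x - mu) / L) ** 2 / (2.0 * y)
            + gamma * (x - mu) / Sigma
            - y * gamma ** 2 / (2.0 * Sigma))
```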
- conditional_expectations(x)[source]
Compute E[g(Y)|X=x] for EM E-step.
- The posterior Y|X follows a GIG-like distribution with parameters:
p_post = p_eff - d/2
a_post = a_eff + γᵀΣ⁻¹γ
b_post = b_eff + (x-μ)ᵀΣ⁻¹(x-μ)
Returns dict with keys: E_log_Y, E_inv_Y, E_Y. These are then used in the M-step.
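The posterior parameters can be checked in the scalar case d = 1: the joint density, viewed as a function of y, must be proportional to an unnormalized GIG(p_post, a_post, b_post) density. A plain-Python sketch with made-up parameter values:

```python
import math

mu, gamma, Sigma = 0.2, 0.5, 1.5      # normal part (d = 1)
p, a, b = -0.4, 1.3, 0.7              # GIG subordinator
x = 1.1

def log_joint(y):
    # log N(x | mu + gamma*y, Sigma*y) + log unnormalized GIG(y | p, a, b)
    r = x - mu - gamma * y
    return (-0.5 * math.log(2.0 * math.pi * Sigma * y)
            - r * r / (2.0 * Sigma * y)
            + (p - 1.0) * math.log(y) - (a * y + b / y) / 2.0)

p_post = p - 0.5                                  # p - d/2
a_post = a + gamma ** 2 / Sigma                   # a + gamma' Sigma^-1 gamma
b_post = b + (x - mu) ** 2 / Sigma                # b + Q(x)

def log_gig_unnorm(y):
    return (p_post - 1.0) * math.log(y) - (a_post * y + b_post / y) / 2.0

# the difference is constant in y iff the posterior really is this GIG
c1 = log_joint(0.5) - log_gig_unnorm(0.5)
c2 = log_joint(2.0) - log_gig_unnorm(2.0)
```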
NormalMixture
NormalMixture — marginal f(x) = ∫ f(x,y) dy.
- Owns a JointNormalMixture. Provides:
log_prob(x) — closed-form marginal log-density
e_step(X) — jax.vmap over conditional_expectations
m_step(X, expectations) — returns new NormalMixture
fit(X, …) — convenience EM fitting with multi-start
- class normix.mixtures.marginal.NormalMixture(joint)[source]
Bases: Module
Marginal f(x) = ∫₀^∞ f(x,y) dy.
Not an exponential family. Owns a JointNormalMixture (which is).
- property joint
- e_step(X, backend='jax')[source]
E-step: compute conditional expectations E[g(Y)|X=xᵢ] for each i.
Returns dict of arrays with shape (n, …) for each expectation.
- backend='jax' (default): jax.vmap over conditional_expectations. JIT-able, differentiable.
- backend='cpu': quad forms in JAX (vmapped) + GIG Bessel on CPU. Faster for large N; not JIT-able.
- m_step(X, expectations, **kwargs)[source]
M-step: update model parameters from sufficient statistics.
Returns a NEW NormalMixture with updated parameters. Subclasses must override _m_step_subordinator.
- Parameters:
- Return type:
- regularize_det_sigma_one()[source]
Enforce |Σ| = 1 by rescaling.
Σ → Σ/s, γ → γ/s, with s = det(Σ)^{1/d}; subordinator params are scaled via _scale_subordinator.
- Return type:
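A NumPy sketch of the rescaling for d = 2 (the subordinator rescaling via _scale_subordinator is omitted here; it absorbs the factor s so that the marginal is unchanged):

```python
import numpy as np

Sigma = np.array([[2.0, 0.3],
                  [0.3, 1.5]])
gamma = np.array([0.4, -0.1])

d = Sigma.shape[0]
s = np.linalg.det(Sigma) ** (1.0 / d)   # s = det(Sigma)^(1/d)
Sigma_new = Sigma / s                   # now det(Sigma_new) = 1
gamma_new = gamma / s
```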
- fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]
Fit using self as initialization. Returns EMResult.
- Parameters:
X ((n, d) data array)
verbose (0 = silent, 1 = summary, 2 = per-iteration table)
max_iter (EM convergence parameters)
tol (EM convergence parameters)
regularization ('det_sigma_one' or 'none')
e_step_backend ('jax' or 'cpu')
m_step_backend ('jax' or 'cpu')
m_step_method ('newton', 'lbfgs', or 'bfgs')
- Return type:
EMResult with .model, .log_likelihoods, .param_changes, etc.
Mixture Distributions
Variance Gamma
Variance Gamma (VG) distribution.
Special case of GH with GIG → Gamma subordinator (b → 0, p > 0). Y ~ Gamma(α, β), i.e. GIG(p = α, a = 2β, b → 0).
Stored: μ, γ, L_Σ (Cholesky of Σ), α (shape), β (rate) of Gamma.
- class normix.distributions.variance_gamma.JointVarianceGamma(mu, gamma, L_Sigma, alpha, beta)[source]
Bases: JointNormalMixture
Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ Gamma(α, β).
GIG limit: p = α, a = 2β, b → 0.
- natural_params()[source]
θ = [α-1-d/2, -½μᵀΛμ, -(β+½γᵀΛγ), Λγ, Λμ, -½vec(Λ)] (Gamma subordinator: p=α, a=2β, b→0).
- Return type:
- class normix.distributions.variance_gamma.VarianceGamma(joint)[source]
Bases: NormalMixture
Marginal Variance Gamma distribution f(x).
- Parameters:
joint (JointVarianceGamma)
Normal Inverse Gamma
Normal-Inverse Gamma (NInvG) distribution.
Special case of GH with GIG → InverseGamma subordinator (a → 0, p < 0). Y ~ InverseGamma(α, β), i.e. GIG(p = −α, a → 0, b = 2β).
Stored: μ, γ, L_Σ (Cholesky of Σ), α (shape), β (rate) of InverseGamma.
- class normix.distributions.normal_inverse_gamma.JointNormalInverseGamma(mu, gamma, L_Sigma, alpha, beta)[source]
Bases: JointNormalMixture
Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ InverseGamma(α, β).
GIG limit: p = −α, a → 0, b = 2β.
- class normix.distributions.normal_inverse_gamma.NormalInverseGamma(joint)[source]
Bases: NormalMixture
Marginal Normal-Inverse Gamma distribution f(x).
- Parameters:
joint (JointNormalInverseGamma)
Normal Inverse Gaussian
Normal-Inverse Gaussian (NIG) distribution.
Special case of GH with GIG → InverseGaussian subordinator (p = −½). Y ~ InverseGaussian(μ_IG, λ), i.e. GIG(p = −½, a = λ/μ_IG², b = λ).
Stored: μ, γ, L_Σ (Cholesky of Σ), μ_IG (IG mean), λ (IG shape).
- class normix.distributions.normal_inverse_gaussian.JointNormalInverseGaussian(mu, gamma, L_Sigma, mu_ig, lam)[source]
Bases: JointNormalMixture
Joint f(x,y): X|Y ~ N(μ+γy, Σy), Y ~ InverseGaussian(μ_IG, λ).
Stored: μ_IG (IG mean) and λ (IG shape) directly. GIG params: p = −½, a = λ/μ_IG², b = λ.
- class normix.distributions.normal_inverse_gaussian.NormalInverseGaussian(joint)[source]
Bases: NormalMixture
Marginal Normal-Inverse Gaussian distribution f(x).
- Parameters:
joint (JointNormalInverseGaussian)
Generalized Hyperbolic
Generalized Hyperbolic (GH) distribution.
Joint: X|Y ~ N(μ+γy, Σy), Y ~ GIG(p, a, b)
Marginal: GH(μ, γ, Σ, p, a, b)
- Marginal log-density (closed form via Bessel functions):
Let Q(x) = (x-μ)ᵀΣ⁻¹(x-μ), A = a + γᵀΣ⁻¹γ
log f(x) = const
+ ((p - d/2)/2) log(Q(x) + b) - (p/2) log A
+ log K_{p-d/2}(√(A(Q(x)+b)))
- log K_p(√(ab))
+ (x-μ)ᵀΣ⁻¹γ (skewness term)
- Full formula:
f(x) = C(p,a,b,Σ) · ((Q(x)+b)/A)^{(p-d/2)/2} · K_{p-d/2}(√(A(Q(x)+b))) · exp(γᵀΣ⁻¹(x-μ))
C = (2π)^{-d/2} |Σ|^{-1/2} (a/b)^{p/2} / K_p(√(ab))
(the factor 2 from the GIG-type integral cancels the 2 in the GIG normalizer)
Posterior Y|X = x ~ GIG(p - d/2, a + γᵀΣ⁻¹γ, b + (x-μ)ᵀΣ⁻¹(x-μ)) → conditional expectations computed from GIG.
- class normix.distributions.generalized_hyperbolic.JointGeneralizedHyperbolic(mu, gamma, L_Sigma, p, a, b)[source]
Bases: JointNormalMixture
Joint f(x,y): X|Y~N(μ+γy, Σy), Y~GIG(p,a,b).
Stored: mu, gamma, L (from JointNormalMixture) + p, a, b (GIG parameters).
- natural_params()[source]
θ = [p-1-d/2, -(b/2+½μᵀΣ⁻¹μ), -(a/2+½γᵀΣ⁻¹γ), Σ⁻¹γ, Σ⁻¹μ, -½vec(Σ⁻¹)]
- Return type:
- classmethod from_classical(*, mu, gamma, sigma, p, a, b)[source]
Construct from classical parameters.
- class normix.distributions.generalized_hyperbolic.GeneralizedHyperbolic(joint)[source]
Bases: NormalMixture
Marginal Generalized Hyperbolic distribution f(x).
- Stores a JointGeneralizedHyperbolic. Provides:
log_prob(x) — closed-form Bessel expression
e_step, m_step — for EM fitting
fit(X, …) — convenience class method
- Parameters:
joint (JointGeneralizedHyperbolic)
- log_prob(x)[source]
Marginal log f(x).
f(x) ∝ ((Q(x)+b)/A)^{(p-d/2)/2} · K_{p-d/2}(√(A(Q(x)+b))) · exp(γᵀΣ⁻¹(x-μ))
where Q(x) = (x-μ)ᵀΣ⁻¹(x-μ), A = a + γᵀΣ⁻¹γ.
- Full normalizing constant:
C = (2π)^{-d/2} |Σ|^{-1/2} · (a/b)^{p/2} / K_p(√(ab))
- classmethod default_init(X)[source]
Warm-start from the best of NIG / VG / NInvG sub-model fits.
Runs 5 EM iterations for each special case, converts the winner to GH parametrisation, and uses it as the starting point. Falls back to moment-based default if all sub-models fail.
- Parameters:
X (Array)
- Return type:
- fit(X, *, verbose=0, max_iter=200, tol=0.001, regularization='det_sigma_one', e_step_backend='cpu', m_step_backend='cpu', m_step_method='newton')[source]
Fit GH distribution using EM.
Defaults to CPU backends and det_sigma_one regularization (GH has a scale non-identifiability that requires |Σ| = 1).
Fitting
EM Fitters
EM fitters for normix distributions.
Model knows math, fitter knows iteration.
- BatchEMFitter — standard batch EM with dual-loop architecture: lax.scan (JIT-able) or Python for-loop (CPU-compatible)
- OnlineEMFitter — online EM, one sample at a time
- MiniBatchEMFitter — mini-batch EM with Robbins-Monro averaging
- class normix.fitting.em.EMResult(model, log_likelihoods, param_changes, n_iter, converged, elapsed_time)[source]
Bases: object
Result of an EM fitting procedure.
- Parameters:
- class normix.fitting.em.BatchEMFitter(*, max_iter=200, tol=0.001, verbose=0, regularization='none', e_step_backend='jax', m_step_backend='cpu', m_step_method='newton')[source]
Bases: object
Batch EM algorithm with dual-loop architecture.
Runs E-step (all data) -> M-step -> regularize until convergence. Convergence is measured by relative parameter change in the normal parameters (mu, gamma, L_Sigma), excluding subordinator (GIG) parameters.
- Loop selection is automatic:
lax.scan when both backends are 'jax' and verbose <= 1
Python for-loop otherwise (CPU backends, or verbose >= 2)
- Parameters:
max_iter (int) – Maximum number of EM iterations.
tol (float) – Convergence tolerance on max relative parameter change.
verbose (int) – 0 = silent, 1 = summary, 2 = per-iteration table.
regularization (str) – Regularization strategy ('det_sigma_one' or 'none').
e_step_backend (str) – 'jax' (default) or 'cpu'.
m_step_backend (str) – 'jax' or 'cpu' (default, faster for GIG).
m_step_method (str) – 'newton' (default), 'lbfgs', or 'bfgs'.
- class normix.fitting.em.OnlineEMFitter(*, tau0=10.0, max_epochs=5, verbose=0, regularization='none')[source]
Bases: object
Online EM algorithm (Robbins-Monro stochastic approximation).
Updates running sufficient statistics with step size 1/(tau0 + t). One epoch = one full pass through data in random order.
Solvers
Bregman divergence solvers.
Minimises f(θ) − θ·η over θ, where f is any convex function (e.g. the log-partition ψ for an exponential family). At the minimum ∇f(θ*) = η.
Public API
solve_bregman — single starting point
solve_bregman_multistart — multiple starting points (vmap for JAX Newton; for-loop for quasi-Newton and CPU)
bregman_objective — utility: f(θ) − θ·η
Backends × methods
backend='jax', method='newton' — custom lax.scan Newton, autodiff or analytical Hessian
backend='jax', method='lbfgs' — jaxopt LBFGSB (bounds native) or LBFGS (reparam)
backend='jax', method='bfgs' — jaxopt BFGS with reparameterization for bounds
backend='cpu', method='lbfgs' — scipy L-BFGS-B
backend='cpu', method='bfgs' — scipy BFGS
backend='cpu', method='newton' — scipy trust-exact with Hessian
Gradient / Hessian sources
- grad_fn : θ → ∇f(θ). For backend='cpu': pure CPU (numpy) gradient. For backend='jax', method='newton': JAX-traceable ∇ψ(θ). If None with backend='cpu': hybrid — jax.grad compiled → NumPy callbacks.
- hess_fn : θ → ∇²f(θ). Required for method='newton'. For backend='cpu': pure CPU (numpy) Hessian. For backend='jax', method='newton': JAX-traceable ∇²ψ(θ). If None with method='newton': jax.hessian of the full objective.
Both grad_fn and hess_fn operate in theta-space only. The solver handles all reparameterization internally via the chain rule.
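A sketch of the backend='cpu', method='lbfgs' path, using scipy L-BFGS-B directly on the Bregman objective ψ(θ) − θ·η for the Gamma family (bounds θ₁ > −1, θ₂ < 0, with an analytic gradient standing in for grad_fn); this is an illustration of the math, not a call into normix:

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln, digamma

# Gamma log-partition and its gradient in theta-space
def psi(t):
    return gammaln(t[0] + 1.0) - (t[0] + 1.0) * np.log(-t[1])

def grad_psi(t):
    return np.array([digamma(t[0] + 1.0) - np.log(-t[1]),
                     -(t[0] + 1.0) / t[1]])

alpha, beta = 3.0, 2.0
eta = np.array([digamma(alpha) - np.log(beta), alpha / beta])  # target E[t(X)]

res = minimize(lambda t: psi(t) - t @ eta,
               x0=np.array([0.0, -1.0]),
               jac=lambda t: grad_psi(t) - eta,
               method="L-BFGS-B",
               bounds=[(-1 + 1e-10, None), (None, -1e-10)])
theta_star = res.x        # expect [alpha - 1, -beta] = [2, -2]
```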
- class normix.fitting.solvers.BregmanResult(theta, fun, grad_norm, num_steps, converged, elapsed_time=0.0)[source]
Bases: object
Result of a Bregman divergence minimization.
Scalar fields accept both Python and JAX types so the result can live inside a lax.scan carry without concretization errors.
- Parameters:
- normix.fitting.solvers.bregman_objective(theta, eta, f)[source]
f(θ) − θ·η — convex dual whose minimum gives ∇f(θ*) = η.
- normix.fitting.solvers.solve_bregman(f, eta, theta0, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]
Minimise f(θ) − θ·η over θ.
- Parameters:
f (convex function θ → scalar (e.g. log-partition ψ))
eta (target vector (e.g. expectation parameters η))
theta0 (initial guess)
backend ('jax' (JIT-able) or 'cpu' (scipy, not JIT-able))
method ('lbfgs', 'bfgs', or 'newton')
bounds (list of (lo, hi) per dimension; None → unconstrained.) – For backend='jax': enforced via reparameterization (except method='lbfgs', which uses jaxopt.LBFGSB natively). For backend='cpu': passed directly as scipy bounds.
max_steps (iteration budget)
tol (convergence tolerance on ‖∇f(θ) − η‖∞)
grad_fn (θ → ∇f(θ).) – For backend='cpu': must accept and return numpy arrays. For backend='jax', method='newton': must be JAX-traceable. If None with backend='cpu': falls back to jax.grad (hybrid mode).
hess_fn (θ → ∇²f(θ).) – Required for method='newton'. For backend='cpu': must accept numpy arrays and return a numpy array. For backend='jax', method='newton': must be JAX-traceable. If None with method='newton': jax.hessian of the full objective is used.
verbose (int) – 0 = silent, >= 1 = print summary after solve.
- Return type:
- normix.fitting.solvers.solve_bregman_multistart(f, eta, theta0_batch, *, backend='jax', method='lbfgs', bounds=None, max_steps=500, tol=1e-10, grad_fn=None, hess_fn=None, verbose=0)[source]
Run solve_bregman from multiple starting points; return the best result.
- Parameters:
theta0_batch – (K, dim) jax.Array for backend='jax', method='newton' (parallel via vmap); list of arrays otherwise (sequential for-loop).
verbose (int) – 0 = silent, >= 1 = print summary.
eta (Array)
backend (str)
method (str)
max_steps (int)
tol (float)
grad_fn (Callable | None)
hess_fn (Callable | None)
- Return type:
Utilities
Bessel Functions
JAX-compatible log modified Bessel function of the second kind.
log_kv(v, z) = log K_v(z), fully pure-JAX with zero scipy callbacks.
- Regime-specific methods, selected via lax.cond (only one branch executes):
Hankel asymptotic (DLMF 10.40.2) — large z
Olver uniform expansion (DLMF 10.41.4) — large v
Small-z leading asymptotic (DLMF 10.30.2) — z → 0
Gauss-Legendre quadrature (Takekawa 2022) — moderate z, v
Using lax.cond (not jnp.where) means only the selected branch executes at runtime. lax.cond requires scalar conditions, so the core scalar function _log_kv_scalar is vmapped over array inputs.
- Custom JVP for full autodiff:
∂/∂z : exact recurrence K'_v = −(K_{v−1}+K_{v+1})/2
∂/∂v : central FD on log_kv itself (ε = BESSEL_EPS_V)
backend='jax' (default): pure-JAX, JIT-able, differentiable.
backend='cpu' : scipy.special.kve, fully vectorized numpy. Not JIT-able. Fast for EM hot path.
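The CPU path can be sketched with scipy.special.kve, which returns e^z K_v(z), so log K_v(z) = log(kve(v, z)) − z is stable for large z; the same snippet checks the symmetry K_v = K_{−v} and the ∂/∂z recurrence used by the custom JVP:

```python
import numpy as np
from scipy.special import kve

def log_kv_cpu(v, z):
    # kve(v, z) = exp(z) * K_v(z), so subtracting z undoes the scaling
    return np.log(kve(v, z)) - z

v, z = 1.7, 3.2
lk = log_kv_cpu(v, z)
sym = log_kv_cpu(-v, z)                      # K_v = K_{-v}

# d/dz log K_v(z) = K_v'(z)/K_v(z), with K_v' = -(K_{v-1} + K_{v+1})/2
h = 1e-6
fd = (log_kv_cpu(v, z + h) - log_kv_cpu(v, z - h)) / (2 * h)
analytic = -0.5 * (np.exp(log_kv_cpu(v - 1.0, z) - lk)
                   + np.exp(log_kv_cpu(v + 1.0, z) - lk))
```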
- normix.utils.bessel.log_kv(v, z, backend='jax')[source]
log K_v(z) — log modified Bessel function of the second kind.
- Parameters:
v (scalar or array — order (any real; K_v = K_{-v}))
z (scalar or array — argument (must be > 0))
backend ('jax' (default) or 'cpu') –
- 'jax' : pure-JAX, lax.cond regime selection, custom JVP. JIT-able, differentiable. Default for log_prob, pdf, etc.
- 'cpu' : scipy.special.kve, fully vectorized numpy. Not JIT-able. Fast for EM hot path.
- Return type:
scalar or array of same broadcast shape as (v, z).
Constants
Shared numerical constants for normix.