EM vs MCECM Algorithm Comparison
This notebook replicates Table 4 from [Shi2016], comparing the EM and MCECM algorithms for fitting Generalized Hyperbolic (GH) distributions.
Procedure (following the thesis):
1. Fit a bivariate GH distribution to real stock-return data (AAPL and MSFT) via EM. This gives the base model with parameters \((\mu, \gamma, \Sigma, p_0, a, b)\).
2. For each \(p \in \{-10, -9, \ldots, 10\}\), construct a “true” model by replacing \(p_0\) with \(p\) in the base model, keeping all other parameters unchanged.
3. Draw 5000 i.i.d. samples from each “true” model.
4. Re-fit each sample with EM and with MCECM, leaving all parameters free.
5. Compare the fitted models' log-likelihoods on the sample (averaged per observation) and their squared Hellinger distance \(H^2\) from the true model.
Expected result: both algorithms should converge to essentially the same MLE, i.e. near-identical log-likelihoods for every \(p\).
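The \(H^2\) column measures how far each re-fitted model is from the model that generated its sample. This sketch assumes the convention \(H^2 = 1 - \int \sqrt{f g}\,dx\); the normalization used by normix's `squared_hellinger` is not shown here. Under that convention, \(H^2\) can be estimated by Monte Carlo as \(1 - \mathbb{E}_f[\sqrt{g(X)/f(X)}]\) with \(X \sim f\).

GH densities involve modified Bessel functions. To keep the sketch self-contained, it substitutes two bivariate Gaussians, for which a closed form exists to check the estimator against. It illustrates the metric only and is not normix's implementation.

```python
import numpy as np

def mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) evaluated at each row of x."""
    d = mean.shape[0]
    L = np.linalg.cholesky(cov)
    # Solve L y = (x - mean)^T so that ||y||^2 is the squared Mahalanobis distance
    y = np.linalg.solve(L, (x - mean).T)
    maha = np.sum(y**2, axis=0)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)

def hellinger2_gauss(m1, S1, m2, S2):
    """Closed-form H^2 = 1 - Bhattacharyya coefficient for two Gaussians."""
    S = 0.5 * (S1 + S2)
    diff = m1 - m2
    coef = (np.linalg.det(S1) ** 0.25 * np.linalg.det(S2) ** 0.25
            / np.sqrt(np.linalg.det(S)))
    expo = np.exp(-0.125 * diff @ np.linalg.solve(S, diff))
    return 1.0 - coef * expo

def hellinger2_mc(m1, S1, m2, S2, n=200_000, seed=0):
    """Monte Carlo estimate H^2 = 1 - E_f[sqrt(g(X) / f(X))] with X ~ f."""
    rng = np.random.default_rng(seed)
    x = rng.multivariate_normal(m1, S1, size=n)
    log_ratio = mvn_logpdf(x, m2, S2) - mvn_logpdf(x, m1, S1)
    return 1.0 - np.mean(np.exp(0.5 * log_ratio))

m1, S1 = np.zeros(2), np.array([[1.0, 0.5], [0.5, 1.0]])
m2, S2 = np.array([0.1, -0.1]), np.array([[1.2, 0.4], [0.4, 0.9]])
h_cf = hellinger2_gauss(m1, S1, m2, S2)
h_mc = hellinger2_mc(m1, S1, m2, S2)
print(f"closed form: {h_cf:.5f}   Monte Carlo: {h_mc:.5f}")
```

The Monte Carlo form only needs log-densities and samples. That is what makes it usable for families such as GH, where no closed-form \(H^2\) is available.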
import time
from pathlib import Path
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
jax.config.update("jax_enable_x64", True)
from normix import GeneralizedHyperbolic, squared_hellinger
np.set_printoptions(precision=6, suppress=False)
# myst-nb executes with cwd = docs/tutorials/em/
data_path = Path("../../../data/sp500_returns.csv").resolve()
returns = pd.read_csv(data_path, index_col=0)
cols = ['AAPL', 'MSFT']
X_real = jnp.array(returns[cols].dropna().values, dtype=jnp.float64)
print(f'Data: {X_real.shape[0]} observations, {X_real.shape[1]} stocks')
Data: 2552 observations, 2 stocks
Step 1: Fit the base model to real data
model_base = GeneralizedHyperbolic.default_init(X_real)
result_base = model_base.fit(
    X_real, max_iter=200, tol=1e-2,
    regularization='det_sigma_one', verbose=1)
model_base = result_base.model
j = model_base.joint
print('\nBase model parameters:')
print(f' p = {float(j.p):.4f}, a = {float(j.a):.4f}, b = {float(j.b):.4f}')
print(f' mu = {np.array(j.mu)}')
print(f' gamma = {np.array(j.gamma)}')
============================================================
EM Fitting: GeneralizedHyperbolic
============================================================
Algorithm : EM
Loop : Python loop
E-step : cpu
M-step : cpu / newton
Regularize : det_sigma_one
Tolerance : 1.0e-02
Max iters : 200
Converged after 8 iterations (6.04s), final LL=5.765284
Base model parameters:
p = -0.8874, a = 1720.2603, b = 0.0002
mu = [0.001787 0.001736]
gamma = [-3.77401 -3.956902]
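Every fit in this notebook passes `regularization='det_sigma_one'`. A plausible reading, assumed here rather than taken from the normix source, is that this option fixes \(\det\Sigma = 1\). That constraint removes a scale ambiguity in the GH mean-variance mixture \(X = \mu + \gamma W + \sqrt{W}\,\Sigma^{1/2} Z\) with \(W \sim \mathrm{GIG}(p, a, b)\).

The ambiguity works as follows. For any \(c > 0\), the parameters \((\mu, \gamma/c, \Sigma/c, p, a/c, bc)\) describe the same distribution, because \(cW \sim \mathrm{GIG}(p, a/c, bc)\).

The sketch below checks this numerically. It assumes the GIG density is proportional to \(w^{p-1} e^{-(a w + b/w)/2}\). `rescale_to_unit_det` is a hypothetical helper, not a normix function. Note that the rescaling leaves \(p\) unchanged.

```python
import numpy as np

def rescale_to_unit_det(gamma, Sigma, a, b):
    """Hypothetical helper: equivalent (gamma, Sigma, a, b) with det(Sigma) = 1.

    Uses the invariance W -> c W of X = mu + gamma W + sqrt(W) Sigma^{1/2} Z,
    W ~ GIG(p, a, b); mu and p are unchanged.
    """
    d = Sigma.shape[0]
    c = np.linalg.det(Sigma) ** (1.0 / d)
    return gamma / c, Sigma / c, a / c, b * c, c

def gig_kernel(w, p, a, b):
    """Unnormalized GIG density w^(p-1) * exp(-(a*w + b/w) / 2)."""
    return w ** (p - 1) * np.exp(-0.5 * (a * w + b / w))

rng = np.random.default_rng(0)
mu = np.array([0.001, 0.002])
gamma = np.array([-0.5, 0.3])
Sigma = np.array([[2.0, 0.6], [0.6, 1.5]])
p, a, b = -0.9, 3.0, 0.5

gamma2, Sigma2, a2, b2, c = rescale_to_unit_det(gamma, Sigma, a, b)

# Same draws under both parameterizations: W for the original, V = c*W for the rescaled one
w = rng.uniform(0.1, 5.0, size=1000)  # any positive mixing values will do
z = rng.standard_normal((1000, 2))
x1 = mu + w[:, None] * gamma + np.sqrt(w)[:, None] * (z @ np.linalg.cholesky(Sigma).T)
v = c * w
x2 = mu + v[:, None] * gamma2 + np.sqrt(v)[:, None] * (z @ np.linalg.cholesky(Sigma2).T)

# The kernels differ only by the constant c^(p-1), so V = cW follows GIG(p, a/c, b*c)
kernel_ratio = gig_kernel(v, p, a2, b2) / gig_kernel(w, p, a, b)

print("det(Sigma2) =", np.linalg.det(Sigma2))
print("same samples:", np.allclose(x1, x2))
print("constant kernel ratio:", np.allclose(kernel_ratio, c ** (p - 1)))
```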
Steps 2–5: Sweep \(p\), simulate, and re-fit via EM and MCECM
p_values = list(range(-10, 11))
results = []
t_em_total = 0.0
t_mcecm_total = 0.0
for p_val in p_values:
    print(f'\np = {p_val:+3d} ', end='', flush=True)
    # "True" model: the base model with p replaced, all other parameters unchanged
    joint_true = eqx.tree_at(lambda jt: jt.p, model_base.joint, jnp.float64(p_val))
    model_true = GeneralizedHyperbolic(joint_true)
    X_synth = model_true.rvs(5000, seed=p_val + 100)
    # Re-fit all parameters with EM
    try:
        t0 = time.perf_counter()
        result_em = model_true.fit(
            X_synth, algorithm='em', max_iter=200, tol=1e-2,
            regularization='det_sigma_one', verbose=0)
        t_em = time.perf_counter() - t0
        t_em_total += t_em
        model_em = result_em.model
        ll_em = float(model_em.marginal_log_likelihood(X_synth))
        h2_em = float(squared_hellinger(model_true.joint, model_em.joint))
        print(f'[EM: {result_em.n_iter} iters, {t_em:.1f}s] ', end='', flush=True)
    except Exception as e:
        print(f'[EM FAILED: {e}] ', end='', flush=True)
        ll_em, h2_em = np.nan, np.nan
    # Re-fit all parameters with MCECM (note: tighter tol than the EM run)
    try:
        t0 = time.perf_counter()
        result_mcecm = model_true.fit(
            X_synth, algorithm='mcecm', max_iter=200, tol=1e-4,
            regularization='det_sigma_one', verbose=0)
        t_mcecm = time.perf_counter() - t0
        t_mcecm_total += t_mcecm
        model_mcecm = result_mcecm.model
        ll_mcecm = float(model_mcecm.marginal_log_likelihood(X_synth))
        h2_mcecm = float(squared_hellinger(model_true.joint, model_mcecm.joint))
        print(f'[MCECM: {result_mcecm.n_iter} iters, {t_mcecm:.1f}s]', end='', flush=True)
    except Exception as e:
        print(f'[MCECM FAILED: {e}]', end='', flush=True)
        ll_mcecm, h2_mcecm = np.nan, np.nan
    results.append({
        'p': p_val,
        'LL_EM': ll_em, 'H2_EM': h2_em,
        'LL_MCECM': ll_mcecm, 'H2_MCECM': h2_mcecm,
    })
print(f'\n\nTotal EM time: {t_em_total:.1f}s')
print(f'Total MCECM time: {t_mcecm_total:.1f}s')
p = -10 [EM: 17 iters, 1.4s] [MCECM: 41 iters, 3.3s]
p =  -9 [EM: 17 iters, 0.7s] [MCECM: 183 iters, 9.6s]
p =  -8 [EM: 17 iters, 0.7s] [MCECM: 81 iters, 4.7s]
p =  -7 [EM: 12 iters, 0.5s] [MCECM: 200 iters, 11.1s]
p =  -6 [EM: 13 iters, 0.6s] [MCECM: 61 iters, 3.7s]
p =  -5 [EM: 12 iters, 0.6s] [MCECM: 25 iters, 1.7s]
p =  -4 [EM: 11 iters, 0.4s] [MCECM: 48 iters, 3.0s]
p =  -3 [EM: 12 iters, 0.5s] [MCECM: 152 iters, 9.6s]
p =  -2 [EM: 3 iters, 0.1s] [MCECM: 19 iters, 1.1s]
p =  -1 [EM: 5 iters, 0.2s] [MCECM: 29 iters, 1.7s]
p =  +0 [EM: 4 iters, 0.1s] [MCECM: 22 iters, 1.2s]
p =  +1 [EM: 7 iters, 0.3s] [MCECM: 30 iters, 1.7s]
p =  +2 [EM: 11 iters, 0.4s] [MCECM: 200 iters, 12.3s]
p =  +3 [EM: 11 iters, 0.4s] [MCECM: 38 iters, 2.7s]
p =  +4 [EM: 13 iters, 0.5s] [MCECM: 200 iters, 12.1s]
p =  +5 [EM: 16 iters, 0.8s] [MCECM: 40 iters, 3.3s]
p =  +6 [EM: 15 iters, 0.6s] [MCECM: 200 iters, 11.6s]
p =  +7 [EM: 17 iters, 1.0s] [MCECM: 45 iters, 4.0s]
p =  +8 [EM: 23 iters, 0.9s] [MCECM: 200 iters, 11.9s]
p =  +9 [EM: 17 iters, 0.7s] [MCECM: 200 iters, 12.1s]
p = +10 [EM: 20 iters, 1.3s] [MCECM: 51 iters, 4.8s]

Total EM time: 12.8s
Total MCECM time: 127.4s
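Several MCECM runs in the log stop at exactly `max_iter=200`. The sketch below transcribes the iteration counts and per-run times from the log into plain Python lists and flags the capped runs.

Two caveats apply to this comparison. First, it assumes that stopping at the cap means the stopping tolerance was not met. Second, MCECM ran with `tol=1e-4` and EM with `tol=1e-2`, so raw iteration counts are not a like-for-like comparison.

```python
# Iteration counts and wall-clock times transcribed from the sweep log above.
# Per-run times are rounded to 0.1 s, so their sums differ slightly from the printed totals.
p_values = list(range(-10, 11))
em_iters = [17, 17, 17, 12, 13, 12, 11, 12, 3, 5, 4, 7, 11, 11, 13, 16, 15, 17, 23, 17, 20]
mcecm_iters = [41, 183, 81, 200, 61, 25, 48, 152, 19, 29, 22, 30,
               200, 38, 200, 40, 200, 45, 200, 200, 51]
em_time = [1.4, 0.7, 0.7, 0.5, 0.6, 0.6, 0.4, 0.5, 0.1, 0.2, 0.1,
           0.3, 0.4, 0.4, 0.5, 0.8, 0.6, 1.0, 0.9, 0.7, 1.3]
mcecm_time = [3.3, 9.6, 4.7, 11.1, 3.7, 1.7, 3.0, 9.6, 1.1, 1.7, 1.2,
              1.7, 12.3, 2.7, 12.1, 3.3, 11.6, 4.0, 11.9, 12.1, 4.8]
MAX_ITER = 200  # cap passed to fit() for both algorithms in the sweep

# Assumption: a run that stops exactly at the cap most likely did not reach `tol`
capped = [p for p, n in zip(p_values, mcecm_iters) if n >= MAX_ITER]
print("MCECM runs stopped at max_iter:", capped)  # -> [-7, 2, 4, 6, 8, 9]
print("EM runs stopped at max_iter:",
      [p for p, n in zip(p_values, em_iters) if n >= MAX_ITER])  # -> []
print(f"mean iterations  EM: {sum(em_iters) / len(em_iters):.1f}  "
      f"MCECM: {sum(mcecm_iters) / len(mcecm_iters):.1f}")
print(f"summed time      EM: {sum(em_time):.1f}s  MCECM: {sum(mcecm_time):.1f}s")
```

In the notebook itself, it would be simpler to store `result_em.n_iter` and `result_mcecm.n_iter` in `results` directly rather than reading them back from the log.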
Table 4: Comparison between EM and MCECM
df = pd.DataFrame(results)
def fmt(v):
    """Format one table cell; failed fits (NaN) are shown as N/A."""
    return 'N/A' if np.isnan(v) else f'{v:.4f}'
print('Table 4: Comparison between the EM algorithm and the MCECM algorithm')
print('=' * 75)
# Each algorithm block is 14 + 1 + 8 = 23 characters wide, so center its title in 23
print(f'{"p":>4s} | {"EM algorithm":^23s} | {"MCECM algorithm":^23s}')
print(f'{"":>4s} | {"Log-likelihood":>14s} {"H^2":>8s} | {"Log-likelihood":>14s} {"H^2":>8s}')
print('-' * 75)
for _, row in df.iterrows():
    p = int(row['p'])
    print(f'{p:+4d} | {fmt(row["LL_EM"]):>14s} {fmt(row["H2_EM"]):>8s} | '
          f'{fmt(row["LL_MCECM"]):>14s} {fmt(row["H2_MCECM"]):>8s}')
print('=' * 75)
Table 4: Comparison between the EM algorithm and the MCECM algorithm
===========================================================================
   p |      EM algorithm       |     MCECM algorithm
     | Log-likelihood      H^2 | Log-likelihood      H^2
---------------------------------------------------------------------------
 -10 |         8.5674   0.0003 |         8.5674   0.0003
  -9 |         8.4649   0.0005 |         8.4651   0.0048
  -8 |         8.3323   0.0005 |         8.3323   0.0009
  -7 |         8.1884   0.0005 |         8.1888   0.0110
  -6 |         7.9669   0.0016 |         7.9670   0.0026
  -5 |         7.7893   0.0005 |         7.7893   0.0005
  -4 |         7.5427   0.0010 |         7.5428   0.0014
  -3 |         7.1671   0.0003 |         7.1671   0.0014
  -2 |         6.6367   0.0002 |         6.6367   0.0003
  -1 |         5.8342   0.0003 |         5.8342   0.0005
  +0 |         4.7823   0.0003 |         4.7824   0.0005
  +1 |         3.8388   0.0011 |         3.8389   0.0018
  +2 |         3.2507   0.0007 |         3.2515   0.0134
  +3 |         2.8541   0.0004 |         2.8541   0.0005
  +4 |         2.5689   0.0010 |         2.5689   0.0015
  +5 |         2.2991   0.0004 |         2.2991   0.0004
  +6 |         2.0828   0.0018 |         2.0829   0.0030
  +7 |         1.9761   0.0007 |         1.9761   0.0008
  +8 |         1.8322   0.0005 |         1.8323   0.0019
  +9 |         1.7238   0.0002 |         1.7240   0.0067
 +10 |         1.6277   0.0008 |         1.6277   0.0008
===========================================================================
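To check the “same MLE” expectation quantitatively, the sketch below copies the printed Table 4 values into NumPy arrays and computes the per-row differences. The values are rounded to four decimals. In the notebook itself, the same columns are available unrounded from `df`.

```python
import numpy as np

# Values transcribed from the printed Table 4 above (rounded to 4 decimals).
p = np.arange(-10, 11)
ll_em = np.array([8.5674, 8.4649, 8.3323, 8.1884, 7.9669, 7.7893, 7.5427,
                  7.1671, 6.6367, 5.8342, 4.7823, 3.8388, 3.2507, 2.8541,
                  2.5689, 2.2991, 2.0828, 1.9761, 1.8322, 1.7238, 1.6277])
ll_mc = np.array([8.5674, 8.4651, 8.3323, 8.1888, 7.9670, 7.7893, 7.5428,
                  7.1671, 6.6367, 5.8342, 4.7824, 3.8389, 3.2515, 2.8541,
                  2.5689, 2.2991, 2.0829, 1.9761, 1.8323, 1.7240, 1.6277])
h2_em = np.array([0.0003, 0.0005, 0.0005, 0.0005, 0.0016, 0.0005, 0.0010,
                  0.0003, 0.0002, 0.0003, 0.0003, 0.0011, 0.0007, 0.0004,
                  0.0010, 0.0004, 0.0018, 0.0007, 0.0005, 0.0002, 0.0008])
h2_mc = np.array([0.0003, 0.0048, 0.0009, 0.0110, 0.0026, 0.0005, 0.0014,
                  0.0014, 0.0003, 0.0005, 0.0005, 0.0018, 0.0134, 0.0005,
                  0.0015, 0.0004, 0.0030, 0.0008, 0.0019, 0.0067, 0.0008])

d_ll = ll_mc - ll_em                   # MCECM minus EM log-likelihood, per p
worst = int(np.argmax(np.abs(d_ll)))   # row with the largest disagreement
print(f"max |LL_MCECM - LL_EM| = {abs(d_ll[worst]):.4f} at p = {int(p[worst]):+d}")
# -> max |LL_MCECM - LL_EM| = 0.0008 at p = +2
# Round before comparing so that float noise in the subtraction cannot flip a sign
print("LL_MCECM >= LL_EM in every row:", bool(np.all(np.round(d_ll, 4) >= 0)))
# -> True
print("H2_EM <= H2_MCECM in every row:", bool(np.all(h2_em <= h2_mc)))
# -> True
```

On these numbers, the two algorithms' log-likelihoods agree to within 0.0008 for every \(p\). The MCECM fits sit slightly farther from the true model in \(H^2\) in most rows.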