Testing Emulator Example Notebook
This notebook shows code to run several performance metrics on a trained ps_emulator object. For demonstration purposes, all figures here were made using the final optimized emulator in PAPER_LINK_HERE.
[1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from tqdm import tqdm
import itertools
from mentat_lss.emulator import ps_emulator
from mentat_lss.utils import load_config_file, calc_avg_loss, delta_chi_squared, get_parameter_ranges
[2]:
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams.update({'font.size': 11})
plt.rcParams["legend.frameon"] = False
Here are some plotting routine functions we will use later
[3]:
def make_heatmap(x, y, z, bins, median=True):
x_new = np.linspace(np.amin(x), np.amax(x), bins+1)
y_new = np.linspace(np.amin(y), np.amax(y), bins+1)
z_new = np.zeros((bins,bins))
for i in range(bins):
for j in range(bins):
if median == True: z_new[j,i] = np.median(z[(x >= x_new[i]) & (x < x_new[i+1]) & (y >= y_new[j]) & (y < y_new[j+1])])
else: z_new[j,i] = np.mean(z[(x >= x_new[i]) & (x < x_new[i+1]) & (y >= y_new[j]) & (y < y_new[j+1])])
return x_new, y_new, z_new
def make_diagonal(x, y, bins, median=True):
x_new = np.linspace(np.amin(x), np.amax(x), bins+1)
y_new = np.zeros(bins)
for i in range(bins):
if median == True: y_new[i] = np.median(y[(x >= x_new[i]) & (x <= x_new[i+1])])
else: y_new[i] = np.mean(y[(x >= x_new[i]) & (x <= x_new[i+1])])
return x_new[:bins], y_new
def plot_heatmap(params, data, label, extents, cmap, log_scale,
names, labels, median=False, save_str=""):
params = params.copy()
fig, axs = plt.subplots(len(names),len(names), figsize=(12,12), sharex="col")
for i in range(len(names)):
for j in range(len(names)):
idx_i = pnames.index(names[i])
idx_j = pnames.index(names[j])
if i < j:
axs[i][j].axis("off")
continue
if i == j:
x, y = make_diagonal(params[:,idx_j], data, 25, median)
axs[i][j].plot(x, y)
else:
X, Y, Z = make_heatmap(params[:,idx_j], params[:,idx_i], data, 25, median)
if log_scale == True: img = axs[i,j].imshow(Z, aspect="auto", extent=(X[0], X[-1], Y[0], Y[-1]), cmap=cmap,
norm=colors.LogNorm(vmin=extents[0], vmax=extents[1]))
else: img = axs[i,j].imshow(Z, aspect="auto", extent=(X[0], X[-1], Y[0], Y[-1]), cmap=cmap, vmin=extents[0], vmax=extents[1])
axs[i,j].set_xlim(X[0] - (X[-1] - X[0]) * 0.05, X[-1] + (X[-1] - X[0]) * 0.05)
axs[i,j].set_ylim(Y[0] - (Y[-1] - Y[0]) * 0.05, Y[-1] + (Y[-1] - Y[0]) * 0.05)
axs[i,j].xaxis.set_ticks_position('both')
axs[i,j].yaxis.set_ticks_position('both')
axs[i,j].tick_params(direction="in")
for item in ([axs[i,j].xaxis.label, axs[i,j].yaxis.label]):
item.set_fontsize(15)
#if i != j: axs[i][j].axhline(params_best[i], linestyle="--", c="black")
#axs[i][j].axvline(params_best[j], linestyle="--", c="black")
if i == len(names) - 1: axs[i][j].set_xlabel(labels[j])
if j == 0 and i != 0: axs[i][j].set_ylabel(labels[i])
#if j == 0 and i != 5: axs[i][j].xaxis.set_ticklabels([])
elif j != 0 and i == len(names)-1: axs[i][j].yaxis.set_ticklabels([])
elif j != 0 and i != len(names)-1:
#axs[i][j].xaxis.set_ticklabels([])
axs[i][j].yaxis.set_ticklabels([])
#axs[5][3].set_xticks([1,2,3,4])
cbar_ax = fig.add_axes([0.96, 0.14, 0.039, 0.7])
cbar = fig.colorbar(img, cax=cbar_ax)
cbar.set_label(label, size=22)
cbar.ax.tick_params(labelsize=18)
plt.subplots_adjust(wspace=0, hspace=0, right=0.95)
if save_str!="": plt.savefig(save_str, dpi=300, bbox_inches='tight')
As a first step, we load in our trained emulator and the testing dataset in the following cells.
[4]:
# load the network
repo_dir = "/Users/JoeyA/Research/SPHEREx/mentat-lss/"
emulator_dir = repo_dir+"emulators/stacked_transformer_2t_2z_hypersphere/"
training_dir = repo_dir+"../../Data/SPHEREx-Data/training_set_eft_2t_2z_hypersphere/"
cosmo_dir = repo_dir + "configs/cosmo_pars/cosmo_pars_2t_2z.yaml"
emulator = ps_emulator(emulator_dir, "eval")
num_networks = emulator.num_zbins * emulator.num_spectra
# used in heatmap plotting routine
config_dict = load_config_file(emulator_dir+"config.yaml")
cosmo_dict = load_config_file(cosmo_dir)
#cosmo_dict = load_config_file(input_dir+config_dict["cosmo_dir"])
pnames, bounds = get_parameter_ranges(cosmo_dict)
[5]:
# load the test dataset
test_data = emulator.load_data("testing", 1., False, training_dir)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=config_dict["batch_size"], shuffle=True)
Loss Function Stats
The following cell displays loss function stats saved during training, and calculates the average loss value for the testing dataset.
[6]:
# make plots showing training / validation loss during training
test_loss = calc_avg_loss(emulator, test_loader, emulator.loss_function)
for (ps, z) in itertools.product(range(emulator.num_spectra), range(emulator.num_zbins)):
training_data = torch.load(emulator_dir+"training_statistics/train_data_"+str(ps)+"_"+str(z)+".dat")
#epochs = range(min(np.where(training_data[0] == 0)[0][0], training_data.shape[1]))
epochs = range(training_data.shape[1])
num_epochs = epochs[-1]+1
train_loss = training_data[0,:num_epochs]
valid_loss = training_data[1,:num_epochs]
learning_rate = training_data[2,:num_epochs]
print("Net [{:d}, {:d}], Total # of epochs = {:0.0f}".format(ps, z, len(train_loss)))
print("Net [{:d}, {:d}], Best training loss = {:0.4f}".format(ps, z, torch.amin(train_loss)))
print("Net [{:d}, {:d}], Best validation loss = {:0.4f}".format(ps, z, torch.amin(valid_loss)))
print("Net [{:d}, {:d}], Average test loss = {:0.4f}\n".format(ps, z, test_loss[ps][z]))
plt.figure()
plt.plot(epochs, train_loss, c="blue", label="training set loss")
plt.plot(epochs, valid_loss, c="red", ls="--", label="validation set loss")
plt.xlabel("epoch")
plt.ylabel("avg loss")
plt.yscale("log")
plt.legend()
Net [0, 0], Total # of epochs = 500
Net [0, 0], Best training loss = 0.0000
Net [0, 0], Best validation loss = 0.0000
Net [0, 0], Average test loss = 0.0068
Net [0, 1], Total # of epochs = 500
Net [0, 1], Best training loss = 0.0000
Net [0, 1], Best validation loss = 0.0000
Net [0, 1], Average test loss = 0.0081
Net [1, 0], Total # of epochs = 500
Net [1, 0], Best training loss = 0.0000
Net [1, 0], Best validation loss = 0.0000
Net [1, 0], Average test loss = 0.0088
Net [1, 1], Total # of epochs = 500
Net [1, 1], Best training loss = 0.0000
Net [1, 1], Best validation loss = 0.0000
Net [1, 1], Average test loss = 0.0108
Net [2, 0], Total # of epochs = 500
Net [2, 0], Best training loss = 0.0000
Net [2, 0], Best validation loss = 0.0000
Net [2, 0], Average test loss = 0.0096
Net [2, 1], Total # of epochs = 500
Net [2, 1], Best training loss = 0.0000
Net [2, 1], Best validation loss = 0.0000
Net [2, 1], Average test loss = 0.0090
Delta chi2 stats
These cells calculate the main performance statistic we use for emulator performance, mainly \(\Delta \chi^2\) given by the folloiwng equation,
Since we have multiple sub-networks, we can calculate this quantity for each individual bin, or the full output.
[7]:
from mentat_lss.utils import is_in_hypersphere
pk_error = []
error_per_bin = np.zeros((emulator.num_spectra, emulator.num_zbins, emulator.num_ells, emulator.num_kbins))
delta_chi2 = np.zeros((emulator.num_spectra, emulator.num_zbins, len(test_data)))
delta_chi2_combined = np.zeros(len(test_data))
save_idx = 0
# calculating delta chi2 in batches is much faster (~100x) than doing so individually
for (i, batch) in enumerate(tqdm(test_loader)):
params = batch[0].to("cpu").detach().numpy()
pk_idx = batch[2].to(torch.int)
#params = test_data[i][0].to("cpu").detach().numpy()
pk_raw = test_data.get_normalized_galaxy_power_spectra(pk_idx)
pk_true = test_data.get_true_galaxy_power_spectra(pk_idx, emulator.ps_fid, emulator.sqrt_eigvals,
emulator.Q, emulator.Q_inv).to("cpu").detach()
pk_pred_raw = emulator.get_power_spectra(params, False, True)
pk_pred = emulator.get_power_spectra(params, False, False)
for j in range(len(pk_idx)):
#loop thru networks
for (ps, z) in itertools.product(range(emulator.num_spectra), range(emulator.num_zbins)):
prediction = pk_pred_raw[j, ps, z].unsqueeze(0)
target = pk_raw[j, ps, z].unsqueeze(0)
chi2 = delta_chi_squared(prediction, target, emulator.invcov_full, True)
delta_chi2[ps, z, save_idx] = chi2.item()
delta_chi2_combined[save_idx] = delta_chi_squared(pk_pred[j], pk_true[j], emulator.invcov_full, False)
pk_error.append((pk_pred[j] - pk_true[j].numpy()) / pk_true[j])
save_idx += 1
pk_error = np.array(pk_error)
100%|██████████| 401/401 [01:21<00:00, 4.94it/s]
[8]:
# Plot chi2 histograms for each individual network
labels=[r"$P_l$(tracer1, tracer1)", r"$P_l$(tracer1, tracer4)", r"$P_l$(tracer4, tracer4)"]
cmap = ["red", "green", "blue"]
style = ["//", "\\"]
fig, axs = plt.subplots(2, 1, figsize=(8,8), sharex=True)
for (ps, z) in itertools.product(range(emulator.num_spectra), range(emulator.num_zbins)):
print("net [{:d}, {:d}], mean chi2 error = {:0.3f}".format(ps, z, np.mean((delta_chi2)[ps, z])))
print("net [{:d}, {:d}], median chi2 error = {:0.3f}".format(ps, z, np.median(delta_chi2[ps, z])))
axs[z].hist(delta_chi2[ps, z], bins=np.geomspace(np.amin(delta_chi2[ps, z]), np.amax(delta_chi2[ps, z]), 50), alpha=0.6,
facecolor=cmap[ps], hatch=style[z], label=labels[ps])
axs[0].legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.15))
axs[0].set_ylabel("zbin 1")
axs[1].set_ylabel("zbin 2")
plt.xscale("log")
plt.xlabel(r"$\Delta \chi^2$")
plt.subplots_adjust(hspace=0)
#plt.savefig("../plots/delta_chi2_histogram.png", dpi=300, bbox_inches='tight')
net [0, 0], mean chi2 error = 0.031
net [0, 0], median chi2 error = 0.007
net [0, 1], mean chi2 error = 0.042
net [0, 1], median chi2 error = 0.012
net [1, 0], mean chi2 error = 0.049
net [1, 0], median chi2 error = 0.013
net [1, 1], mean chi2 error = 0.071
net [1, 1], median chi2 error = 0.019
net [2, 0], mean chi2 error = 0.060
net [2, 0], median chi2 error = 0.019
net [2, 1], mean chi2 error = 0.051
net [2, 1], median chi2 error = 0.013
[9]:
# plot delta chi2 heatmap
# @Grace these plots are the main ones we're using to quantify network performance
print("mean combined chi2 error = {:0.3f}".format(np.mean(delta_chi2_combined)))
print("median combined chi2 error = {:0.3f}".format(np.median(delta_chi2_combined)))
names = ['h', "ombh2", "omch2", "As", "ns"]
labels = [r"h", r"$\omega_{b}$", r"$\omega_{cdm}$", r"$A_s$", r"$n_s$"]
plt.hist(delta_chi2_combined, bins=np.geomspace(np.amin(delta_chi2_combined), np.amax(delta_chi2_combined), 50), alpha=0.6,
facecolor="blue", hatch="", label=r"$\Delta \chi^2$ from full output")
plt.xscale("log")
plt.xlabel(r"$\Delta \chi^2$")
plt.legend()
#plt.savefig("../plots/chi2-histogram-combined.png", dpi=300)
cmap="inferno"
#extents = [2e6, 1e10]
extents = [5e-2, 1e0]
save_str = "../plots/delta_chi2_heatmap.png"
plot_heatmap(test_data.params.to("cpu").detach().numpy(), delta_chi2_combined, r"$\Delta \chi^2$",
extents, cmap, True, names, labels, median=False, save_str="")
mean combined chi2 error = 0.352
median combined chi2 error = 0.167
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/Users/JoeyA/miniconda3/envs/mentat_lss/lib/python3.11/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide
ret = ret.dtype.type(ret / rcount)
Test individual power spectra
Finally, it can be helpful to take a look at individual outputs to see how the emulator is doing on a single run-through.
[13]:
idx = np.random.randint(len(test_data))
print("Stats for index {:d}".format(idx))
params = test_data[idx][0].to("cpu")
pk_emu = emulator.get_power_spectra(params.detach().numpy(), False, False)
pk_emu_raw = emulator.get_power_spectra(params.detach().numpy(), False, True)[0]
pk_raw = test_data.get_normalized_galaxy_power_spectra(idx)
pk_true = test_data.get_true_galaxy_power_spectra(idx, emulator.ps_fid,
emulator.sqrt_eigvals,
emulator.Q,
emulator.Q_inv).detach().numpy()
print("input parameters:", params)
error = 100 * (pk_emu - pk_true) / pk_true
chi2_full = delta_chi_squared(pk_emu, pk_true, emulator.invcov_full, False)
print("Combined Delta chi2 = {:0.3f}".format(chi2_full))
#Calculate delta chi2 for each network seperately
d_chi2_sum = 0
for (ps, z) in itertools.product(range(emulator.num_spectra), range(emulator.num_zbins)):
d_chi2 = delta_chi_squared(pk_emu_raw[ps, z].unsqueeze(0), pk_raw[ps, z].unsqueeze(0), emulator.invcov_blocks, True)
d_chi2_sum += d_chi2
print("Net [{:d}, {:d}], Delta chi2 = {:0.3f}".format(ps, z, d_chi2))
print("Sum of Delta chi2 = {:0.3f}".format(d_chi2_sum))
print("Average error per bin = {:0.2f}%".format(np.mean(abs(error))))
plt.figure()
plt.title("normalized output (sample idx = "+str(idx)+")")
plt.plot(pk_raw.flatten(), c="black", label="from ps_1loop")
plt.plot(pk_emu_raw.flatten(), c="red", ls="--", label="from emulator")
plt.legend()
plt.figure()
plt.title("normalized output (sample idx = "+str(idx)+")")
plt.plot(abs(pk_raw - pk_emu_raw).flatten(), label="|predict - true|")
plt.yscale("log")
plt.legend()
fig, axs = plt.subplots(2, 1, figsize=(8, 8), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
axs[0].plot(pk_true.flatten(), c="black", label="from ps_1loop")
#axs[0].plot(pk_true[:,0].flatten(), c="green", ls="--")
axs[0].plot(pk_emu.flatten(), c="red", ls="--", label="from emulator")
axs[1].axhline(0, c="black")
axs[1].plot(error.flatten(), c="red")
axs[1].plot(error.flatten(), c="red", ls="--")
axs[1].set_xlabel("k")
axs[0].set_ylabel("P(k)")
axs[1].set_ylabel("error (%)")
axs[0].legend()
fig.subplots_adjust(hspace=0)
Stats for index 186082
input parameters: tensor([ 7.7051e-01, 2.1514e-02, 9.8107e-02, 2.0968e-09, 9.5865e-01,
1.3389e+00, 1.7207e+00, 1.7048e+00, 1.2258e+00, -1.2775e+00,
-1.9652e+00, -8.1640e-01, -1.3647e+00, -2.6730e-01, -7.8944e-01,
-4.9422e-01, -5.5165e-01])
Combined Delta chi2 = 1.011
Net [0, 0], Delta chi2 = 0.012
Net [0, 1], Delta chi2 = 0.018
Net [1, 0], Delta chi2 = 0.105
Net [1, 1], Delta chi2 = 0.035
Net [2, 0], Delta chi2 = 0.328
Net [2, 1], Delta chi2 = 0.019
Sum of Delta chi2 = 0.518
Average error per bin = 0.03%
[11]:
# Sanity Check: make sure two different samples to see if the network outputs different results
idx_1 = np.random.randint(len(test_data))
idx_2 = np.random.randint(len(test_data))
params = test_data[idx_1][0].to("cpu")
pk_emu_1 = emulator.get_power_spectra(params.detach().numpy())
params = test_data[idx_2][0].to("cpu")
pk_emu_2 = emulator.get_power_spectra(params.detach().numpy())
fig, axs = plt.subplots(1, 1, figsize=(7, 6), sharex=True)
axs.plot(pk_emu_1.flatten(), c="red", label="idx = " + str(idx_1))
axs.plot(pk_emu_2.flatten(), c="blue", ls="--", label="idx = " + str(idx_2))
#axs[0].plot(k, data_vector[25:], c="blue", ls="--")
axs.set_ylabel("P(k)")
axs.legend()
fig.subplots_adjust(hspace=0)