import torch
import torch.nn as nn
import numpy as np
import yaml, math, os, copy
import itertools
import logging
from mentat_lss.models import blocks
from mentat_lss.models.stacked_mlp import stacked_mlp
from mentat_lss.models.stacked_transformer import stacked_transformer
from mentat_lss.models.analytic_terms import analytic_eft_model
from mentat_lss.dataset import pk_galaxy_dataset
from mentat_lss.utils import load_config_file, get_parameter_ranges,\
normalize_cosmo_params, un_normalize_power_spectrum, \
delta_chi_squared, mse_loss, hyperbolic_loss, hyperbolic_chi2_loss, \
get_invcov_blocks, get_full_invcov, is_in_hypersphere
# (stray "[docs]" link marker from the rendered Sphinx source page removed; it is not part of the module)
class ps_emulator():
    """Class defining the neural network emulator.

    Bundles the stacked network model (one sub-network per spectrum / redshift-bin pair) with the
    input and output normalization data needed to train it and to turn its raw output back into
    galaxy power spectrum multipoles. Every key of the yaml configuration file is also exposed as
    an instance attribute of the same name.
    """

    def __init__(self, net_dir:str, mode:str="train", device:torch.device=None):
        """Emulator constructor, initializes the network structure and all supporting data.

        Args:
            net_dir (str): path specifying either the directory or full filepath of the trained emulator to load from.
                if a directory, assumes the config file is called "config.yaml"
            mode (str): whether the emulator should initialize for training, or to load from a previous training run. One
                of either ["train", "eval"]. Default "train"
            device (torch.device): Device to load the emulator on. If None, will attempt to load on any available
                GPU (or mps for macos) device. Default None. NOTE: in "eval" mode the emulator is always placed on cpu.

        Raises:
            KeyError: if mode is not correctly specified
            IOError: if no input yaml file was found
        """
        # validate the mode up-front so that no model is built for an invalid request
        if mode not in ("train", "eval"):
            raise KeyError(f"Invalid mode specified! Must be one of ['train', 'eval'] but was {mode}.")

        if net_dir.endswith(".yaml"):
            self.config_dict = load_config_file(net_dir)
        else:
            self.config_dict = load_config_file(os.path.join(net_dir, "config.yaml"))

        self.logger = logging.getLogger('ps_emulator')

        # load dictionary entries into their own class variables
        for key in self.config_dict:
            setattr(self, key, self.config_dict[key])

        self._init_device(device, mode)
        self._init_model()
        self._init_loss()

        if mode == "train":
            self.logger.debug("Initializing power spectrum emulator in training mode")
            self._init_fiducial_power_spectrum()
            self._init_inverse_covariance()
            self._diagonalize_covariance()
            self._init_input_normalizations()
            self.galaxy_ps_model.apply(self._init_weights)
            self.galaxy_ps_checkpoint = copy.deepcopy(self.galaxy_ps_model.state_dict())
        else:
            self.logger.debug("Initializing power spectrum emulator in evaluation mode")
            self.load_trained_model(net_dir)
            self._init_analytic_model()

    def load_trained_model(self, path):
        """loads the pre-trained network from file into the current model, as well as all relevant information needed for normalization.

        This function is called by the constructor, but can also be called directly by the user if desired.

        Args:
            path: The directory containing the trained network files. A path to the ".yaml" config file
                inside that directory is also accepted (as the constructor allows), in which case its
                parent directory is used.
        """
        # the constructor may hand over the config filepath itself; all other files live next to it
        if path.endswith(".yaml"):
            path = os.path.dirname(path)

        self.logger.info(f"loading emulator from {path}")
        self.galaxy_ps_model.eval()
        self.galaxy_ps_model.load_state_dict(torch.load(os.path.join(path, 'network_galaxy.params'),
                                                        weights_only=True, map_location=self.device))

        input_norm_data = torch.load(os.path.join(path, "input_normalizations.pt"),
                                     map_location=self.device, weights_only=True)
        self.input_normalizations = input_norm_data[0] # <- in shape expected by networks
        self.required_emu_params = input_norm_data[1]
        self.emu_param_bounds = input_norm_data[2]

        ps_properties = np.load(os.path.join(path, "ps_properties.npz"))
        self.k_emu = ps_properties["k"]
        self.ells = ps_properties["ells"]
        self.z_eff = ps_properties["z_eff"]
        self.ndens = ps_properties["ndens"]

        output_norm_data = torch.load(os.path.join(path, "output_normalizations.pt"),
                                      map_location=self.device, weights_only=True)
        self.ps_fid = output_norm_data[0]
        self.invcov_full = output_norm_data[1]
        self.invcov_blocks = output_norm_data[2]
        self.sqrt_eigvals = output_norm_data[3]
        self.Q = output_norm_data[4]

        # Q_inv is not stored on disk: rebuild it on cpu in double precision, then cast back down
        self.Q_inv = torch.zeros_like(self.Q, device="cpu")
        for (ps, z) in itertools.product(range(self.num_spectra), range(self.num_zbins)):
            self.Q_inv[ps, z] = torch.linalg.inv(self.Q[ps, z].to("cpu").to(torch.float64)).to(torch.float32)
        self.Q_inv = self.Q_inv.to(self.device)

    def load_data(self, key:str, data_frac = 1.0, return_dataloader=True, data_dir=""):
        """loads and returns the training / validation / test dataset into memory

        Args:
            key: one of ["training", "validation", "testing"] that specifies what type of data-set to load
            data_frac: fraction of the total data-set to load in. Default 1
            return_dataloader: Determines what object type to return the data as. Default True
                If true: returns data as a pytorch.utils.data.DataLoader object.
                If false: returns data as a pk_galaxy_dataset object.
            data_dir: location of the data-set on disk. Default "", in which case
                input_dir + training_dir from the config file is used.

        Returns:
            data: The desired data-set in either a pk_galaxy_dataset or DataLoader object.

        Raises:
            KeyError: If key is an incorrect value.
            ValueError: If some property of the loaded dataset does not match with the emulator.
        """
        if key not in ("training", "validation", "testing"):
            raise KeyError("Invalid value for key! must be one of ['training', 'validation', 'testing']")

        # plain concatenation is kept on purpose: the config values are combined this way elsewhere too
        data_path = data_dir if data_dir != "" else self.input_dir + self.training_dir

        if not hasattr(self, "k_emu"):
            self.logger.info("loading kbins from training set")
            ps_properties = np.load(os.path.join(data_path, "ps_properties.npz"))
            self.k_emu = ps_properties["k"]
            self.ells = ps_properties["ells"]
            self.z_eff = ps_properties["z_eff"]
            self.ndens = ps_properties["ndens"]

        data = pk_galaxy_dataset(data_path, key, data_frac)
        data.to(self.device)
        data.normalize_data(self.ps_fid, self.sqrt_eigvals, self.Q)
        self._check_training_set(data)

        if not return_dataloader:
            return data
        return torch.utils.data.DataLoader(data, batch_size=self.config_dict["batch_size"], shuffle=True)

    def get_power_spectra(self, params, extrapolate:bool = False, raw_output:bool = False):
        """Gets the full galaxy power spectrum multipoles (emulated and analytically calculated)

        NOTE: the analytic terms are only added when a single parameter set is given (i.e. the
        emulated output is 4-dimensional) and raw_output is False. In every other case the emulated
        output is returned unchanged.

        Args:
            params: 1D or 2D numpy array or torch Tensor containing a list of cosmology + galaxy bias parameters.
                if params is a 2D array, this function generates a batch of power spectra simultaneously
            extrapolate (bool): Whether or not to pass through the emulator if the given input parameters are outside the range it was trained on.
                Default False
            raw_output (bool): Whether or not to return the raw network output without undoing normalization. Default False

        Returns:
            galaxy_ps (np.array): Emulated galaxy power spectrum multipoles.
                If raw_output = False, has shape [nps, nz, nk, nl] or [nb, nps, nz, nk, nl]. Else has shape [nb, nps, nz, nk*nl]
        """
        galaxy_ps_emu = self.get_emulated_power_spectrum(params, extrapolate, raw_output)
        if raw_output or len(galaxy_ps_emu.shape) != 4:
            return galaxy_ps_emu

        # the constructor only builds the analytic model in "eval" mode, so create it on first use otherwise
        if not hasattr(self, "analytic_model"):
            self._init_analytic_model()
        analytic_terms = self.analytic_model.get_analytic_terms(params, self.required_emu_params,
                                                                self.get_required_analytic_parameters())
        return galaxy_ps_emu + analytic_terms

    def get_emulated_power_spectrum(self, params, extrapolate:bool = False, raw_output:bool = False):
        """Gets the power spectra corresponding to the given input params by passing them though the emulator

        Args:
            params: 1D or 2D numpy array or torch Tensor containing a list of cosmology + galaxy bias parameters.
                if params is a 2D array, this function generates a batch of power spectra simultaneously
            extrapolate (bool): Whether or not to pass through the emulator if the given input parameters are outside the range it was trained on.
                Default False
            raw_output: bool specifying whether or not to return the raw network output without undoing normalization. Default False

        Returns:
            galaxy_ps (np.array): emulated galaxy power spectrum multipoles (P_tree + P_1loop). If given a batch of parameters, has shape [nb, nps, nz, nk, nl].
                Otherwise, has shape [nps, nz, nk, nl]. If extrapolate is false and the given input parameters are out of bounds, then this function returns
                an array of all zeros. If raw_output is True, the network output is returned directly (as a torch Tensor, still normalized).

        Raises:
            TypeError: if params is neither a numpy array nor a torch Tensor.
        """
        self.galaxy_ps_model.eval()
        with torch.no_grad():
            emu_params, skip_emulation = self._check_params(params, extrapolate)
            num_dims = len(params.shape)
            single_shape = (self.num_spectra, self.num_zbins, self.num_kbins, self.num_ells)

            if skip_emulation and not raw_output:
                if num_dims == 1:
                    return np.zeros(single_shape)
                elif num_dims > 1:
                    return np.zeros((params.shape[0],) + single_shape)

            galaxy_ps = self.galaxy_ps_model.forward(emu_params) # <- shape [nb, nps, nz, nk*nl]
            if raw_output:
                return galaxy_ps

            galaxy_ps = un_normalize_power_spectrum(torch.flatten(galaxy_ps, start_dim=3),
                                                    self.ps_fid, self.sqrt_eigvals, self.Q, self.Q_inv)
            if num_dims == 1:
                galaxy_ps = galaxy_ps.view(*single_shape)
            else:
                galaxy_ps = galaxy_ps.view(-1, *single_shape)
            return galaxy_ps.to("cpu").detach().numpy()

    def get_required_emu_parameters(self):
        """Returns a list of input parameters needed by the emulator.

        Currently, mentat-lss requires input parameters to be in the same order as given by
        the return value of this function. For example. If the return list is ['h', 'omch2'], you
        should pass in [h, omch2] to get_power_spectra in that order.

        Returns:
            required_emu_params (list): list of input cosmology + bias parameters required by the emulator.
        """
        return self.required_emu_params

    def get_required_analytic_parameters(self):
        """Returns a list of input parameters used by our analytic eft model, not directly emulated.

        NOTE: These parameters are currently hard-coded.

        Returns:
            required_analytic_params (list): list of input (counterterm + stoch) parameters.
        """
        # one counterterm per multipole that is actually present, in increasing ell order
        analytic_params = ["counterterm_" + str(ell) for ell in (0, 2, 4) if ell in self.ells]
        analytic_params.extend(["counterterm_fog", "P_shot"])
        return analytic_params

    def check_kbins_are_compatible(self, test_kbins:np.array):
        """Tests whether the passed test_kbins is the same as the emulator k-bins

        Args:
            test_kbins (np.array): k-array to check

        Returns:
            is_compatible (bool): Whether or not the given k-bins are compatible
        """
        # compare shapes first, np.allclose would otherwise try to broadcast mismatched arrays
        if test_kbins.shape != self.k_emu.shape:
            return False
        return np.allclose(test_kbins, self.k_emu)

    # -----------------------------------------------------------
    # Helper methods: Not meant to be called by the user directly
    # -----------------------------------------------------------

    def _init_device(self, device, mode):
        """Sets emulator device based on machine configuration

        Priority: "eval" mode always uses cpu, then an explicitly given device, then the
        use_gpu config flag, then cuda, then mps, and finally cpu.
        """
        self.num_gpus = torch.cuda.device_count()
        if mode == "eval":
            self.device = torch.device("cpu")
        elif device is not None:
            self.device = device
        elif self.use_gpu == False:
            self.device = torch.device('cpu')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device('cpu')

    def _init_model(self):
        """Initializes the networks

        Raises:
            KeyError: if model_type is not one of ["stacked_mlp", "stacked_transformer"]
        """
        # auto spectra (one per tracer) + cross spectra (one per unordered pair of tracers)
        self.num_spectra = self.num_tracers + math.comb(self.num_tracers, 2)
        if self.model_type == "stacked_mlp":
            self.galaxy_ps_model = stacked_mlp(self.config_dict).to(self.device)
        elif self.model_type == "stacked_transformer":
            self.galaxy_ps_model = stacked_transformer(self.config_dict).to(self.device)
        else:
            raise KeyError(f"Invalid value for model_type: {self.model_type}")

    def _init_analytic_model(self):
        """Initializes object for calculating analytic eft terms"""
        self.analytic_model = analytic_eft_model(self.num_tracers, self.z_eff, self.ells, self.k_emu, self.ndens)

    def _init_input_normalizations(self):
        """Initializes input parameter names and normalization factors

        Normalizations are in the shape (low / high bound, net_idx, parameter).
        If the cosmology config file cannot be read, every parameter falls back to the range [0, 1]
        and the list of parameter names is left empty.
        """
        try:
            cosmo_dict = load_config_file(os.path.join(self.input_dir, self.cosmo_dir))
            param_names, param_bounds = get_parameter_ranges(cosmo_dict)
            input_normalizations = torch.Tensor(param_bounds.T).to(self.device)
        except IOError:
            self.logger.warning("Could not load parameter ranges! Using [0, 1] bounds for every input parameter...")
            num_params = self.num_cosmo_params + (self.num_tracers*self.num_zbins*self.num_nuisance_params)
            input_normalizations = torch.vstack((torch.zeros(num_params), torch.ones(num_params))).to(self.device)
            param_names = []
            # keep the stored bounds consistent with the fallback normalization instead of uninitialized memory
            param_bounds = np.tile(np.array([0., 1.]), (num_params, 1))

        lower_bounds = self.galaxy_ps_model.organize_parameters(input_normalizations[0].unsqueeze(0))
        upper_bounds = self.galaxy_ps_model.organize_parameters(input_normalizations[1].unsqueeze(0))
        self.input_normalizations = torch.vstack([lower_bounds, upper_bounds])
        self.required_emu_params = param_names
        self.emu_param_bounds = torch.from_numpy(param_bounds).to(torch.float32).to(self.device)

    def _init_fiducial_power_spectrum(self):
        """Loads the fiducial galaxy power spectrum for use in normalization

        The result is stored in self.ps_fid with shape [nps, nz, nk*nl]. If no "ps_fid.npy" file
        exists in the training directory, an all-zero fiducial spectrum is used instead.
        """
        ps_file = self.input_dir+self.training_dir+"ps_fid.npy"
        fid_shape = (self.num_spectra, self.num_zbins, self.num_kbins * self.num_ells)
        if not os.path.exists(ps_file):
            self.ps_fid = torch.zeros(fid_shape).to(self.device)
            return

        self.ps_fid = torch.from_numpy(np.load(ps_file)).to(torch.float32).to(self.device)[0]
        # permute input power spectrum if it's a different shape than expected
        if self.ps_fid.shape[3] == self.num_kbins:
            self.ps_fid = torch.permute(self.ps_fid, (0, 1, 3, 2))
        if self.ps_fid.shape[0] == self.num_zbins:
            self.ps_fid = torch.permute(self.ps_fid, (1, 0, 2, 3))
        self.ps_fid = self.ps_fid.reshape(*fid_shape)

    def _init_inverse_covariance(self):
        """Loads the inverse data covariance matrix for use in certain loss functions and normalizations

        Looks for "cov.dat" (torch format) and then "cov.npy" in the training directory. If neither
        exists, an identity matrix is used for every redshift bin.
        """
        # TODO: Upgrade to handle different number of k-bins for each zbin
        cov_file = self.input_dir+self.training_dir

        if os.path.exists(cov_file+"cov.dat"):
            cov = torch.load(cov_file+"cov.dat", weights_only=True)
        elif os.path.exists(cov_file+"cov.npy"):
            cov = torch.from_numpy(np.load(cov_file+"cov.npy"))
        else:
            self.logger.warning("Could not find covariance matrix! Using identity matrix instead...")
            cov = torch.eye(self.num_spectra*self.num_ells*self.num_kbins).unsqueeze(0)
            cov = cov.repeat(self.num_zbins, 1, 1)

        # Temporarily store with double precision to increase numerical stability,
        # whichever source the matrix came from
        cov = cov.to(torch.float64)

        self.invcov_blocks = get_invcov_blocks(cov, self.num_spectra, self.num_zbins, self.num_kbins, self.num_ells)
        self.invcov_full = get_full_invcov(cov, self.num_zbins)

    def _diagonalize_covariance(self):
        """performs an eigenvalue decomposition of each diagonal block of the inverse covariance matrix

        Stores the eigenvectors (self.Q), their inverse (self.Q_inv) and the square root of the
        eigenvalues (self.sqrt_eigvals) for every (spectrum, redshift-bin) block, then casts these
        and the inverse covariance itself to single precision on self.device.
        """
        self.Q = torch.zeros_like(self.invcov_blocks)
        self.Q_inv = torch.zeros_like(self.invcov_blocks)
        self.sqrt_eigvals = torch.zeros((self.num_spectra, self.num_zbins, self.num_ells*self.num_kbins))

        for (ps, z) in itertools.product(range(self.num_spectra), range(self.num_zbins)):
            eig, q = torch.linalg.eigh(self.invcov_blocks[ps, z])
            # a single NaN is already enough to corrupt the normalization
            assert not torch.any(torch.isnan(q)), "ERROR! eigenvectors of the inverse covariance matrix contain NaNs!"
            assert torch.all(eig > 0), "ERROR! inverse covariance matrix has negative eigenvalues? Is it positive definite?"
            self.Q[ps, z] = q.real
            self.Q_inv[ps, z] = torch.linalg.inv(q).real
            self.sqrt_eigvals[ps, z] = torch.sqrt(eig)

        # move data to gpu and convert to single precision
        self.invcov_blocks = self.invcov_blocks.to(torch.float32).to(self.device)
        self.invcov_full = self.invcov_full.to(torch.float32).to(self.device)
        self.Q = self.Q.to(torch.float32).to(self.device)
        self.Q_inv = self.Q_inv.to(torch.float32).to(self.device)
        self.sqrt_eigvals = self.sqrt_eigvals.to(torch.float32).to(self.device)

    def _init_loss(self):
        """Defines the loss function to use

        Raises:
            KeyError: if loss_type is not one of ['chi2', 'mse', 'hyperbolic', 'hyperbolic_chi2']
        """
        if self.loss_type == "chi2":
            self.loss_function = delta_chi_squared
        elif self.loss_type == "mse":
            self.loss_function = mse_loss
        elif self.loss_type == "hyperbolic":
            self.loss_function = hyperbolic_loss
        elif self.loss_type == "hyperbolic_chi2":
            self.loss_function = hyperbolic_chi2_loss
        else:
            raise KeyError("ERROR: Invalid loss function type! Must be one of ['chi2', 'mse', 'hyperbolic', 'hyperbolic_chi2']")

    def _init_weights(self, m):
        """Initializes weights using a specific scheme set in the input yaml file

        This function is meant to be called by the constructor only (through torch's Module.apply).
        Current options for initialization schemes are ["normal", "He", "xavier"]

        Args:
            m: the (sub-)module to initialize. Only nn.Linear and blocks.linear_with_channels are touched.
        """
        if isinstance(m, nn.Linear):
            if self.weight_initialization == "He":
                nn.init.kaiming_uniform_(m.weight)
            elif self.weight_initialization == "xavier":
                nn.init.xavier_normal_(m.weight)
            else:
                # "normal", which also serves as the substitute for any invalid scheme
                nn.init.normal_(m.weight, mean=0., std=0.1)
                if m.bias is not None: # layers built with bias=False have nothing to reset
                    nn.init.zeros_(m.bias)
        elif isinstance(m, blocks.linear_with_channels):
            m.initialize_params(self.weight_initialization)

    def _init_training_stats(self):
        """initializes training data as nested lists with dims [nps, nz]"""
        self.train_loss = [[[] for i in range(self.num_zbins)] for j in range(self.num_spectra)]
        self.valid_loss = [[[] for i in range(self.num_zbins)] for j in range(self.num_spectra)]
        self.train_time = 0.

    def _init_optimizer(self):
        """Sets optimization objects, one for each sub-network

        Raises:
            KeyError: if optimizer_type is not "Adam" (the only optimizer currently supported)
        """
        self.optimizer = [[[] for i in range(self.num_zbins)] for j in range(self.num_spectra)]
        self.scheduler = [[[] for i in range(self.num_zbins)] for j in range(self.num_spectra)]
        if self.optimizer_type != "Adam":
            raise KeyError("Error! Invalid optimizer type specified!")

        for (ps, z) in itertools.product(range(self.num_spectra), range(self.num_zbins)):
            net_idx = (z * self.num_spectra) + ps
            self.optimizer[ps][z] = torch.optim.Adam(self.galaxy_ps_model.networks[net_idx].parameters(),
                                                     lr=self.galaxy_ps_learning_rate)
            # use an adaptive learning rate
            self.scheduler[ps][z] = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer[ps][z],
                                                                               "min", factor=0.1, patience=15)

    def _update_checkpoint(self, net_idx=0, mode="galaxy_ps"):
        """saves current best network to an independent state_dict, then writes the model to file

        Args:
            net_idx: index of the sub-network whose weights should be copied into the checkpoint.
            mode: which model to checkpoint. Only "galaxy_ps" is implemented.

        Raises:
            NotImplementedError: if mode is anything other than "galaxy_ps"
        """
        if mode != "galaxy_ps":
            raise NotImplementedError

        # the trailing "." stops e.g. net_idx=1 from also matching "networks.10", "networks.11", ...
        net_key = "networks." + str(net_idx) + "."
        for name, value in self.galaxy_ps_model.state_dict().items():
            if net_key in name:
                # state_dict() hands out tensors that share memory with the live parameters:
                # clone them so further training cannot silently change the stored checkpoint
                self.galaxy_ps_checkpoint[name] = value.detach().clone()
        self._save_model()

    def _save_model(self):
        """saves the current model state and normalization information to file"""
        save_dir = os.path.join(self.input_dir, self.save_dir)
        training_data_dir = os.path.join(save_dir, "training_statistics")

        # creates save_dir and any missing parent as well. exist_ok keeps this safe when several
        # ranks (multi-GPU training) try to create the shared parent directory at the same time
        os.makedirs(training_data_dir, exist_ok=True)

        # training statistics
        for (ps, z) in itertools.product(range(self.num_spectra), range(self.num_zbins)):
            training_data = torch.vstack([torch.Tensor(self.train_loss[ps][z]),
                                          torch.Tensor(self.valid_loss[ps][z]),
                                          torch.Tensor([self.train_time]*len(self.valid_loss[ps][z]))])
            torch.save(training_data, os.path.join(training_data_dir, "train_data_"+str(ps)+"_"+str(z)+".dat"))

        # configuration data
        with open(os.path.join(save_dir, 'config.yaml'), 'w') as outfile:
            yaml.dump(dict(self.config_dict), outfile, sort_keys=False, default_flow_style=False)

        if hasattr(self, "k_emu"):
            np.savez(os.path.join(save_dir, "ps_properties.npz"), k=self.k_emu, ells=self.ells, z_eff=self.z_eff, ndens=self.ndens)
        else:
            self.logger.warning("power spectrum properties not initialized!")

        # data related to input normalization
        input_files = [self.input_normalizations, self.required_emu_params, self.emu_param_bounds]
        torch.save(input_files, os.path.join(save_dir, "input_normalizations.pt"))
        with open(os.path.join(save_dir, "param_names.txt"), "w") as outfile:
            yaml.dump(self.get_required_emu_parameters(), outfile, sort_keys=False, default_flow_style=False)

        # data related to output normalization
        output_files = [self.ps_fid, self.invcov_full, self.invcov_blocks, self.sqrt_eigvals, self.Q]
        torch.save(output_files, os.path.join(save_dir, "output_normalizations.pt"))

        # Finally, the actual model parameters
        torch.save(self.galaxy_ps_checkpoint, os.path.join(save_dir, 'network_galaxy.params'))

    def _check_params(self, params, extrapolate=False):
        """checks that input parameters are in the expected format and within the specified boundaries

        Args:
            params: 1D or 2D numpy array or torch Tensor of input parameters. Any columns beyond the
                emulated parameters are dropped.
            extrapolate (bool): if True, out-of-bounds parameters only produce a warning. Default False

        Returns:
            norm_params: the organized + normalized parameters to feed to the network.
            skip_emulation (bool): True if the parameters are out of bounds and extrapolate is False.

        Raises:
            TypeError: if params is neither a numpy array nor a torch Tensor.
        """
        skip_emulation = False

        if isinstance(params, torch.Tensor):
            params = params.to(self.device)
        elif isinstance(params, np.ndarray):
            params = torch.from_numpy(params).to(torch.float32).to(self.device)
        else:
            raise TypeError(f"invalid type for variable params ({type(params)})")

        if params.dim() == 1:
            params = params.unsqueeze(0)
        # drop trailing (non-emulated) parameters. Skipped when the parameter names are unknown,
        # since truncating to zero columns would leave nothing to emulate with
        num_emu_params = len(self.required_emu_params)
        if num_emu_params > 0 and params.shape[1] > num_emu_params:
            params = params[:, :num_emu_params]

        org_params = self.galaxy_ps_model.organize_parameters(params)

        # TODO: Better handling with batch of parameters
        # Right now, this check will trigger if any of the batch of parameters
        # are out of bounds
        if self.sampling_type == "hypercube":
            out_of_bounds = bool(torch.any(org_params < self.input_normalizations[0])) or \
                            bool(torch.any(org_params > self.input_normalizations[1]))
        elif self.sampling_type == "hypersphere":
            # NOTE(review): assumes is_in_hypersphere(...)[0] flags the samples lying inside the sphere - confirm
            out_of_bounds = not bool(torch.all(is_in_hypersphere(self.emu_param_bounds, params)[0]))
        else:
            out_of_bounds = False

        if out_of_bounds:
            if extrapolate:
                self.logger.warning("Input parameters out of bounds! Emulator output will be untrustworthy")
            else:
                self.logger.info("Input parameters out of bounds! Skipping emulation...")
                skip_emulation = True

        norm_params = normalize_cosmo_params(org_params, self.input_normalizations)
        return norm_params, skip_emulation

    def _check_training_set(self, data:"pk_galaxy_dataset"):
        """checks that loaded-in data for training / validation / testing is compatible with the given network config

        Args:
            data: the loaded dataset to compare against the emulator configuration.

        Raises:
            ValueError: If a given property of the training set does not match with the emulator.
        """
        if len(data.cosmo_params) != self.num_cosmo_params:
            raise ValueError("num_cosmo_params mismatch with training dataset! {:d} vs {:d}".format(len(data.cosmo_params), self.num_cosmo_params))

        num_bias_params = self.num_nuisance_params*self.num_tracers*self.num_zbins
        if len(data.bias_params) != num_bias_params:
            raise ValueError("num_nuisance_params mismatch with training dataset! {:d} vs {:d}".format(len(data.bias_params), num_bias_params))
        if data.num_spectra != self.num_spectra:
            raise ValueError("num_spectra (derived from num_tracers) mismatch with training dataset! {:d} vs {:d}".format(data.num_spectra, self.num_spectra))
        if data.num_zbins != self.num_zbins:
            raise ValueError("num_zbins mismatch with training dataset! {:d} vs {:d}".format(data.num_zbins, self.num_zbins))
        if data.num_ells != self.num_ells:
            raise ValueError("num_ells mismatch with training dataset! {:d} vs {:d}".format(data.num_ells, self.num_ells))
        if data.num_kbins != self.num_kbins:
            raise ValueError("num_kbins mismatch with training dataset! {:d} vs {:d}".format(data.num_kbins, self.num_kbins))
# --------------------------------------------------------------------------
# extra helper function (TODO: Find a better place for this)
# --------------------------------------------------------------------------
# (stray "[docs]" link marker from the rendered Sphinx source page removed; it is not part of the module)
def compile_multiple_device_training_results(save_dir:str, config_dir:str, num_gpus:int):
    """takes networks saved on separate ranks and combines them to the same format as when training on one device

    The (spectrum, redshift-bin) pairs are split into at most num_gpus chunks, and chunk n is read
    from the sub-directory "rank_n" of save_dir.

    Args:
        save_dir (string): base save directory, where each rank was saved in its own sub-directory
        config_dir (string): path+name of the original network config file
        num_gpus (int): number of gpus to compile results of

    Returns:
        full_emulator (ps_emulator): emulator object with all training data combined together.
    """
    full_emulator = ps_emulator(config_dir, "train")
    full_emulator.galaxy_ps_model.eval()

    num_spectra = full_emulator.num_spectra
    num_zbins = full_emulator.num_zbins
    num_epochs = full_emulator.num_epochs

    all_pairs = torch.tensor(list(itertools.product(range(num_spectra), range(num_zbins))), dtype=torch.long)
    # chunk() may hand back fewer chunks than requested, so loop over what actually exists
    split_indices = all_pairs.chunk(num_gpus)

    full_emulator.train_loss = torch.zeros((num_spectra, num_zbins, num_epochs))
    full_emulator.valid_loss = torch.zeros((num_spectra, num_zbins, num_epochs))
    full_emulator.train_time = 0.

    for n, rank_pairs in enumerate(split_indices):
        sub_dir = "rank_"+str(n)
        seperate_network = ps_emulator(os.path.join(save_dir, sub_dir), "eval")

        # power spectrum properties used by analytic_terms.py
        if n == 0:
            ps_properties = np.load(os.path.join(save_dir, sub_dir, "ps_properties.npz"))
            full_emulator.k_emu = ps_properties["k"]
            full_emulator.ells = ps_properties["ells"]
            full_emulator.z_eff = ps_properties["z_eff"]
            full_emulator.ndens = ps_properties["ndens"]

        # galaxy power spectrum networks
        for (ps, z) in rank_pairs.tolist():
            net_idx = (z * num_spectra) + ps
            # emulators loaded in "eval" mode live on cpu: move the sub-network to the
            # device of the combined model before splicing it in
            rank_net = seperate_network.galaxy_ps_model.networks[net_idx].to(full_emulator.device)
            full_emulator.galaxy_ps_model.networks[net_idx] = rank_net

            train_data = torch.load(os.path.join(save_dir, sub_dir, "training_statistics/train_data_"+str(ps)+"_"+str(z)+".dat"), weights_only=True)
            # never write past the pre-allocated number of epochs
            epochs = min(train_data.shape[1], num_epochs)
            full_emulator.train_loss[ps, z, :epochs] = train_data[0, :epochs]
            full_emulator.valid_loss[ps, z, :epochs] = train_data[1, :epochs]
            # as before, the value of the last network read wins. Kept as a plain float
            if train_data.shape[1] > 0:
                full_emulator.train_time = float(train_data[2, 0])

    full_emulator.galaxy_ps_checkpoint = copy.deepcopy(full_emulator.galaxy_ps_model.state_dict())
    return full_emulator