Source code for mentat_lss.dataset

import torch
import numpy as np

from mentat_lss.utils import load_config_file, normalize_power_spectrum, un_normalize_power_spectrum

[docs] class pk_galaxy_dataset(torch.utils.data.Dataset): """Custom dataset storing large sets of galaxy power spectrum multipoles""" def __init__(self, data_dir:str, type:str, frac=1.): """Initializes the dataset. Args: data_dir (str): location to load the data to load type (str): One of ["training", "validation", "testing"]. Determines which data file to try loading. frac (float, optional): fraction of the full dataset to laod in. Defaults to 1.. """ self._load_data(data_dir, type, frac) def _load_data(self, data_dir:str, type:str, frac:float): """Loads in galaxy power spectrum multipoles from disk. Args: data_dir (str): location to load the data to load type (str): One of ["Training", "Validation", "Testing"]. Determines which data file to try loading. frac (float, optional): fraction of the full dataset to laod in. Defaults to 1.. Raises: KeyError: If type is invalid. """ if type.lower()=="training": file = data_dir+"pk-training.npz" elif type.lower()=="validation": file = data_dir+"pk-validation.npz" elif type.lower()=="testing": file = data_dir+"pk-testing.npz" else: raise KeyError(f"Invalid dataset type! Must be [training, validation, testing], but got {type.lower()}") data = np.load(file) self.params = torch.from_numpy(data["params"]).to(torch.float32) self.galaxy_ps = torch.from_numpy(data["galaxy_ps"]).to(torch.float32) del data header_info = load_config_file(data_dir+"info.yaml") self.cosmo_params = header_info["cosmo_params"] self.bias_params = header_info["nuisance_params"] self.num_spectra = self.galaxy_ps.shape[1] self.num_zbins = self.galaxy_ps.shape[2] self.num_kbins = self.galaxy_ps.shape[3] self.num_ells = self.galaxy_ps.shape[4] if frac != 1.: N_frac = int(self.params.shape[0] * frac) self.params = self.params[0:N_frac] self.galaxy_ps = self.galaxy_ps[0:N_frac] def __len__(self): """Returns the number of samples in the dataset Returns: len (int): number of samples in the dataset """ return self.params.shape[0] def __getitem__(self, idx): """Returns specific items from the dataset Args: idx (int or torch.Tensor): index (or set od indices) to access Returns: params (torch.Tensor): input cosmology and bias parameters corresponding to idx galaxy_ps (torch.Tensor): (normalized) power spectrum multipoles corresponding to idx nw_ps (torch.Tensor): (normalized non-wiggle linear power spectra corresponding to idx. NOTE: in developement idx (int or torch.Tensor): The index of the corresponding data. """ return self.params[idx], self.galaxy_ps[idx], idx
[docs] def to(self, device:torch.device): """send data to the specified device, similar to the corresponding method for Tensors Args: device (torch.device): device to send the data to. """ self.params = self.params.to(device) self.galaxy_ps = self.galaxy_ps.to(device)
[docs] def normalize_data(self, ps_fid:torch.Tensor, sqrt_eigvals:torch.Tensor, Q:torch.Tensor): """Normalizes the reshapes the data Args: ps_fid (torch.Tensor): fiducial power spectrum multipoles in units of (Mpc/h)^3 used for normalization. Should have shape [nps, z, nk*nl] ps_nw_fid (torch.Tensor): NOTE: currently not used. sqrt_eigvals (torch.Tensor): set of sqrt eigenvalues used for normalization. Should have shape [ps, z, nk*nl] Q (torch.Tensor): set of eigenvectors used for normalization. Should have shape [ps, z, nk*nl, nk*nl] """ self.galaxy_ps = normalize_power_spectrum(torch.flatten(self.galaxy_ps, start_dim=3), ps_fid, sqrt_eigvals, Q) self.galaxy_ps = self.galaxy_ps.reshape(-1, self.num_spectra, self.num_zbins, self.num_kbins*self.num_ells)
#self.nw_ps = (self.nw_ps / ps_nw_fid) - 1.
[docs] def get_normalized_galaxy_power_spectra(self, idx): """Returns the normalized power spectrum multipoles corresponding to idx Args: idx (int or torch.Tensor): index (or set of indexes) to access Returns: galaxy_ps[idx] (torch.Tensor): normalized power spectrum to access. """ if isinstance(idx, int): return torch.flatten(self.galaxy_ps[idx], start_dim=2) else: return torch.flatten(self.galaxy_ps[idx], start_dim=3)
[docs] def get_true_galaxy_power_spectra(self, idx, ps_fid:torch.Tensor, sqrt_eigvals:torch.Tensor, Q:torch.Tensor, Q_inv:torch.Tensor): """Returns the galaxy power spectrum multipoles in units of (Mpc/h)^3 corresponding to idx Args: idx (int or torch.Tensor): index (or set of indexes) to access ps_fid (torch.Tensor): fiducial power spectrum used to reverse normalization. Expected shape is [nps*nz, nk*nl] sqrt_eigvals (torch.Tensor): square root eigenvalues of the inverse covariance matrix. Expected shape is [nps*nz, nk*nl] Q (torch.Tensor): eigenvectors of the inverse covariance matrix. Expected shape is [nps*nz, nk*nl, nk*nl] Q_inv (torch.Tensor): inverse eigenvectors of the inverse covariance matrix. Expected shape is [nps*nz, nk*nl, nk*nl] Returns: galaxy_ps[idx] (torch.Tensor): galaxy power spectrum in units of (Mpc/h)^3 to access. has shape [b, nps, nz, nk, nl] or [nps, nz, nk, nl] """ if isinstance(idx, int): flatten_dim = 2 final_shape = (self.num_spectra, self.num_zbins, self.num_kbins, self.num_ells) else: flatten_dim = 3 final_shape = (-1, self.num_spectra, self.num_zbins, self.num_kbins, self.num_ells) ps_true = un_normalize_power_spectrum(torch.flatten(self.galaxy_ps[idx], start_dim=flatten_dim), ps_fid, sqrt_eigvals, Q, Q_inv) ps_true = ps_true.reshape(final_shape) return ps_true