Stacked Transformer Model
A set of classes defining the full MLP + Transformer network used in Adamo et al. (2025).
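The class listing describes the stack as one `single_transformer` per portion of the power spectrum output. The pure-Python sketch below illustrates that dispatch. It rests on three assumptions that the documentation does not confirm: each portion is one (tracer pair, redshift bin) combination, the tracer pairs are the auto spectra plus the unique cross spectra, and the stacked output is the concatenation of the per-portion outputs. The helpers `enumerate_portions`, `make_toy_network`, and `stacked_forward` are hypothetical stand-ins, not part of `mentat_lss`, and the toy networks are fixed random linear maps rather than transformers.

```python
import random
from itertools import combinations_with_replacement


def enumerate_portions(num_tracers, num_zbins):
    """Hypothetical: one (tracer_a, tracer_b, zbin, is_cross) tuple per output portion.

    Assumes auto spectra plus unique cross spectra, each evaluated in every redshift bin.
    """
    portions = []
    for a, b in combinations_with_replacement(range(num_tracers), 2):
        for z in range(num_zbins):
            # is_cross mirrors the is_cross_spectra flag of single_transformer (assumed meaning).
            portions.append((a, b, z, a != b))
    return portions


def make_toy_network(num_inputs, num_outputs, seed):
    """Hypothetical stand-in for a single_transformer: a fixed random linear map."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(num_inputs)]
               for _ in range(num_outputs)]

    def forward(row):
        return [sum(w * x for w, x in zip(w_row, row)) for w_row in weights]

    return forward


def stacked_forward(param_rows, networks):
    """Run each independent network on its own parameter row and concatenate the outputs."""
    if len(param_rows) != len(networks):
        raise ValueError("need exactly one parameter row per sub-network")
    output = []
    for row, net in zip(param_rows, networks):
        output.extend(net(row))
    return output


# Example: 2 tracers and 2 redshift bins -> 3 spectra x 2 bins = 6 independent sub-networks.
portions = enumerate_portions(num_tracers=2, num_zbins=2)
num_inputs, num_k = 4, 5  # illustrative sizes only
networks = [make_toy_network(num_inputs, num_k, seed=i) for i in range(len(portions))]
rows = [[0.1 * (i + j) for j in range(num_inputs)] for i in range(len(portions))]
prediction = stacked_forward(rows, networks)
print(len(portions), sum(p[3] for p in portions), len(prediction))  # 6 2 30
```

Because no weights are shared between sub-networks, each portion of the output can be fit independently of the others.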
- class mentat_lss.models.stacked_transformer.single_transformer(config_dict, is_cross_spectra: bool)
Class defining a single independent transformer network
- class mentat_lss.models.stacked_transformer.stacked_transformer(config_dict)
Class defining a stack of single_transformer objects, one for each portion of the power spectrum output
- organize_parameters(input_params)
Organizes input cosmology + bias parameters into a form the rest of the network expects
- Parameters:
input_params – tensor of input parameters with shape [batch, num_cosmo_params + (num_nuisance_params*num_zbins*num_tracers)]
- Returns:
- organized_params – tensor of input parameters with shape [batch, num_spectra*num_zbins, num_cosmo_params + (2*self.num_nuisance_params)].
The bias parameters are split according to their respective redshift and tracer bins.
- Return type:
tensor
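The reshaping that `organize_parameters` describes can be sketched in pure Python on nested lists instead of tensors. The sketch makes assumptions that the documentation does not confirm. First, each input row holds the cosmology parameters followed by one block of `num_nuisance_params` values per (tracer, redshift bin), ordered tracer-major. Second, each output row corresponds to one (tracer pair, redshift bin) combination, covering the auto spectra plus the unique cross spectra. Third, each output row carries the bias block of both tracers in the pair, which would account for the `2*num_nuisance_params` width. The function `organize_parameters_sketch` is hypothetical and not the library's implementation.

```python
from itertools import combinations_with_replacement


def organize_parameters_sketch(input_params, num_cosmo_params, num_nuisance_params,
                               num_tracers, num_zbins):
    """Hypothetical list-based version of the reshaping described above.

    Assumed layout of each input row (an assumption, not taken from mentat_lss):
        [cosmo params] + [nuisance block (tracer 0, zbin 0)] + [(tracer 0, zbin 1)] + ...
    with nuisance blocks ordered tracer-major, then by redshift bin.
    """
    expected = num_cosmo_params + num_nuisance_params * num_zbins * num_tracers

    def bias_block(sample, tracer, zbin):
        # Offset of the nuisance block belonging to this (tracer, zbin) pair.
        start = num_cosmo_params + (tracer * num_zbins + zbin) * num_nuisance_params
        return list(sample[start:start + num_nuisance_params])

    organized = []
    for sample in input_params:
        if len(sample) != expected:
            raise ValueError(f"expected {expected} parameters, got {len(sample)}")
        cosmo = list(sample[:num_cosmo_params])
        rows = []
        # One row per (tracer pair, redshift bin): auto spectra plus unique cross spectra.
        for a, b in combinations_with_replacement(range(num_tracers), 2):
            for z in range(num_zbins):
                # Each row carries the bias block of both tracers -> width cosmo + 2*nuisance.
                rows.append(cosmo + bias_block(sample, a, z) + bias_block(sample, b, z))
        organized.append(rows)
    # Nested-list shape: [batch, num_spectra*num_zbins, num_cosmo_params + 2*num_nuisance_params]
    return organized


# Example: 2 cosmology params, 1 bias param, 2 tracers, 2 redshift bins (toy values).
batch = [[10, 20, 1, 2, 3, 4]]
out = organize_parameters_sketch(batch, num_cosmo_params=2, num_nuisance_params=1,
                                 num_tracers=2, num_zbins=2)
print(len(out[0]), len(out[0][0]))  # 6 4
print(out[0][2])  # [10, 20, 1, 3]  (cross spectrum of tracers 0 and 1, zbin 0)
```

For an auto spectrum the same bias block appears twice, so every row keeps the same width regardless of whether it is an auto or cross spectrum.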