Stacked Transformer Model

A set of classes defining the full MLP + Transformer network used in Adamo et al. (2025)

class mentat_lss.models.stacked_transformer.single_transformer(config_dict, is_cross_spectra: bool)[source]

Class defining a single independent transformer network

forward(input_params: Tensor)[source]

Passes an input tensor through the network

class mentat_lss.models.stacked_transformer.stacked_transformer(config_dict)[source]

Class defining a stack of single_transformer objects, one for each portion of the power spectrum output

forward(input_params, net_idx=None)[source]

Passes an input tensor through the network

organize_parameters(input_params)[source]

Organizes input cosmology + bias parameters into a form the rest of the network expects

Parameters:

input_params – tensor of input parameters with shape [batch, num_cosmo_params + (num_nuisance_params*num_zbins*num_tracers)]

Returns:

organized_params – tensor of input parameters with shape [batch, num_spectra*num_zbins, num_cosmo_params + (2*num_nuisance_params)]. The bias parameters are split according to their respective redshift / tracer bin

Return type:

Tensor
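The reshaping described above can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: plain Python lists stand in for tensors, the function name organize_parameters_sketch is hypothetical, and the layout is assumed. Specifically, it assumes the flat input holds the cosmology parameters first, followed by the nuisance parameters ordered by redshift bin, then tracer; that num_spectra counts every auto- and cross-spectrum of the tracers; and that output rows are ordered with the redshift bin as the outer index. The actual ordering used by mentat_lss may differ.

```python
from itertools import combinations_with_replacement


def organize_parameters_sketch(row, num_cosmo_params, num_nuisance_params,
                               num_zbins, num_tracers):
    """Reorganize one flat parameter row into one row per (z-bin, spectrum).

    Assumed (hypothetical) flat layout: cosmology parameters first, then
    nuisance parameters ordered by z-bin, then tracer, then parameter index.
    """
    expected = num_cosmo_params + num_nuisance_params * num_zbins * num_tracers
    if len(row) != expected:
        raise ValueError(f"expected {expected} parameters, got {len(row)}")

    cosmo = row[:num_cosmo_params]
    nuisance = row[num_cosmo_params:]

    def tracer_block(z, t):
        # Nuisance parameters belonging to tracer t in redshift bin z.
        start = (z * num_tracers + t) * num_nuisance_params
        return nuisance[start:start + num_nuisance_params]

    # Assumption: one spectrum per unordered tracer pair (auto + cross).
    pairs = list(combinations_with_replacement(range(num_tracers), 2))

    organized = []
    for z in range(num_zbins):
        for t1, t2 in pairs:
            # Each row: shared cosmology + nuisance parameters of both tracers.
            organized.append(cosmo + tracer_block(z, t1) + tracer_block(z, t2))
    return organized


# Two cosmology parameters, one nuisance parameter, two z-bins, two tracers.
flat_row = [0.3, 0.8, 1.0, 2.0, 3.0, 4.0]
rows = organize_parameters_sketch(flat_row, num_cosmo_params=2,
                                  num_nuisance_params=1, num_zbins=2,
                                  num_tracers=2)
for r in rows:
    print(r)
```

Under these assumptions the example yields num_spectra*num_zbins = 3*2 = 6 rows, each of length num_cosmo_params + 2*num_nuisance_params = 4, matching the documented output shape for a single batch element.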