Network Blocks

This page describes the classes that define the neural network layers used by the emulator.

class mentat_lss.models.blocks.activation_function(d: int)

Custom nonlinear activation function

forward(X: Tensor)

Passes the input through the activation function.

Parameters:

X (torch.Tensor) – Input to the function. Should be shape (batch_size, d)

Returns:

Output of the function. Has shape (batch_size, d)

Return type:

X (torch.Tensor)

class mentat_lss.models.blocks.block_addnorm(shape, dropout_prob=0.0)

forward(X, Y)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined in this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.
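The note above can be illustrated with a small sketch. The AddNormSketch class below is a hypothetical stand-in, not the block_addnorm implementation: it assumes the block computes a residual sum followed by layer normalization, which this page does not confirm. Only the hook behavior of torch.nn.Module is the point here.

```python
import torch
from torch import nn


class AddNormSketch(nn.Module):
    """Hypothetical stand-in: dropout(Y) + X followed by LayerNorm (an assumption)."""

    def __init__(self, shape, dropout_prob=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.norm = nn.LayerNorm(shape)

    def forward(self, X, Y):
        return self.norm(self.dropout(Y) + X)


block = AddNormSketch(shape=8)

# The forward hook records the output shape every time it fires.
hook_calls = []
handle = block.register_forward_hook(
    lambda module, inputs, output: hook_calls.append(tuple(output.shape))
)

X = torch.randn(4, 8)
Y = torch.randn(4, 8)

out_call = block(X, Y)             # calling the instance runs the hook
out_forward = block.forward(X, Y)  # calling forward() directly skips the hook

handle.remove()
```

After this runs, hook_calls holds a single entry even though the computation was performed twice, because only the call through the instance triggered the hook.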

class mentat_lss.models.blocks.block_resnet(input_dim: int, output_dim: int, num_layers: int, skip_connection: bool = True)

forward(X: Tensor)

Passes the input through the block.

Parameters:

X (torch.Tensor) – Input to the block. Should have shape (batch_size, input_dim)

Returns:

Output of the block. Has shape (batch_size, output_dim)

Return type:

X (torch.Tensor)
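As a rough illustration of the shape contract above, here is a minimal residual block sketch. ResNetBlockSketch is hypothetical: the choice of ReLU, the layer widths, and the linear projection on the skip path when input_dim differs from output_dim are all assumptions, not details taken from block_resnet.

```python
import torch
from torch import nn


class ResNetBlockSketch(nn.Module):
    """Hypothetical residual MLP block; not the mentat_lss implementation."""

    def __init__(self, input_dim, output_dim, num_layers, skip_connection=True):
        super().__init__()
        # Assumption: the first layer changes the width, later layers keep it.
        dims = [input_dim] + [output_dim] * num_layers
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.layers = nn.Sequential(*layers)
        self.skip_connection = skip_connection
        # Assumption: project the input when the widths differ so the sum is defined.
        if input_dim == output_dim:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, X):
        out = self.layers(X)
        if self.skip_connection:
            out = out + self.skip(X)
        return out


block = ResNetBlockSketch(input_dim=6, output_dim=10, num_layers=2)
X = torch.randn(32, 6)   # (batch_size, input_dim)
out = block(X)           # (batch_size, output_dim)
```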

class mentat_lss.models.blocks.block_transformer_encoder(embedding_dim: int, split_dim: int, dropout_prob=0.0)

Custom transformer encoder class

forward(X: Tensor)

Passes the input through the transformer block.

Parameters:

X (torch.Tensor) – Input to the block. Should have shape (batch_size, embedding_dim)

Returns:

Output of the block. Has shape (batch_size, embedding_dim)

Return type:

X (torch.Tensor)

class mentat_lss.models.blocks.linear_with_channels(input_dim: int, output_dim: int, num_channels: int)

Independent MLP layers, one per channel, evaluated in parallel

forward(X: Tensor)

Passes the input through the layer.

Parameters:

X (torch.Tensor) – Input to the layer. Should have shape (batch_size, num_channels, input_dim)

Returns:

Output of the layer. Has shape (batch_size, num_channels, output_dim)

Return type:

X (torch.Tensor)

initialize_params(weight_initialization)

Initializes the layer weights, since PyTorch struggles to do so automatically.
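One way to picture a layer with independent per-channel weights is a batched matrix product over a weight tensor of shape (num_channels, input_dim, output_dim). The sketch below is an assumption about how such a layer could be written, not the linear_with_channels source. The attribute names w and b, the use of einsum, and the per-channel initialization loop are all hypothetical. The loop reflects one plausible reason automatic initialization is awkward: PyTorch's built-in initializers derive fan-in and fan-out from a tensor's shape, and a 3-D weight is not treated as a stack of independent matrices.

```python
import torch
from torch import nn


class LinearWithChannelsSketch(nn.Module):
    """Hypothetical per-channel linear layer; not the mentat_lss implementation."""

    def __init__(self, input_dim, output_dim, num_channels):
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_channels, input_dim, output_dim))
        self.b = nn.Parameter(torch.zeros(num_channels, output_dim))
        self.initialize_params(nn.init.xavier_uniform_)

    def initialize_params(self, weight_initialization):
        # Initialize each channel's 2-D weight matrix separately so that
        # fan-in and fan-out are computed per channel.
        with torch.no_grad():
            for c in range(self.w.shape[0]):
                weight_initialization(self.w[c])
            self.b.zero_()

    def forward(self, X):
        # (batch, channel, in) x (channel, in, out) -> (batch, channel, out)
        return torch.einsum("bci,cio->bco", X, self.w) + self.b


layer = LinearWithChannelsSketch(input_dim=5, output_dim=3, num_channels=4)
X = torch.randn(16, 4, 5)   # (batch_size, num_channels, input_dim)
out = layer(X)              # (batch_size, num_channels, output_dim)

# Channel 0 should match an ordinary linear map using only channel 0's weights.
ref0 = X[:, 0, :] @ layer.w[0] + layer.b[0]
```

Comparing out[:, 0, :] with ref0 confirms that each channel only ever sees its own weights, which is what "independent" means here.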

class mentat_lss.models.blocks.multi_headed_attention(hidden_dim, num_heads=2, dropout_prob=0.0)

forward(queries, keys, values)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined in this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.

transpose_output(X)

Reverse the operation of transpose_qkv.

transpose_qkv(X)

Transposition for parallel computation of multiple attention heads.
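The pair of transpositions can be sketched with a common multi-head attention pattern: fold the heads into the batch dimension before attention, then unfold them afterwards. The functions below are an assumption about what these methods do, not their source. They are written as free functions taking num_heads explicitly, whereas the documented methods take only X, and the 3-D input shape (batch_size, seq_len, hidden_dim) is also assumed.

```python
import torch


def transpose_qkv(X, num_heads):
    """(batch, seq, hidden) -> (batch * num_heads, seq, hidden // num_heads)."""
    batch_size, seq_len, hidden_dim = X.shape
    X = X.reshape(batch_size, seq_len, num_heads, hidden_dim // num_heads)
    X = X.permute(0, 2, 1, 3)  # move the head axis next to the batch axis
    return X.reshape(batch_size * num_heads, seq_len, hidden_dim // num_heads)


def transpose_output(X, num_heads):
    """Inverse of transpose_qkv: (batch * num_heads, seq, d) -> (batch, seq, num_heads * d)."""
    batch_heads, seq_len, head_dim = X.shape
    X = X.reshape(batch_heads // num_heads, num_heads, seq_len, head_dim)
    X = X.permute(0, 2, 1, 3)  # move the head axis back next to the feature axis
    return X.reshape(batch_heads // num_heads, seq_len, num_heads * head_dim)


X = torch.arange(2 * 3 * 8, dtype=torch.float32).reshape(2, 3, 8)
heads = transpose_qkv(X, num_heads=2)           # shape (4, 3, 4)
restored = transpose_output(heads, num_heads=2)  # shape (2, 3, 8)
```

Because each head becomes its own batch entry, a single batched matrix product can compute attention for all heads at once, and transpose_output restores the original layout exactly.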