Source code for bonni.model.mlp
from dataclasses import dataclass
from flax import linen as nn
import jax
from flax.linen import initializers
from bonni.model.utils import ActivationType, InitType, get_activation_fn, get_init_fn
class MLPLayer(nn.Module):
out_channels: int
activation_type: ActivationType
norm_groups: int | None
dropout_prob: float = 0.0
bias_init: InitType = InitType.zeros
skip_if_possible: bool = False
def setup(self):
bias_init_fn = get_init_fn(self.bias_init)
self.fc1 = nn.Dense(
features=self.out_channels,
bias_init=bias_init_fn,
kernel_init=initializers.he_normal(),
)
self.norm = None
if self.norm_groups is not None:
self.norm = nn.GroupNorm(
num_groups=self.norm_groups,
epsilon=1e-5, # Default epsilon value
)
self.activation = get_activation_fn(self.activation_type)
self.dropout = None
if self.dropout_prob > 0:
self.dropout = nn.Dropout(
rate=self.dropout_prob,
)
def __call__(
self,
x: jax.Array,
deterministic: bool = False,
) -> jax.Array:
# Apply first linear transformation
y = self.fc1(x)
# group norm
if self.norm is not None:
y = self.norm(y)
# Apply activation
y = self.activation(y)
if self.skip_if_possible and x.shape == y.shape:
y = x + y
# Dropout
if self.dropout is not None:
y = self.dropout(y, deterministic=deterministic)
return y
[docs]
@dataclass(frozen=True, kw_only=True)
class MLPModelConfig:
"""
Configuration object for a Multi-Layer Perceptron (MLP) model.
This dataclass defines the structural and hyperparameter settings for an MLP,
including layer dimensions, normalization, dropout, and activation strategies.
It is frozen (immutable) and requires keyword arguments for initialization.
Attributes:
num_layer (int): The total number of linear layers in the MLP.
out_channels (int): The dimensionality of the output features.
hidden_channels (int | None): The dimensionality of the hidden layers.
If None, this is typically inferred from the input or output channels
depending on the implementation. Defaults to None.
norm_groups (int | None): The number of groups to use for Group Normalization
in the hidden layers. If None, normalization is skipped. Defaults to None.
last_norm_groups (int | None): The number of groups for Group Normalization
applied to the final layer. If None, no normalization is applied to the
output. Defaults to None.
dropout_prob (float): The dropout probability applied after hidden layers.
Must be between 0.0 and 1.0. Defaults to 0.0.
last_dropout_prob (float): The dropout probability applied after the final layer.
Defaults to 0.0.
activation_type (ActivationType): The activation function used after hidden layers.
Defaults to ActivationType.gelu.
different_last_activation (ActivationType | None): The activation function
used after the final layer. If set to `ActivationType.identity`, the output
is linear. If None, the model typically uses the same activation as
`activation_type`. Defaults to ActivationType.identity.
bias_init (InitType): The initialization strategy for the layer biases
(e.g., zeros, uniform). Defaults to InitType.zeros.
skip_if_possible (bool): If True, adds residual connections (skip connections)
around layers where the input and output dimensions are identical.
Defaults to True.
"""
num_layer: int
out_channels: int
hidden_channels: int | None = None
norm_groups: int | None = None
last_norm_groups: int | None = None
dropout_prob: float = 0.0
last_dropout_prob: float = 0.0
activation_type: ActivationType = ActivationType.gelu
different_last_activation: ActivationType | None = (
ActivationType.identity
) # if none, use same activation
bias_init: InitType = InitType.zeros
skip_if_possible: bool = True
class MLP(nn.Module):
cfg: MLPModelConfig
def setup(self):
if self.cfg.num_layer > 1:
assert self.cfg.hidden_channels is not None, "need hidden dim with >1 layer"
layers = []
for idx in range(self.cfg.num_layer):
# select activation
cur_activ = (
self.cfg.activation_type
if idx != self.cfg.num_layer - 1
or self.cfg.different_last_activation is None
else self.cfg.different_last_activation
)
# select out channels
cur_out_channels = self.cfg.out_channels
if idx != self.cfg.num_layer - 1:
assert self.cfg.hidden_channels is not None
cur_out_channels = self.cfg.hidden_channels
# select dropout prob, norm_groups
dropout_prob = (
self.cfg.dropout_prob
if idx != self.cfg.num_layer - 1
else self.cfg.last_dropout_prob
)
norm_groups = (
self.cfg.norm_groups
if idx != self.cfg.num_layer - 1
else self.cfg.last_norm_groups
)
cur_skip_if_possible = (
self.cfg.skip_if_possible if idx != self.cfg.num_layer - 1 else False
)
# build current layer
cur_layer = MLPLayer(
out_channels=cur_out_channels,
activation_type=cur_activ,
norm_groups=norm_groups,
dropout_prob=dropout_prob,
bias_init=self.cfg.bias_init,
skip_if_possible=cur_skip_if_possible,
)
layers.append(cur_layer)
self.layers = layers
def __call__(
self,
x: jax.Array,
deterministic: bool = False,
) -> jax.Array:
for layer in self.layers:
x = layer(x, deterministic=deterministic)
return x