Source code for bonni.model.utils
from enum import Enum
from typing import Callable
import jax
import jax.numpy as jnp
from flax import linen as nn
class SkipConnectionType(Enum):
    """Kinds of skip connection that ``get_skip_connection`` can build.

    Attributes:
        linear: Always use a learned, bias-free linear projection.
        identity: Pass the input through unchanged when the input and
            output channel counts match; otherwise ``get_skip_connection``
            falls back to the linear projection.
    """

    linear = "linear"
    identity = "identity"
def get_skip_connection(
    in_channels: int,
    out_channels: int,
    skip_type: SkipConnectionType,
):
    """Build the callable used for a skip connection.

    Args:
        in_channels: Input channel count; compared against ``out_channels``
            and forwarded to ``_LinearSkipConnection``.
        out_channels: Output channel count; forwarded to
            ``_LinearSkipConnection``.
        skip_type: Which kind of skip connection to build.

    Returns:
        A plain pass-through function when ``skip_type`` is ``identity`` and
        both channel counts are equal; a ``_LinearSkipConnection`` in every
        other supported case.

    Raises:
        ValueError: If ``skip_type`` is neither ``linear`` nor ``identity``.
    """
    wants_linear = skip_type == SkipConnectionType.linear
    wants_identity = skip_type == SkipConnectionType.identity

    # Reject unknown kinds up front so the rest is a two-way choice.
    if not (wants_linear or wants_identity):
        raise ValueError(f"Unsupported skip connection type: {skip_type}")

    # A true identity is only possible when no reshaping of the feature
    # dimension is required; otherwise identity degrades to a projection.
    if not wants_linear and in_channels == out_channels:
        return lambda x: x

    return _LinearSkipConnection(in_channels, out_channels)
class _LinearSkipConnection(nn.Module):
    """Bias-free dense projection used as a learnable skip path.

    Attributes:
        in_features: Declared input width. The code in this class never
            reads it; it is kept as part of the constructor signature.
        out_features: Number of output features of the wrapped ``nn.Dense``.
    """

    in_features: int
    out_features: int

    def setup(self):
        """Create the single ``nn.Dense`` submodule (no bias, LeCun-normal kernel)."""
        kernel_init = nn.initializers.lecun_normal()
        self.linear = nn.Dense(
            features=self.out_features,
            use_bias=False,
            kernel_init=kernel_init,
        )

    def __call__(self, x):
        """Apply the projection to ``x`` and return the result."""
        projected = self.linear(x)
        return projected
[docs]
class ActivationType(Enum):
    """Activation functions selectable from the model configuration.

    Each member names the non-linearity that is applied after a linear
    transformation; ``get_activation_fn`` maps a member to its callable.

    Attributes:
        identity: No activation, ``f(x) = x``. Typically used on the final
            output layer so predictions stay unbounded.
        gelu: Gaussian Error Linear Unit, a smooth ReLU-like activation
            common in Transformers and modern MLPs.
        relu: Rectified Linear Unit, ``f(x) = max(0, x)``.
        leaky_relu: Like ReLU, but keeps a small non-zero slope for
            negative inputs so the gradient does not vanish there.
        sigmoid: Logistic sigmoid; squashes values into ``(0, 1)`` and is
            often used for binary-classification probabilities.
        tanh: Hyperbolic tangent; squashes values into ``(-1, 1)``.
    """

    identity = "identity"
    gelu = "gelu"
    relu = "relu"
    leaky_relu = "leaky_relu"
    sigmoid = "sigmoid"
    tanh = "tanh"
def get_activation_fn(
    activation_type: ActivationType,
) -> Callable[[jax.Array], jax.Array]:
    """Return the callable that implements ``activation_type``.

    Args:
        activation_type: The activation to look up.

    Returns:
        The matching ``jax.nn`` function, or a pass-through function for
        ``ActivationType.identity``.

    Raises:
        ValueError: If ``activation_type`` is not a supported member.
    """
    # Ordered (member, function) pairs; compared with ``==`` so that
    # unsupported or unhashable inputs still end in ValueError below.
    candidates = (
        (ActivationType.identity, lambda x: x),
        (ActivationType.gelu, jax.nn.gelu),
        (ActivationType.relu, jax.nn.relu),
        (ActivationType.leaky_relu, jax.nn.leaky_relu),
        (ActivationType.sigmoid, jax.nn.sigmoid),
        (ActivationType.tanh, jax.nn.tanh),
    )
    for member, fn in candidates:
        if activation_type == member:
            return fn
    raise ValueError(f"Invalid activation type: {activation_type}")
[docs]
class InitType(Enum):
    """Parameter-initialization strategies selectable from the model config.

    Used primarily for ``bias_init``; ``get_init_fn`` turns a member into
    the actual initializer callable.

    Attributes:
        zeros: Every parameter starts at exactly 0, the usual choice for
            bias terms.
        ones: Every parameter starts at exactly 1.
        uniform: Parameters are drawn from a uniform distribution;
            ``get_init_fn`` scales ``jax.random.uniform`` draws by 0.01.
        normal: Parameters are drawn from a normal (Gaussian) distribution;
            ``get_init_fn`` scales ``jax.random.normal`` draws by 0.1.
    """

    zeros = "zeros"
    ones = "ones"
    uniform = "uniform"
    normal = "normal"
def get_init_fn(
    init_type: InitType,
):
    """Return the initializer callable for ``init_type``.

    Args:
        init_type: The initialization strategy to look up.

    Returns:
        ``nn.initializers.zeros`` or ``nn.initializers.ones`` for the
        constant strategies. For ``uniform`` and ``normal``, a
        ``jax.tree_util.Partial`` wrapping a function
        ``init(key, shape, dtype=jnp.float64, out_sharding=None)`` that draws
        random values and multiplies them by a fixed scale (0.01 for
        uniform, 0.1 for normal).

    Raises:
        ValueError: If ``init_type`` is not a supported member.
    """

    def _scaled_sampler(sampler, scale):
        # Shared factory for both random strategies: draw with ``sampler``
        # and shrink the draws by ``scale`` cast to the requested dtype.
        def init(key, shape, dtype=jnp.float64, out_sharding=None) -> jax.Array:
            draws = sampler(key, shape, dtype=dtype, out_sharding=out_sharding)
            return draws * jnp.array(scale, dtype)

        return jax.tree_util.Partial(init)

    if init_type == InitType.zeros:
        return nn.initializers.zeros
    if init_type == InitType.ones:
        return nn.initializers.ones
    if init_type == InitType.uniform:
        return _scaled_sampler(jax.random.uniform, 0.01)
    if init_type == InitType.normal:
        return _scaled_sampler(jax.random.normal, 0.1)
    raise ValueError(f"Invalid init type: {init_type}")