ferritin_plms/esm/layers/regression_head.rs

use crate::esm::models::esmc::ESMCConfig;
use candle_core::Tensor;
use candle_nn::{self as nn, LayerNormConfig, Module, Sequential, VarBuilder};
/// A small MLP head that maps final hidden states of width `d_model`
/// to regression targets: Linear -> GELU -> LayerNorm -> Linear.
pub struct RegressionHead {
    model: Sequential,
}
impl RegressionHead {
    /// Build the head from `config`, loading weights under `vb`.
    ///
    /// The sub-paths "0", "2", and "3" follow `nn.Sequential`-style
    /// indexing in the checkpoint; index 1, the GELU, has no parameters.
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> candle_core::Result<Self> {
        let ESMCConfig {
            d_model,
            regression_head_output_dim,
            regression_head_hidden_dim,
            ..
        } = config;

        // d_model -> hidden projection.
        let linear1 = nn::linear(*d_model, *regression_head_hidden_dim, vb.pp("0"))?;
        let gelu = nn::Activation::Gelu;
        // LayerNorm over the hidden dimension with the PyTorch default eps.
        let ln_conf = LayerNormConfig::from(1e-5);
        let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("2"))?;
        // hidden -> output projection.
        let linear2 = nn::linear(
            *regression_head_hidden_dim,
            *regression_head_output_dim,
            vb.pp("3"),
        )?;
        let model = nn::seq().add(linear1).add(gelu).add(norm).add(linear2);
        Ok(Self { model })
    }
}
impl Module for RegressionHead {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        self.model.forward(x)
    }
}
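
// --- Usage sketch (illustrative addition, not part of the upstream file) ---
// A minimal shape smoke test under stated assumptions: it presumes
// `ESMCConfig` implements `Default` (hypothetical here) and only overrides
// the three fields this module reads. A fresh `VarMap`-backed `VarBuilder`
// initializes weights on demand; real inference would instead load a
// checkpoint, e.g. via `VarBuilder::from_mmaped_safetensors`.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    #[test]
    fn forward_maps_d_model_to_output_dim() -> candle_core::Result<()> {
        let device = Device::Cpu;
        // VarMap-backed builder creates missing variables on demand.
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        // Hypothetical small dimensions; assumes `ESMCConfig: Default`.
        let config = ESMCConfig {
            d_model: 8,
            regression_head_hidden_dim: 16,
            regression_head_output_dim: 4,
            ..Default::default()
        };
        let head = RegressionHead::load(vb, &config)?;
        // A batch of two embeddings of width `d_model`.
        let x = Tensor::zeros((2, 8), DType::F32, &device)?;
        let y = head.forward(&x)?;
        assert_eq!(y.dims(), &[2, 4]);
        Ok(())
    }
}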