ferritin_plms/esm/layers/transformer_stack.rs
use crate::esm::layers::blocks::UnifiedTransformerBlock;
use crate::esm::models::esmc::ESMCConfig;
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::Result;
use candle_nn::{self as nn, LayerNorm, LayerNormConfig};
/// A stack of transformer blocks used in the ESM-3 model. Each block is a
/// [`UnifiedTransformerBlock`], which uses either geometric attention or
/// standard multi-head attention.
///
/// Configuration (taken from [`ESMCConfig`]):
/// * `d_model` - dimensionality of the input and output feature vectors.
/// * `n_heads` - number of attention heads.
/// * `v_heads` - optional number of voting heads.
/// * `n_layers` - number of transformer blocks in the stack.
/// * `n_layers_geom` - number of transformer blocks that use geometric attention.
/// * `scale_residue` - whether to scale the residue connections in each transformer block.
/// * `mask_and_zero_frameless` - whether to mask and zero frameless positions in the input.
///   Only applies to the geometric attention blocks, which are conditioned on the structure.
pub struct TransformerStack {
    blocks: Vec<UnifiedTransformerBlock>,
    norm: LayerNorm,
}
impl TransformerStack {
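    /// Builds the stack from pretrained weights via a [`nn::VarBuilder`],
    /// creating one [`UnifiedTransformerBlock`] per layer plus the final norm.
    ///
    /// A minimal usage sketch, assuming a safetensors checkpoint and a populated
    /// `ESMCConfig`; the file name, weight prefix, and `esmc_300m` constructor
    /// below are illustrative, not part of this crate's API:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device};
    /// use candle_nn::VarBuilder;
    ///
    /// let device = Device::Cpu;
    /// let vb = unsafe {
    ///     VarBuilder::from_mmaped_safetensors(&["esmc.safetensors"], DType::F32, &device)?
    /// };
    /// let config = ESMCConfig::esmc_300m(); // hypothetical config constructor
    /// let stack = TransformerStack::load(vb.pp("transformer"), &config)?;
    /// ```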
    pub fn load(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model, n_layers, ..
        } = config;
        let mut blocks = Vec::with_capacity(*n_layers);
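        // Each block loads its weights under the "blocks.{i}" prefix of the VarBuilder.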
        for i in 0..*n_layers {
            blocks.push(UnifiedTransformerBlock::load(
                vb.pp(format!("blocks.{}", i)),
                config,
                i,
            )?);
        }
        // let ln_conf = LayerNormConfig::from(1e-5);
        let ln_conf = LayerNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: false,
        };
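        // `affine: false` loads only the scale weight and skips the bias tensor,
        // mirroring the bias-free final LayerNorm of the reference implementation.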
        let norm = nn::layer_norm(*d_model, ln_conf, vb.pp("norm"))?;
        Ok(Self { blocks, norm })
    }
    // pub fn new(
    //     d_model: i64,
    //     n_heads: i64,
    //     v_heads: Option<i64>,
    //     n_layers: i64,
    //     n_layers_geom: i64,
    //     scale_residue: bool,
    //     mask_and_zero_frameless: bool,
    //     bias: bool,
    //     qk_layernorm: bool,
    //     ffn_type: &str,
    //     expansion_ratio: f64,
    // ) -> Result<Self> {
    //     let mut blocks = Vec::with_capacity(n_layers as usize);
    //     for i in 0..n_layers {
    //         blocks.push(UnifiedTransformerBlock::new(
    //             d_model,
    //             n_heads,
    //             v_heads,
    //             i < n_layers_geom,
    //             if scale_residue {
    //                 (n_layers as f64 / 36.0).sqrt()
    //             } else {
    //                 1.0
    //             },
    //             expansion_ratio,
    //             mask_and_zero_frameless,
    //             bias,
    //             qk_layernorm,
    //             ffn_type,
    //         )?);
    //     }
    //     let norm = nn::LayerNorm::new(d_model, 1e-5, false)?;
    //     Ok(Self { blocks, norm })
    // }
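    // Forward pass, kept commented out alongside the Affine3D import above until the
    // structural inputs are wired up; it mirrors the reference implementation's forward.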
    // pub fn forward(
    //     &self,
    //     x: &Tensor,
    //     sequence_id: Option<&Tensor>,
    //     affine: Option<&Affine3D>,
    //     affine_mask: Option<&Tensor>,
    //     chain_id: Option<&Tensor>,
    // ) -> Result<(Tensor, Tensor)> {
    //     let mut x = x.clone();
    //     let chain_id = if let Some(chain_id) = chain_id {
    //         chain_id.clone()
    //     } else {
    //         // Default to all-ones chain ids over the batch dimensions.
    //         let batch_dims = x.dims().split_last().unwrap().1;
    //         Tensor::ones(batch_dims, DType::I64, x.device())?
    //     };
    //     for block in self.blocks.iter() {
    //         x = block.forward(&x, sequence_id, affine, affine_mask, &chain_id)?;
    //     }
    //     // Return both the normalized output and the raw final hidden state.
    //     let normalized = self.norm.forward(&x)?;
    //     Ok((normalized, x))
    // }
}