// ferritin_amplify/amplify/config.rs
use candle_nn::Activation;
use serde::Deserialize;
/// Hyperparameters describing an AMPLIFY model.
///
/// Derives `Deserialize`, so a value can be built by serde from input whose
/// keys match the field names below. Ready-made presets are provided by
/// `AMPLIFYConfig::amp_120m` and `AMPLIFYConfig::amp_350m`.
///
/// NOTE(review): the per-field descriptions are inferred from the field names;
/// confirm them against the code that consumes this config.
#[derive(Debug, Clone, Deserialize)]
pub struct AMPLIFYConfig {
/// Presumably the width of the hidden representation (640 / 960 in the presets).
pub hidden_size: usize,
/// Presumably the number of stacked encoder layers (24 / 32 in the presets).
pub num_hidden_layers: usize,
/// Presumably the number of attention heads per layer (10 / 15 in the presets).
pub num_attention_heads: usize,
/// Presumably the inner width of the feed-forward block (2560 / 3840 in the presets).
pub intermediate_size: usize,
/// Presumably the dropout probability; both presets use 0.0.
pub dropout_prob: f64,
/// Presumably the init range for embedding weights; both presets use 0.02 — TODO confirm usage.
pub embedding_init_range: f64,
/// Presumably the init range for decoder weights; both presets use 0.02 — TODO confirm usage.
pub decoder_init_range: f64,
/// Presumably selects RMS normalization over standard layer norm; `true` in both presets.
pub rms_norm: bool,
/// Presumably the epsilon passed to the normalization layers; both presets use 1e-5.
pub norm_eps: f64,
/// Activation function (a `candle_nn::Activation`); both presets use `Activation::Swiglu`.
pub hidden_act: Activation,
/// Presumably enables a normalization layer right after the embedding; `false` in both presets.
pub layer_norm_after_embedding: bool,
/// Presumably enables a normalization layer before the final layer; `true` in both presets.
pub layer_norm_before_last_layer: bool,
/// Presumably the number of entries in the token vocabulary; both presets use 27.
pub vocab_size: usize,
/// Presumably toggles bias terms in the feed-forward projections; `false` in both presets.
pub ffn_bias: bool,
/// Presumably toggles bias terms in the attention projections; `false` in both presets.
pub att_bias: bool,
/// Presumably the token id used for padding; both presets use 0.
pub pad_token_id: usize,
/// Presumably the maximum supported sequence length; both presets use 2048.
pub max_length: usize,
}
impl Default for AMPLIFYConfig {
fn default() -> Self {
AMPLIFYConfig::amp_120m()
}
}
impl AMPLIFYConfig {
    /// Builds the `amp_120m` preset (presumably the 120M-parameter AMPLIFY
    /// variant — the name is the only evidence visible here).
    pub fn amp_120m() -> Self {
        Self {
            // Model dimensions.
            hidden_size: 640,
            intermediate_size: 2560,
            num_hidden_layers: 24,
            num_attention_heads: 10,
            // Vocabulary and sequence limits.
            vocab_size: 27,
            pad_token_id: 0,
            max_length: 2048,
            // Normalization and activation.
            rms_norm: true,
            norm_eps: 1e-5,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            hidden_act: Activation::Swiglu,
            // Bias switches.
            ffn_bias: false,
            att_bias: false,
            // Regularization and initialization.
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
        }
    }

    /// Builds the `amp_350m` preset (presumably the 350M-parameter AMPLIFY
    /// variant — the name is the only evidence visible here).
    ///
    /// Only the four size fields differ from `amp_120m`; every other field is
    /// taken over from that preset unchanged.
    pub fn amp_350m() -> Self {
        Self {
            hidden_size: 960,
            intermediate_size: 3840,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            ..Self::amp_120m()
        }
    }
}