use super::rotary::{apply_rotary_emb, precompute_freqs_cis};
use candle_core::{Module, Result, Tensor, D};
use candle_nn::{
embedding, linear, linear_no_bias, ops::softmax, rms_norm, Activation, Dropout, Embedding,
Linear, RmsNorm, VarBuilder,
};
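/// Hyper-parameters for the AMPLIFY protein language model.
///
/// `Default::default()` matches the preset returned by [`AMPLIFYConfig::amp_350m`];
/// [`AMPLIFYConfig::amp_120m`] is the smaller preset.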
#[derive(Debug, Clone)]
pub struct AMPLIFYConfig {
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub dropout_prob: f64,
pub embedding_init_range: f64,
pub decoder_init_range: f64,
pub rms_norm: bool,
pub norm_eps: f64,
pub hidden_act: Activation,
pub layer_norm_after_embedding: bool,
pub layer_norm_before_last_layer: bool,
pub vocab_size: usize,
pub ffn_bias: bool,
pub att_bias: bool,
pub pad_token_id: usize,
pub max_length: usize,
}
impl Default for AMPLIFYConfig {
fn default() -> Self {
Self {
hidden_size: 960,
num_hidden_layers: 32,
num_attention_heads: 15,
intermediate_size: 3840,
dropout_prob: 0.0,
embedding_init_range: 0.02,
decoder_init_range: 0.02,
rms_norm: true,
norm_eps: 1e-5,
hidden_act: Activation::Swiglu,
layer_norm_after_embedding: false,
layer_norm_before_last_layer: true,
vocab_size: 27,
ffn_bias: false,
att_bias: false,
pad_token_id: 0,
max_length: 2048,
}
}
}
impl AMPLIFYConfig {
pub fn amp_120m() -> Self {
Self {
hidden_size: 640,
num_hidden_layers: 24,
num_attention_heads: 10,
intermediate_size: 2560,
dropout_prob: 0.0,
embedding_init_range: 0.02,
decoder_init_range: 0.02,
rms_norm: true,
norm_eps: 1e-5,
hidden_act: Activation::Swiglu,
layer_norm_after_embedding: false,
layer_norm_before_last_layer: true,
vocab_size: 27,
ffn_bias: false,
att_bias: false,
pad_token_id: 0,
max_length: 2048,
}
}
pub fn amp_350m() -> Self {
AMPLIFYConfig::default()
}
}
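/// A pre-norm transformer encoder block: rotary multi-head self-attention followed by a
/// SwiGLU feed-forward network, each preceded by an RMSNorm and wrapped in a residual
/// connection.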
pub struct EncoderBlock {
q: Linear,
k: Linear,
v: Linear,
wo: Linear,
resid_dropout: Dropout,
w12: Linear,
w3: Linear,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
ffn_dropout: Dropout,
d_head: usize,
config: AMPLIFYConfig,
}
impl EncoderBlock {
    pub fn new(config: &AMPLIFYConfig, vb: VarBuilder, layer: i32) -> Result<Self> {
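        // SwiGLU packs two projections into `w12`, so the hidden width is scaled to 2/3 of
        // `intermediate_size` and rounded up to a multiple of 8, the usual way of keeping
        // the parameter count close to a standard FFN of width `intermediate_size`.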
let multiple_of = 8;
let intermediate_size = (config.intermediate_size * 2) / 3;
let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
let vb = vb.pp(layer);
let q = linear(config.hidden_size, config.hidden_size, vb.pp("q"))?;
let k = linear(config.hidden_size, config.hidden_size, vb.pp("k"))?;
let v = linear(config.hidden_size, config.hidden_size, vb.pp("v"))?;
let wo = linear(config.hidden_size, config.hidden_size, vb.pp("wo"))?;
        let w12 = linear_no_bias(config.hidden_size, intermediate_size * 2, vb.pp("ffn.w12"))?;
        let w3 = linear_no_bias(intermediate_size, config.hidden_size, vb.pp("ffn.w3"))?;
let ffn_norm = rms_norm(config.hidden_size, config.norm_eps, vb.pp("ffn_norm"))?;
let attention_norm =
rms_norm(config.hidden_size, config.norm_eps, vb.pp("attention_norm"))?;
Ok(Self {
q,
k,
v,
wo,
resid_dropout: Dropout::new(config.dropout_prob as f32),
w12,
w3,
attention_norm,
ffn_norm,
ffn_dropout: Dropout::new(config.dropout_prob as f32),
d_head: config.hidden_size / config.num_attention_heads,
            config: config.clone(),
        })
}
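    /// Applies the pre-norm attention and feed-forward sub-blocks with residual connections,
    /// returning the updated hidden states and, optionally, the attention weights.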
pub fn forward(
&self,
x: &Tensor,
pad_mask: Option<&Tensor>,
freqs_cis: &Tensor,
output_attentions: bool,
) -> Result<(Tensor, Option<Tensor>)> {
let normed = self.attention_norm.forward(x)?;
let (attn, contacts) =
self.attention_block(&normed, pad_mask, freqs_cis, output_attentions)?;
let x = x.add(&attn)?;
let normed = self.ffn_norm.forward(&x)?;
let ffn_output = self.ffn_forward(&normed)?;
        let ff = self.ffn_dropout.forward(&ffn_output, false)?;
        let x = x.add(&ff)?;
Ok((x, contacts))
}
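    /// SwiGLU feed-forward: `w12` produces the gate and value halves, the gate is passed
    /// through SiLU and multiplied element-wise with the value, then `w3` projects back to
    /// `hidden_size`.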
fn ffn_forward(&self, x: &Tensor) -> Result<Tensor> {
let dims = x.dims();
let batch_shape = &dims[..dims.len() - 1];
        let x_flat = self.flatten_last_dim(x)?;
let w12_out = self.w12.forward(&x_flat)?;
let chunks = w12_out.chunk(2, 1)?;
let x1 = &chunks[0];
let x2 = &chunks[1];
let hidden = x1.silu()?.mul(x2)?;
let output = self.w3.forward(&hidden)?;
let mut new_shape = batch_shape.to_vec();
new_shape.push(output.dim(1)?);
output.reshape(new_shape)
}
fn flatten_last_dim(&self, x: &Tensor) -> Result<Tensor> {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let total_elements = dims.iter().product::<usize>();
let first_dim = total_elements / last_dim;
x.reshape((first_dim, last_dim))
}
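    /// Computes `softmax(Q·K^T / sqrt(d_k) + mask) · V`, optionally applying dropout to the
    /// attention weights. The `_is_causal` flag is currently unused; no causal mask is
    /// applied here.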
fn scaled_dot_product_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
dropout_p: f64,
        _is_causal: bool,
) -> Result<Tensor> {
        let d_k = key.dim(D::Minus1)? as f64;
        let scaling = 1.0 / d_k.sqrt();
        let scores = (query.matmul(&key.transpose(D::Minus2, D::Minus1)?)? * scaling)?;
        // The additive mask (large negative values at padded positions) is applied before
        // the softmax so that masked positions receive near-zero attention weight.
        let scores = match attn_mask {
            Some(mask) => scores.broadcast_add(mask)?,
            None => scores,
        };
        let attn = softmax(&scores, D::Minus1)?;
let attn = if dropout_p > 0.0 {
candle_nn::ops::dropout(&attn, dropout_p as f32)?
} else {
attn
};
attn.matmul(value)
}
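    /// Multi-head self-attention with rotary position embeddings. Returns the output
    /// projection and, when `output_attentions` is set, the softmax attention weights.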
fn attention_block(
&self,
x: &Tensor,
pad_mask: Option<&Tensor>,
freqs_cis: &Tensor,
output_attentions: bool,
) -> Result<(Tensor, Option<Tensor>)> {
let (batch_size, seq_len, _) = x.dims3()?;
        let xq = self.q.forward(x)?;
        let xk = self.k.forward(x)?;
let xv = self.v.forward(x)?;
let xq = xq.reshape((
batch_size,
seq_len,
self.config.num_attention_heads,
self.d_head,
))?;
let xk = xk.reshape((
batch_size,
seq_len,
self.config.num_attention_heads,
self.d_head,
))?;
let xv = xv.reshape((
batch_size,
seq_len,
self.config.num_attention_heads,
self.d_head,
))?;
let (xq, xk) = apply_rotary_emb(&xq, &xk, &freqs_cis)?;
let dropout_prob = self.config.dropout_prob;
        let pad_mask = match pad_mask {
            // A raw (batch, seq_len) mask is broadcast to one additive mask per head and
            // query position.
            Some(mask) if mask.rank() == 2 => Some(mask.unsqueeze(1)?.unsqueeze(1)?.expand((
                batch_size,
                self.config.num_attention_heads,
                seq_len,
                seq_len,
            ))?),
            // Masks that were already expanded upstream (e.g. by
            // `AMPLIFY::process_attention_mask`) are used as provided.
            Some(mask) => Some(mask.clone()),
            None => None,
        };
        // The permuted (batch, heads, seq, d_head) views are materialized with
        // `contiguous()` before the batched matmuls.
        let attn = self.scaled_dot_product_attention(
            &xq.permute((0, 2, 1, 3))?.contiguous()?,
            &xk.permute((0, 2, 1, 3))?.contiguous()?,
            &xv.permute((0, 2, 1, 3))?.contiguous()?,
            pad_mask.as_ref(),
            dropout_prob,
            false,
        )?;
let attn = attn.permute((0, 2, 1, 3))?;
        let attn_weights = if output_attentions {
            let xq_t = xq.permute((0, 2, 1, 3))?.contiguous()?;
            let xk_t = xk.permute((0, 2, 3, 1))?.contiguous()?;
let mut attn_weights = xq_t.matmul(&xk_t)?;
let scale = (xq.dim(D::Minus1)? as f64).sqrt();
attn_weights = (attn_weights / scale)?;
Some(softmax(&attn_weights, D::Minus1)?)
} else {
None
};
let output = attn.reshape((
batch_size,
seq_len,
self.config.num_attention_heads * self.d_head,
))?;
let output01 = self.wo.forward(&output)?;
let output02 = self.resid_dropout.forward(&output01, false)?;
        Ok((output02, attn_weights))
}
    pub fn load(vb: VarBuilder, cfg: &AMPLIFYConfig, layer: i32) -> Result<Self> {
let multiple_of = 8;
let intermediate_size = (cfg.intermediate_size * 2) / 3;
let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
        let vb = vb.pp(layer);
        let q = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("q"))?;
let k = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("k"))?;
let v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("v"))?;
let wo = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("wo"))?;
let w12 = linear_no_bias(cfg.hidden_size, intermediate_size * 2, vb.pp("ffn.w12"))?;
let w3 = linear_no_bias(intermediate_size, cfg.hidden_size, vb.pp("ffn.w3"))?;
let ffn_norm = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("ffn_norm"))?;
let attention_norm = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("attention_norm"))?;
Ok(Self {
q,
k,
v,
wo,
resid_dropout: Dropout::new(cfg.dropout_prob as f32),
w12,
w3,
attention_norm,
ffn_norm,
ffn_dropout: Dropout::new(cfg.dropout_prob as f32),
d_head: cfg.hidden_size / cfg.num_attention_heads,
config: cfg.clone(),
})
}
}
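/// The AMPLIFY encoder: a token embedding, a stack of [`EncoderBlock`]s, and a linear decoder
/// head that maps the final hidden states back to vocabulary logits.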
pub struct AMPLIFY {
encoder: Embedding,
transformer_encoder: Vec<EncoderBlock>,
layer_norm_2: RmsNorm,
decoder: Linear,
freqs_cis: Tensor,
config: AMPLIFYConfig,
}
impl AMPLIFY {
    pub fn new(_config: &AMPLIFYConfig, _vb: VarBuilder) -> Result<Self> {
unimplemented!()
}
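    /// Expands a `(batch, seq_len)` additive padding mask to
    /// `(batch, num_attention_heads, seq_len, seq_len)`. Returns `None` when no mask is
    /// supplied or when the mask is all zeros (nothing is padded).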
fn process_attention_mask(
&self,
pad_mask: Option<&Tensor>,
        num_attention_heads: usize,
) -> Result<Option<Tensor>> {
let Some(mask) = pad_mask else {
return Ok(None);
};
if mask.sum_all()?.to_scalar::<f32>()? == 0.0 {
return Ok(None);
}
let batch_size = mask.dim(0)?;
let seq_length = mask.dim(D::Minus1)?;
        let expanded_mask = mask
            .unsqueeze(1)?
            .unsqueeze(1)?
            .expand((batch_size, num_attention_heads, seq_length, seq_length))?;
Ok(Some(expanded_mask))
}
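    /// Runs the full model: embeds `src`, applies every encoder block with rotary
    /// frequencies truncated to the input length, and decodes the (optionally normalized)
    /// final hidden states into vocabulary logits.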
pub fn forward(
&self,
src: &Tensor,
pad_mask: Option<&Tensor>,
output_hidden_states: bool,
output_attentions: bool,
) -> Result<ModelOutput> {
let mut hidden_states = vec![];
let mut attentions = vec![];
        let attention_mask =
            self.process_attention_mask(pad_mask, self.config.num_attention_heads)?;
let freqs_cis = self.freqs_cis.narrow(0, 0, src.dim(1)?)?;
let mut x = self.encoder.forward(src)?;
for layer in self.transformer_encoder.iter() {
let (new_x, attn) =
layer.forward(&x, attention_mask.as_ref(), &freqs_cis, output_attentions)?;
x = new_x;
if output_hidden_states {
hidden_states.push(x.clone());
}
if output_attentions {
if let Some(attn) = attn {
attentions.push(attn);
}
}
}
let logits = if self.config.layer_norm_before_last_layer {
self.decoder.forward(&self.layer_norm_2.forward(&x)?)?
} else {
self.decoder.forward(&x)?
};
Ok(ModelOutput {
logits,
hidden_states: if output_hidden_states {
Some(hidden_states)
} else {
None
},
attentions: if output_attentions {
Some(attentions)
} else {
None
},
})
}
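    /// Loads the embedding, encoder stack, final norm, and decoder head from `vb`, and
    /// precomputes the rotary frequencies up to `max_length`.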
    pub fn load(vb: VarBuilder, cfg: &AMPLIFYConfig) -> Result<Self> {
let mut transformer_encoder = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
transformer_encoder.push(EncoderBlock::load(
vb.pp("transformer_encoder"),
cfg,
i as i32,
)?);
}
let encoder = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("encoder"))?;
let layer_norm_2 = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("layer_norm_2"))?;
let decoder = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let freqs_cis = precompute_freqs_cis(head_dim, cfg.max_length)?;
Ok(Self {
encoder,
transformer_encoder,
layer_norm_2,
decoder,
freqs_cis,
config: cfg.clone(),
})
}
}
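/// Output of [`AMPLIFY::forward`]: per-token vocabulary logits plus, when requested, the
/// hidden states after each layer and the attention weights of each layer.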
#[derive(Debug)]
pub struct ModelOutput {
pub logits: Tensor,
pub hidden_states: Option<Vec<Tensor>>,
pub attentions: Option<Vec<Tensor>>,
}
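// A minimal usage sketch, not shipped API: it assumes a local "amplify.safetensors"
// checkpoint whose tensor names match the prefixes used in `AMPLIFY::load` ("encoder",
// "transformer_encoder.<i>.*", "layer_norm_2", "decoder") and already-tokenized input ids
// below `vocab_size`; the file name and dummy inputs are placeholders.
#[allow(dead_code)]
fn example_forward(device: &candle_core::Device) -> Result<ModelOutput> {
    let cfg = AMPLIFYConfig::amp_120m();
    // Memory-map the checkpoint; `from_mmaped_safetensors` is unsafe because the file must
    // not be modified while it is mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["amplify.safetensors"],
            candle_core::DType::F32,
            device,
        )?
    };
    let model = AMPLIFY::load(vb, &cfg)?;
    // Two dummy sequences of length 16; real inputs would come from the AMPLIFY tokenizer.
    let src = Tensor::zeros((2, 16), candle_core::DType::U32, device)?;
    // No padding mask, no hidden states, no attention maps.
    model.forward(&src, None, false, false)
}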