1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
//! AMPLIFY is an encoder-only transformer protein language model that aims to learn strong
//! sequence representations while remaining computationally efficient.
//!
//! Key features:
//! - Rotary positional embeddings (RoPE) applied to queries and keys
//! - Pre-norm RMSNorm layers for improved training stability
//! - SwiGLU feed-forward networks
//! - Bias-free attention and feed-forward projections in the preset configurations
//! - Memory-efficient inference
use super::rotary::{apply_rotary_emb, precompute_freqs_cis};
use candle_core::{Module, Result, Tensor, D};
use candle_nn::{
    embedding, linear, linear_no_bias, ops::softmax, rms_norm, Activation, Dropout, Embedding,
    Linear, RmsNorm, VarBuilder,
};

/// Hyper-parameters of an AMPLIFY model.
///
/// Use [`AMPLIFYConfig::amp_120m`] or [`AMPLIFYConfig::amp_350m`] (the same as `Default`) for
/// the preset model sizes.
#[derive(Debug, Clone)]
pub struct AMPLIFYConfig {
    /// Width of the residual stream; input and output size of the q/k/v/wo projections.
    pub hidden_size: usize,
    /// Number of encoder blocks stacked in the transformer.
    pub num_hidden_layers: usize,
    /// Number of attention heads. The per-head size is `hidden_size / num_attention_heads`,
    /// so `hidden_size` should be divisible by this value.
    pub num_attention_heads: usize,
    /// Nominal feed-forward width. The SwiGLU block uses `2/3` of it, rounded up to a
    /// multiple of 8.
    pub intermediate_size: usize,
    /// Dropout probability used to build the dropout layers.
    pub dropout_prob: f64,
    /// Not read by the code in this file; presumably the std-dev used to initialise the
    /// embedding weights during training — TODO confirm.
    pub embedding_init_range: f64,
    /// Not read by the code in this file; presumably the std-dev used to initialise the
    /// decoder weights during training — TODO confirm.
    pub decoder_init_range: f64,
    /// Not read by the code in this file; the normalisation layers built here are always
    /// RMSNorm. Presumably selects RMSNorm over LayerNorm — TODO confirm.
    pub rms_norm: bool,
    /// Epsilon passed to every RMSNorm layer.
    pub norm_eps: f64,
    /// Not read by the code in this file; the feed-forward block always computes
    /// `silu(x1) * x2` (SwiGLU).
    pub hidden_act: Activation,
    /// Not read by the code in this file — NOTE(review): confirm whether a post-embedding
    /// norm is required when this is `true`.
    pub layer_norm_after_embedding: bool,
    /// When `true`, `layer_norm_2` is applied to the final hidden states before the decoder.
    pub layer_norm_before_last_layer: bool,
    /// Token vocabulary size; sizes the embedding table and the decoder output.
    pub vocab_size: usize,
    /// Bias flag for the SwiGLU feed-forward projections — NOTE(review): confirm the loader
    /// honours it.
    pub ffn_bias: bool,
    /// Bias flag for the attention projections (q, k, v, wo) — NOTE(review): confirm the
    /// loader honours it.
    pub att_bias: bool,
    /// Not read by the code in this file; presumably the id of the padding token — TODO confirm.
    pub pad_token_id: usize,
    /// Number of positions for which rotary frequencies are precomputed; inputs longer than
    /// this cannot be processed.
    pub max_length: usize,
}

impl Default for AMPLIFYConfig {
    /// The 350M-parameter preset (also returned by [`AMPLIFYConfig::amp_350m`]).
    fn default() -> Self {
        Self {
            // Model size.
            hidden_size: 960,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            intermediate_size: 3840,
            // Vocabulary and sequence limits.
            vocab_size: 27,
            pad_token_id: 0,
            max_length: 2048,
            // Architecture switches.
            hidden_act: Activation::Swiglu,
            rms_norm: true,
            norm_eps: 1e-5,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            ffn_bias: false,
            att_bias: false,
            // Regularisation and initialisation.
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
        }
    }
}

impl AMPLIFYConfig {
    pub fn amp_120m() -> Self {
        Self {
            hidden_size: 640,
            num_hidden_layers: 24,
            num_attention_heads: 10,
            intermediate_size: 2560,
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
            rms_norm: true,
            norm_eps: 1e-5,
            hidden_act: Activation::Swiglu,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            vocab_size: 27,
            ffn_bias: false,
            att_bias: false,
            pad_token_id: 0,
            max_length: 2048,
        }
    }
    pub fn amp_350m() -> Self {
        AMPLIFYConfig::default()
    }
}

/// One pre-norm encoder layer of AMPLIFY.
///
/// The layer computes `h = x + Attn(RMSNorm(x))` (multi-head self-attention with rotary
/// embeddings), followed by `h + FFN(RMSNorm(h))` where the FFN is a SwiGLU network.
///
/// See also:
/// - SwiGLU reference implementation: <https://github.com/facebookresearch/xformers/blob/main/xformers/ops/swiglu_op.py#L462>
/// - candle T5 block: <https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/t5.rs#L331>
/// - candle DistilBERT FFN: <https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/distilbert.rs#L198>
/// - candle GLM4 block: <https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/glm4.rs#L340>
pub struct EncoderBlock {
    // Attention sub-layer: pre-norm, projections and residual dropout.
    attention_norm: RmsNorm,
    q: Linear,
    k: Linear,
    v: Linear,
    wo: Linear,
    resid_dropout: Dropout,
    // Feed-forward sub-layer: pre-norm, packed SwiGLU input projection, output projection.
    ffn_norm: RmsNorm,
    w12: Linear,
    w3: Linear,
    ffn_dropout: Dropout,
    // Size of each attention head and a copy of the model configuration.
    d_head: usize,
    config: AMPLIFYConfig,
}

impl EncoderBlock {
    pub fn new(config: &AMPLIFYConfig, vb: VarBuilder, layer: i32) -> Result<Self> {
        let multiple_of = 8;
        let intermediate_size = (config.intermediate_size * 2) / 3;
        let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
        let vb = vb.pp(layer);
        let q = linear(config.hidden_size, config.hidden_size, vb.pp("q"))?;
        let k = linear(config.hidden_size, config.hidden_size, vb.pp("k"))?;
        let v = linear(config.hidden_size, config.hidden_size, vb.pp("v"))?;
        let wo = linear(config.hidden_size, config.hidden_size, vb.pp("wo"))?;
        let w12 = linear_no_bias(intermediate_size * 2, config.hidden_size, vb.pp("ffn.w12"))?;
        let w3 = linear_no_bias(config.hidden_size, intermediate_size, vb.pp("ffn.w3"))?;
        let ffn_norm = rms_norm(config.hidden_size, config.norm_eps, vb.pp("ffn_norm"))?;
        let attention_norm =
            rms_norm(config.hidden_size, config.norm_eps, vb.pp("attention_norm"))?;

        Ok(Self {
            q,
            k,
            v,
            wo,
            resid_dropout: Dropout::new(config.dropout_prob as f32),
            w12,
            w3,
            attention_norm,
            ffn_norm,
            ffn_dropout: Dropout::new(config.dropout_prob as f32),
            d_head: config.hidden_size / config.num_attention_heads,
            config: config.clone(), // Todo: remove this clone
        })
    }

    pub fn forward(
        &self,
        x: &Tensor,
        pad_mask: Option<&Tensor>,
        freqs_cis: &Tensor,
        output_attentions: bool,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed = self.attention_norm.forward(x)?;
        let (attn, contacts) =
            self.attention_block(&normed, pad_mask, freqs_cis, output_attentions)?;
        let x = x.add(&attn)?;
        let normed = self.ffn_norm.forward(&x)?;
        let ffn_output = self.ffn_forward(&normed)?;
        let ff = self.ffn_dropout.forward(&ffn_output, false)?; // Todo: pass in the Inference/Training bit
        let x = x.add(&ff)?;
        Ok((x, contacts))
    }

    // process the FFN Block using swiglu
    //
    fn ffn_forward(&self, x: &Tensor) -> Result<Tensor> {
        // Swiglu
        //
        // Todo: see if the apply or add can be done di
        // Store original batch dimensions
        let dims = x.dims();
        let batch_shape = &dims[..dims.len() - 1];
        // Reshape input to 2D: (batch_size, input_dim)
        let x_flat = self.flatten_last_dim(&x)?;

        // Apply packed W1W2 linear transformation
        let w12_out = self.w12.forward(&x_flat)?;
        // Split the output into two halves (for SwiGLU activation)
        let chunks = w12_out.chunk(2, 1)?;
        let x1 = &chunks[0];
        let x2 = &chunks[1];

        // Apply SwiGLU: silu(x1) * x2
        let hidden = x1.silu()?.mul(x2)?;
        // Final linear transformation
        let output = self.w3.forward(&hidden)?;
        // Reshape back to original batch dimensions
        let mut new_shape = batch_shape.to_vec();
        new_shape.push(output.dim(1)?);
        output.reshape(new_shape)
    }

    fn flatten_last_dim(&self, x: &Tensor) -> Result<Tensor> {
        let dims = x.dims();
        let last_dim = dims[dims.len() - 1];
        let total_elements = dims.iter().product::<usize>();
        let first_dim = total_elements / last_dim;
        x.reshape((first_dim, last_dim))
    }

    fn scaled_dot_product_attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attn_mask: Option<&Tensor>,
        dropout_p: f64,
        is_causal: bool,
    ) -> Result<Tensor> {
        // Calculate attention scores
        let d_k = key.dim(key.dims().len() - 1)? as f64;
        let scaling = 1.0 / d_k.sqrt();
        // (B, H, L, S) = (batch, heads, query_length, key_length)
        let scores = (query.matmul(&key.transpose(D::Minus2, D::Minus1)?)? * scaling)?;

        // Apply mask if provided
        if let Some(mask) = attn_mask {
            let scores = scores.add(mask)?;
        }

        // Apply softmax
        let attn = candle_nn::ops::softmax(&scores, scores.dims().len() - 1)?;
        // Apply dropout if needed
        let attn = if dropout_p > 0.0 {
            candle_nn::ops::dropout(&attn, dropout_p as f32)?
        } else {
            attn
        };
        // Final matrix multiplication with values
        attn.matmul(value)
    }

    fn attention_block(
        &self,
        x: &Tensor,
        pad_mask: Option<&Tensor>,
        freqs_cis: &Tensor,
        output_attentions: bool,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // Query, Key, Value projections
        let (batch_size, seq_len, _) = x.dims3()?;
        let xq = self.q.forward(x)?; // [batch_size, seq_len, hidden_size]
        let xk = self.k.forward(x)?;
        let xv = self.v.forward(x)?;
        // Reshape for rotary embeddings
        let xq = xq.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;
        let xk = xk.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;
        let xv = xv.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;

        let (xq, xk) = apply_rotary_emb(&xq, &xk, &freqs_cis)?;
        let dropout_prob = self.config.dropout_prob;

        // need to handle pad_mask better ....
        //
        let pad_mask = if let Some(mask) = pad_mask {
            let (batch_size, seq_len) = (x.dim(0)?, x.dim(1)?);
            let num_heads = self.config.num_attention_heads;

            // Following PyTorch's implementation:
            // 1. unsqueeze twice to add head dimensions
            // 2. repeat to match attention matrix size
            let mask = mask
                .unsqueeze(1)? // Add first head dimension
                .unsqueeze(1)? // Add second head dimension
                .expand((batch_size, num_heads, seq_len, seq_len))?; // Expand to full attention size
            Some(mask)
        } else {
            None
        };

        let attn = self.scaled_dot_product_attention(
            &xq.permute((0, 2, 1, 3))?,
            &xk.permute((0, 2, 1, 3))?,
            &xv.permute((0, 2, 1, 3))?,
            pad_mask.as_ref(),
            dropout_prob,
            false,
        )?;
        // `[batch, num_heads, seq_len, head_dim]` → `[batch, seq_len, num_heads, head_dim]`
        let attn = attn.permute((0, 2, 1, 3))?;
        let _attn = if output_attentions {
            let xq_t = xq.permute((0, 2, 1, 3))?;
            let xk_t = xk.permute((0, 2, 3, 1))?;
            let mut attn_weights = xq_t.matmul(&xk_t)?;
            let scale = (xq.dim(D::Minus1)? as f64).sqrt();
            attn_weights = (attn_weights / scale)?;
            // attn_weights = attn_weights.add(pad_mask)?;  <- Todo. Revisit
            Some(softmax(&attn_weights, D::Minus1)?)
        } else {
            None
        };

        // Final projection and dropout
        let output = attn.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads * self.d_head,
        ))?;
        let output01 = self.wo.forward(&output)?;
        let output02 = self.resid_dropout.forward(&output01, false)?;
        Ok((output02, _attn))
    }

    /// Load Weights from a Model
    pub fn load(vb: VarBuilder, cfg: &AMPLIFYConfig, layer: i32) -> Result<Self> {
        // To keep the number of parameters and the amount of computation constant, we reduce the number of
        // hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
        // avoid RuntimeError due to misaligned operand
        let multiple_of = 8;
        let intermediate_size = (cfg.intermediate_size * 2) / 3;
        let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
        let vb = vb.pp(layer); // handle the layer nubmer here.
        let q = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("q"))?;
        let k = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("k"))?;
        let v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("v"))?;
        let wo = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("wo"))?;
        let w12 = linear_no_bias(cfg.hidden_size, intermediate_size * 2, vb.pp("ffn.w12"))?;
        let w3 = linear_no_bias(intermediate_size, cfg.hidden_size, vb.pp("ffn.w3"))?;
        let ffn_norm = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("ffn_norm"))?;
        let attention_norm = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("attention_norm"))?;

        Ok(Self {
            q,
            k,
            v,
            wo,
            resid_dropout: Dropout::new(cfg.dropout_prob as f32),
            w12,
            w3,
            attention_norm,
            ffn_norm,
            ffn_dropout: Dropout::new(cfg.dropout_prob as f32),
            d_head: cfg.hidden_size / cfg.num_attention_heads,
            config: cfg.clone(),
        })
    }
}

/// The complete AMPLIFY model: token embedding, a stack of [`EncoderBlock`]s, an optional
/// final RMSNorm and a linear decoder producing per-token logits.
///
/// References:
/// - [GH PythonModel](https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/src/amplify/model/amplify.py)
/// - [paper](https://www.biorxiv.org/content/10.1101/2024.09.23.614603v1)
/// - [HF](https://huggingface.co/chandar-lab/AMPLIFY_120M)
pub struct AMPLIFY {
    /// Hyper-parameters the model was built with.
    config: AMPLIFYConfig,
    /// Token-id embedding table.
    encoder: Embedding,
    /// Precomputed rotary-embedding table, sliced to the input length on every forward pass.
    freqs_cis: Tensor,
    /// Encoder layers, applied in order.
    transformer_encoder: Vec<EncoderBlock>,
    /// Final normalisation, used when `config.layer_norm_before_last_layer` is set.
    layer_norm_2: RmsNorm,
    /// Projection from hidden states to vocabulary logits.
    decoder: Linear,
}

impl AMPLIFY {
    /// Builds the model by loading weights from `vb`; equivalent to [`AMPLIFY::load`].
    ///
    /// # Errors
    /// Returns an error if a required weight is missing or has an unexpected shape.
    pub fn new(config: &AMPLIFYConfig, vb: VarBuilder) -> Result<Self> {
        Self::load(vb, config)
    }

    /// Prepares the optional padding mask before it is handed to the encoder blocks.
    ///
    /// Returns `None` when no mask is given or when every entry is zero, since adding an
    /// all-zero additive mask is a no-op. Otherwise the mask is returned unchanged;
    /// broadcasting it to `(batch, heads, seq_len, seq_len)` is left to `EncoderBlock`.
    ///
    /// NOTE(review): the mask is presumably an additive f32 mask (0 = attend, large negative /
    /// -inf = ignore) — confirm with callers. `to_scalar::<f32>` errors for other dtypes.
    fn process_attention_mask(&self, pad_mask: Option<&Tensor>) -> Result<Option<Tensor>> {
        let Some(mask) = pad_mask else {
            return Ok(None);
        };
        // `abs` so that entries of opposite sign cannot cancel out in the sum.
        if mask.abs()?.sum_all()?.to_scalar::<f32>()? == 0.0 {
            return Ok(None);
        }
        Ok(Some(mask.clone()))
    }

    /// Runs the model on a batch of token ids.
    ///
    /// * `src` - token ids; dimension 1 is the sequence length (presumably
    ///   `(batch, seq_len)` — TODO confirm).
    /// * `pad_mask` - optional additive attention mask, e.g. `(batch, seq_len)`.
    /// * `output_hidden_states` - collect the output of every encoder block.
    /// * `output_attentions` - collect the attention probabilities of every encoder block.
    ///
    /// # Errors
    /// Fails if the sequence is longer than the precomputed rotary table (`narrow` fails), or
    /// if any tensor operation fails.
    pub fn forward(
        &self,
        src: &Tensor,
        pad_mask: Option<&Tensor>,
        output_hidden_states: bool,
        output_attentions: bool,
    ) -> Result<ModelOutput> {
        let attention_mask = self.process_attention_mask(pad_mask)?;
        let freqs_cis = self.freqs_cis.narrow(0, 0, src.dim(1)?)?;

        // NOTE(review): `config.layer_norm_after_embedding` is not applied here (no such norm is
        // loaded); both presets set it to `false`.
        let mut x = self.encoder.forward(src)?;

        let num_layers = self.transformer_encoder.len();
        let mut hidden_states = Vec::with_capacity(if output_hidden_states { num_layers } else { 0 });
        let mut attentions = Vec::with_capacity(if output_attentions { num_layers } else { 0 });
        for layer in &self.transformer_encoder {
            let (next, attn) =
                layer.forward(&x, attention_mask.as_ref(), &freqs_cis, output_attentions)?;
            x = next;
            if output_hidden_states {
                hidden_states.push(x.clone());
            }
            if output_attentions {
                attentions.extend(attn);
            }
        }

        let x = if self.config.layer_norm_before_last_layer {
            self.layer_norm_2.forward(&x)?
        } else {
            x
        };
        let logits = self.decoder.forward(&x)?;

        Ok(ModelOutput {
            logits,
            hidden_states: output_hidden_states.then_some(hidden_states),
            attentions: output_attentions.then_some(attentions),
        })
    }

    /// Loads all model weights from `vb`.
    ///
    /// Reads `transformer_encoder.{i}.*` for each encoder layer, plus `encoder` (embedding),
    /// `layer_norm_2` and `decoder`. It also precomputes the rotary table for
    /// `cfg.max_length` positions.
    ///
    /// # Errors
    /// Returns an error if a required weight is missing or has an unexpected shape.
    pub fn load(vb: VarBuilder, cfg: &AMPLIFYConfig) -> Result<Self> {
        let layers_vb = vb.pp("transformer_encoder");
        let transformer_encoder = (0..cfg.num_hidden_layers)
            .map(|i| EncoderBlock::load(layers_vb.clone(), cfg, i as i32))
            .collect::<Result<Vec<_>>>()?;
        let encoder = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("encoder"))?;
        let layer_norm_2 = rms_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("layer_norm_2"))?;
        let decoder = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let freqs_cis = precompute_freqs_cis(head_dim, cfg.max_length)?;

        Ok(Self {
            encoder,
            transformer_encoder,
            layer_norm_2,
            decoder,
            freqs_cis,
            config: cfg.clone(),
        })
    }
}

/// Result of [`AMPLIFY::forward`].
#[derive(Debug)]
pub struct ModelOutput {
    /// Output of the decoder projection, with `vocab_size` scores per position
    /// (presumably `(batch, seq_len, vocab_size)` — TODO confirm).
    pub logits: Tensor,
    /// Output of each encoder block, one tensor per layer in order (the embedding output is
    /// not included). `Some` only when `output_hidden_states` was requested.
    pub hidden_states: Option<Vec<Tensor>>,
    /// Post-softmax attention weights returned by each encoder block. `Some` only when
    /// `output_attentions` was requested.
    pub attentions: Option<Vec<Tensor>>,
}