ferritin_amplify/amplify/outputs.rs
use candle_core::{Result, Tensor, D};
// Helper structs and enums

/// AMPLIFY model output.
///
/// Holds the logits plus, optionally, the per-layer hidden states and
/// attention maps.
///
/// - `logits`: per-position token scores (unnormalized log-probabilities
///   over the vocabulary).
/// - `attentions`: per-layer attention weights, from which a residue contact
///   map can be derived via [`ModelOutput::get_contact_map`].
#[derive(Debug)]
pub struct ModelOutput {
pub logits: Tensor,
pub hidden_states: Option<Vec<Tensor>>,
pub attentions: Option<Vec<Tensor>>,
}
impl ModelOutput {
/// "Perform average product correct, used for contact prediction."
/// https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L83
/// "Perform average product correct, used for contact prediction."
fn apc(&self, x: &Tensor) -> Result<Tensor> {
        // Row sums, column sums, and the grand total (dims kept for broadcasting).
        let a1 = x.sum_keepdim(D::Minus1)?;
        let a2 = x.sum_keepdim(D::Minus2)?;
        let a12 = x.sum_keepdim((D::Minus1, D::Minus2))?;
        // Outer product of the row and column sums: a1 is (..., n, 1) and a2 is
        // (..., 1, n), so matmul yields (..., n, n).
        let avg = a1.matmul(&a2)?;
        // PyTorch's in-place `div_` broadcasts automatically; candle's `div`
        // requires matching shapes, so broadcast the divisor explicitly.
        let avg = avg.broadcast_div(&a12)?;
x.sub(&avg)
}
// From https://github.com/facebookresearch/esm/blob/main/esm/modules.py
// https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L77
// "Make layer symmetric in final two dimensions, used for contact prediction."
fn symmetrize(&self, x: &Tensor) -> Result<Tensor> {
let x_transpose = x.transpose(D::Minus1, D::Minus2)?;
x.add(&x_transpose)
}
    /// Contact maps can be obtained from the self-attentions by symmetrizing
    /// each attention map and applying the average product correction,
    /// following the ESM contact-prediction recipe.
pub fn get_contact_map(&self) -> Result<Option<Tensor>> {
let Some(attentions) = &self.attentions else {
return Ok(None);
};
        let Some(first) = attentions.first() else {
            return Ok(None);
        };
        // The attention tensors have shape (batch, n_heads, seq_len, seq_len);
        // we need the sequence length for the reshape below.
        let (_batch, _n_heads, _seq_len, seq_length) = first.dims4()?;
        let last_dim = seq_length;
        // Stack the per-layer attentions, then collapse the (layers, batch,
        // heads) dimensions into a single leading dimension.
        let attn_stacked = Tensor::stack(attentions, 0)?;
        let total_elements = attn_stacked.dims().iter().product::<usize>();
        let first_dim = total_elements / (last_dim * last_dim);
        let attn_maps = attn_stacked.reshape(&[first_dim, last_dim, last_dim])?;
        // Drop the first and last positions (the special BOS/EOS tokens).
        // In PyTorch: attn_map = attn_map[:, 1:-1, 1:-1]
        let attn_maps = attn_maps
            .narrow(1, 1, attn_maps.dim(1)? - 2)? // trim rows
            .narrow(2, 1, attn_maps.dim(2)? - 2)?; // trim columns
        let symmetric = self.symmetrize(&attn_maps)?;
        let normalized = self.apc(&symmetric)?;
        // Move the map dimension last: (residues, residues, layers * heads).
        let proximity_map = normalized.permute((1, 2, 0))?;
        Ok(Some(proximity_map))
}
}
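
#[cfg(test)]
mod tests {
    // A minimal sketch, added for illustration and not part of the AMPLIFY
    // reference code: it builds a `ModelOutput` from dummy tensors (the
    // shapes and vocabulary size here are arbitrary assumptions) and checks
    // the shape of the derived contact map.
    use super::*;
    use candle_core::{DType, Device};

    #[test]
    fn contact_map_shape() -> Result<()> {
        let device = Device::Cpu;
        let (n_layers, batch, n_heads, seq_len) = (2, 1, 4, 10);
        let attn = Tensor::ones((batch, n_heads, seq_len, seq_len), DType::F32, &device)?;
        let output = ModelOutput {
            // Dummy logits; the vocabulary size (27) is arbitrary here.
            logits: Tensor::zeros((batch, seq_len, 27), DType::F32, &device)?,
            hidden_states: None,
            attentions: Some(vec![attn; n_layers]),
        };
        let contacts = output.get_contact_map()?.expect("attentions were provided");
        // The first/last positions are trimmed, so residues = seq_len - 2.
        assert_eq!(
            contacts.dims(),
            [seq_len - 2, seq_len - 2, n_layers * n_heads]
        );
        Ok(())
    }
}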