Amplify Pytorch->Candle Conversion

Conversion of AMPLIFY from Pytorch to Candle
rust
ai
proteins
Author

Zachary Charlop-Powers

Published

November 12, 2024

Intro

In the previous post I left off with the comment that Amplify would be a better first candidate to port to Candle. In this post I am happy to report an initial, but as-yet-untested, port of the Amplify model from pytorch to Candle. Here are a few of the lessons learned.

# runnable test you can use/modify to start playing
# git clone git@github.com:zachcp/ferritin.git
# cd ferritin
cargo run --example amplify

# Initial Test: "METVALMETVAL".to_string()
AMPLIFY.forward():  calculating logits
Encoded Logits Dimension: Tensor[dims 1, 14, 27; f32],
indices: [25, 11, 21, 17, 7, 15, 8, 17, 16, 7, 18, 25, 25, 15]
Decoded Values: C E Y K A D G K P A Q C C D
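
The last two lines of that output go from logits of shape [1, 14, 27] to one index per position and then to letters. Presumably this is an argmax over the vocabulary axis followed by a vocabulary lookup. Here is a minimal std-only sketch of that step, assuming that interpretation. The function names `argmax_rows` and `decode` and the four-token vocabulary are made up for illustration and are not taken from the ferritin code.

```rust
/// Return the index of the largest value in each row (argmax over the last axis).
fn argmax_rows(logits: &[Vec<f32>]) -> Vec<usize> {
    logits
        .iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(i, _)| i)
                .unwrap_or(0) // empty row: fall back to index 0
        })
        .collect()
}

/// Map token ids back to their string tokens using a toy vocabulary.
fn decode(ids: &[usize], vocab: &[&str]) -> Vec<String> {
    ids.iter().map(|&i| vocab[i].to_string()).collect()
}

fn main() {
    // Hypothetical vocabulary; the real one comes from tokenizer.json.
    let vocab = ["<pad>", "M", "E", "T"];
    // One row of logits per sequence position: shape [seq_len, vocab_size].
    let logits = vec![
        vec![0.1, 2.0, 0.3, 0.0],
        vec![0.0, 0.1, 1.5, 0.2],
        vec![0.0, 0.2, 0.1, 3.0],
    ];
    let ids = argmax_rows(&logits);
    println!("indices: {:?}", ids);
    println!("Decoded Values: {}", decode(&ids, &vocab).join(" "));
}
```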

Parts

Weights

The weights were the easiest part, as they were already hosted on Hugging Face and could be retrieved via the hf-hub crate. Because I was able to load these weights, I was also able to print the names and dimensions of each of the tensors in the file. This was very helpful later for making sure I had copied the model correctly: the model should use and match all of the weights in the safetensors file.

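// Imports this snippet appears to rely on (hf-hub sync API and the safetensors crate)
use hf_hub::{api::sync::Api, Repo, RepoType};
use safetensors::SafeTensors;

// Setup HF API and model info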
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";

// Initialize the Hugging Face API client
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
    model_id.to_string(),
    RepoType::Model,
    revision.to_string(),
));

// Load and analyze the safetensors file
let weights_path = repo.get("model.safetensors")?;
let weights = std::fs::read(&weights_path)?;
let tensors = SafeTensors::deserialize(&weights)?;

// Print all tensor names and their metadata
println!("Model tensors:");
tensors.names().iter().for_each(|tensor_name| {
    if let Ok(tensor_info) = tensors.tensor(tensor_name) {
        println!(
            "Tensor: {:<44}  ||  Shape: {:?}",
            tensor_name,
            tensor_info.shape(),
        );
    }
});
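
The check described above (the model should use and match every tensor in the file) can be done mechanically by diffing two lists of names. This is a small std-only sketch of that idea, not code from the port. `diff_tensor_names` is a hypothetical helper, and the tensor names in `main` are illustrative placeholders rather than the real contents of model.safetensors.

```rust
use std::collections::BTreeSet;

/// Compare the tensor names a model expects against the names found in a weights file.
/// Returns (missing_from_file, unused_in_file), both in sorted order.
fn diff_tensor_names(expected: &[&str], in_file: &[&str]) -> (Vec<String>, Vec<String>) {
    let expected: BTreeSet<&str> = expected.iter().copied().collect();
    let in_file: BTreeSet<&str> = in_file.iter().copied().collect();
    // Names the model asks for that the file does not contain.
    let missing: Vec<String> = expected.difference(&in_file).map(|s| s.to_string()).collect();
    // Names present in the file that the model never loads.
    let unused: Vec<String> = in_file.difference(&expected).map(|s| s.to_string()).collect();
    (missing, unused)
}

fn main() {
    // Illustrative names only; the real list would come from `tensors.names()`.
    let expected = ["encoder.weight", "decoder.weight", "decoder.bias"];
    let in_file = ["encoder.weight", "decoder.weight", "layer_norm_2.weight"];
    let (missing, unused) = diff_tensor_names(&expected, &in_file);
    println!("missing from file: {:?}", missing);
    println!("never loaded:      {:?}", unused);
}
```

Both lists being empty is the condition to aim for: nothing the model needs is absent, and nothing in the file is silently ignored.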

Tokenizer

From the perspective of tokenizers, using Amplify was a good choice. The Amplify Hugging Face repo has a tokenizer.json file and uses the Python Hugging Face tokenizers library. There is a comparable Rust crate, tokenizers, that uses the same underlying code and logic. A bit of Claude magic and I had a functional Rust tokenizer, ProteinTokenizer, and was able to retrieve and encode sequences using the Amplify sequence vocab. This vocabulary lets us convert from char to int and vice versa.

// examples/amplify/main.rs

// The HF Repo
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";

// Initialize the Hugging Face API client
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
    model_id.to_string(),
    RepoType::Model,
    revision.to_string(),
));

// create and use the tokenizer...
let tokenizer = repo.get("tokenizer.json")?;
let protein_tokenizer = ProteinTokenizer::new(tokenizer)?;
println!("Successfully created the tokenizer!");
let pmatrix = protein_tokenizer.encode(&["METVALMETVAL".to_string()], Some(20), true, false)?;
// pmatrix: Tensor[dims 14; i64]   <-  sequence is now a set of ints.
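
To make the char-to-int round trip concrete, here is a toy std-only sketch. It is not the ProteinTokenizer from the repo. `ToyVocab` is a hypothetical type, its ids are simply positions in a 20-letter amino acid alphabet, and it has none of the special tokens that the real tokenizer.json defines.

```rust
use std::collections::HashMap;

/// A toy character-level vocabulary: residue letter <-> integer id.
struct ToyVocab {
    token_to_id: HashMap<char, i64>,
    id_to_token: Vec<char>,
}

impl ToyVocab {
    /// Build the vocabulary; each character's id is its position in `alphabet`.
    fn new(alphabet: &str) -> Self {
        let id_to_token: Vec<char> = alphabet.chars().collect();
        let token_to_id: HashMap<char, i64> = id_to_token
            .iter()
            .enumerate()
            .map(|(i, &c)| (c, i as i64))
            .collect();
        ToyVocab { token_to_id, id_to_token }
    }

    /// Encode a sequence; returns None if any character is not in the vocabulary.
    fn encode(&self, seq: &str) -> Option<Vec<i64>> {
        seq.chars().map(|c| self.token_to_id.get(&c).copied()).collect()
    }

    /// Decode ids back to a string; returns None on an out-of-range id.
    fn decode(&self, ids: &[i64]) -> Option<String> {
        ids.iter()
            .map(|&i| {
                if i < 0 {
                    None
                } else {
                    self.id_to_token.get(i as usize).copied()
                }
            })
            .collect()
    }
}

fn main() {
    // The 20 standard amino acids; an assumption for this toy, not the Amplify vocab.
    let vocab = ToyVocab::new("ACDEFGHIKLMNPQRSTVWY");
    let ids = vocab.encode("METVALMETVAL").expect("unknown residue");
    println!("ids: {:?}", ids);
    println!("round trip: {:?}", vocab.decode(&ids));
}
```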

Model

In addition to the weights and tokenizer mentioned above, I also had access to the pytorch source code and model info on GitHub. There are only 3 files, and I was able to use Claude to get a decent first-pass conversion from pytorch to Candle. The major issues I encountered were related to:

  1. Claude hallucinating or inferring pytorch methods that don’t exist in Candle. This led to some spelunking in the Candle source code and my first PR to the Candle repo.
  2. Differences in idioms, e.g. Candle uses enums like D::Minus1 to specify some dimensions and never accepts -1.
  3. Direct python->rust translations that feel or look clunky. These can sometimes be addressed with a few rounds of "Can you write this more idiomatically?", "What would this code look like if it used iterators?", or other prompts to get Claude to rustize the code.
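
The second point is easier to see in code. Rust indexes dimensions with unsigned integers, so "the last axis" cannot be spelled -1 and an enum variant stands in for it. The following is a toy model of that idea only, and it is not Candle's source. `Dim` and `resolve` are hypothetical names chosen for this sketch.

```rust
/// Toy stand-in for a dimension selector: an explicit index, or a position
/// counted from the end of the shape.
#[derive(Debug, Clone, Copy)]
enum Dim {
    Index(usize),
    Minus1,
    Minus2,
}

impl Dim {
    /// Resolve the selector to a concrete axis for a tensor of the given rank.
    /// Returns None when the selector does not fit the rank.
    fn resolve(self, rank: usize) -> Option<usize> {
        match self {
            Dim::Index(i) if i < rank => Some(i),
            Dim::Index(_) => None,
            Dim::Minus1 => rank.checked_sub(1),
            Dim::Minus2 => rank.checked_sub(2),
        }
    }
}

fn main() {
    // Shape of the logits tensor from the example run: [batch, seq_len, vocab].
    let shape = [1usize, 14, 27];
    let axis = Dim::Minus1.resolve(shape.len()).expect("rank too small");
    println!("last axis = {}, size = {}", axis, shape[axis]);
}
```

A pytorch call that passes dim=-1 therefore has to be rewritten to name the axis explicitly or to use the enum, which is one of the places a literal translation goes wrong.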

My initial test was to write a load function so that I could load the weights into an appropriately shaped model, as specified by the source code and model weights.

AMPLIFY(
  (encoder): Embedding(27, 960, padding_idx=0)
  (transformer_encoder): ModuleList(
    (0-31): 32 x EncoderBlock(
      (q): Linear(in_features=960, out_features=960, bias=False)
      (k): Linear(in_features=960, out_features=960, bias=False)
      (v): Linear(in_features=960, out_features=960, bias=False)
      (wo): Linear(in_features=960, out_features=960, bias=False)
      (resid_dropout): Dropout(p=0, inplace=False)
      (ffn): SwiGLU(
        (w12): Linear(in_features=960, out_features=5120, bias=False)
        (w3): Linear(in_features=2560, out_features=960, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (ffn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (layer_norm_2): RMSNorm()
  (decoder): Linear(in_features=960, out_features=27, bias=True)
)
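
One detail in this printout deserves a note: w12 maps 960 features to 5120, while w3 takes only 2560. That is consistent with a SwiGLU layer in which w12 is a fused projection whose output is split into two halves, one passed through SiLU and used to gate the other. Below is a minimal std-only sketch for a single token, assuming the first half is the gate. The half ordering and the helper names `silu`, `linear` and `swiglu` are assumptions of this sketch and should be checked against the pytorch source.

```rust
/// SiLU (a.k.a. swish) activation: x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Bias-free linear layer for one input vector.
/// `w[o][i]` is the weight from input feature `i` to output feature `o`.
fn linear(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// SwiGLU feed-forward for one token.
/// `w12` has 2 * hidden output rows, `w3` has `hidden` input columns.
fn swiglu(w12: &[Vec<f32>], w3: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    let h = linear(w12, x); // length 2 * hidden
    let (x1, x2) = h.split_at(h.len() / 2); // gate half, value half
    let gated: Vec<f32> = x1.iter().zip(x2).map(|(a, b)| silu(*a) * b).collect();
    linear(w3, &gated) // project back to the model dimension
}

fn main() {
    // Toy sizes: model dim 2, hidden 3, so w12 is 6x2 and w3 is 2x3.
    let w12 = vec![
        vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0], // gate rows
        vec![0.5, 0.0], vec![0.0, 0.5], vec![0.5, 0.5], // value rows
    ];
    let w3 = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];
    let out = swiglu(&w12, &w3, &[1.0, 2.0]);
    println!("output: {:?}", out);
}
```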

Encoding

The last major mechanical task, once the model was loaded and running, was to test whether I could run a new, encoded sequence through it and get results. This was the most time-consuming bit, as it is where all of the internal connections have to be correctly mapped/mimicked. I won’t get into all of the details, but I did run into a bit of trouble with:

  • Rotary Embeddings. I didn’t know what those were or how they work, and they do some funky things with dimensions. I ended up using Claude heavily to iterate on possible translations, but I am not sure I have it right. This is the most likely place I’ve failed to replicate the architecture if my tests against the pytorch model don’t pass.
  • Xformers. The Amplify team used xformers in a few places for optimizations: in the feed-forward part of the network, in the SwiGLU activation step, and in a memory-efficient attention block. I tried to translate the intention of the model as best I could to plain Candle. I am SURE that there are more optimal ways to implement these.
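
For readers who, like me, had not met rotary embeddings before: the core idea is to treat pairs of values in each head vector as 2D points and rotate them by an angle that grows with sequence position. The following is a minimal std-only sketch for one head vector. It assumes adjacent pairs (x[2i], x[2i+1]) and the common base of 10000, either of which may differ from the Amplify source. `apply_rotary` is a hypothetical name and this is not the code in the port.

```rust
/// Apply a rotary position embedding to one head vector at sequence position `pos`.
/// Each adjacent pair (x[2i], x[2i+1]) is rotated by the angle pos * base^(-2i / d),
/// where d is the head dimension.
fn apply_rotary(x: &[f32], pos: usize, base: f32) -> Vec<f32> {
    assert!(x.len() % 2 == 0, "head_dim must be even");
    let d = x.len() as f32;
    let mut out = Vec::with_capacity(x.len());
    for (i, pair) in x.chunks_exact(2).enumerate() {
        // Lower pair indices rotate quickly, higher ones slowly.
        let theta = base.powf(-2.0 * i as f32 / d);
        let (sin, cos) = (pos as f32 * theta).sin_cos();
        // Standard 2D rotation, i.e. multiplication by cos + i*sin.
        out.push(pair[0] * cos - pair[1] * sin);
        out.push(pair[0] * sin + pair[1] * cos);
    }
    out
}

fn main() {
    let q = [1.0_f32, 0.0, 1.0, 0.0];
    for pos in 0..3 {
        println!("pos {}: {:?}", pos, apply_rotary(&q, pos, 10000.0));
    }
}
```

Two properties are handy as sanity checks when porting: position 0 leaves the vector unchanged, and the rotation never changes the vector's length.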

My goal was to get a working model that is as close to a literal translation as possible and then add tests so I can make sure that any architectural tweaks are not altering the fundamental scientific outputs.
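
The tests I have in mind boil down to comparing Candle outputs against reference values exported from the pytorch model, within a tolerance. Here is a std-only sketch of such a comparison, using the same criterion as torch.allclose. `all_close` and `max_abs_diff` are hypothetical helpers and the numbers in `main` are made up, not real model outputs.

```rust
/// Largest absolute element-wise difference, or None if the lengths differ.
fn max_abs_diff(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() {
        return None;
    }
    Some(a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max))
}

/// True when every element satisfies |a - b| <= atol + rtol * |b|.
fn all_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x - y).abs() <= atol + rtol * y.abs())
}

fn main() {
    // Made-up values standing in for a few logits from each implementation.
    let candle_logits = [0.1234_f32, -2.5001, 7.0];
    let pytorch_logits = [0.1234_f32, -2.5000, 7.0];
    println!("max abs diff: {:?}", max_abs_diff(&candle_logits, &pytorch_logits));
    println!("close: {}", all_close(&candle_logits, &pytorch_logits, 1e-3, 1e-5));
}
```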

Claude Sonnet

Once again, amazing. I think the updated Sonnet 3.5 got a big logic update. I am going to paste a bit of one of my interactions with Sonnet to give a taste of how it helped me reason about a technical field that I am relatively new to. In this exchange I gave Claude my Rust code and asked it to trace the path of an encoded sequence through the model while keeping track of the shape of the tensors at each step. This is quite a remarkable response:

Claude’s Response:

Let’s trace the path of a sequence through the AMPLIFY model, following the dimensions at each step. I’ll use a concrete example to make it clearer:

Let’s say we input a protein sequence of length 8 with a batch size of 1:

  1. Initial Input

    Input sequence: "METVAL..." (8 tokens)
    Shape: [1, 8]  // [batch_size, seq_length]
  2. Embedding Layer (self.encoder)

    Converts tokens to embeddings
    Input:  [1, 8]
    Output: [1, 8, 640]  // [batch, seq_len, hidden_size]
  3. Transformer Encoder Blocks (repeated 24 times) Each block contains:

    1. Self-Attention
    • Input normalization (RMSNorm)

      Input:  [1, 8, 640]
      Output: [1, 8, 640]
    • Query/Key/Value projections

      Input:  [1, 8, 640]
      Q,K,V each: [1, 8, 640]
    • Reshape for attention heads

      Reshape to: [1, 8, 10, 64]
      // [batch, seq_len, n_heads, head_dim]
      // where n_heads=10, head_dim=640/10=64
    • Rotary Position Embeddings

      freqs_cis: [8, 32, 2]  // [seq_len, head_dim/2, 2]
      Applied to Q,K maintaining shape: [1, 8, 10, 64]
    • Attention Computation

      Q * K^T -> [1, 10, 8, 8]  // [batch, n_heads, seq_len, seq_len]
      After softmax -> [1, 10, 8, 8]
      After V multiply -> [1, 8, 10, 64]
    • Project back

      Reshape: [1, 8, 640]
      Linear: [1, 8, 640]
    2. Feed Forward Network
    • Input normalization

      Input:  [1, 8, 640]
      Output: [1, 8, 640]
    • FFN computation (with SwiGLU)

      W12: [1, 8, 2560]  // expanded internal dimension
      After SwiGLU: [1, 8, 1280]
      Final: [1, 8, 640]
  4. Final Layer Norm

    Input:  [1, 8, 640]
    Output: [1, 8, 640]
  5. Decoder (Final Linear Layer)

    Input:  [1, 8, 640]
    Output: [1, 8, 27]  // [batch, seq_len, vocab_size]

The key transformations happen in the attention mechanism where:

  1. The sequence is first projected into Q,K,V
  2. These are split into multiple attention heads
  3. Position information is added via rotary embeddings
  4. Attention is computed across the sequence
  5. The results are merged back and processed through a feed-forward network

The rotary embeddings are particularly tricky because they need to:

  1. Handle complex number rotations (hence the pairs of values)
  2. Be broadcastable across the batch and head dimensions
  3. Align properly with the sequence length and head dimensions

This path helps explain why the shapes need to align precisely at each step, and why the rotary embedding shape mismatch is causing issues. Would you like me to elaborate on any particular part of this flow?
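
Stepping back out of Claude's response: the shape bookkeeping in that trace can be written down as a few lines of arithmetic, which is a cheap way to catch a mismatch before any tensors are involved. This is a std-only sketch using the numbers from the trace (hidden size 640, 10 heads, sequence length 8). `AttentionShapes` and `attention_shapes` are hypothetical names and not part of the port.

```rust
/// Shapes of the main intermediate tensors in one self-attention block.
#[derive(Debug, PartialEq)]
struct AttentionShapes {
    qkv: [usize; 3],    // [batch, seq_len, hidden]
    heads: [usize; 4],  // [batch, seq_len, n_heads, head_dim]
    scores: [usize; 4], // [batch, n_heads, seq_len, seq_len]
    output: [usize; 3], // [batch, seq_len, hidden]
}

/// Compute the shapes, or None when `hidden` does not split evenly across heads.
fn attention_shapes(
    batch: usize,
    seq_len: usize,
    hidden: usize,
    n_heads: usize,
) -> Option<AttentionShapes> {
    if n_heads == 0 || hidden % n_heads != 0 {
        return None;
    }
    let head_dim = hidden / n_heads;
    Some(AttentionShapes {
        qkv: [batch, seq_len, hidden],
        heads: [batch, seq_len, n_heads, head_dim],
        scores: [batch, n_heads, seq_len, seq_len],
        output: [batch, seq_len, hidden],
    })
}

fn main() {
    // Values taken from the trace above.
    match attention_shapes(1, 8, 640, 10) {
        Some(shapes) => println!("{:#?}", shapes),
        None => println!("hidden size is not divisible by the number of heads"),
    }
}
```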

Todo

  • implement a set of tests based on the pytorch model.
  • a style and perf pass
  • WASM!