Amplify Pytorch->Candle Conversion Part 3

Conversion of AMPLIFY from Pytorch to Candle
rust
ai
proteins
Author

Zachary Charlop-Powers

Published

November 15, 2024

Now that we’ve got the model together, I want to start shaping the API. We’d like to make it easy to download and invoke the model, and to do stuff with the outputs. The code below is a piece of the test suite as of Nov 15 which shows:

Note

Internally I am using candle_hf_hub which will download and cache the model weights to ~/.cache/huggingface/hub. See cache docs
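
Under the hood, the download step boils down to a few calls to the hub API. Here is a minimal sketch, assuming candle_hf_hub mirrors the hf-hub sync API; the repo id ("chandar-lab/AMPLIFY_120M") and the file names are illustrative, not necessarily what load_from_huggingface uses internally.

use candle_hf_hub::api::sync::Api;

fn fetch_amplify_files() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Hypothetical repo id, for illustration only.
    let repo = api.model("chandar-lab/AMPLIFY_120M".to_string());

    // Each get() downloads the file once and caches it under
    // ~/.cache/huggingface/hub; later calls return the cached path.
    let config_path = repo.get("config.json")?;
    let weights_path = repo.get("model.safetensors")?;
    println!("config: {:?}", config_path);
    println!("weights: {:?}", weights_path);
    Ok(())
}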

Core API in place

use candle_core::D;
use ferritin_featurizers::AMPLIFY;

fn test_amplify_full_model() -> Result<(), Box<dyn std::error::Error>> {

    // Load the Model and the Tokenizer
    let (tokenizer, amplify) = AMPLIFY::load_from_huggingface()?;

    // Test the outputs of the Encoding from the Amplify Test Suite
    let AMPLIFY_TEST_SEQ = "MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR";

    // encode the sequence
    let pmatrix = tokenizer.encode(
        &[AMPLIFY_TEST_SEQ.to_string()], None, true, false)?;
    let pmatrix = pmatrix.unsqueeze(0)?; // [batch, length]

    // Run the sequence through the model....
    let encoded = amplify.forward(&pmatrix, None, true, true)?;

    // There are vocab_size (27) logits per residue in the sequence;
    // argmax takes the index of the highest-scoring token at each position.
    let predictions = &encoded.logits.argmax(D::Minus1)?;
    let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
    let decoded = tokenizer.decode(indices.as_slice(), true)?;

    // woo hoo! if this passes we have roundtripped.
    // we strongly expect the model to recover the inputs.
    assert_eq!(AMPLIFY_TEST_SEQ, decoded.replace(" ", ""));

    // What if we want the contact map of the residues?
    // For that we will need to retrieve it from the attentions.
    if let Some(contact_map) = &encoded.get_contact_map()? {
        // The contact map has dims [seqlen, seqlen, 240]
        println!("contact map dims: {:?}", contact_map.dims());
    }
    ...
}
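
To go from that per-head contact map to a single [seqlen, seqlen] matrix, one simple option is to average over the head dimension. A minimal sketch, assuming get_contact_map hands back a candle Tensor shaped [seqlen, seqlen, heads]; averaging over heads is my own simplification here, not necessarily what the final API will expose.

use candle_core::{D, Result, Tensor};

// Collapse [seqlen, seqlen, heads] attention maps into a single
// [seqlen, seqlen] matrix by averaging over the trailing head dimension.
fn average_contact_map(contact_map: &Tensor) -> Result<Tensor> {
    contact_map.mean(D::Minus1)
}

From there the matrix can be thresholded or plotted and compared against known structural contacts.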

Lots more work to do, of course, but it looks like this is shaping up!