Tags: transformers, biology, ESM, ESM-2, protein, protein language model

ESM-2 RNA Binding Site LoRA

This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the esm2_t6_8M_UR50D model for the binary token classification task of predicting RNA binding sites in proteins. You can also find a version of this model that was fine-tuned without LoRA here.

Training procedure

This is a Low-Rank Adaptation (LoRA) of esm2_t6_8M_UR50D, trained on 166 protein sequences from the RNA binding sites dataset using an 80/20 train/test split. The model was trained with class weighting because the dataset is imbalanced (far fewer binding-site residues than non-binding-site residues); a minimal sketch of this weighting appears after the install commands below. You can train your own version using this notebook! You just need the RNA binding_sites.xml file found here. You may also need to run some pip install statements at the beginning of the script. If you are running in Colab, run:

!pip install transformers[torch] datasets peft -q
!pip install accelerate -U -q
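
For reference, the class weighting mentioned above can be implemented by overriding the loss computation in a custom Trainer. This is a minimal sketch, assuming illustrative weights and the standard token classification head; it is not the exact code or weights from the notebook:

import torch
from torch import nn
from transformers import Trainer

class WeightedLossTrainer(Trainer):
    """Trainer that counteracts the binding/non-binding imbalance with class weights."""

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Assumed example weights: penalize missed binding sites more heavily.
        self.class_weights = class_weights if class_weights is not None else torch.tensor([0.2, 0.8])

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device), ignore_index=-100)
        # Flatten (batch, seq_len, num_labels) -> (batch * seq_len, num_labels) for per-token loss.
        loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss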

Try to improve upon these metrics by adjusting the hyperparameters (a sketch of the main knobs follows the metrics below):

{'eval_loss': 0.49476009607315063,
'eval_precision': 0.14372964169381108,
'eval_recall': 0.7526652452025586,
'eval_f1': 0.24136752136752138,
'eval_auc': 0.7710141129858947,
'epoch': 15.0}
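
As a starting point, the main LoRA and training hyperparameters look roughly like the following. The values here are assumptions to experiment with, not the settings that produced the metrics above:

from transformers import AutoModelForTokenClassification, TrainingArguments
from peft import LoraConfig, get_peft_model

base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2
)

# LoRA knobs: rank, scaling, dropout, and which modules receive adapters.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=16,                                       # rank of the low-rank update matrices
    lora_alpha=16,                              # scaling factor
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],   # attention projections in ESM-2
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# Optimizer/schedule knobs, passed to the Trainer.
training_args = TrainingArguments(
    output_dir="esm2_rna_binding_lora",
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    num_train_epochs=15,
    weight_decay=0.01,
)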

A similar model can also be trained using the GitHub repository, which includes a training script and a conda environment YAML and can be found here. That version uses wandb sweeps for hyperparameter search, but it does not use class weighting.
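
For the sweep-based version, a wandb sweep can be configured with a small dictionary of parameter ranges. The ranges and project name below are illustrative assumptions, not those used in the repository:

import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval/f1", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16, 32]},
        "lora_alpha": {"values": [8, 16, 32]},
        "lora_dropout": {"values": [0.05, 0.1, 0.2]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="esm2-rna-binding-lora")  # assumed project name
# wandb.agent(sweep_id, function=train)  # `train` is your training entry point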

Framework versions

Using the Model

To use the model, try running the following pip install statements:

!pip install transformers peft -q

then try running:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
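
If you prefer per-residue probabilities over hard argmax labels (for example, to threshold at a custom cutoff), a softmax over the logits works as a drop-in continuation of the snippet above:

# Probability of the "Binding site" class for each residue
probs = torch.softmax(logits, dim=2)[0, :, 1]

for token, p in zip(tokens, probs.numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print(f"{token}\t{p:.3f}")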