
Hugging Face RoBERTa with Flash Attention 2 🚀

Re-implementation of Hugging Face 🤗 RoBERTa with Flash Attention 2 in PyTorch: a drop-in replacement of the legacy PyTorch self-attention in Hugging Face RoBERTa with Flash Attention 2, based on the standard implementation.
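
At its core, the change routes each layer's query/key/value projections through the fused Flash Attention 2 kernel instead of materializing the full softmax(QK^T)V. Below is a minimal sketch of that idea (not the repo's exact code); it assumes non-causal, unpadded fp16/bf16 inputs and uses the flash_attn_func entry point from the flash-attn library.

import torch
from flash_attn import flash_attn_func

def flash_self_attention(q, k, v, dropout_p=0.0):
    # q, k, v: (batch, seq_len, num_heads, head_dim), fp16/bf16, on a CUDA device.
    # causal=False because RoBERTa uses bidirectional self-attention.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=False)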

Installation

You need to install the flash-attn library, which currently supports only Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100):

MAX_JOBS=4 pip install flash-attn --no-build-isolation
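
As a quick sanity check (a small sketch, not part of the repo), you can verify that the GPU meets the compute-capability requirement and that the wheel built correctly:

import torch
import flash_attn

# Flash Attention 2 requires compute capability 8.0 or higher (Ampere/Ada/Hopper).
major, minor = torch.cuda.get_device_capability()
assert (major, minor) >= (8, 0), f"Unsupported GPU: compute capability {major}.{minor}"
print("flash-attn version:", flash_attn.__version__)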

MLM Demo

The implementation and demo code for MLM is available at: https://github.com/iliaschalkidis/flash-roberta

git clone https://github.com/iliaschalkidis/flash-roberta
cd flash-roberta
sh demo_mlm.sh

You can use any pre-trained RoBERTa model from the Hugging Face model hub and evaluate it on any corpus. For example, to compare the standard and Flash Attention versions of roberta-base on SST-2:

python demo_mlm.py --model_class roberta --model_path roberta-base --dataset_name sst2
python demo_mlm.py --model_class flash-roberta --model_path roberta-base --dataset_name sst2
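
For reference, here is a rough illustration of masked-LM inference with the standard checkpoint (a hypothetical snippet, not the repo's demo_mlm.py):

from transformers import pipeline

# Predict the masked token with the standard roberta-base checkpoint.
fill = pipeline("fill-mask", model="roberta-base")
print(fill("The movie was absolutely <mask>.")[0]["token_str"])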

Use with Hugging Face Transformers

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("kiddothe2b/flash-roberta-base", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
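
Continuing the snippet above, a short usage sketch (the input sentence and printout are illustrative): the flash-attn kernels require fp16/bf16 tensors on a CUDA device.

import torch

model = model.to("cuda", dtype=torch.float16)  # flash-attn kernels need fp16/bf16 on GPU
inputs = tokenizer("Flash Attention makes RoBERTa faster.", return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)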

Citation

@misc{flashroberta,
  title={Hugging Face RoBERTa with Flash Attention 2},
  author={Chalkidis, Ilias},
  year={2023},
  url={https://huggingface.co/kiddothe2b/flash-roberta-base},
  howpublished={Hugging Face Hub}
}

@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  journal={arXiv preprint arXiv:2307.08691},
  year={2023}
}