This checkpoint was obtained by training FlaxBigBirdForQuestionAnswering (with an extra pooler head) on the natural_questions dataset on a TPU v3-8. The dataset takes around 100 GB on disk, but thanks to Cloud TPUs and JAX, each epoch took just 4.5 hours. The training script can be found here: https://github.com/vasudevgupta7/bigbird

Use this model just like any other model from 🤗 Transformers:

from transformers import FlaxBigBirdForQuestionAnswering, BigBirdTokenizerFast

model_id = "vasudevgupta/flax-bigbird-natural-questions"
model = FlaxBigBirdForQuestionAnswering.from_pretrained(model_id)
tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
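
The question answering head returns per-token start and end logits, which still have to be turned into an answer span. The following is a minimal sketch of one common way to do that, not necessarily what the training or evaluation script does. The helper name `best_span` and the `max_answer_len` parameter are hypothetical and introduced here only for illustration.

```python
import numpy as np


def best_span(start_logits, end_logits, max_answer_len=30):
    """Pick the (start, end) token pair with the highest combined score.

    Hypothetical helper: only spans with start <= end and at most
    `max_answer_len` tokens are considered.
    """
    start_logits = np.asarray(start_logits, dtype=np.float32)
    end_logits = np.asarray(end_logits, dtype=np.float32)

    # scores[s, e] = start_logits[s] + end_logits[e]
    scores = start_logits[:, None] + end_logits[None, :]

    # Keep the band 0 <= e - s < max_answer_len, mask everything else.
    ones = np.ones_like(scores, dtype=bool)
    valid = np.triu(ones) & ~np.triu(ones, k=max_answer_len)
    scores = np.where(valid, scores, -np.inf)

    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end), float(scores[start, end])
```

With the model and tokenizer loaded above, this could be applied to `outputs.start_logits[0]` and `outputs.end_logits[0]` after calling the model on inputs tokenized with `return_tensors="np"`, and the selected token ids decoded with `tokenizer.decode`. Note that this sketch does not exclude question tokens from the candidate spans.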

If you are also interested in predicting the answer category (null, long, short, yes, no), use FlaxBigBirdForNaturalQuestions (instead of FlaxBigBirdForQuestionAnswering) from my training script.

Exact Match: 55.12

Evaluation script: https://colab.research.google.com/github/vasudevgupta7/bigbird/blob/main/notebooks/evaluate-flax-natural-questions.ipynb
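
For readers unfamiliar with the metric, here is a rough sketch of how an exact match score is typically computed. It assumes SQuAD-style answer normalization (lowercasing, stripping punctuation and English articles), which is an assumption on my part here; the evaluation notebook linked above may normalize and score answers differently. The function names `normalize_answer` and `exact_match` are illustrative only.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, collapse whitespace.

    Assumed SQuAD-style normalization, not taken from the notebook.
    """
    text = text.lower()
    punctuation = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(predictions, references):
    """Percentage of predictions matching any of their reference answers."""
    hits = sum(
        any(normalize_answer(pred) == normalize_answer(ref) for ref in refs)
        for pred, refs in zip(predictions, references)
    )
    return 100.0 * hits / len(predictions)
```

For example, `exact_match(["The Eiffel Tower"], [["eiffel tower."]])` counts the prediction as correct, since both strings normalize to the same text.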