datasets: https://github.com/omrikeren/ParaShoot/

metrics: f1 49.612 exact_match 26.439

language: he

pipeline_tag: question-answering

license: unknown

mT5-small-Hebrew-ParaShoot-QA

This repository contains a fine-tuned mT5-small (Multilingual Text-to-Text Transfer Transformer) model on the ParaShoot dataset (github). To enhance its performance, a "domain-specific" fine-tuning approach was employed. Initially, the model was pretrained on a Hebrew dataset to capture Hebrew linguistic nuances. Subsequently, I further fine-tuned the model on the ParaShoot dataset, aiming to improve its proficiency in the Question-Answering task. This model builds upon the original work by imvladikon who initially fine-tuned the mT5-small model for the summarization task.

Model Details

Google's mT5

mT5 is pretrained on the mC4 corpus, covering 101 languages. Note: mT5 was only pre-trained on mC4 excluding any supervised training. Therefore, this model has to be fine-tuned before it is useable on a downstream task.

Related papers:

Paper: mT5: A massively multilingual pre-trained text-to-text transformer Authors: Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel

Paper: Multilingual Sequence-to-Sequence Models for Hebrew NLP Authors: Matan Eyal, Hila Noga, Roee Aharoni, Idan Szpektor, Reut Tsarfaty

Paper: PARASHOOT: A Hebrew Question Answering Dataset Authors: Omri Keren, Omer Levy

This model achieves the following results on the test set:

Note: In the paper Multilingual Sequence-to-Sequence Models for Hebrew NLP the results were F1 - 48.71, EM - 24.52.

How to use the model:

Use the code below to get started with the model.

from transformers import MT5ForConditionalGeneration, AutoTokenizer
MODEL_NAME = "Livyatan/mT5-small-Hebrew-ParaShoot-QA"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
def generate_answer(question, context):
input_encoding = tokenizer(
question,
context,
max_length = len(context),
padding="max_length",
truncation="only_second",
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt"
).to(DEVICE)

with torch.no_grad():
generated_ids = model.generate(
input_ids = input_encoding['input_ids'].to(DEVICE),
attention_mask = input_encoding['attention_mask'].to(DEVICE),
max_length=20,
)

preds = [
tokenizer.decode(generated_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
for generated_id in generated_ids
]

return "".join(preds)

context = 'סדרת הלווייתנאים כוללת כ-90 מינים, שכולם חיים באוקיינוסים מלבד חמישה מיני דולפינים החיים במים מתוקים. הלווייתנאים החיים מחולקים לשתי תת-סדרות: לווייתני מזיפות (Mysticeti) ולווייתני שיניים (Odontoceti; ובהם גם דולפינים); בעבר התקיימה תת-סדרה נוספת: לווייתנים קדומים (Archaeoceti), שנכחדה. במרבית המקרים לווייתני המזיפות גדולים באופן משמעותי מלווייתני השיניים, הקטנים והמהירים יותר, וכמה מלווייתני המזיפות הם מבעלי החיים הגדולים ביותר בכדור הארץ. לווייתני השיניים מתאפיינים בשיניים חדות, והם ציידים מהירים שניזונים מדגים ומיצורים ימיים אחרים. לעומתם לווייתני המזיפות הם חסרי שיניים ובמקום זאת יש להם מזיפות ארוכות דמויות מסננת, שבעזרתן הם מסננים פלנקטון מהמים.'
question = 'כמה מינים כוללת סדרת הלווייתנאים?'
answer = generate_answer(question, context)
print(answer)
>>> 'כ-90 מינים'