Question Generator
This model generates questions from a given context string and a target answer contained in that context.
Out-of-Scope Use
The model supports English only; input in other languages is out of scope.
How to Get Started with the Model
Use the code below to get started with the model.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_NAME = "pipesanma/chasquilla-question-generator"


def question_parser(question: str) -> str:
    # Drop the prefix that ends at the first colon and normalize whitespace.
    # Splitting only once keeps colons that appear inside the question itself,
    # and an output without any colon is returned unchanged instead of raising.
    return " ".join(question.split(":", 1)[-1].split())


def generate_questions_v2(context: str, answer: str, n_questions: int = 1):
    # Note: the model and tokenizer are loaded on every call.
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    text = "context: " + context + " " + "answer: " + answer + " </s>"
    # truncation=True is needed for max_length to actually cap the input.
    encoding = tokenizer(
        text, max_length=512, truncation=True, padding=True, return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Beam search cannot return more sequences than it has beams,
    # so widen the beam when more than 5 questions are requested.
    with torch.no_grad():
        beam_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=72,
            early_stopping=True,
            num_beams=max(5, n_questions),
            num_return_sequences=n_questions,
        )

    questions = []
    for beam_output in beam_outputs:
        sent = tokenizer.decode(
            beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        print(sent)  # raw decoded output, before the prefix is stripped
        questions.append(question_parser(sent))
    return questions


context = "President Donald Trump said and predicted that some states would reopen this month."
answer = "Donald Trump"
questions = generate_questions_v2(context, answer, 1)
print(questions)
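The text layout on either side of the model can be checked without downloading any weights. The sketch below mirrors the prompt construction from the code above and uses a slightly more defensive variant of the output parsing. The helper names build_prompt and parse_question are hypothetical, and the sample raw string is made up for illustration; it is not actual model output.

```python
def build_prompt(context: str, answer: str) -> str:
    """Build the input string in the 'context: ... answer: ...' layout."""
    return "context: " + context + " " + "answer: " + answer + " </s>"


def parse_question(raw: str) -> str:
    """Drop everything up to the first colon and collapse whitespace."""
    return " ".join(raw.split(":", 1)[-1].split())


prompt = build_prompt(
    "President Donald Trump said and predicted that some states would reopen this month.",
    "Donald Trump",
)
print(prompt)

# Made-up string standing in for a decoded model output (an assumption).
sample_raw = "question:   Who predicted that some states would reopen?"
print(parse_question(sample_raw))
```

Splitting on the first colon only means a question that itself contains a colon is kept whole, and a string without any colon is returned as-is.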
Training Details
Dataset generation
The training data is the "squad" dataset from the Hugging Face datasets library.
See utils/dataset_gen.py for the dataset generation code.
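As a rough illustration of what such a generation step might look like, the sketch below turns one SQuAD-style record into an input/target text pair. The actual logic lives in utils/dataset_gen.py and may differ. The record layout (context, question, answers with a text list) follows the public "squad" dataset; the function name to_text_pair, the output keys, the "question: " target prefix, and the sample record are all assumptions.

```python
from typing import Dict, Optional


def to_text_pair(record: dict) -> Optional[Dict[str, str]]:
    """Convert one SQuAD-style record into model input and target strings.

    Returns None when the record has no answer text.
    """
    answers = record["answers"]["text"]
    if not answers:
        return None
    # Same 'context: ... answer: ...' layout as the inference code.
    source = "context: " + record["context"] + " " + "answer: " + answers[0]
    # Assumed target prefix, inferred from the colon-based output parser.
    target = "question: " + record["question"]
    return {"input_text": source, "target_text": target}


# Hand-written sample in the SQuAD record layout (not taken from the dataset).
sample = {
    "context": "The Eiffel Tower is located in Paris.",
    "question": "Where is the Eiffel Tower located?",
    "answers": {"text": ["Paris"], "answer_start": [31]},
}

pair = to_text_pair(sample)
print(pair["input_text"])
print(pair["target_text"])
```

With the datasets library installed, real records could be obtained with load_dataset("squad") and passed through the same function.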
Training model
See utils/t5_train_model.py for the training process.
Model and Tokenizer versions
(v1.0) Model and Tokenizer V1: trained with 1000 rows
(v1.1) Model and Tokenizer V2: trained with 3000 rows
(v1.2) Model and Tokenizer V3: trained with all rows from the dataset (78664 train rows, 9652 validation rows)