Example
gen.py
from transformers import GPTNeoForCausalLM, AutoTokenizer
import contextlib
import torch
import sys

# Model name (or local path) comes from the first command-line argument.
model_name = sys.argv[1]
model = GPTNeoForCausalLM.from_pretrained(model_name).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate(model, text, temperature=0.9, min_length=256, max_length=256,
             no_grad=True, use_cache=False, do_sample=True, **kwargs):
    # Tokenize the prompt and move it onto the same device as the model.
    ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
    # Generation needs no gradients; disabling them saves memory unless
    # the caller explicitly opts out with no_grad=False.
    ctx = torch.no_grad() if no_grad else contextlib.nullcontext()
    with ctx:
        gen_tokens = model.generate(
            ids,
            do_sample=do_sample,
            min_length=min_length,
            max_length=max_length,
            temperature=temperature,
            use_cache=use_cache,
            **kwargs,
        )
    # Decode the first (and only) sequence in the batch and print it.
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    print(gen_text)
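The script assumes a CUDA GPU is available. On a CPU-only machine the same calls work if the device is chosen at runtime; a minimal sketch of that substitution (not part of the original script):

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
model = GPTNeoForCausalLM.from_pretrained(model_name).to(device)
ids = tokenizer(text, return_tensors="pt").input_ids.to(device)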
Load the script in an interactive session (note the -i flag) so generate() can be called afterwards:

python -i gen.py spamtontalk_gpt_neo_xl_v9
>>> text = """Talk (anything):
Example dialogue"""
>>> generate(model, text, temperature=0.92)
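Extra keyword arguments are forwarded to model.generate, so the standard transformers sampling controls apply as well; for example, nucleus and top-k sampling (top_p and top_k are stock generation arguments, not specific to this script):

>>> generate(model, text, temperature=0.92, top_p=0.95, top_k=50)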