from transformers import GPT2Tokenizer, GPT2LMHeadModel

model_load_path = 'clam004/emerg-intent-consistent-good-gpt2-xl-v2'
tokenizer_name = 'gpt2-xl'

# EOS doubles as the pad token; left padding keeps the generated token
# adjacent to the prompt for causal (left-to-right) generation.
tokenizer = GPT2Tokenizer.from_pretrained(
    tokenizer_name,
    pad_token='<|endoftext|>',
    padding_side='left',
)

model = GPT2LMHeadModel.from_pretrained(
    model_load_path,
    cache_dir=None,
    pad_token_id=tokenizer.eos_token_id,
)

# Few-shot demonstrations as (utterance, label) pairs; each is rendered as
# "<|?|>\n[human]: <utterance><|emergency?|><label>" in the pretext.
_few_shot_examples = [
    ("i just slipped fell hit my head now im bleeding.", "True"),
    ("aluminium foil is bad, not good.", "False"),
    ("i dont want to live anymore i want to kill myself.", "True"),
    ("im dying for some coffee and donuts but i have diabetes.", "False"),
    ("i just got back from the hospital.", "False"),
    ("im fasting for my blood test tomorrow to see if my meds are working.", "False"),
]

# Rendered pretext ends with an open "[human]: " slot for the query text.
few_shot_pretext = ''.join(
    '<|?|>\n[human]: ' + utterance + '<|emergency?|>' + label
    for utterance, label in _few_shot_examples
) + '<|?|>\n[human]: '

###############################################################

def _classify_emergency(query_text):
    """Classify one utterance as an emergency via few-shot prompting.

    Appends *query_text* plus the '<|emergency?|>' cue to the module-level
    few-shot pretext, greedily generates exactly one token with the GPT-2
    model, and returns that token decoded (expected: 'True' or 'False').
    Prints the full prompt and the decoded label, matching the original
    script's output.
    """
    few_shot_prompt = few_shot_pretext + query_text + '<|emergency?|>'

    print(repr(few_shot_prompt))
    print('-' * 50)

    prompt_dic = tokenizer(few_shot_prompt, return_tensors="pt")
    prompt_ids = prompt_dic.input_ids.to(model.device)
    prompt_mask = prompt_dic.attention_mask.to(model.device)
    prompt_len = prompt_ids.shape[1]

    # max_length = prompt_len + 1 => generate exactly one new token
    # (the True/False label) using greedy decoding (generate's default).
    output_ids = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        max_length=prompt_len + 1,
    )

    # Keep only the newly generated last token and decode it.
    generated_text = tokenizer.batch_decode(output_ids[:, -1])

    print(generated_text[0])
    print('-' * 50)
    return generated_text[0]


###############################################################

_classify_emergency("I want to die, I am worthless anyways.")  # expected 'True'

_classify_emergency("I am dying for a cup of coffee.")  # expected 'False'