Binary classifier to rate how well an input/output pair follows instructions.

Each example must be formatted into a single prompt string like this:

def render_prompts(ex, key):
    """Render one example into a single prompt string.

    Sections are emitted in the order System, Instruction, Input, Output, and
    the combined text is stripped of leading and trailing whitespace. The
    example passed in is never modified.

    Args:
        ex: Mapping describing the example. ``output`` is required;
            ``system``, ``instruction`` and ``input`` are optional. When
            ``input`` is present without ``instruction``, its value is
            rendered as the instruction and no Input section is emitted.
        key: Name under which the rendered text is returned.

    Returns:
        A single-entry dict ``{key: rendered_text}``.

    Raises:
        KeyError: If ``ex`` has no ``output`` entry.
    """
    sections = []

    system = ex.get("system")
    if system and system.strip():
        sections.append(f"### System:\n{system}\n")

    # Resolve the instruction/input pair into locals rather than rewriting
    # ``ex``, so the caller's example is left untouched.
    extra_input = None
    if "instruction" in ex:
        sections.append(f"\n### Instruction:\n{ex['instruction']}\n")
        extra_input = ex.get("input")
    elif "input" in ex:
        # With no instruction, the input text takes the instruction slot.
        sections.append(f"\n### Instruction:\n{ex['input']}\n")

    # A missing, ``None``, empty or whitespace-only input gets no section.
    if extra_input and extra_input.strip():
        sections.append(f"\n### Input:\n{extra_input}\n")

    sections.append(f"\n### Output:\n{ex['output']}")

    return {key: "".join(sections).strip()}

Then run inference like this:

# Forward pass with gradient tracking disabled.
with torch.no_grad():
    pred = model(**tokenized)

# Bring the raw logits to the CPU as a tensor detached from the graph.
logits = pred.logits.detach().cpu()
# Predicted class index per row.
labels = logits.argmax(dim=1)
# Softmax in float32, keeping only the probability of the last class.
probs = torch.softmax(logits.float(), dim=1)[:, -1]

`labels` are 0 or 1; `probs` are between 0 and 1 and give the probability of the last class (label 1).