Training Code

from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq
)
import numpy as np

# The prompt packs the post plus two candidate summaries; the target is a
# single letter ("A" or "B") plus EOS, hence the tiny output length.
MAX_LENGTH_INPUT = 512 + 128
MAX_LENGTH_OUTPUT = 2

class Seq2SeqDataset(Dataset):

    def __init__(self, tokenizer, type_data='train'):

        # Load the pairwise summary comparisons.
        data_path = "CarperAI/openai_summarize_comparisons"
        if type_data == 'train':
            dataset = load_dataset(data_path, split="train")
        else:
            dataset = load_dataset(data_path, split="test").select(range(20000))
        self.prompts = []
        self.outputs = []
        inputs = dataset["prompt"]
        chosen = dataset["chosen"]
        rejected = dataset["rejected"]
        for inp, good, bad in zip(inputs, chosen, rejected):
            # Randomly assign the chosen summary to position A or B so the
            # model cannot learn a positional shortcut; the label is decided
            # together with the ordering.
            if np.random.rand() < 0.5:
                first, second, label = good, bad, "A"
            else:
                first, second, label = bad, good, "B"
            prompt = f"""POST: {inp}\n\nRESPONSE A: {first}\n\nRESPONSE B: {second}\n\nWhich response is better? RESPONSE"""
            self.prompts.append(prompt)
            self.outputs.append(label)
        print("Example prompt: ", self.prompts[0])
        print("Example output: ", self.outputs[0])
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.prompts)
    
    def __getitem__(self, idx):
        input_text = self.prompts[idx]
        output_text = self.outputs[idx]

        model_input = self.tokenizer(
            input_text,
            max_length=MAX_LENGTH_INPUT,
            padding='max_length',
            truncation=True
        )
        # text_target routes the string through the target-side tokenizer.
        labels = self.tokenizer(
            text_target=output_text,
            max_length=MAX_LENGTH_OUTPUT,
            padding='max_length',
            truncation=True
        )["input_ids"]
        # Replace padding with -100 so it is ignored by the cross-entropy loss.
        model_input['labels'] = [
            -100 if token == self.tokenizer.pad_token_id else token for token in labels
        ]
        return model_input
    
import wandb 
wandb.init(name="stanfordnlp/SteamSHP-flan-t5-xl", project="trlx", entity="pvduy")


if __name__ == "__main__":
    config = {
        "logging_steps": 100,
        "batch_size": 4,
        "batch_size_val": 4,
        "warmup_steps": 100,
        "accum_steps": 2,
        "num_beams": 3,
        "output_dir": "flan-t5-rm",
    }
    
    def compute_metrics(pred):
        labels_ids = pred.label_ids
        pred_ids = pred.predictions
        # -100 is the loss ignore index and cannot be decoded;
        # map it back to the pad token first.
        labels_ids = np.where(labels_ids == -100, tokenizer.pad_token_id, labels_ids)
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        labels_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        acc = np.mean(np.array(labels_str) == np.array(pred_str))
        return {"accuracy": acc}

    training_args = Seq2SeqTrainingArguments(
        output_dir=config["output_dir"],
        do_train=True,
        num_train_epochs=5,
        predict_with_generate=True,  # evaluate by generating the "A"/"B" letter
        generation_max_length=MAX_LENGTH_OUTPUT,
        generation_num_beams=config["num_beams"],
        adam_beta1=0.9,
        adam_beta2=0.999,
        learning_rate=5e-5,
        bf16=True,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size_val"],
        logging_steps=config["logging_steps"],
        evaluation_strategy="epoch",
        warmup_steps=config["warmup_steps"],
        eval_accumulation_steps=1,
        lr_scheduler_type="linear",
        save_strategy="epoch",
        gradient_accumulation_steps=config["accum_steps"],
        deepspeed='configs/ds_configs/ds_config_gpt_2.json',
    )
    
    tokenizer = AutoTokenizer.from_pretrained("stanfordnlp/SteamSHP-flan-t5-xl")
    model = AutoModelForSeq2SeqLM.from_pretrained("stanfordnlp/SteamSHP-flan-t5-xl")
    
    train_dataset = Seq2SeqDataset(tokenizer, type_data='train')
    val_dataset = Seq2SeqDataset(tokenizer, type_data='val')
    print("Train dataset size: ", len(train_dataset))
    print("Val dataset size: ", len(val_dataset))

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {params}")

    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Collate the list-valued features (and labels) into padded tensors.
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        compute_metrics=compute_metrics,
    )

    trainer.train()
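
Since the training arguments reference a DeepSpeed config, the script is meant to be started through the DeepSpeed launcher rather than plain python. A minimal sketch, assuming the script above is saved as train_rm.py (a hypothetical name) and all visible GPUs should be used:

deepspeed train_rm.py

The launcher spawns one process per GPU and sets up the distributed environment that the Trainer's DeepSpeed integration expects.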

Inference Code

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import numpy as np
import torch
from tqdm import tqdm
dataset = load_dataset("CarperAI/openai_summarize_comparisons", split="test")

tokenizer = AutoTokenizer.from_pretrained("flan-t5-rm/checkpoint-4338/")
model = AutoModelForSeq2SeqLM.from_pretrained("flan-t5-rm/checkpoint-4338/")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

df = dataset.to_pandas()
predictions = []
for i, row in tqdm(df.iterrows(), total=len(df)):
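    # The chosen summary is always placed in position A at inference time,
    # so the model answers correctly exactly when it outputs "A".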
    prompt = f"""POST: {row["prompt"]}\n\nRESPONSE A: {row["chosen"]}\n\nRESPONSE B: {row["rejected"]}\n\nWhich response is better? RESPONSE"""
    x = tokenizer([prompt], return_tensors='pt').input_ids.to(device)
    y = model.generate(x, max_new_tokens=1)
    predictions.append(tokenizer.batch_decode(y, skip_special_tokens=True)[0])

print("Accuracy: ", sum(np.array(predictions) == 'A') / len(predictions))