from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-sst-byt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("kaiyuy/leandojo-lean4-sst-byt5-small")
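# The model takes a pair of Lean proof states (the goal before a tactic is
# applied and the goal after) and predicts the tactic connecting them.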

state_pair = """before
α : Type ?u.60285
β : Type ?u.60288
G : Type u_1
inst✝ : Group G
a b c d : G
⊢ b / a = c / a ↔ b = c

after
α : Type ?u.60285
β : Type ?u.60288
G : Type u_1
inst✝ : Group G
a b c d : G
⊢ b * a⁻¹ = c * a⁻¹ ↔ b = c"""
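# ByT5 operates directly on UTF-8 bytes, so Lean's Unicode symbols (⊢, ⁻¹, ↔)
# need no special preprocessing before tokenization.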
tokenized_state_pair = tokenizer(state_pair, return_tensors="pt")

# Generate a single tactic.
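# With default settings this is greedy decoding; max_length caps the generated
# output in tokens (bytes, for ByT5).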
tactic_ids = model.generate(tokenized_state_pair.input_ids, max_length=1024)
tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
print(tactic, end="\n\n")

# Generate multiple tactics via beam search.
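# num_return_sequences must not exceed num_beams; length_penalty=0.0 disables
# length normalization, and do_sample=False keeps decoding deterministic.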
tactic_candidates_ids = model.generate(
    tokenized_state_pair.input_ids,
    max_length=1024,
    num_beams=4,
    length_penalty=0.0,
    do_sample=False,
    num_return_sequences=4,
    early_stopping=False,
)
tactic_candidates = tokenizer.batch_decode(
    tactic_candidates_ids, skip_special_tokens=True
)
for tac in tactic_candidates:
    print(tac)