from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-sst-byt5-small-updated")
model = AutoModelForSeq2SeqLM.from_pretrained("kaiyuy/leandojo-lean4-sst-byt5-small-updated")
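# Example input: a pair of tactic states labeled "before" and "after"; the
# model generates a tactic candidate that bridges the two states.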
state_pair = """before
α : Type u
β : Type u
a : Cardinal
b : Cardinal
c : Cardinal
⊢ a ^< b ≤ c ↔ ∀ (x : Cardinal), x < b → a ^ x ≤ c
after
...
⊢ (∀ (i : ↑(Iio b)), a ^ ↑i ≤ c) ↔ ∀ (x : Cardinal), x < b → a ^ x ≤ c"""
tokenized_state_pair = tokenizer(state_pair, return_tensors="pt")
# Generate a single tactic.
tactic_ids = model.generate(tokenized_state_pair.input_ids, max_length=1024)
tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
print(tactic, end="\n\n")
# Generate multiple tactics via beam search.
tactic_candidates_ids = model.generate(
    tokenized_state_pair.input_ids,
    max_length=1024,
    num_beams=4,
    length_penalty=0.0,
    do_sample=False,
    num_return_sequences=4,
    early_stopping=False,
)
tactic_candidates = tokenizer.batch_decode(
    tactic_candidates_ids, skip_special_tokens=True
)
for tac in tactic_candidates:
    print(tac)
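
# Optional (not part of the original snippet): a minimal sketch of running the
# same generation on a GPU, assuming PyTorch with CUDA is available. Wrapping
# generation in torch.no_grad() skips gradient bookkeeping during inference.
import torch

if torch.cuda.is_available():
    model = model.to("cuda")
    input_ids = tokenized_state_pair.input_ids.to("cuda")
    with torch.no_grad():
        gpu_tactic_ids = model.generate(input_ids, max_length=1024)
    print(tokenizer.decode(gpu_tactic_ids[0], skip_special_tokens=True))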