from transformers import TextClassificationPipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification


class TemporalRelationClassificationPipeline(TextClassificationPipeline):
    def check_model_type(self, supported_models):
        pass

pretrained_checkpoint = "guyyanko/split-3-hebrew-trc-alephbert-base-EMP"

model = AutoModelForSequenceClassification.from_pretrained(pretrained_checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint, trust_remote_code=True)
classifier = pipeline(task='text-classification', model=model, tokenizer=tokenizer)

txt = "מחר [א1] אתאמן [/א1] אם [א2] אסיים [/א2] את כל המשימות שלי"
print(classifier(txt))

txt = "אחרי [א1] שאסיים [/א1] את כל המשימות שלי [א2] אלך [/א2] להתאמן בחדר הכושר"
print(classifier(txt))