from datasets import load_dataset

chinese_dataset = load_dataset("amazon_reviews_multi", "zh")
english_dataset = load_dataset("amazon_reviews_multi", "en")
print(english_dataset)
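
To get a feel for the data before filtering, a small helper like the one below can print a few random training examples. The helper name show_samples, the sample count, and the seed are illustrative choices, not part of the original code.

def show_samples(dataset, num_samples=3, seed=42):
    # Pick a few random training examples and print their title and body
    sample = dataset["train"].shuffle(seed=seed).select(range(num_samples))
    for example in sample:
        print(f"\n'>>> Title: {example['review_title']}'")
        print(f"'>>> Review: {example['review_body']}'")


show_samples(english_dataset)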


def filter_books(example):
    # Keep only book-related reviews: printed books and digital ebook purchases
    return (
        example["product_category"] == "book"
        or example["product_category"] == "digital_ebook_purchase"
    )

chinese_books = chinese_dataset.filter(filter_books)
english_books = english_dataset.filter(filter_books)
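
As a quick sanity check (not part of the original code), Dataset.unique can confirm that only the two book categories survive the filter:

# Sanity check: only "book" and "digital_ebook_purchase" should remain
print(english_books["train"].unique("product_category"))
print(chinese_books["train"].unique("product_category"))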


from datasets import concatenate_datasets, DatasetDict

books_dataset = DatasetDict()

# Merge the English and Chinese book reviews split by split, then shuffle each split
for split in english_books.keys():
    books_dataset[split] = concatenate_datasets(
        [english_books[split], chinese_books[split]]
    )
    books_dataset[split] = books_dataset[split].shuffle(seed=42)


# Drop examples whose titles are two words or fewer, so the target summaries are informative
books_dataset = books_dataset.filter(lambda x: len(x["review_title"].split()) > 2)



from transformers import pipeline

hub_model_id = "sdinger/mt5-finetuned-amazon-en-zh"
summarizer = pipeline("summarization", model=hub_model_id)



def print_summary(idx):
    # Compare the model's generated summary with the human-written review title
    review = books_dataset["test"][idx]["review_body"]
    title = books_dataset["test"][idx]["review_title"]
    summary = summarizer(review)[0]["summary_text"]
    print(f"'>>> Review: {review}'")
    print(f"\n'>>> Title: {title}'")
    print(f"\n'>>> Summary: {summary}'")


print_summary(0)

model_checkpoint = "google/mt5-small"

Training setup: a single 3090 (24 GB) GPU, 8 epochs.
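
The snippet below is a minimal sketch of how google/mt5-small could be fine-tuned on books_dataset to produce a checkpoint like the one loaded above. Only the 8 epochs come from the notes; the truncation lengths, batch size, learning rate, and output directory are illustrative assumptions, not values from the original run.

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

max_input_length = 512  # assumed truncation length for review bodies
max_target_length = 30  # assumed truncation length for titles


def preprocess_function(examples):
    # Review bodies are the inputs; review titles serve as the reference summaries
    model_inputs = tokenizer(
        examples["review_body"], max_length=max_input_length, truncation=True
    )
    labels = tokenizer(
        text_target=examples["review_title"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_datasets = books_dataset.map(preprocess_function, batched=True)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-finetuned-amazon-en-zh",
    num_train_epochs=8,             # matches the note above
    learning_rate=5.6e-5,           # assumed value
    per_device_train_batch_size=8,  # assumed value for a 24 GB GPU
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    predict_with_generate=True,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()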