vit-xray-pneumonia-classification

This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the chest-xray-classification dataset. Per-epoch loss and accuracy on the evaluation set are listed under Training results.

Inference example

from transformers import pipeline

classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")

# image taken from https://www.news-medical.net/health/What-is-Viral-Pneumonia.aspx
classifier("https://d2jx2rerrg6sh3.cloudfront.net/image-handler/ts/20200618040600/ri/650/picture/2020/6/shutterstock_786937069.jpg")

>>>
[{'score': 0.990334689617157, 'label': 'PNEUMONIA'},
 {'score': 0.009665317833423615, 'label': 'NORMAL'}]
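
The pipeline returns one dictionary per class, sorted by score. A minimal sketch of turning that list into a single decision is shown below; the helper `top_label` and the `min_score` cutoff are hypothetical names and values chosen for illustration, not part of the model or of `transformers`.

```python
from typing import Dict, List, Optional, Union

# Output copied from the inference example above.
results: List[Dict[str, Union[str, float]]] = [
    {"score": 0.990334689617157, "label": "PNEUMONIA"},
    {"score": 0.009665317833423615, "label": "NORMAL"},
]


def top_label(preds: List[Dict[str, Union[str, float]]],
              min_score: float = 0.5) -> Optional[str]:
    """Return the highest-scoring label, or None if it falls below min_score.

    min_score is an illustrative cutoff (an assumption), not a tuned value.
    """
    best = max(preds, key=lambda p: p["score"])
    return best["label"] if best["score"] >= min_score else None


prediction = top_label(results, min_score=0.9)     # confident enough: label kept
undecided = top_label(results, min_score=0.995)    # stricter cutoff: returns None
```

Returning `None` below the cutoff leaves room for the caller to route low-confidence images elsewhere rather than forcing a label.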

Training procedure

Notebook link: here

Training hyperparameters

The following hyperparameters were used during training:

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="vit-xray-pneumonia-classification",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=15,
    save_total_limit=2,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=True,
    push_to_hub=True,
    report_to="tensorboard"
)

early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
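
The arithmetic implied by these settings can be sketched as follows. The batch size, accumulation steps, epochs, warmup ratio, and learning rate are copied from the arguments above. Everything else is an assumption: a single training device, 255 batches per epoch (inferred from the training log, where step 255 coincides with epoch 4.0), leftover batches carrying over between epochs, and the Trainer's default linear schedule being left unchanged. The total step count used for the schedule is a rough estimate; the Trainer may compute a slightly different value.

```python
import math

# Values copied from the TrainingArguments above.
PER_DEVICE_TRAIN_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 4
NUM_TRAIN_EPOCHS = 15
WARMUP_RATIO = 0.1
LEARNING_RATE = 5e-5

# Assumptions, not stated in this card.
NUM_DEVICES = 1
BATCHES_PER_EPOCH = 255  # inferred from the log: step 255 at epoch 4.0


def effective_batch_size() -> int:
    """Samples contributing to each optimizer update."""
    return PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NUM_DEVICES


def optimizer_steps_after(epochs: int) -> int:
    """Cumulative optimizer updates after a number of full epochs,
    assuming leftover batches carry over into the next epoch."""
    return (BATCHES_PER_EPOCH * epochs) // GRADIENT_ACCUMULATION_STEPS


def linear_schedule_lr(step: int, total_steps: int, warmup_steps: int,
                       base_lr: float = LEARNING_RATE) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return base_lr * remaining


batch = effective_batch_size()                        # 16 * 4 * 1
logged_steps = [optimizer_steps_after(e) for e in range(1, 14)]
est_total_steps = optimizer_steps_after(NUM_TRAIN_EPOCHS)  # rough estimate
est_warmup_steps = math.ceil(est_total_steps * WARMUP_RATIO)
peak_lr = linear_schedule_lr(est_warmup_steps, est_total_steps, est_warmup_steps)
```

Under these assumptions `logged_steps` reproduces the Step column of the training log, which suggests the inferred batch count is at least consistent with the run.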

Training results

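| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|---------------|-------|------|-----------------|----------|
| 0.5152        | 0.99  | 63   | 0.2507          | 0.9245   |
| 0.2334        | 1.99  | 127  | 0.1766          | 0.9382   |
| 0.1647        | 3.0   | 191  | 0.1218          | 0.9588   |
| 0.144         | 4.0   | 255  | 0.1222          | 0.9502   |
| 0.1348        | 4.99  | 318  | 0.1293          | 0.9571   |
| 0.1276        | 5.99  | 382  | 0.1000          | 0.9665   |
| 0.1175        | 7.0   | 446  | 0.1177          | 0.9502   |
| 0.109         | 8.0   | 510  | 0.1079          | 0.9665   |
| 0.0914        | 8.99  | 573  | 0.0804          | 0.9717   |
| 0.0872        | 9.99  | 637  | 0.0800          | 0.9717   |
| 0.0804        | 11.0  | 701  | 0.0862          | 0.9682   |
| 0.0935        | 12.0  | 765  | 0.0883          | 0.9657   |
| 0.0686        | 12.99 | 828  | 0.0868          | 0.9742   |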
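
The log stops after 13 of the 15 configured epochs, which is consistent with early stopping on validation loss with a patience of 3. The sketch below replays that rule over the logged values. It is a simplified stand-in for `EarlyStoppingCallback` (it ignores the improvement threshold and the Trainer's internal state), and `replay_early_stopping` is a hypothetical helper written for this illustration.

```python
from typing import List, Tuple

# (step, validation loss) pairs copied from the training results above.
EVAL_LOG: List[Tuple[int, float]] = [
    (63, 0.2507), (127, 0.1766), (191, 0.1218), (255, 0.1222),
    (318, 0.1293), (382, 0.1000), (446, 0.1177), (510, 0.1079),
    (573, 0.0804), (637, 0.0800), (701, 0.0862), (765, 0.0883),
    (828, 0.0868),
]


def replay_early_stopping(log: List[Tuple[int, float]],
                          patience: int = 3) -> Tuple[int, int]:
    """Return (best_step, stop_step) for a lower-is-better metric.

    Training stops once `patience` consecutive evaluations fail to
    improve on the best loss seen so far.
    """
    best_step, best_loss = log[0]
    evals_without_improvement = 0
    for step, loss in log[1:]:
        if loss < best_loss:
            best_step, best_loss = step, loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return best_step, step
    # Patience never exhausted: training ran to the end of the log.
    return best_step, log[-1][0]


best_step, stop_step = replay_early_stopping(EVAL_LOG, patience=3)
```

With `load_best_model_at_end=True` and `metric_for_best_model="eval_loss"`, the checkpoint at `best_step` would be expected to be the one kept. Note that this is the lowest-loss checkpoint, not the highest-accuracy row in the log.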

Framework versions