ViT fine-tuned on CIFAR-10 (50,000 training images)

Inference sample code

from transformers import AutoImageProcessor, ViTForImageClassification
import torch
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Download the sample image. Bound the wait with a timeout, and fail loudly on
# HTTP errors instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Normalize to 3-channel RGB so grayscale/RGBA/palette images don't trip up
# preprocessing (the model presumably expects 3 channels; see
# model.config.num_channels).
image = Image.open(response.raw).convert("RGB")

image_processor = AutoImageProcessor.from_pretrained("againeureka/vit_cifar10_classification")
model = ViTForImageClassification.from_pretrained("againeureka/vit_cifar10_classification")

inputs = image_processor(image, return_tensors="pt")

# Inference only: no gradients are needed.
with torch.no_grad():
    logits = model(**inputs).logits

# The model is fine-tuned on CIFAR-10, so it predicts one of the 10 CIFAR-10
# classes; the index -> name mapping comes from model.config.id2label.
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

Input image: the COCO val2017 sample downloaded from `url` in the code above.

Output (the predicted label printed by the code above):

cat

Training setup

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="againeureka/vit_cifar10_classification",
    per_device_train_batch_size=16,
    # `eval_strategy` is the current name (transformers >= 4.41); on older
    # versions pass `evaluation_strategy="steps"` instead.
    eval_strategy="steps",
    num_train_epochs=4,
    # load_best_model_at_end requires the save and eval schedules to match
    # (save_steps must be a multiple of eval_steps); both are 100 here.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    # Keep at most two checkpoints on disk.
    save_total_limit=2,
    # NOTE(review): presumably disabled so the raw image column reaches the
    # preprocessing transform / data collator; confirm against the training script.
    remove_unused_columns=False,
    push_to_hub=False,
    # Reload the best checkpoint when training finishes; selection uses
    # metric_for_best_model, which defaults to the eval loss when unset.
    load_best_model_at_end=True,
)