ViT, Cifar10 (50,000 images)
Inference sample code
from transformers import AutoImageProcessor, ViTForImageClassification
import torch
from PIL import Image
import requests

# Download the sample input image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Load the image preprocessor and the fine-tuned classifier
image_processor = AutoImageProcessor.from_pretrained("againeureka/vit_cifar10_classification")
model = ViTForImageClassification.from_pretrained("againeureka/vit_cifar10_classification")

# Preprocess into a PyTorch batch and run a forward pass without tracking gradients
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Map the highest-scoring class index to its label
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
Input image
- http://images.cocodataset.org/val2017/000000039769.jpg
Output
cat
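The sample code prints only the single best label. If the runner-up classes and their probabilities are also of interest, the logits can be passed through a softmax and ranked. The following is a minimal sketch in plain Python; the helper name `top_k_labels`, the example logits, and the shortened label map are all hypothetical and are not taken from this model.

```python
import math


def top_k_labels(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    # Softmax, subtracting the maximum first for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical logits and label map, for illustration only.
example_id2label = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat"}
example_logits = [0.1, -1.2, 0.4, 3.0]
for label, prob in top_k_labels(example_logits, example_id2label, k=2):
    print(f"{label}: {prob:.3f}")
```

With the model from the sample code, the call would look like `top_k_labels(logits[0].tolist(), model.config.id2label)`, assuming a single image in the batch.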
Training setup
- dataset: cifar10 (50,000 training images)
- base ViT model: 'google/vit-base-patch16-224-in21k'
- training arguments:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="againeureka/vit_cifar10_classification",
    per_device_train_batch_size=16,
    # Recent transformers releases name this argument eval_strategy
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)
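To get a feel for what these arguments imply, the step counts can be worked out by hand. The sketch below uses the values from the training setup and assumes a single device with no gradient accumulation; neither is stated in the setup, so the numbers shift if training ran on several GPUs.

```python
import math

# Values taken from the training setup above.
num_train_images = 50_000
per_device_train_batch_size = 16
num_train_epochs = 4
eval_steps = 100
logging_steps = 10

# Assumptions (not stated in the setup): one device, no gradient accumulation.
num_devices = 1
gradient_accumulation_steps = 1

effective_batch = per_device_train_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_train_images / effective_batch)
total_steps = steps_per_epoch * num_train_epochs
num_evaluations = total_steps // eval_steps
num_log_entries = total_steps // logging_steps

print("steps per epoch:", steps_per_epoch)   # 3125
print("total steps:", total_steps)           # 12500
print("evaluations:", num_evaluations)       # 125
print("log entries:", num_log_entries)       # 1250
```

Under these assumptions a checkpoint is also written every 100 steps, and `save_total_limit=2` limits how many of them are kept on disk at once.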