Dataset source:
https://www.kaggle.com/datasets/asdasdasasdas/garbage-classification?sort=votes
Inference example:
from peft import PeftConfig, PeftModel
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from PIL import Image
import requests
# Hugging Face Hub repo holding the fine-tuned LoRA adapter (and its image processor).
# Plain string: the original f-prefix had no placeholders.
repo_name = "wtnan2003/vit-base-patch16-224-in21k-finetuned-lora-garbage_classification"

# Class name -> index mapping used for the classification head.
label2id = {
    "cardboard": 0,
    "glass": 1,
    "metal": 2,
    "paper": 3,
    "plastic": 4,
    "trash": 5,
}
# Inverse mapping (index -> class name), used to decode predictions.
id2label = {value: key for key, value in label2id.items()}
# Read the adapter's PEFT config; it records which base checkpoint the adapter was trained on.
config = PeftConfig.from_pretrained(repo_name)
base_checkpoint = config.base_model_name_or_path

# Build the base image classifier with this project's label maps.
# ignore_mismatched_sizes=True lets a classifier head of a different shape
# (from the base checkpoint) be replaced instead of raising an error.
model = AutoModelForImageClassification.from_pretrained(
    base_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Attach the LoRA adapter weights from the same repo on top of the base model.
inference_model = PeftModel.from_pretrained(model, repo_name)
# Example input image; per the expected output at the end of this script it is
# classified as "cardboard".
url = "https://www.uky.edu/facilities/sites/www.uky.edu.facilities/files/Cardboard%20Image.png"
# Alternative test images. The original "glass" comment had two URLs glued
# together; they are split here. TODO confirm which one actually shows glass.
# url = "https://th.bing.com/th/id/OIP.BkzhM2nwEy1edmV7WvU4EAHaJ4?pid=ImgDet&rs=1"
# url = "https://i.redd.it/01msg69otvl21.jpg"

# Bound the wait on a stalled server, and fail loudly on HTTP errors instead of
# passing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# .raw is the undecoded socket stream; ask urllib3 to undo any gzip/deflate
# Content-Encoding so PIL sees the actual image bytes.
response.raw.decode_content = True
image = Image.open(response.raw)

# Load the image processor stored in the adapter repo and turn the image into
# PyTorch tensors. convert("RGB") drops any alpha channel (e.g. from RGBA PNGs).
image_processor = AutoImageProcessor.from_pretrained(repo_name)
encoding = image_processor(image.convert("RGB"), return_tensors="pt")
# Single forward pass with gradient tracking disabled (inference only).
with torch.no_grad():
    outputs = inference_model(**encoding)
    logits = outputs.logits

# Index of the highest logit is the predicted class. .item() assumes a batch
# of one image, as produced above.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", inference_model.config.id2label[predicted_class_idx])
# Expected output for the default cardboard URL: Predicted class: cardboard