# bird-classifier / app.py
# Author: jerpint
# Commit 94ab357: "actually classify images"
# (Scraped from the Hugging Face file viewer: raw / history / blame links,
#  "No virus" scan badge, file size 1.21 kB.)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# Re-add the `np.int` alias, which recent NumPy releases no longer provide.
# Presumably the hub model code still references `np.int` — confirm against
# the nicolalandro/ntsnet-cub200 repository before removing this shim.
setattr(np, "int", int)

# Load the pretrained NTS-Net (200 classes, CUB-200 repo) from Torch Hub on CPU,
# then switch it to inference mode (model.eval()).
model = torch.hub.load(
    "nicolalandro/ntsnet-cub200",
    "ntsnet",
    pretrained=True,
    topN=6,
    device="cpu",
    num_classes=200,
)
model.eval()
# Inference-time preprocessing: resize to 600x600, center-crop to 448x448,
# convert to a tensor and normalize with the standard ImageNet mean/std values.
# Built once at import time rather than on every request.
TRANSFORM_TEST = transforms.Compose([
    transforms.Resize((600, 600), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def classify_bird(img):
    """Predict the bird class name for a single PIL image.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The entry of ``model.bird_classes`` at the index of the highest
        ``concat_logits`` score.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload a bird image first.")

    # Normalize() above has exactly 3 channel statistics; force RGB so RGBA
    # PNGs or grayscale images don't fail with a channel-count mismatch.
    batch = TRANSFORM_TEST(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        # The model returns 7 outputs; only concat_logits (the 4th) is used
        # to pick the predicted class.
        _, _, _, concat_logits, _, _, _ = model(batch)

    pred_id = concat_logits.argmax(dim=1).item()
    bird_class = model.bird_classes[pred_id]
    print(f"{bird_class=}")
    return bird_class
# Gradio UI: one PIL image input, plain-text prediction output.
image_component = gr.Image(type="pil", label="Bird Image")
demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs="text")

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()