import sys

sys.path.insert(0, './code')

from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image

import gradio as gr
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Load Vision Transformer
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"

vit = ViTForImageClassification.from_pretrained(hf_model)
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit.eval()
vit_imagenet.eval()

# Load Feature Extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
feature_extractor_imagenet = ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
feature_extractor = UnNest(feature_extractor)
feature_extractor_imagenet = UnNest(feature_extractor_imagenet)

# Load Vision DiffMask
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)

diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)

diffmask.eval()
diffmask_imagenet.eval()

# Dropdown choice -> (DiffMask model, the feature extractor of the ViT it wraps).
# Pairing them here ensures each model gets its own preprocessing.
MODELS = {
    'DiffMask-CiFAR-10': (diffmask, feature_extractor),
    'DiffMask-ImageNet': (diffmask_imagenet, feature_extractor_imagenet),
}

# Number of highest-probability classes shown in the probability panels.
TOP_K = 5


# Define mask plotting functions
def draw_mask(image, mask):
    """Overlay the smoothed ``mask`` on ``image`` and return it as an HWC numpy array clipped to [0, 1]."""
    return draw_mask_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()


def draw_heatmap(image, mask):
    """Draw the smoothed ``mask`` as a heatmap over ``image``; return an HWC numpy array clipped to [0, 1]."""
    return draw_heatmap_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()


# Define callable method for the demo
def get_mask(image, model_name: str):
    """Run the selected DiffMask model on an uploaded image and build the demo outputs.

    Args:
        image: Uploaded image as a numpy array, or ``None`` when nothing is
            uploaded. It is converted with ``permute(2, 0, 1)`` and ``/ 255``,
            so it is assumed to be (H, W, C) with values in [0, 255].
        model_name: Dropdown choice; expected to be one of the keys of ``MODELS``.

    Returns:
        A 4-tuple matching the interface outputs:
        - the mask overlay and heatmap stacked side by side;
        - the predicted class label;
        - ``{label: probability}`` of the original model's top classes;
        - ``{label: probability}`` of the same classes after masking.
        All four are ``None`` when there is no image or no recognised model
        is selected.
    """
    # One placeholder per interface output (the interface declares four).
    if image is None or model_name not in MODELS:
        return None, None, None, None

    diffmask_model, extractor = MODELS[model_name]
    id2label = diffmask_model.model.config.id2label

    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    mask = dm_out["mask"][0].detach()
    probs = dm_out["logits"][0].detach().softmax(dim=-1)
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)

    # Pick the top classes of the ORIGINAL (unmasked) prediction and report
    # both the original and the masked probabilities for those same classes.
    k = min(TOP_K, probs_orig.shape[-1])
    top_orig = probs_orig.topk(k, dim=-1)
    top_idx = top_orig.indices
    # Label each probability with its real class index, not its rank.
    labels = [id2label[i] for i in top_idx.tolist()]
    orig_probs = dict(zip(labels, top_orig.values.tolist()))
    pred_probs = dict(zip(labels, probs[top_idx].tolist()))

    pred = dm_out["pred_class"][0].detach()
    pred = id2label[pred.item()]

    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs


# Launch demo interface
gr.Interface(
    get_mask,
    inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
            gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
    outputs=[gr.outputs.Image(label="Output"),
             gr.outputs.Label(label="Prediction"),
             gr.Label(label="Original Probabilities"),
             gr.Label(label="Predicted Probabilities")],
    title="Vision DiffMask Demo",
    live=True,
).launch()