RRoundTable committed on
Commit
26be1cc
1 Parent(s): 948d643

Add app for dinov2 pca

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import cv2
4
+ import gradio as gr
5
+ import glob
6
+ from typing import List
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+ from sklearn.decomposition import PCA
10
+ import sklearn
11
+ import numpy as np
12
+
13
+
14
# Constants
# Size of the patch grid. Inputs are resized to (patch_h * 14, patch_w * 14)
# pixels below; 14 is presumably the patch size of the ViT-L/14 backbone —
# TODO confirm against the DINOv2 model card.
patch_h = 40
patch_w = 40

# Use GPU if available
# NOTE(review): `device` is selected here, but nothing in this setup section
# moves the model onto it; the model stays wherever torch.hub.load places it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DINOv2 backbone, fetched through torch.hub (downloads weights on first use).
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# The model is only used for feature extraction, so switch it to inference
# mode to disable any train-time behaviour (dropout, stochastic depth, ...).
model.eval()

# Transforms: resize to the fixed patch-grid resolution, then apply the
# ImageNet channel statistics.
transform = T.Compose([
    T.Resize((patch_h * 14, patch_w * 14)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Empty tensor: preallocated batch buffer for four 3-channel images at the
# resized resolution.
imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14)


# PCA: three components, one per output colour channel.
pca = PCA(n_components=3)
39
+
40
def _to_model_input(img):
    """Convert one H x W x C image array into a normalised C x H x W tensor."""
    arr = np.asarray(img)
    pixels = arr.astype(np.float32)
    if arr.dtype == np.uint8:
        # The ImageNet mean/std used by `transform` are expressed for pixel
        # values in [0, 1]; 8-bit images must be rescaled first.
        pixels /= 255.0
    chw = torch.from_numpy(np.ascontiguousarray(pixels.transpose(2, 0, 1)))
    return transform(chw)


def _min_max_columns(values):
    """Rescale every column of a 2-D array to [0, 1]; constant columns map to 0."""
    lo = values.min(axis=0)
    span = values.max(axis=0) - lo
    # A constant column would otherwise divide by zero and produce NaNs.
    span[span == 0] = 1.0
    return (values - lo) / span


def query_image(img1, img2, img3, img4) -> List[np.ndarray]:
    """Visualise the main PCA components of DINOv2 patch features.

    The four images are embedded together, a first PCA separates foreground
    from background patches, and a second PCA fitted on the foreground
    patches only is min-max scaled and used as the RGB colour of each patch.

    Args:
        img1, img2, img3, img4: Images as arrays in height x width x channel
            layout (presumably the RGB uint8 arrays produced by ``gr.Image``
            -- confirm if another source feeds this function).

    Returns:
        One ``(patch_h, patch_w, 3)`` float array per input image with values
        in [0, 1]; background patches are left at zero (black).

    Raises:
        gr.Error: If any of the four image slots is empty.
    """
    imgs = [img1, img2, img3, img4]
    if any(img is None for img in imgs):
        # Gradio passes None for an image slot the user left empty.
        raise gr.Error("Please provide all four images.")

    n_imgs = len(imgs)
    n_patches = patch_h * patch_w

    # Build a fresh batch per call instead of writing into a shared global
    # buffer, and run it on whatever device the model currently lives on.
    model_device = next(model.parameters()).device
    batch = torch.stack([_to_model_input(img) for img in imgs]).to(model_device)

    # Get features from patches
    with torch.no_grad():
        features_dict = model.forward_features(batch)
    # Drop the first token of every sequence (presumably the class token --
    # confirm against the DINOv2 implementation) to keep patch tokens only.
    features = features_dict['x_prenorm'][:, 1:]
    # scikit-learn needs a CPU NumPy array.
    features = features.reshape(n_imgs * n_patches, -1).cpu().numpy()

    # First PCA over all patches; the first component splits the patches
    # into foreground and background.
    pca_features = pca.fit_transform(features)

    # NOTE(review): the sign of a principal component is not tied to image
    # content, so this threshold may need flipping for some inputs -- confirm.
    background = pca_features[:, 0] < 0
    foreground = ~background

    # Background patches stay black.
    pca_features_rgb = np.zeros((n_imgs * n_patches, 3))

    # Second PCA with only foreground patches. PCA cannot be fitted on fewer
    # samples than components, so skip it when the foreground is too small.
    if foreground.sum() >= pca.n_components:
        foreground_features = pca.fit_transform(features[foreground])
        pca_features_rgb[foreground] = _min_max_columns(foreground_features)

    pca_features_rgb = pca_features_rgb.reshape(n_imgs, patch_h, patch_w, 3)
    return list(pca_features_rgb)
77
+
78
# Text shown under the title of the web UI.
description = """
DINOV2 PCA
"""

# Four image slots go in and four PCA visualisations come out, matching the
# four positional arguments and four return values of `query_image`.
image_inputs = [gr.Image() for _ in range(4)]
image_outputs = [gr.Image() for _ in range(4)]

demo = gr.Interface(
    fn=query_image,
    inputs=image_inputs,
    outputs=image_outputs,
    title="DINOV2 PCA",
    description=description,
    examples=[],
)

# Start the Gradio server.
demo.launch()