Spaces:
Sleeping
Sleeping
Commit
•
afebd23
1
Parent(s):
44712bf
Upload 3 files
Browse files- app.py +135 -0
- requirements.in +8 -0
- requirements.txt +296 -0
app.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import multiprocessing
|
2 |
+
import random
|
3 |
+
from datasets import load_dataset
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from PIL.Image import Image, ANTIALIAS
|
6 |
+
import gradio as gr
|
7 |
+
from faiss import METRIC_INNER_PRODUCT
|
8 |
+
import requests
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
import backoff
|
12 |
+
from functools import lru_cache
|
13 |
+
|
14 |
+
# Number of CPU cores on this machine. Not used elsewhere in this script —
# presumably kept for dataset.map(num_proc=...) experiments; TODO confirm.
cpu_count = multiprocessing.cpu_count()

# CLIP ViT-B/16 via sentence-transformers: embeds text and images into a
# shared vector space, which is what makes text->image search work below.
model = SentenceTransformer("clip-ViT-B-16")
|
17 |
+
|
18 |
+
|
19 |
+
def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals *size*, keeping aspect ratio.

    Args:
        image: Source PIL image.
        size: Target length in pixels of the shorter side.

    Returns:
        A resized copy of the image (``Image.resize`` returns a new image).
    """
    # NOTE(review): ANTIALIAS is deprecated and removed in Pillow >= 10;
    # consider Image.Resampling.LANCZOS — confirm against pinned Pillow 9.4.
    w, h = image.size
    if w == h:
        return image.resize((size, size), ANTIALIAS)
    if w > h:
        # Landscape: fix height at `size`, scale width proportionally.
        scale = size / float(h)
        return image.resize((int(float(w) * scale), size), ANTIALIAS)
    # Portrait: fix width at `size`, scale height proportionally.
    # BUG FIX: the original computed the new height from the *width*
    # (int(float(w) * width_percent)), which always produced a size x size
    # square for portrait images instead of preserving the aspect ratio.
    scale = size / float(w)
    return image.resize((size, int(float(h) * scale)), ANTIALIAS)
|
35 |
+
|
36 |
+
|
37 |
+
# Load the pre-embedded Internet Archive image dataset (train split).
dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
# Drop rows whose embedding failed to compute upstream.
dataset = dataset.filter(lambda x: x["embedding"] is not None)
# Build a FAISS index over the "embedding" column. Inner product matches
# cosine similarity only if the stored embeddings are L2-normalised —
# NOTE(review): confirm the upstream embedding job normalises them.
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)
|
40 |
+
|
41 |
+
|
42 |
+
def get_nearest_k_examples(input, k):
    """Return the `k` dataset images nearest to the text query `input`.

    Encodes the query with the module-level CLIP model, searches the FAISS
    index on the "embedding" column, and pairs each hit with a short
    metadata caption.

    Args:
        input: Free-text search query.
        k: Number of neighbours to retrieve.

    Returns:
        list: (image, caption) pairs in the format gr.Gallery accepts.
    """
    query = model.encode(input)
    # get_nearest_examples already returns exactly k rows, so no extra
    # slicing is needed (the original sliced images but not the metadata).
    _, retrieved = dataset.get_nearest_examples("embedding", query=query, k=k)
    images = retrieved["image"]
    metadata = [
        f"last_modified {modified}, crawl date:{crawl}"
        for modified, crawl in zip(
            retrieved["last_modified_date"], retrieved["crawl_date"]
        )
    ]
    return list(zip(images, metadata))
|
57 |
+
|
58 |
+
|
59 |
+
def return_random_sample(k=27):
    """Pick `k` random images from the dataset, resized and converted to RGB.

    Returns:
        list: RGB PIL images suitable for a gr.Gallery.
    """
    indices = random.sample(range(len(dataset)), k)
    picked = dataset[indices]["image"]
    return [resize_image(img).convert("RGB") for img in picked]
|
63 |
+
|
64 |
+
|
65 |
+
def predict_subset(model_id, token):
    """Classify 10 random dataset images with a hosted HF Inference API model.

    Args:
        model_id: Hub id of an image-classification model.
        token: Hugging Face API token used for authorization.

    Returns:
        tuple: (gallery, df) where `gallery` is a list of (url, caption)
        pairs for gr.Gallery and `df` is a label-frequency DataFrame for
        gr.BarPlot.
    """
    # Stdlib replacement for the original mid-function `from toolz import
    # frequencies` — toolz was never a declared dependency (requirements.in),
    # only a transitive one; Counter preserves the same first-seen key order.
    from collections import Counter

    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

    # Retry with exponential backoff while the model is still loading (503),
    # for at most 30 seconds.
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        r = requests.post(API_URL, headers=headers, data=url)
        print(r)
        return r

    @lru_cache(maxsize=1000)
    def query(url):
        response = _query(url)
        try:
            data = response.json()
            # The API returns predictions sorted by score; take the top one.
            argmax = data[0]
            return {"score": argmax["score"], "label": argmax["label"]}
        except Exception:
            # Malformed or error responses degrade to an empty prediction
            # rather than crashing the UI.
            return {"score": None, "label": None}

    # Reload without the FAISS index attached to the module-level `dataset`.
    dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample = dataset.select(random.sample(range(len(dataset)), 10))
    print("predicting...")
    predictions = [query(row["url"]) for row in sample]
    gallery = [
        (url, f"{prediction['label'], prediction['score']}")
        for url, prediction in zip(sample["url"], predictions)
    ]
    labels = [d["label"] for d in predictions]
    # Count once (the original called frequencies() twice).
    counts = Counter(labels)
    df = pd.DataFrame({"labels": counts.keys(), "freqs": counts.values()})
    return gallery, df
|
106 |
+
|
107 |
+
|
108 |
+
with gr.Blocks() as demo:
    # Tab 1: a refreshable grid of random dataset images.
    with gr.Tab("Random image gallery"):
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400")
        button.click(return_random_sample, [], [gallery])
    # Tab 2: free-text image search backed by the CLIP/FAISS index.
    with gr.Tab("image search"):
        text = gr.Textbox(label="Search for images")
        k = gr.Slider(minimum=3, maximum=18, step=1)
        button = gr.Button("search")
        gallery = gr.Gallery().style(grid=3)
        button.click(get_nearest_k_examples, [text, k], [gallery])
    # Disabled experiment: export the dataset for Label Studio annotation.
    # with gr.Tab("Export for label studio"):
    #     button = gr.Button("Export")
    #     dataset2 = copy.deepcopy(dataset)
    #     # dataset2 = dataset2.remove_columns('image')
    #     # dataset2 = dataset2.rename_column("url", "image")
    #     csv = dataset2.to_csv("label_studio.csv")
    #     csv_file = gr.File("label_studio.csv")
    #     button.click(dataset.save_to_disk, [], [csv_file])
    # Tab 3: run a hosted image-classification model over a random sample;
    # outputs a captioned gallery plus a label-frequency bar plot.
    with gr.Tab("predict"):
        token = gr.Textbox(label="token", type="password")
        model_id = gr.Textbox(label="model_id")
        button = gr.Button("predict")
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery()
        button.click(predict_subset, [model_id, token], [gallery, plot])

demo.launch(enable_queue=True, debug=True)
|
requirements.in
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
gradio
|
3 |
+
torch
|
4 |
+
# transformers @ git+https://github.com/huggingface/transformers@dde718e7a62bf8caa6623b5635ba02d6cb758c75
|
5 |
+
faiss-cpu
|
6 |
+
fuego
|
7 |
+
sentence_transformers
|
8 |
+
backoff
|
requirements.txt
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# This file is autogenerated by pip-compile with Python 3.9
|
3 |
+
# by the following command:
|
4 |
+
#
|
5 |
+
# pip-compile --resolver=backtracking requirements.in
|
6 |
+
#
|
7 |
+
aiofiles==23.1.0
|
8 |
+
# via gradio
|
9 |
+
aiohttp==3.8.4
|
10 |
+
# via
|
11 |
+
# datasets
|
12 |
+
# fsspec
|
13 |
+
# gradio
|
14 |
+
aiosignal==1.3.1
|
15 |
+
# via aiohttp
|
16 |
+
altair==4.2.2
|
17 |
+
# via gradio
|
18 |
+
anyio==3.6.2
|
19 |
+
# via
|
20 |
+
# httpcore
|
21 |
+
# starlette
|
22 |
+
async-timeout==4.0.2
|
23 |
+
# via aiohttp
|
24 |
+
attrs==22.2.0
|
25 |
+
# via
|
26 |
+
# aiohttp
|
27 |
+
# jsonschema
|
28 |
+
backoff==2.2.1
|
29 |
+
# via -r requirements.in
|
30 |
+
certifi==2022.12.7
|
31 |
+
# via
|
32 |
+
# httpcore
|
33 |
+
# httpx
|
34 |
+
# requests
|
35 |
+
charset-normalizer==3.1.0
|
36 |
+
# via
|
37 |
+
# aiohttp
|
38 |
+
# requests
|
39 |
+
click==8.1.3
|
40 |
+
# via
|
41 |
+
# nltk
|
42 |
+
# uvicorn
|
43 |
+
contourpy==1.0.7
|
44 |
+
# via matplotlib
|
45 |
+
cycler==0.11.0
|
46 |
+
# via matplotlib
|
47 |
+
datasets==2.10.1
|
48 |
+
# via -r requirements.in
|
49 |
+
dill==0.3.6
|
50 |
+
# via
|
51 |
+
# datasets
|
52 |
+
# multiprocess
|
53 |
+
entrypoints==0.4
|
54 |
+
# via altair
|
55 |
+
faiss-cpu==1.7.3
|
56 |
+
# via -r requirements.in
|
57 |
+
fastapi==0.95.0
|
58 |
+
# via gradio
|
59 |
+
ffmpy==0.3.0
|
60 |
+
# via gradio
|
61 |
+
filelock==3.10.0
|
62 |
+
# via
|
63 |
+
# huggingface-hub
|
64 |
+
# torch
|
65 |
+
# transformers
|
66 |
+
fire==0.5.0
|
67 |
+
# via fuego
|
68 |
+
fonttools==4.39.2
|
69 |
+
# via matplotlib
|
70 |
+
frozenlist==1.3.3
|
71 |
+
# via
|
72 |
+
# aiohttp
|
73 |
+
# aiosignal
|
74 |
+
fsspec[http]==2023.3.0
|
75 |
+
# via
|
76 |
+
# datasets
|
77 |
+
# gradio
|
78 |
+
fuego==0.0.8
|
79 |
+
# via -r requirements.in
|
80 |
+
gitdb==4.0.10
|
81 |
+
# via gitpython
|
82 |
+
gitpython==3.1.31
|
83 |
+
# via fuego
|
84 |
+
gradio==3.22.1
|
85 |
+
# via -r requirements.in
|
86 |
+
h11==0.14.0
|
87 |
+
# via
|
88 |
+
# httpcore
|
89 |
+
# uvicorn
|
90 |
+
httpcore==0.16.3
|
91 |
+
# via httpx
|
92 |
+
httpx==0.23.3
|
93 |
+
# via gradio
|
94 |
+
huggingface-hub==0.13.3
|
95 |
+
# via
|
96 |
+
# datasets
|
97 |
+
# fuego
|
98 |
+
# gradio
|
99 |
+
# sentence-transformers
|
100 |
+
# transformers
|
101 |
+
idna==3.4
|
102 |
+
# via
|
103 |
+
# anyio
|
104 |
+
# requests
|
105 |
+
# rfc3986
|
106 |
+
# yarl
|
107 |
+
importlib-resources==5.12.0
|
108 |
+
# via matplotlib
|
109 |
+
jinja2==3.1.2
|
110 |
+
# via
|
111 |
+
# altair
|
112 |
+
# gradio
|
113 |
+
# torch
|
114 |
+
joblib==1.2.0
|
115 |
+
# via
|
116 |
+
# nltk
|
117 |
+
# scikit-learn
|
118 |
+
jsonschema==4.17.3
|
119 |
+
# via altair
|
120 |
+
kiwisolver==1.4.4
|
121 |
+
# via matplotlib
|
122 |
+
linkify-it-py==2.0.0
|
123 |
+
# via markdown-it-py
|
124 |
+
markdown-it-py[linkify]==2.2.0
|
125 |
+
# via
|
126 |
+
# gradio
|
127 |
+
# mdit-py-plugins
|
128 |
+
markupsafe==2.1.2
|
129 |
+
# via
|
130 |
+
# gradio
|
131 |
+
# jinja2
|
132 |
+
matplotlib==3.7.1
|
133 |
+
# via gradio
|
134 |
+
mdit-py-plugins==0.3.3
|
135 |
+
# via gradio
|
136 |
+
mdurl==0.1.2
|
137 |
+
# via markdown-it-py
|
138 |
+
mpmath==1.3.0
|
139 |
+
# via sympy
|
140 |
+
multidict==6.0.4
|
141 |
+
# via
|
142 |
+
# aiohttp
|
143 |
+
# yarl
|
144 |
+
multiprocess==0.70.14
|
145 |
+
# via datasets
|
146 |
+
networkx==3.0
|
147 |
+
# via torch
|
148 |
+
nltk==3.8.1
|
149 |
+
# via sentence-transformers
|
150 |
+
numpy==1.24.2
|
151 |
+
# via
|
152 |
+
# altair
|
153 |
+
# contourpy
|
154 |
+
# datasets
|
155 |
+
# gradio
|
156 |
+
# matplotlib
|
157 |
+
# pandas
|
158 |
+
# pyarrow
|
159 |
+
# scikit-learn
|
160 |
+
# scipy
|
161 |
+
# sentence-transformers
|
162 |
+
# torchvision
|
163 |
+
# transformers
|
164 |
+
orjson==3.8.7
|
165 |
+
# via gradio
|
166 |
+
packaging==23.0
|
167 |
+
# via
|
168 |
+
# datasets
|
169 |
+
# huggingface-hub
|
170 |
+
# matplotlib
|
171 |
+
# transformers
|
172 |
+
pandas==1.5.3
|
173 |
+
# via
|
174 |
+
# altair
|
175 |
+
# datasets
|
176 |
+
# gradio
|
177 |
+
pillow==9.4.0
|
178 |
+
# via
|
179 |
+
# gradio
|
180 |
+
# matplotlib
|
181 |
+
# torchvision
|
182 |
+
pyarrow==11.0.0
|
183 |
+
# via datasets
|
184 |
+
pydantic==1.10.6
|
185 |
+
# via
|
186 |
+
# fastapi
|
187 |
+
# gradio
|
188 |
+
pydub==0.25.1
|
189 |
+
# via gradio
|
190 |
+
pyparsing==3.0.9
|
191 |
+
# via matplotlib
|
192 |
+
pyrsistent==0.19.3
|
193 |
+
# via jsonschema
|
194 |
+
python-dateutil==2.8.2
|
195 |
+
# via
|
196 |
+
# matplotlib
|
197 |
+
# pandas
|
198 |
+
python-multipart==0.0.6
|
199 |
+
# via gradio
|
200 |
+
pytz==2022.7.1
|
201 |
+
# via pandas
|
202 |
+
pyyaml==6.0
|
203 |
+
# via
|
204 |
+
# datasets
|
205 |
+
# gradio
|
206 |
+
# huggingface-hub
|
207 |
+
# transformers
|
208 |
+
regex==2022.10.31
|
209 |
+
# via
|
210 |
+
# nltk
|
211 |
+
# transformers
|
212 |
+
requests==2.28.2
|
213 |
+
# via
|
214 |
+
# datasets
|
215 |
+
# fsspec
|
216 |
+
# gradio
|
217 |
+
# huggingface-hub
|
218 |
+
# responses
|
219 |
+
# torchvision
|
220 |
+
# transformers
|
221 |
+
responses==0.18.0
|
222 |
+
# via datasets
|
223 |
+
rfc3986[idna2008]==1.5.0
|
224 |
+
# via httpx
|
225 |
+
scikit-learn==1.2.2
|
226 |
+
# via sentence-transformers
|
227 |
+
scipy==1.10.1
|
228 |
+
# via
|
229 |
+
# scikit-learn
|
230 |
+
# sentence-transformers
|
231 |
+
sentence-transformers==2.2.2
|
232 |
+
# via -r requirements.in
|
233 |
+
sentencepiece==0.1.97
|
234 |
+
# via sentence-transformers
|
235 |
+
six==1.16.0
|
236 |
+
# via
|
237 |
+
# fire
|
238 |
+
# python-dateutil
|
239 |
+
smmap==5.0.0
|
240 |
+
# via gitdb
|
241 |
+
sniffio==1.3.0
|
242 |
+
# via
|
243 |
+
# anyio
|
244 |
+
# httpcore
|
245 |
+
# httpx
|
246 |
+
starlette==0.26.1
|
247 |
+
# via fastapi
|
248 |
+
sympy==1.11.1
|
249 |
+
# via torch
|
250 |
+
termcolor==2.2.0
|
251 |
+
# via fire
|
252 |
+
threadpoolctl==3.1.0
|
253 |
+
# via scikit-learn
|
254 |
+
tokenizers==0.13.2
|
255 |
+
# via transformers
|
256 |
+
toolz==0.12.0
|
257 |
+
# via altair
|
258 |
+
torch==2.0.0
|
259 |
+
# via
|
260 |
+
# -r requirements.in
|
261 |
+
# sentence-transformers
|
262 |
+
# torchvision
|
263 |
+
torchvision==0.15.1
|
264 |
+
# via sentence-transformers
|
265 |
+
tqdm==4.65.0
|
266 |
+
# via
|
267 |
+
# datasets
|
268 |
+
# huggingface-hub
|
269 |
+
# nltk
|
270 |
+
# sentence-transformers
|
271 |
+
# transformers
|
272 |
+
transformers==4.27.2
|
273 |
+
# via sentence-transformers
|
274 |
+
typing-extensions==4.5.0
|
275 |
+
# via
|
276 |
+
# gradio
|
277 |
+
# huggingface-hub
|
278 |
+
# pydantic
|
279 |
+
# starlette
|
280 |
+
# torch
|
281 |
+
uc-micro-py==1.0.1
|
282 |
+
# via linkify-it-py
|
283 |
+
urllib3==1.26.15
|
284 |
+
# via
|
285 |
+
# requests
|
286 |
+
# responses
|
287 |
+
uvicorn==0.21.1
|
288 |
+
# via gradio
|
289 |
+
websockets==10.4
|
290 |
+
# via gradio
|
291 |
+
xxhash==3.2.0
|
292 |
+
# via datasets
|
293 |
+
yarl==1.8.2
|
294 |
+
# via aiohttp
|
295 |
+
zipp==3.15.0
|
296 |
+
# via importlib-resources
|