davanstrien committed
Commit afebd23
1 parent: 44712bf

Upload 3 files

Files changed (3)
  1. app.py +135 -0
  2. requirements.in +8 -0
  3. requirements.txt +296 -0
app.py ADDED
@@ -0,0 +1,135 @@
+ import multiprocessing
+ import random
+ from functools import lru_cache
+
+ import backoff
+ import gradio as gr
+ import pandas as pd
+ import requests
+ from datasets import load_dataset
+ from faiss import METRIC_INNER_PRODUCT
+ from PIL.Image import Image, ANTIALIAS
+ from sentence_transformers import SentenceTransformer
+ from toolz import frequencies
+
+ cpu_count = multiprocessing.cpu_count()
+
+ # CLIP model used to embed text queries into the same space as the image embeddings.
+ model = SentenceTransformer("clip-ViT-B-16")
+
+
+ def resize_image(image: Image, size: int = 224) -> Image:
+     """Resize an image so its shorter side is `size`, retaining the aspect ratio."""
+     w, h = image.size
+     if w == h:
+         return image.resize((size, size), ANTIALIAS)
+     if w > h:
+         height_percent = size / float(h)
+         width_size = int(float(w) * height_percent)
+         return image.resize((width_size, size), ANTIALIAS)
+     # w < h
+     width_percent = size / float(w)
+     height_size = int(float(h) * width_percent)
+     return image.resize((size, height_size), ANTIALIAS)
+
+
+ # Load the pre-embedded dataset, drop rows with missing embeddings, and build a
+ # FAISS index over the embedding column for nearest-neighbour search.
+ dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
+ dataset = dataset.filter(lambda x: x["embedding"] is not None)
+ dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)
+
+
+ def get_nearest_k_examples(input, k):
+     """Encode a text query and return the k nearest images with their metadata."""
+     k = int(k)  # the slider may deliver a float
+     query = model.encode(input)
+     # TODO maybe switch to a range search with a score threshold, e.g.:
+     # faiss_index = dataset.get_index("embedding").faiss_index
+     # limits, distances, indices = faiss_index.range_search(x=query, thresh=0.95)
+     # images = dataset[indices]
+     _, retrieved_examples = dataset.get_nearest_examples("embedding", query=query, k=k)
+     images = retrieved_examples["image"][:k]
+     last_modified = retrieved_examples["last_modified_date"]
+     crawl_date = retrieved_examples["crawl_date"]
+     metadata = [
+         f"last modified: {modified}, crawl date: {crawl}"
+         for modified, crawl in zip(last_modified, crawl_date)
+     ]
+     return list(zip(images, metadata))
+
+
+ def return_random_sample(k=27):
+     """Return k random images from the dataset, resized for the gallery."""
+     sample = random.sample(range(len(dataset)), k)
+     images = dataset[sample]["image"]
+     return [resize_image(image).convert("RGB") for image in images]
+
+
+ def predict_subset(model_id, token):
+     """Classify a random sample of images via the Inference API and tally the labels."""
+     API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
+     headers = {"Authorization": f"Bearer {token}"}
+
+     # Retry with exponential backoff while the hosted model is loading (HTTP 503).
+     @backoff.on_predicate(backoff.expo, lambda r: r.status_code == 503, max_time=30)
+     def _query(url):
+         return requests.post(API_URL, headers=headers, data=url)
+
+     @lru_cache(maxsize=1000)
+     def query(url):
+         response = _query(url)
+         try:
+             data = response.json()
+             argmax = data[0]
+             return {"score": argmax["score"], "label": argmax["label"]}
+         except Exception:
+             return {"score": None, "label": None}
+
+     # Reload the dataset without the FAISS index so a random subset can be selected.
+     # dataset2 = copy.deepcopy(dataset)
+     # dataset2.drop_index("embedding")
+     dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
+     sample = dataset.select(random.sample(range(len(dataset)), 10))
+     print("predicting...")
+     predictions = [query(row["url"]) for row in sample]
+     gallery = [
+         (url, f"{prediction['label']}: {prediction['score']}")
+         for url, prediction in zip(sample["url"], predictions)
+     ]
+     labels = [d["label"] for d in predictions]
+     freqs = frequencies(labels)
+     df = pd.DataFrame({"labels": list(freqs.keys()), "freqs": list(freqs.values())})
+     return gallery, df
+
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Random image gallery"):
+         button = gr.Button("Refresh")
+         gallery = gr.Gallery().style(grid=9, height="1400")
+         button.click(return_random_sample, [], [gallery])
+     with gr.Tab("Image search"):
+         text = gr.Textbox(label="Search for images")
+         k = gr.Slider(minimum=3, maximum=18, step=1)
+         button = gr.Button("Search")
+         gallery = gr.Gallery().style(grid=3)
+         button.click(get_nearest_k_examples, [text, k], [gallery])
+     # with gr.Tab("Export for label studio"):
+     #     button = gr.Button("Export")
+     #     dataset2 = copy.deepcopy(dataset)
+     #     # dataset2 = dataset2.remove_columns('image')
+     #     # dataset2 = dataset2.rename_column("url", "image")
+     #     csv = dataset2.to_csv("label_studio.csv")
+     #     csv_file = gr.File("label_studio.csv")
+     #     button.click(dataset.save_to_disk, [], [csv_file])
+     with gr.Tab("Predict"):
+         token = gr.Textbox(label="token", type="password")
+         model_id = gr.Textbox(label="model_id")
+         button = gr.Button("Predict")
+         plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
+         gallery = gr.Gallery()
+         button.click(predict_subset, [model_id, token], [gallery, plot])
+
+ demo.launch(enable_queue=True, debug=True)
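
For reference, the text-to-image search wired into the "Image search" tab can also be exercised from a plain script. A minimal sketch, assuming the same dataset and its url column (as used by predict_subset above) are available; the query string is just an example:

from datasets import load_dataset
from faiss import METRIC_INNER_PRODUCT
from sentence_transformers import SentenceTransformer

# Same model and index setup as app.py.
model = SentenceTransformer("clip-ViT-B-16")
ds = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
ds = ds.filter(lambda x: x["embedding"] is not None)
ds.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)

# Encode a text query with CLIP and fetch the five nearest images.
scores, examples = ds.get_nearest_examples("embedding", query=model.encode("a sailing ship"), k=5)
for score, url in zip(scores, examples["url"]):
    print(f"{score:.3f}  {url}")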
requirements.in ADDED
@@ -0,0 +1,8 @@
+ datasets
+ gradio
+ torch
+ # transformers @ git+https://github.com/huggingface/transformers@dde718e7a62bf8caa6623b5635ba02d6cb758c75
+ faiss-cpu
+ fuego
+ sentence_transformers
+ backoff
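
backoff is pinned here for the Inference API retry in app.py's predict_subset. A minimal sketch of that pattern in isolation (model_id and token are placeholder values):

import backoff
import requests

@backoff.on_predicate(backoff.expo, lambda r: r.status_code == 503, max_time=30)
def classify(url: str, model_id: str, token: str) -> requests.Response:
    # Retries with exponential backoff while the hosted model is still
    # loading (the Inference API returns HTTP 503 until it is ready).
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}
    return requests.post(api_url, headers=headers, data=url)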
requirements.txt ADDED
@@ -0,0 +1,296 @@
+ #
+ # This file is autogenerated by pip-compile with Python 3.9
+ # by the following command:
+ #
+ #    pip-compile --resolver=backtracking requirements.in
+ #
+ aiofiles==23.1.0
+     # via gradio
+ aiohttp==3.8.4
+     # via
+     #   datasets
+     #   fsspec
+     #   gradio
+ aiosignal==1.3.1
+     # via aiohttp
+ altair==4.2.2
+     # via gradio
+ anyio==3.6.2
+     # via
+     #   httpcore
+     #   starlette
+ async-timeout==4.0.2
+     # via aiohttp
+ attrs==22.2.0
+     # via
+     #   aiohttp
+     #   jsonschema
+ backoff==2.2.1
+     # via -r requirements.in
+ certifi==2022.12.7
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ charset-normalizer==3.1.0
+     # via
+     #   aiohttp
+     #   requests
+ click==8.1.3
+     # via
+     #   nltk
+     #   uvicorn
+ contourpy==1.0.7
+     # via matplotlib
+ cycler==0.11.0
+     # via matplotlib
+ datasets==2.10.1
+     # via -r requirements.in
+ dill==0.3.6
+     # via
+     #   datasets
+     #   multiprocess
+ entrypoints==0.4
+     # via altair
+ faiss-cpu==1.7.3
+     # via -r requirements.in
+ fastapi==0.95.0
+     # via gradio
+ ffmpy==0.3.0
+     # via gradio
+ filelock==3.10.0
+     # via
+     #   huggingface-hub
+     #   torch
+     #   transformers
+ fire==0.5.0
+     # via fuego
+ fonttools==4.39.2
+     # via matplotlib
+ frozenlist==1.3.3
+     # via
+     #   aiohttp
+     #   aiosignal
+ fsspec[http]==2023.3.0
+     # via
+     #   datasets
+     #   gradio
+ fuego==0.0.8
+     # via -r requirements.in
+ gitdb==4.0.10
+     # via gitpython
+ gitpython==3.1.31
+     # via fuego
+ gradio==3.22.1
+     # via -r requirements.in
+ h11==0.14.0
+     # via
+     #   httpcore
+     #   uvicorn
+ httpcore==0.16.3
+     # via httpx
+ httpx==0.23.3
+     # via gradio
+ huggingface-hub==0.13.3
+     # via
+     #   datasets
+     #   fuego
+     #   gradio
+     #   sentence-transformers
+     #   transformers
+ idna==3.4
+     # via
+     #   anyio
+     #   requests
+     #   rfc3986
+     #   yarl
+ importlib-resources==5.12.0
+     # via matplotlib
+ jinja2==3.1.2
+     # via
+     #   altair
+     #   gradio
+     #   torch
+ joblib==1.2.0
+     # via
+     #   nltk
+     #   scikit-learn
+ jsonschema==4.17.3
+     # via altair
+ kiwisolver==1.4.4
+     # via matplotlib
+ linkify-it-py==2.0.0
+     # via markdown-it-py
+ markdown-it-py[linkify]==2.2.0
+     # via
+     #   gradio
+     #   mdit-py-plugins
+ markupsafe==2.1.2
+     # via
+     #   gradio
+     #   jinja2
+ matplotlib==3.7.1
+     # via gradio
+ mdit-py-plugins==0.3.3
+     # via gradio
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ multidict==6.0.4
+     # via
+     #   aiohttp
+     #   yarl
+ multiprocess==0.70.14
+     # via datasets
+ networkx==3.0
+     # via torch
+ nltk==3.8.1
+     # via sentence-transformers
+ numpy==1.24.2
+     # via
+     #   altair
+     #   contourpy
+     #   datasets
+     #   gradio
+     #   matplotlib
+     #   pandas
+     #   pyarrow
+     #   scikit-learn
+     #   scipy
+     #   sentence-transformers
+     #   torchvision
+     #   transformers
+ orjson==3.8.7
+     # via gradio
+ packaging==23.0
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   matplotlib
+     #   transformers
+ pandas==1.5.3
+     # via
+     #   altair
+     #   datasets
+     #   gradio
+ pillow==9.4.0
+     # via
+     #   gradio
+     #   matplotlib
+     #   torchvision
+ pyarrow==11.0.0
+     # via datasets
+ pydantic==1.10.6
+     # via
+     #   fastapi
+     #   gradio
+ pydub==0.25.1
+     # via gradio
+ pyparsing==3.0.9
+     # via matplotlib
+ pyrsistent==0.19.3
+     # via jsonschema
+ python-dateutil==2.8.2
+     # via
+     #   matplotlib
+     #   pandas
+ python-multipart==0.0.6
+     # via gradio
+ pytz==2022.7.1
+     # via pandas
+ pyyaml==6.0
+     # via
+     #   datasets
+     #   gradio
+     #   huggingface-hub
+     #   transformers
+ regex==2022.10.31
+     # via
+     #   nltk
+     #   transformers
+ requests==2.28.2
+     # via
+     #   datasets
+     #   fsspec
+     #   gradio
+     #   huggingface-hub
+     #   responses
+     #   torchvision
+     #   transformers
+ responses==0.18.0
+     # via datasets
+ rfc3986[idna2008]==1.5.0
+     # via httpx
+ scikit-learn==1.2.2
+     # via sentence-transformers
+ scipy==1.10.1
+     # via
+     #   scikit-learn
+     #   sentence-transformers
+ sentence-transformers==2.2.2
+     # via -r requirements.in
+ sentencepiece==0.1.97
+     # via sentence-transformers
+ six==1.16.0
+     # via
+     #   fire
+     #   python-dateutil
+ smmap==5.0.0
+     # via gitdb
+ sniffio==1.3.0
+     # via
+     #   anyio
+     #   httpcore
+     #   httpx
+ starlette==0.26.1
+     # via fastapi
+ sympy==1.11.1
+     # via torch
+ termcolor==2.2.0
+     # via fire
+ threadpoolctl==3.1.0
+     # via scikit-learn
+ tokenizers==0.13.2
+     # via transformers
+ toolz==0.12.0
+     # via altair
+ torch==2.0.0
+     # via
+     #   -r requirements.in
+     #   sentence-transformers
+     #   torchvision
+ torchvision==0.15.1
+     # via sentence-transformers
+ tqdm==4.65.0
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   nltk
+     #   sentence-transformers
+     #   transformers
+ transformers==4.27.2
+     # via sentence-transformers
+ typing-extensions==4.5.0
+     # via
+     #   gradio
+     #   huggingface-hub
+     #   pydantic
+     #   starlette
+     #   torch
+ uc-micro-py==1.0.1
+     # via linkify-it-py
+ urllib3==1.26.15
+     # via
+     #   requests
+     #   responses
+ uvicorn==0.21.1
+     # via gradio
+ websockets==10.4
+     # via gradio
+ xxhash==3.2.0
+     # via datasets
+ yarl==1.8.2
+     # via aiohttp
+ zipp==3.15.0
+     # via importlib-resources