neggles committed on
Commit
5b0d6ce
1 Parent(s): 92d4909

update for v3

Browse files
Files changed (4) hide show
  1. README.md +11 -5
  2. app.py +169 -63
  3. data/selected_tags.csv +0 -0
  4. tagger/common.py +56 -4
README.md CHANGED
@@ -9,12 +9,18 @@ app_file: app.py
9
  pinned: false
10
  short_description: A WD Tagger Space for pi-chan to use
11
  preload_from_hub:
12
- - SmilingWolf/wd-v1-4-moat-tagger-v2 model.onnx
13
- - SmilingWolf/wd-v1-4-swinv2-tagger-v2 model.onnx
14
- - SmilingWolf/wd-v1-4-convnext-tagger-v2 model.onnx
15
- - SmilingWolf/wd-v1-4-convnextv2-tagger-v2 model.onnx
16
- - SmilingWolf/wd-v1-4-vit-tagger-v2 model.onnx
 
 
 
17
  models:
 
 
 
18
  - SmilingWolf/wd-v1-4-moat-tagger-v2
19
  - SmilingWolf/wd-v1-4-swinv2-tagger-v2
20
  - SmilingWolf/wd-v1-4-convnext-tagger-v2
 
9
  pinned: false
10
  short_description: A WD Tagger Space for pi-chan to use
11
  preload_from_hub:
12
+ - SmilingWolf/wd-vit-tagger-v3 model.onnx,selected_tags.csv
13
+ - SmilingWolf/wd-swinv2-tagger-v3 model.onnx,selected_tags.csv
14
+ - SmilingWolf/wd-convnext-tagger-v3 model.onnx,selected_tags.csv
15
+ - SmilingWolf/wd-v1-4-moat-tagger-v2 model.onnx,selected_tags.csv
16
+ - SmilingWolf/wd-v1-4-swinv2-tagger-v2 model.onnx,selected_tags.csv
17
+ - SmilingWolf/wd-v1-4-convnext-tagger-v2 model.onnx,selected_tags.csv
18
+ - SmilingWolf/wd-v1-4-convnextv2-tagger-v2 model.onnx,selected_tags.csv
19
+ - SmilingWolf/wd-v1-4-vit-tagger-v2 model.onnx,selected_tags.csv
20
  models:
21
+ - SmilingWolf/wd-vit-tagger-v3
22
+ - SmilingWolf/wd-swinv2-tagger-v3
23
+ - SmilingWolf/wd-convnext-tagger-v3
24
  - SmilingWolf/wd-v1-4-moat-tagger-v2
25
  - SmilingWolf/wd-v1-4-swinv2-tagger-v2
26
  - SmilingWolf/wd-v1-4-convnext-tagger-v2
app.py CHANGED
@@ -7,25 +7,41 @@ import numpy as np
7
  import onnxruntime as rt
8
  from PIL import Image
9
 
10
- from tagger.common import LabelData, load_labels, preprocess_image
11
  from tagger.model import create_session
12
 
 
 
 
 
 
 
13
  HF_TOKEN = getenv("HF_TOKEN", None)
14
- WORK_DIR = Path.cwd().resolve()
15
 
16
  MODEL_VARIANTS: dict[str, str] = {
17
- "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
18
- "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
19
- "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
20
- "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
21
- "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
 
 
 
 
 
 
 
22
  }
 
 
 
23
 
 
 
24
  # allowed extensions
25
  IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
26
 
27
- # model input shape
28
- IMAGE_SIZE = 448
29
  example_images = sorted(
30
  [
31
  str(x.relative_to(WORK_DIR))
@@ -33,34 +49,51 @@ example_images = sorted(
33
  if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
34
  ]
35
  )
36
- loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k, _ in MODEL_VARIANTS.items()}
37
 
38
 
39
- def load_model(variant: str) -> rt.InferenceSession:
40
  global loaded_models
41
 
42
  # resolve the repo name
43
- model_repo = MODEL_VARIANTS.get(variant, None)
44
  if model_repo is None:
45
- raise ValueError(f"Unknown model variant: {variant}")
46
 
47
- if loaded_models.get(variant, None) is None:
 
48
  # save model to cache
49
- loaded_models[variant] = create_session(model_repo, token=HF_TOKEN)
 
 
 
50
 
51
- return loaded_models[variant]
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  def predict(
55
  image: Image.Image,
 
56
  variant: str,
57
- general_threshold: float = 0.35,
58
- character_threshold: float = 0.85,
 
 
59
  ):
60
- # Load model
61
- model: rt.InferenceSession = load_model(variant)
62
  # load labels
63
- labels: LabelData = load_labels()
64
 
65
  # get input size and name
66
  _, h, w, _ = model.get_inputs()[0].shape
@@ -85,13 +118,21 @@ def predict(
85
  rating_labels = dict([probs[i] for i in labels.rating])
86
 
87
  # General labels, pick any where prediction confidence > threshold
 
 
 
 
88
  gen_labels = [probs[i] for i in labels.general]
89
- gen_labels = dict([x for x in gen_labels if x[1] > general_threshold])
90
  gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
91
 
92
  # Character labels, pick any where prediction confidence > threshold
 
 
 
 
93
  char_labels = [probs[i] for i in labels.character]
94
- char_labels = dict([x for x in char_labels if x[1] > character_threshold])
95
  char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
96
 
97
  # Combine general and character labels, sort by confidence
@@ -102,64 +143,129 @@ def predict(
102
  caption = ", ".join(combined_names)
103
  booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
104
 
105
- return image, caption, booru, rating_labels, char_labels, gen_labels
106
 
107
 
108
- with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title="pi-chan's tagger") as demo:
 
 
 
 
 
 
 
 
 
109
  with gr.Row(equal_height=False):
110
- with gr.Column():
111
- img_input = gr.Image(
112
- label="Input",
113
- type="pil",
114
- image_mode="RGB",
115
- sources=["upload", "clipboard"],
116
- )
117
- variant = gr.Radio(choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="MOAT")
118
- gen_thresh = gr.Slider(0.0, 1.0, value=0.35, label="General Tag Threshold")
119
- char_thresh = gr.Slider(0.0, 1.0, value=0.85, label="Character Tag Threshold")
120
- show_processed = gr.Checkbox(label="Show Preprocessed", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  with gr.Row():
122
- submit = gr.Button(value="Submit", variant="primary", size="lg")
123
  clear = gr.ClearButton(
124
  components=[],
125
  variant="secondary",
126
  size="lg",
127
  )
128
- with gr.Row():
129
- examples = gr.Examples(
130
- examples=[
131
- [imgpath, var, 0.35, 0.85]
132
- for imgpath in example_images
133
- for var in ["MOAT", "ConvNeXTv2"]
134
- ],
135
- inputs=[img_input, variant, gen_thresh, char_thresh],
136
- )
137
- with gr.Column():
138
- img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False)
139
  with gr.Group():
140
- tags_string = gr.Textbox(
141
- label="Caption", placeholder="Caption will appear here", show_copy_button=True
142
- )
143
- tags_booru = gr.Textbox(
144
- label="Tags", placeholder="Tag string will appear here", show_copy_button=True
145
- )
146
- rating = gr.Label(label="Rating")
147
- character = gr.Label(label="Character")
148
- general = gr.Label(label="General")
 
 
 
 
 
 
 
 
 
149
 
150
  # tell clear button which components to clear
151
- clear.add([img_input, img_output, tags_string, rating, character, general])
 
 
 
 
 
 
 
 
152
 
153
  # show/hide processed image
154
- def on_select_show_processed(evt: gr.SelectData):
155
- return gr.update(visible=evt.selected)
 
 
 
 
 
 
 
 
 
156
 
157
- show_processed.select(on_select_show_processed, inputs=[], outputs=[img_output])
 
158
 
159
  submit.click(
160
  predict,
161
- inputs=[img_input, variant, gen_thresh, char_thresh],
162
- outputs=[img_output, tags_string, tags_booru, rating, character, general],
163
  api_name="predict",
164
  )
165
 
 
7
  import onnxruntime as rt
8
  from PIL import Image
9
 
10
+ from tagger.common import LabelData, load_labels_hf, preprocess_image
11
  from tagger.model import create_session
12
 
13
+ TITLE = "WaifuDiffusion Tagger"
14
+ DESCRIPTION = """
15
+ Tag images with the WaifuDiffusion Tagger models!
16
+
17
+ Primarily used as a backend for a Discord bot.
18
+ """
19
  HF_TOKEN = getenv("HF_TOKEN", None)
 
20
 
21
# Hugging Face repo ids of the tagger models, keyed first by model version
# ("v3", "v2") and then by variant name. The value type is a nested mapping,
# so the annotation is dict[str, dict[str, str]] rather than dict[str, str].
MODEL_VARIANTS: dict[str, dict[str, str]] = {
    "v3": {
        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
        "ViT": "SmilingWolf/wd-vit-tagger-v3",
    },
    "v2": {
        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
    },
}
# prepopulate cache keys in model cache: one "<version>-<variant>" key per
# known model, each starting out as None (no session created yet)
cache_keys = [
    "-".join([version, variant])
    for version, variants in MODEL_VARIANTS.items()
    for variant in variants
]
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in cache_keys}
38
 
39
+ # get the repo root (or the current working directory if running in ipython)
40
+ WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()
41
  # allowed extensions
42
  IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
43
 
44
+ # get the example images
 
45
  example_images = sorted(
46
  [
47
  str(x.relative_to(WORK_DIR))
 
49
  if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
50
  ]
51
  )
 
52
 
53
 
54
def load_model(version: str, variant: str) -> rt.InferenceSession:
    """Return the inference session for a model, creating it on first request.

    The repo id is looked up in ``MODEL_VARIANTS[version][variant]``. Sessions
    are kept in the module-level ``loaded_models`` dict under the key
    ``"<version>-<variant>"`` so each model is only created once.

    Raises:
        ValueError: if ``version``/``variant`` is not present in ``MODEL_VARIANTS``.
    """
    global loaded_models

    # an unknown version falls back to an empty mapping, so both kinds of
    # bad input end up as ``None`` here
    repo_id = MODEL_VARIANTS.get(version, {}).get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown model variant: {version}-{variant}")

    key = f"{version}-{variant}"
    session = loaded_models.get(key, None)
    if session is None:
        # first use of this model: build the session and remember it
        session = create_session(repo_id, token=HF_TOKEN)
        loaded_models[key] = session

    return session
68
+
69
 
70
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the scores in descending order, finds the largest gap between two
    neighbouring scores and returns the midpoint of that gap.

    Args:
        probs: 1-D array of prediction scores.

    Returns:
        The threshold as a Python float. With fewer than two scores there is
        no gap to cut, so 0.0 is returned instead of raising.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        # argmax over an empty difference array raises ValueError; fall back
        # to a threshold that does not filter anything out
        return 0.0

    ordered = np.sort(probs)[::-1]
    gaps = ordered[:-1] - ordered[1:]
    cut = int(gaps.argmax())
    return float((ordered[cut] + ordered[cut + 1]) / 2)
82
 
83
 
84
  def predict(
85
  image: Image.Image,
86
+ version: str,
87
  variant: str,
88
+ gen_threshold: float = 0.35,
89
+ gen_use_mcut: bool = False,
90
+ char_threshold: float = 0.85,
91
+ char_use_mcut: bool = False,
92
  ):
93
+ # join variant for cache key
94
+ model: rt.InferenceSession = load_model(version, variant)
95
  # load labels
96
+ labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])
97
 
98
  # get input size and name
99
  _, h, w, _ = model.get_inputs()[0].shape
 
118
  rating_labels = dict([probs[i] for i in labels.rating])
119
 
120
  # General labels, pick any where prediction confidence > threshold
121
+ if gen_use_mcut:
122
+ gen_array = np.array([probs[i][1] for i in labels.general])
123
+ gen_threshold = mcut_threshold(gen_array)
124
+
125
  gen_labels = [probs[i] for i in labels.general]
126
+ gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
127
  gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
128
 
129
  # Character labels, pick any where prediction confidence > threshold
130
+ if char_use_mcut:
131
+ char_array = np.array([probs[i][1] for i in labels.character])
132
+ char_threshold = round(mcut_threshold(char_array), 2)
133
+
134
  char_labels = [probs[i] for i in labels.character]
135
+ char_labels = dict([x for x in char_labels if x[1] > char_threshold])
136
  char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
137
 
138
  # Combine general and character labels, sort by confidence
 
143
  caption = ", ".join(combined_names)
144
  booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
145
 
146
+ return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
147
 
148
 
149
+ css = """
150
+ #gen_mcut, #char_mcut {
151
+ padding-top: var(--scale-3);
152
+ }
153
+ #gen_threshold.dimmed, #char_threshold.dimmed {
154
+ filter: brightness(75%);
155
+ }
156
+ """
157
+
158
+ with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
159
  with gr.Row(equal_height=False):
160
+ with gr.Column(min_width=720):
161
+ with gr.Group():
162
+ img_input = gr.Image(
163
+ label="Input",
164
+ type="pil",
165
+ image_mode="RGB",
166
+ sources=["upload", "clipboard"],
167
+ )
168
+ show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
169
+ with gr.Row():
170
+ version = gr.Radio(
171
+ choices=list(MODEL_VARIANTS.keys()),
172
+ label="Model Version",
173
+ value="v3",
174
+ min_width=160,
175
+ scale=1,
176
+ ) # gen_threshold > div.wrap.hide
177
+ variant = gr.Radio(
178
+ choices=list(MODEL_VARIANTS[version.value].keys()),
179
+ label="Model Variant",
180
+ value="ConvNeXT",
181
+ min_width=560,
182
+ )
183
+ with gr.Group():
184
+ with gr.Row():
185
+ gen_threshold = gr.Slider(
186
+ minimum=0.0,
187
+ maximum=1.0,
188
+ value=0.35,
189
+ step=0.01,
190
+ label="General Tag Threshold",
191
+ scale=5,
192
+ elem_id="gen_threshold",
193
+ )
194
+ gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
195
+ with gr.Row():
196
+ char_threshold = gr.Slider(
197
+ minimum=0.0,
198
+ maximum=1.0,
199
+ value=0.85,
200
+ step=0.01,
201
+ label="Character Tag Threshold",
202
+ scale=5,
203
+ elem_id="char_threshold",
204
+ )
205
+ char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
206
  with gr.Row():
 
207
  clear = gr.ClearButton(
208
  components=[],
209
  variant="secondary",
210
  size="lg",
211
  )
212
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
213
+
214
+ with gr.Column(min_width=720):
215
+ img_output = gr.Image(
216
+ label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
217
+ )
 
 
 
 
 
218
  with gr.Group():
219
+ caption = gr.Textbox(label="Caption", show_copy_button=True)
220
+ tags = gr.Textbox(label="Tags", show_copy_button=True)
221
+ with gr.Group():
222
+ rating = gr.Label(label="Rating")
223
+ with gr.Group():
224
+ char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
225
+ character = gr.Label(label="Character")
226
+ with gr.Group():
227
+ gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
228
+ general = gr.Label(label="General")
229
+
230
+ with gr.Row():
231
+ examples = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]
232
+
233
+ examples = gr.Examples(
234
+ examples=examples,
235
+ inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
236
+ )
237
 
238
  # tell clear button which components to clear
239
+ clear.add([img_input, img_output, caption, rating, character, general])
240
+
241
+ def on_select_variant(evt: gr.SelectData, variant: str):
242
+ if evt.selected:
243
+ choices = list(MODEL_VARIANTS[variant])
244
+ return gr.update(choices=choices, value=choices[0])
245
+ return gr.update()
246
+
247
+ version.select(on_select_variant, inputs=[version], outputs=[variant])
248
 
249
  # show/hide processed image
250
+ def on_change_show(val: gr.Checkbox):
251
+ return gr.update(visible=val)
252
+
253
+ show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])
254
+
255
+ # handle mcut thresholding (auto-calculate threshold from probs, disable slider)
256
+ def on_change_mcut(val: gr.Checkbox):
257
+ return (
258
+ gr.update(interactive=not val, elem_classes=["dimmed"] if val else []),
259
+ gr.update(visible=val),
260
+ )
261
 
262
+ gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
263
+ char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])
264
 
265
  submit.click(
266
  predict,
267
+ inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
268
+ outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
269
  api_name="predict",
270
  )
271
 
data/selected_tags.csv DELETED
The diff for this file is too large to render. See raw diff
 
tagger/common.py CHANGED
@@ -3,10 +3,12 @@ from dataclasses import asdict, dataclass
3
  from functools import lru_cache
4
  from os import PathLike
5
  from pathlib import Path
6
- from typing import Any
7
 
8
  import numpy as np
9
  import pandas as pd
 
 
10
  from PIL import Image
11
 
12
 
@@ -36,10 +38,36 @@ class ImageLabels(DictJsonMixin):
36
 
37
 
38
  @lru_cache(maxsize=5)
39
- def load_labels(csv_path: PathLike = "data/selected_tags.csv") -> LabelData:
40
- csv_path = Path(csv_path).resolve()
 
41
  if not csv_path.is_file():
42
- raise FileNotFoundError("No selected_tags.csv found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
45
  tag_data = LabelData(
@@ -101,3 +129,27 @@ def preprocess_image(
101
  image.thumbnail(size_px, Image.BICUBIC)
102
 
103
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from functools import lru_cache
4
  from os import PathLike
5
  from pathlib import Path
6
+ from typing import Any, Optional
7
 
8
  import numpy as np
9
  import pandas as pd
10
+ from huggingface_hub import hf_hub_download
11
+ from huggingface_hub.utils import HfHubHTTPError
12
  from PIL import Image
13
 
14
 
 
38
 
39
 
40
  @lru_cache(maxsize=5)
41
def load_labels(version: str = "v3", data_dir: PathLike = "./data") -> LabelData:
    """Load tag names and per-category row indices from a local CSV file.

    Reads ``selected_tags_<version>.csv`` from ``data_dir`` (only the ``name``
    and ``category`` columns) and groups the row positions by category value:
    9 goes to ``rating``, 0 to ``general`` and 4 to ``character``.

    Raises:
        FileNotFoundError: if the CSV file does not exist in ``data_dir``.
    """
    root = Path(data_dir).resolve()
    tags_csv = root / f"selected_tags_{version}.csv"
    if not tags_csv.is_file():
        raise FileNotFoundError(f"{tags_csv.name} not found in {root}")

    frame: pd.DataFrame = pd.read_csv(tags_csv, usecols=["name", "category"])
    category = frame["category"]

    def rows_with(code: int) -> list:
        # positions of all rows whose category equals ``code``
        return list(np.where(category == code)[0])

    return LabelData(
        names=frame["name"].tolist(),
        rating=rows_with(9),
        general=rows_with(0),
        character=rows_with(4),
    )
56
+
57
+
58
+ @lru_cache(maxsize=5)
59
+ def load_labels_hf(
60
+ repo_id: str,
61
+ revision: Optional[str] = None,
62
+ token: Optional[str] = None,
63
+ ) -> LabelData:
64
+ try:
65
+ csv_path = hf_hub_download(
66
+ repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
67
+ )
68
+ csv_path = Path(csv_path).resolve()
69
+ except HfHubHTTPError as e:
70
+ raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
71
 
72
  df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
73
  tag_data = LabelData(
 
129
  image.thumbnail(size_px, Image.BICUBIC)
130
 
131
  return image
132
+
133
+
134
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
135
# Emoticon-style tag names that contain underscores.
# NOTE(review): presumably these are excluded when underscores in tag names are
# replaced with spaces - no use is visible in this part of the file, confirm
# against callers.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]