AdamOswald1 commited on
Commit
7ad2d01
1 Parent(s): ba96dd1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +459 -0
app.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import sys

# Make the vendored BLIP and CLIP packages importable before anything tries to load them.
sys.path.append('src/blip')
sys.path.append('src/clip')

# --- standard library ---
import argparse
import base64
import hashlib
import html
import json
import logging
import math
import os
import pickle
import random
import re
import shutil
import sqlite3
import subprocess
import textwrap
import time
from dataclasses import asdict
from io import BytesIO
from os import mkdir
from os.path import isdir
from pathlib import Path

# --- third-party ---
import cv2 as cv
import datasets
import gradio as gr
import gradio as grad
import numpy as np
import pandas as pd
import requests
import streamlit as st
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import tornado
import wget
import yaml
from colorthief import ColorThief
from datasets import load_dataset
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler  # added: used below but was never imported
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import HfApi
from huggingface_hub import Repository
from huggingface_hub import hf_hub_download
from huggingface_hub import login
from huggingface_hub.inference_api import InferenceApi
from jsonschema import ValidationError
from PIL import Image
from torch import autocast
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import pipeline, set_seed

# --- local / vendored modules ---
import clip
import utils
from data_measurements.dataset_statistics import DatasetStatisticsCacheClass as dmt_cls
from models.blip import blip_decoder
from share_btn import community_icon_html, loading_icon_html, share_js
from utils import dataset_utils
from utils import streamlit_utils as st_utils
from .transfer import transfer_color
from .utils import convert_bytes_to_pil

#from torch import autocast
#from diffusers import StableDiffusionPipeline
#from io import BytesIO
#import base64
#import torch
+
77
+ is_colab = utils.is_google_colab()
78
+
79
+ class Model:
80
+ def __init__(self, name, path, prefix):
81
+ self.name = name
82
+ self.path = path
83
+ self.prefix = prefix
84
+ self.pipe_t2i = None
85
+ self.pipe_i2i = None
86
+
87
+ models = [
88
+ Model("Custom model", "", ""),
89
+ Model("Arcane", "nitrosocke/Arcane-Diffusion", "arcane style"),
90
+ Model("Archer", "nitrosocke/archer-diffusion", "archer style"),
91
+ Model("Elden Ring", "nitrosocke/elden-ring-diffusion", "elden ring style"),
92
+ Model("Spider-Verse", "nitrosocke/spider-verse-diffusion", "spiderverse style"),
93
+ Model("Modern Disney", "nitrosocke/modern-disney-diffusion", "modern disney style"),
94
+ Model("Classic Disney", "nitrosocke/classic-anim-diffusion", "classic disney style"),
95
+ Model("Waifu", "hakurei/waifu-diffusion", ""),
96
+ Model("Pokémon", "lambdalabs/sd-pokemon-diffusers", "pokemon style"),
97
+ Model("Pokémon", "svjack/Stable-Diffusion-Pokemon-en", "pokemon style"),
98
+ Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", "pony style"),
99
+ Model("Robo Diffusion", "nousr/robo-diffusion", "robo style"),
100
+ Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion, flax/Cyberpunk-Anime-Diffusion", "cyberpunk style"),
101
+ Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion", "cyberpunk style"),
102
+ Model("Cyberpunk Anime", "flax/Cyberpunk-Anime-Diffusion", "cyberpunk style"),
103
+ Model("Cyberware", "Eppinette/Cyberware", "cyberware"),
104
+ Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy"),
105
+ Model("Waifu", "flax/waifu-diffusion", ""),
106
+ Model("Dark Souls", "Guizmus/DarkSoulsDiffusion", "dark souls style"),
107
+ Model("Waifu", "technillogue/waifu-diffusion", ""),
108
+ Model("Ouroborus", "Eppinette/Ouroboros", "m_ouroboros style"),
109
+ Model("Ouroborus alt", "Eppinette/Ouroboros", "m_ouroboros"),
110
+ Model("Waifu", "Eppinette/Mona", "Mona"),
111
+ Model("Waifu", "Eppinette/Mona", "Mona Woman"),
112
+ Model("Waifu", "Eppinette/Mona", "Mona Genshin"),
113
+ Model("Genshin", "Eppinette/Mona", "Mona"),
114
+ Model("Genshin", "Eppinette/Mona", "Mona Woman"),
115
+ Model("Genshin", "Eppinette/Mona", "Mona Genshin"),
116
+ Model("Space Machine", "rabidgremlin/sd-db-epic-space-machine", "EpicSpaceMachine"),
117
+ Model("Spacecraft", "rabidgremlin/sd-db-epic-space-machine", "EpicSpaceMachine"),
118
+ Model("TARDIS", "Guizmus/Tardisfusion", "Classic Tardis style"),
119
+ Model("TARDIS", "Guizmus/Tardisfusion", "Modern Tardis style"),
120
+ Model("TARDIS", "Guizmus/Tardisfusion", "Tardis Box style"),
121
+ Model("Spacecraft", "Guizmus/Tardisfusion", "Classic Tardis style"),
122
+ Model("Spacecraft", "Guizmus/Tardisfusion", "Modern Tardis style"),
123
+ Model("Spacecraft", "Guizmus/Tardisfusion", "Tardis Box style"),
124
+ Model("CLIP", "EleutherAI/clip-guided-diffusion", "CLIP"),
125
+ Model("Face Swap", "felixrosberg/face-swap", "faceswap"),
126
+ Model("Face Swap", "felixrosberg/face-swap", "faceswap with"),
127
+ Model("Face Swap", "felixrosberg/face-swap", "faceswapped"),
128
+ Model("Face Swap", "felixrosberg/face-swap", "faceswapped with"),
129
+ Model("Face Swap", "felixrosberg/face-swap", "face on"),
130
+ Model("Waifu", "Fampai/lumine_genshin_impact", "lumine_genshin"),
131
+ Model("Waifu", "Fampai/lumine_genshin_impact", "lumine"),
132
+ Model("Waifu", "Fampai/lumine_genshin_impact", "Lumine Genshin"),
133
+ Model("Waifu", "Fampai/lumine_genshin_impact", "Lumine_genshin"),
134
+ Model("Waifu", "Fampai/lumine_genshin_impact", "Lumine_Genshin"),
135
+ Model("Waifu", "Fampai/lumine_genshin_impact", "Lumine"),
136
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Lumine_genshin"),
137
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Lumine_Genshin"),
138
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Lumine"),
139
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Lumine Genshin"),
140
+ Model("Genshin", "Fampai/lumine_genshin_impact", "lumine"),
141
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Ganyu"),
142
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Ganyu Woman"),
143
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Ganyu Genshin"),
144
+ Model("Waifu", "sd-concepts-library/ganyu-genshin-impact", "Ganyu"),
145
+ Model("Waifu", "sd-concepts-library/ganyu-genshin-impact", "Ganyu Woman"),
146
+ Model("Waifu", "sd-concepts-library/ganyu-genshin-impact", "Ganyu Genshin"),
147
+ Model("Waifu", "Fampai/raiden_genshin_impact", "raiden_ei"),
148
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Raiden Ei"),
149
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Ei Genshin"),
150
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Raiden Genshin"),
151
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Raiden_Genshin"),
152
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Ei_Genshin"),
153
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Raiden"),
154
+ Model("Waifu", "Fampai/raiden_genshin_impact", "Ei"),
155
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Raiden Ei"),
156
+ Model("Genshin", "Fampai/raiden_genshin_impact", "raiden_ei"),
157
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Raiden"),
158
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Raiden Genshin"),
159
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Ei Genshin"),
160
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Raiden_Genshin"),
161
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Ei_Genshin"),
162
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Ei"),
163
+ Model("Waifu", "Fampai/hutao_genshin_impact", "hutao_genshin"),
164
+ Model("Waifu", "Fampai/hutao_genshin_impact", "HuTao_Genshin"),
165
+ Model("Waifu", "Fampai/hutao_genshin_impact", "HuTao Genshin"),
166
+ Model("Waifu", "Fampai/hutao_genshin_impact", "HuTao"),
167
+ Model("Waifu", "Fampai/hutao_genshin_impact", "hutao_genshin"),
168
+ Model("Genshin", "Fampai/hutao_genshin_impact", "hutao_genshin"),
169
+ Model("Genshin", "Fampai/hutao_genshin_impact", "HuTao_Genshin"),
170
+ Model("Genshin", "Fampai/hutao_genshin_impact", "HuTao Genshin"),
171
+ Model("Genshin", "Fampai/hutao_genshin_impact", "HuTao"),
172
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "Female"),
173
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "female"),
174
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "Woman"),
175
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "woman"),
176
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "Girl"),
177
+ Model("Genshin", "Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "girl"),
178
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Female"),
179
+ Model("Genshin", "Fampai/lumine_genshin_impact", "female"),
180
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Woman"),
181
+ Model("Genshin", "Fampai/lumine_genshin_impact", "woman"),
182
+ Model("Genshin", "Fampai/lumine_genshin_impact", "Girl"),
183
+ Model("Genshin", "Fampai/lumine_genshin_impact", "girl"),
184
+ Model("Genshin", "Eppinette/Mona", "Female"),
185
+ Model("Genshin", "Eppinette/Mona", "female"),
186
+ Model("Genshin", "Eppinette/Mona", "Woman"),
187
+ Model("Genshin", "Eppinette/Mona", "woman"),
188
+ Model("Genshin", "Eppinette/Mona", "Girl"),
189
+ Model("Genshin", "Eppinette/Mona", "girl"),
190
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Female"),
191
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "female"),
192
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Woman"),
193
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "woman"),
194
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "Girl"),
195
+ Model("Genshin", "sd-concepts-library/ganyu-genshin-impact", "girl"),
196
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Female"),
197
+ Model("Genshin", "Fampai/raiden_genshin_impact", "female"),
198
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Woman"),
199
+ Model("Genshin", "Fampai/raiden_genshin_impact", "woman"),
200
+ Model("Genshin", "Fampai/raiden_genshin_impact", "Girl"),
201
+ Model("Genshin", "Fampai/raiden_genshin_impact", "girl"),
202
+ Model("Genshin", "Fampai/hutao_genshin_impact", "Female"),
203
+ Model("Genshin", "Fampai/hutao_genshin_impact", "female"),
204
+ Model("Genshin", "Fampai/hutao_genshin_impact", "Woman"),
205
+ Model("Genshin", "Fampai/hutao_genshin_impact", "woman"),
206
+ Model("Genshin", "Fampai/hutao_genshin_impact", "Girl"),
207
+ Model("Genshin", "Fampai/hutao_genshin_impact", "girl"),
208
+ Model("Waifu", "crumb/genshin-stable-inversion, yuiqena/GenshinImpact, Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "Genshin"),
209
+ Model("Waifu", "crumb/genshin-stable-inversion, yuiqena/GenshinImpact, Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", "Genshin Impact"),
210
+ Model("Genshin", "crumb/genshin-stable-inversion, yuiqena/GenshinImpact, Fampai/lumine_genshin_impact, Eppinette/Mona, sd-concepts-library/ganyu-genshin-impact, Fampai/raiden_genshin_impact, Fampai/hutao_genshin_impact", ""),
211
+ Model("Waifu", "crumb/genshin-stable-inversion", "Genshin"),
212
+ Model("Waifu", "crumb/genshin-stable-inversion", "Genshin Impact"),
213
+ Model("Genshin", "crumb/genshin-stable-inversion", ""),
214
+ Model("Waifu", "yuiqena/GenshinImpact", "Genshin"),
215
+ Model("Waifu", "yuiqena/GenshinImpact", "Genshin Impact"),
216
+ Model("Genshin", "yuiqena/GenshinImpact", ""),
217
+ Model("Waifu", "hakurei/waifu-diffusion, flax/waifu-diffusion, technillogue/waifu-diffusion, Guizmus/AnimeChanStyle, katakana/2D-Mix", ""),
218
+ Model("Pokémon", "lambdalabs/sd-pokemon-diffusers, svjack/Stable-Diffusion-Pokemon-en", "pokemon style"),
219
+ Model("Pokémon", "lambdalabs/sd-pokemon-diffusers, svjack/Stable-Diffusion-Pokemon-en", ""),
220
+ Model("Test", "AdamoOswald1/Idk", ""),
221
+ Model("Anime", "Guizmus/AnimeChanStyle", "AnimeChan Style"),
222
+ Model("Genshin", "Guizmus/AnimeChanStyle", "AnimeChan Style"),
223
+ Model("Waifu", "Guizmus/AnimeChanStyle", "AnimeChan Style"),
224
+ Model("Waifu", "Guizmus/AnimeChanStyle", "Genshin"),
225
+ Model("Waifu", "Guizmus/AnimeChanStyle", "Genshin Impact"),
226
+ Model("Genshin", "Guizmus/AnimeChanStyle", ""),
227
+ Model("Anime", "Guizmus/AnimeChanStyle", ""),
228
+ Model("Waifu", "Guizmus/AnimeChanStyle", ""),
229
+ Model("Anime", "Guizmus/AnimeChanStyle, katakana/2D-Mix", ""),
230
+ Model("Anime", "katakana/2D-Mix", "2D-Mix"),
231
+ Model("Genshin", "katakana/2D-Mix", "2D-Mix"),
232
+ Model("Waifu", "katakana/2D-Mix", "2D-Mix"),
233
+ Model("Waifu", "katakana/2D-Mix", "Genshin"),
234
+ Model("Waifu", "katakana/2D-Mix", "Genshin Impact"),
235
+ Model("Genshin", "katakana/2D-Mix", ""),
236
+ Model("Anime", "katakana/2D-Mix", ""),
237
+ Model("Waifu", "katakana/2D-Mix", ""),
238
+ Model("Beeple", "riccardogiorato/beeple-diffusion", "beeple style "),
239
+ Model("Avatar", "riccardogiorato/avatar-diffusion", "avatartwow style "),
240
+ Model("Poolsuite", "prompthero/poolsuite", "poolsuite style ")
241
+ ]
242
+ # Model("Beksinski", "s3nh/beksinski-style-stable-diffusion", "beksinski style "),
243
+ # Model("Guohua", "Langboat/Guohua-Diffusion", "guohua style ")
244
+
245
+ scheduler = DPMSolverMultistepScheduler(
246
+ beta_start=0.00085,
247
+ beta_end=0.012,
248
+ beta_schedule="scaled_linear",
249
+ num_train_timesteps=1000,
250
+ trained_betas=None,
251
+ predict_epsilon=True,
252
+ thresholding=False,
253
+ algorithm_type="dpmsolver++",
254
+ solver_type="midpoint",
255
+ lower_order_final=True,
256
+ )
257
+
258
+ custom_model = None
259
+ if is_colab:
260
+ models.insert(0, Model("Custom model", "", ""))
261
+ custom_model = models[0]
262
+
263
+ last_mode = "txt2img"
264
+ current_model = models[1] if is_colab else models[0]
265
+ current_model_path = current_model.path
266
+
267
+ if is_colab:
268
+ pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16, scheduler=scheduler)
269
+
270
+ else: # download all models
271
+ vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=torch.float16)
272
+ for model in models:
273
+ try:
274
+ unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=torch.float16)
275
+ model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler)
276
+ model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model.path, unet=unet, vae=vae, torch_dtype=torch.float16, scheduler=scheduler)
277
+ except:
278
+ models.remove(model)
279
+ pipe = models[0].pipe_t2i
280
+
281
+ if torch.cuda.is_available():
282
+ pipe = pipe.to("cuda")
283
+
284
+ device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
285
+
286
+ def custom_model_changed(path):
287
+ models[0].path = path
288
+ global current_model
289
+ current_model = models[0]
290
+
291
+ def inference(model_name, prompt, guidance, steps, width=512, height=512, seed=0, img=None, strength=0.5, neg_prompt=""):
292
+
293
+ global current_model
294
+ for model in models:
295
+ if model.name == model_name:
296
+ current_model = model
297
+ model_path = current_model.path
298
+
299
+ generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
300
+
301
+ if img is not None:
302
+ return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator)
303
+ else:
304
+ return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator)
305
+
306
+ def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator=None):
307
+
308
+ global last_mode
309
+ global pipe
310
+ global current_model_path
311
+ if model_path != current_model_path or last_mode != "txt2img":
312
+ current_model_path = model_path
313
+
314
+ if is_colab or current_model == custom_model:
315
+ pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16, scheduler=scheduler)
316
+ else:
317
+ pipe.to("cpu")
318
+ pipe = current_model.pipe_t2i
319
+
320
+ if torch.cuda.is_available():
321
+ pipe = pipe.to("cuda")
322
+ last_mode = "txt2img"
323
+
324
+ prompt = current_model.prefix + prompt
325
+ result = pipe(
326
+ prompt,
327
+ negative_prompt = neg_prompt,
328
+ # num_images_per_prompt=n_images,
329
+ num_inference_steps = int(steps),
330
+ guidance_scale = guidance,
331
+ width = width,
332
+ height = height,
333
+ generator = generator)
334
+
335
+ return replace_nsfw_images(result)
336
+
337
+ def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator=None):
338
+
339
+ global last_mode
340
+ global pipe
341
+ global current_model_path
342
+ if model_path != current_model_path or last_mode != "img2img":
343
+ current_model_path = model_path
344
+
345
+ if is_colab or current_model == custom_model:
346
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16, scheduler=scheduler)
347
+ else:
348
+ pipe.to("cpu")
349
+ pipe = current_model.pipe_i2i
350
+
351
+ if torch.cuda.is_available():
352
+ pipe = pipe.to("cuda")
353
+ last_mode = "img2img"
354
+
355
+ prompt = current_model.prefix + prompt
356
+ ratio = min(height / img.height, width / img.width)
357
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
358
+ result = pipe(
359
+ prompt,
360
+ negative_prompt = neg_prompt,
361
+ # num_images_per_prompt=n_images,
362
+ init_image = img,
363
+ num_inference_steps = int(steps),
364
+ strength = strength,
365
+ guidance_scale = guidance,
366
+ width = width,
367
+ height = height,
368
+ generator = generator)
369
+
370
+ return replace_nsfw_images(result)
371
+
372
+ def replace_nsfw_images(results):
373
+ for i in range(len(results.images)):
374
+ if results.nsfw_content_detected[i]:
375
+ results.images[i] = Image.open("nsfw.png")
376
+ return results.images[0]
377
+
378
+ css = """.finetuned-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.finetuned-diffusion-div div h1{font-weight:900;margin-bottom:7px}.finetuned-diffusion-div p{margin-bottom:10px;font-size:94%}a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
379
+ """
380
+ with gr.Blocks(css=css) as demo:
381
+ gr.HTML(
382
+ f"""
383
+ <div class="finetuned-diffusion-div">
384
+ <div>
385
+ <h1>Playground Diffusion</h1>
386
+ </div>
387
+ <p>
388
+ Demo for multiple fine-tuned Stable Diffusion models, trained on different styles: <br>
389
+ <a href="https://huggingface.co/riccardogiorato/avatar-diffusion">Avatar</a>,<br/>
390
+ <a href="https://huggingface.co/riccardogiorato/beeple-diffusion">Beeple</a>,<br/>
391
+ <a href="https://huggingface.co/s3nh/beksinski-style-stable-diffusion">Beksinski</a>,<br/>
392
+ Diffusers 🧨 SD model hosted on HuggingFace 🤗.
393
+ Running on <b>{device}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
394
+ </p>
395
+ </div>
396
+ """
397
+ )
398
+ with gr.Row():
399
+
400
+ with gr.Column(scale=55):
401
+ with gr.Group():
402
+ model_name = gr.Dropdown(label="Model", choices=[m.name for m in models], value=current_model.name)
403
+ with gr.Box(visible=False) as custom_model_group:
404
+ custom_model_path = gr.Textbox(label="Custom model path", placeholder="Path to model, e.g. nitrosocke/Arcane-Diffusion", interactive=True)
405
+ gr.HTML("<div><font size='2'>Custom models have to be downloaded first, so give it some time.</font></div>")
406
+
407
+ with gr.Row():
408
+ prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=2,placeholder="Enter prompt. Style applied automatically").style(container=False)
409
+ generate = gr.Button(value="Generate").style(rounded=(False, True, True, False))
410
+
411
+
412
+ image_out = gr.Image(height=512)
413
+ # gallery = gr.Gallery(
414
+ # label="Generated images", show_label=False, elem_id="gallery"
415
+ # ).style(grid=[1], height="auto")
416
+
417
+ with gr.Column(scale=45):
418
+ with gr.Tab("Options"):
419
+ with gr.Group():
420
+ neg_prompt = gr.Textbox(label="Negative prompt", placeholder="What to exclude from the image")
421
+
422
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
423
+
424
+ with gr.Row():
425
+ guidance = gr.Slider(label="Guidance scale", value=7.5, maximum=15)
426
+ steps = gr.Slider(label="Steps", value=25, minimum=2, maximum=75, step=1)
427
+
428
+ with gr.Row():
429
+ width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
430
+ height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
431
+
432
+ seed = gr.Slider(0, 2147483647, label='Seed (0 = random)', value=0, step=1)
433
+
434
+ with gr.Tab("Image to image"):
435
+ with gr.Group():
436
+ image = gr.Image(label="Image", height=256, tool="editor", type="pil")
437
+ strength = gr.Slider(label="Transformation strength", minimum=0, maximum=1, step=0.01, value=0.5)
438
+
439
+ if is_colab:
440
+ model_name.change(lambda x: gr.update(visible = x == models[0].name), inputs=model_name, outputs=custom_model_group)
441
+ custom_model_path.change(custom_model_changed, inputs=custom_model_path, outputs=None)
442
+ # n_images.change(lambda n: gr.Gallery().style(grid=[2 if n > 1 else 1], height="auto"), inputs=n_images, outputs=gallery)
443
+
444
+ inputs = [model_name, prompt, guidance, steps, width, height, seed, image, strength, neg_prompt]
445
+ prompt.submit(inference, inputs=inputs, outputs=image_out)
446
+ generate.click(inference, inputs=inputs, outputs=image_out)
447
+
448
+ ex = gr.Examples([
449
+ [models[1].name, "Neon techno-magic robot with spear pierces an ancient beast, hyperrealism, no blur, 4k resolution, ultra detailed", 7.5, 50],
450
+ [models[1].name, "halfturn portrait of a big crystal face of a beautiful abstract ancient Egyptian elderly shaman woman, made of iridescent golden crystals, half - turn, bottom view, ominous, intricate, studio, art by anthony macbain and greg rutkowski and alphonse mucha, concept art, 4k, sharp focus", 7.5, 25],
451
+ ], [model_name, prompt, guidance, steps, seed], image_out, inference, cache_examples=False)
452
+
453
+ gr.HTML("""
454
+ <p>Models by <a href="https://huggingface.co/riccardogiorato">@riccardogiorato</a><br></p>
455
+ """)
456
+
457
+ if not is_colab:
458
+ demo.queue(concurrency_count=1)
459
+ demo.launch(debug=is_colab, share=is_colab)