|
import gradio as gr |
|
import os |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from insightface.app import FaceAnalysis |
|
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline |
|
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus |
|
import argparse |
|
import random |
|
from insightface.utils import face_align |
|
from pyngrok import ngrok |
|
import threading |
|
import time |
|
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlusXL |
|
import hashlib |
|
from datetime import datetime |
|
|
|
|
|
# Command-line options for launching the app.
parser = argparse.ArgumentParser()

# Enable Gradio's public share link.
parser.add_argument("--share", action="store_true", help="Enable Gradio share option")

# Default value shown in the "Number of Images to Generate" UI field.
parser.add_argument("--num_images", type=int, default=1, help="Number of images to generate")

# Maximum number of loaded pipelines kept in memory at once (see load_model).
parser.add_argument("--cache_limit", type=int, default=1, help="Limit for model cache")

# Providing a token enables the background ngrok tunnel in __main__.
parser.add_argument("--ngrok_token", type=str, default=None, help="ngrok authtoken for tunneling")



args = parser.parse_args()
|
|
|
|
|
# Hugging Face repo ids selectable in the UI model dropdown.
# Mix of SD 1.5 checkpoints and (last two entries) SDXL checkpoints; the user must
# toggle the "Activate SDXL" checkbox to match the selected model family.
static_model_names = [

    "SG161222/Realistic_Vision_V6.0_B1_noVAE",

    "stablediffusionapi/rev-animated-v122-eol",

    "Lykon/DreamShaper",

    "stablediffusionapi/toonyou",

    "stablediffusionapi/real-cartoon-3d",

    "KBlueLeaf/kohaku-v2.1",

    "nitrosocke/Ghibli-Diffusion",

    "Linaqruf/anything-v3.0",

    "jinaai/flat-2d-animerge",

    "stablediffusionapi/realcartoon3d",

    "stablediffusionapi/disney-pixar-cartoon",

    "stablediffusionapi/pastel-mix-stylized-anime",

    "stablediffusionapi/anything-v5",

    "SG161222/Realistic_Vision_V2.0",

    "SG161222/Realistic_Vision_V4.0_noVAE",

    "SG161222/Realistic_Vision_V5.1_noVAE",

    "stablediffusionapi/anime-illust-diffusion-xl",

    "stabilityai/stable-diffusion-xl-base-1.0",



]
|
|
|
|
|
# In-memory cache of loaded IP-Adapter models; evicted FIFO in load_model when full.
model_cache = {}

# Cache capacity, taken from --cache_limit (default 1: only one pipeline in VRAM).
max_cache_size = args.cache_limit

# Maps image hash -> (faceid_embeds, aligned face crop) so repeated generations from
# the same input photo skip face detection. NOTE(review): unbounded — grows for the
# lifetime of the process.
embeddings_cache = {}
|
|
|
def get_image_hash(image):
    """Return a SHA-256 hex digest identifying an image for embedding caching.

    The digest covers the image's mode and size in addition to the raw pixel
    bytes, so two images whose pixel buffers happen to contain the same bytes
    but with different dimensions (e.g. a WxH vs an HxW image of the same
    data) do not collide in ``embeddings_cache``.
    """
    hasher = hashlib.sha256()
    # Mode/size are included when present (PIL images have both); missing
    # attributes fall back to empty strings so any object exposing tobytes()
    # (e.g. a numpy array) can still be hashed.
    hasher.update(f"{getattr(image, 'mode', '')}:{getattr(image, 'size', '')}".encode())
    hasher.update(image.tobytes())
    return hasher.hexdigest()
|
|
|
def convert_model(checkpoint_path, output_path, isSDXL):
    """Convert a single-file checkpoint (.safetensors / .ckpt) to diffusers layout.

    Parameters:
        checkpoint_path: path to the single-file checkpoint to load.
        output_path: directory where the diffusers-format model is written.
        isSDXL: True to load with the SDXL pipeline, False for SD 1.5.

    Returns:
        A human-readable status string (success or the error message), which is
        displayed in the "Conversion Status" textbox.
    """
    # Pick the pipeline class once; the load/save sequence is identical for both.
    pipeline_cls = StableDiffusionXLPipeline if isSDXL else StableDiffusionPipeline
    try:
        pipeline = pipeline_cls.from_single_file(checkpoint_path)
        pipeline.save_pretrained(output_path)
    except Exception as e:
        # Surface the failure as text rather than crashing the UI callback.
        return f"Error: {str(e)}"
    return f"Model converted and saved to {output_path}"
|
|
|
|
|
|
|
def load_model(model_name, isSDXL):
    """Load (or fetch from cache) an IP-Adapter FaceID Plus model for a checkpoint.

    Parameters:
        model_name: Hugging Face repo id or local diffusers-format model path.
        isSDXL: True to build an SDXL pipeline + XL adapter, False for SD 1.5.

    Returns:
        An IPAdapterFaceIDPlus (SD 1.5) or IPAdapterFaceIDPlusXL (SDXL) instance
        with its pipeline moved to CUDA.
    """
    # BUG FIX: the cache was previously keyed on model_name alone, so requesting
    # the same checkpoint name with a different isSDXL flag returned a cached model
    # of the wrong pipeline/adapter type. Key on the (name, flag) pair instead.
    cache_key = (model_name, isSDXL)
    if cache_key in model_cache:
        return model_cache[cache_key]
    print(f"loading model {model_name}")

    # FIFO eviction: drop the oldest inserted entry once the cache is full.
    if len(model_cache) >= max_cache_size:
        model_cache.pop(next(iter(model_cache)))

    device = "cuda"
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    # Use the fine-tuned VAE matching the pipeline family, in fp16.
    vae_model_path = "stabilityai/sdxl-vae" if isSDXL else "stabilityai/sd-vae-ft-mse"
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)

    if isSDXL:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            vae=vae,
            scheduler=noise_scheduler,
            add_watermarker=False,
        ).to(device)
    else:
        # Safety checker / feature extractor deliberately disabled, matching the
        # original app behavior.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None
        ).to(device)

    # Both adapter variants share the same CLIP image encoder repo path.
    image_encoder_path = "h94/IP-Adapter/models/image_encoder"
    if isSDXL:
        ip_ckpt = "adapters/ip-adapter-faceid-plusv2_sdxl.bin"
        ip_model = IPAdapterFaceIDPlusXL(pipe, image_encoder_path, ip_ckpt, device)
    else:
        ip_ckpt = "adapters/ip-adapter-faceid-plusv2_sd15.bin"
        ip_model = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_ckpt, device)

    model_cache[cache_key] = ip_model
    return ip_model
|
|
|
|
|
def generate_image(input_image, positive_prompt, negative_prompt, width, height, model_name, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL,cfg):
    """Generate face-conditioned images with IP-Adapter FaceID Plus.

    Extracts a face identity embedding from ``input_image`` (cached per image
    hash), then runs ``num_images`` generation rounds of ``batch_size`` samples
    each, saving every output as a PNG under ./outputs.

    Returns:
        (list of saved file paths, summary string, seed used for the last round).

    Raises:
        ValueError: if no face is detected in the input image.
    """

    saved_images = []

    # A non-empty custom model path overrides the dropdown selection.
    if custom_model_path:

        model_name = custom_model_path



    ip_model = load_model(model_name, isSDXL)





    # insightface expects BGR (OpenCV order); convert from PIL's RGB.
    input_image = input_image.convert("RGB")

    input_image_cv2 = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    image_hash = get_image_hash(input_image)





    # Reuse the embedding + aligned crop if this exact image was processed before;
    # otherwise run face detection (the FaceAnalysis app is built per cache miss).
    if image_hash in embeddings_cache:

        faceid_embeds, face_image = embeddings_cache[image_hash]

    else:

        app = FaceAnalysis(

            name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]

        )

        app.prepare(ctx_id=0, det_size=(640, 640))

        faces = app.get(input_image_cv2)

        if not faces:

            raise ValueError("No faces found in the image.")



        # Only the first detected face is used.
        faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

        # 224x224 aligned crop for the IP-Adapter image encoder (still BGR here —
        # NOTE(review): confirm the adapter expects this channel order).
        face_image = face_align.norm_crop(input_image_cv2, landmark=faces[0].kps, image_size=224)



        embeddings_cache[image_hash] = (faceid_embeds, face_image)



    for image_index in range(num_images):

        # Every round after the first gets a fresh random seed even when
        # "Randomize Seed" is off, so multiple rounds don't produce duplicates.
        if randomize_seed or image_index > 0:

            seed = random.randint(0, 2**32 - 1)





        generated_images = ip_model.generate(

            prompt=positive_prompt,

            negative_prompt=negative_prompt,

            faceid_embeds=faceid_embeds,

            face_image=face_image,

            num_samples=batch_size,

            shortcut=enable_shortcut,

            s_scale=s_scale,

            width=width,

            height=height,

            guidance_scale=cfg,

            num_inference_steps=num_inference_steps,

            seed=seed,

        )





        # Save each sample with a timestamp plus a counter derived from the current
        # directory listing to keep filenames unique within the same second.
        outputs_dir = "outputs"

        if not os.path.exists(outputs_dir):

            os.makedirs(outputs_dir)

        for i, img in enumerate(generated_images, start=1):

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            image_path = os.path.join(outputs_dir, f"{timestamp}_image_{len(os.listdir(outputs_dir)) + i}.png")

            img.save(image_path)

            saved_images.append(image_path)



    # The returned seed is the one used for the final round (shown in the UI).
    return saved_images, f"Saved images: {', '.join(saved_images)}", seed
|
|
|
|
|
# Gradio UI layout and event wiring.
with gr.Blocks() as demo:

    gr.Markdown("Developed by SECourses - only distributed on https://www.patreon.com/posts/95759342")

    # Input photo + main action button.
    with gr.Row():

        input_image = gr.Image(type="pil")

        generate_btn = gr.Button("Generate")

    # Output resolution and classifier-free guidance scale.
    with gr.Row():

        width = gr.Number(value=512, label="Width")

        height = gr.Number(value=768, label="Height")

        cfg = gr.Number(value=7.5, label="CFG")

    with gr.Row():

        num_inference_steps = gr.Number(value=30, label="Number of Inference Steps", step=1, minimum=10, maximum=100)

        seed = gr.Number(value=2023, label="Seed")

        randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")

    with gr.Row():

        num_images = gr.Number(value=args.num_images, label="Number of Images to Generate", step=1, minimum=1)

        batch_size = gr.Number(value=1, label="Batch Size", step=1)

    # IP-Adapter options; "Activate SDXL" must match the selected model family.
    with gr.Row():

        isSDXL = gr.Checkbox(value=False, label="Activate SDXL")

        enable_shortcut = gr.Checkbox(value=True, label="Enable Shortcut")

        s_scale = gr.Number(value=1.0, label="Scale Factor (s_scale)", step=0.1, minimum=0.5, maximum=4.0)

    with gr.Row():

        positive_prompt = gr.Textbox(label="Positive Prompt")

        negative_prompt = gr.Textbox(label="Negative Prompt")

    # Model selection: dropdown of known checkpoints, or a free-form override path.
    with gr.Row():

        model_selector = gr.Dropdown(label="Select Model", choices=static_model_names, value=static_model_names[0])

        custom_model_path = gr.Textbox(label="Custom Model Path (Optional)")



    # Generation results.
    with gr.Column():

        output_gallery = gr.Gallery(label="Generated Images")

        output_text = gr.Textbox(label="Output Info")

        display_seed = gr.Textbox(label="Used Seed", interactive=False)



    # Single-file checkpoint -> diffusers-format converter.
    with gr.Row():

        checkpoint_path_input = gr.Textbox(label="Enter Checkpoint File Path .e.g G:\model\model.safetensors", )

        output_path_input = gr.Textbox(label="Enter Output Folder Path, e.g. G:\model\model_diffusers")

        convert_btn = gr.Button("Convert Model")



    # Input order here must match generate_image's parameter order.
    generate_btn.click(

        generate_image,

        inputs=[input_image, positive_prompt, negative_prompt, width, height, model_selector, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL,cfg],

        outputs=[output_gallery, output_text, display_seed]

    )



    convert_btn.click(

        convert_model,

        inputs=[checkpoint_path_input, output_path_input, isSDXL],

        outputs=[gr.Text(label="Conversion Status")],

    )
|
|
|
|
|
def start_ngrok():
    """Open an ngrok tunnel to the local Gradio server on port 7860.

    Intended to run in a background daemon thread (see __main__); sleeps briefly
    so Gradio has time to bind the port before the tunnel is established.
    Reads the authtoken from the module-level ``args``.
    """
    print("Starting ngrok...")
    time.sleep(10)
    ngrok.set_auth_token(args.ngrok_token)
    # BUG FIX: pyngrok >= 5 removed the `port=` keyword (signature is
    # connect(addr=None, ...)); pass the local port positionally instead.
    public_url = ngrok.connect(7860)
    print(f"ngrok tunnel started at {public_url}")
|
|
|
if __name__ == "__main__":

    # Optionally expose the app through an ngrok tunnel running in a background
    # daemon thread (so it won't block interpreter shutdown).
    if args.ngrok_token:



        ngrok_thread = threading.Thread(target=start_ngrok, daemon=True)

        ngrok_thread.start()





    # Blocks until the Gradio server stops; opens a local browser tab on start.
    demo.launch(share=args.share, inbrowser=True)
|
|