# Xlabs-Gradio / app.py
# John6666's picture
# Upload app.py
# c855c9b verified
# raw
# history blame
# No virus
# 18.8 kB
import spaces
from pathlib import Path
import torch
import gradio as gr
from PIL import Image, ExifTags
import numpy as np
import torch
from torch import Tensor
from einops import rearrange
import uuid
import os
from src.flux.modules.layers import (
SingleStreamBlockProcessor,
DoubleStreamBlockLoraProcessor,
IPDoubleStreamBlockProcessor,
ImageProjModel,
)
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (
load_ae,
load_clip,
load_flow_model,
load_t5,
load_controlnet,
load_flow_model_quintized,
Annotator,
get_lora_rank,
load_checkpoint
)
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
class XFluxPipeline:
    """Flux text-to-image pipeline with optional LoRA, ControlNet and IP-Adapter.

    The constructor loads the CLIP and T5 text encoders, the autoencoder and
    the flow model.  LoRA weights, a ControlNet and an IP-Adapter can be
    attached afterwards through the ``set_*`` methods.  Calling the instance
    runs one generation and returns a PIL image.
    """

    def __init__(self, model_type, device, offload: bool = False):
        """Load the base models.

        Args:
            model_type: Model identifier forwarded to the loader helpers.  When
                it contains ``"fp8"`` the quantized flow-model loader is used.
            device: Anything accepted by ``torch.device``.
            offload: When true, the autoencoder and flow model are loaded on
                the CPU and moved to ``device`` only while they are needed.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)

        # The autoencoder and flow model start on CPU when offloading is on.
        initial_device = "cpu" if offload else self.device
        self.ae = load_ae(model_type, device=initial_device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device=initial_device)
        else:
            self.model = load_flow_model(model_type, device=initial_device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False
        self.ip_loaded = False
        # Set by set_controlnet(); defined here so the attribute always exists.
        self.control_type = None

    def set_ip(self, local_path: str = None, repo_id=None, name: str = None):
        """Load IP-Adapter weights and install the IP attention processors.

        Args:
            local_path: Checkpoint path, forwarded to ``load_checkpoint``.
            repo_id: Hub repository id, forwarded to ``load_checkpoint``.
            name: Checkpoint file name, forwarded to ``load_checkpoint``.
        """
        self.model.to(self.device)

        checkpoint = load_checkpoint(local_path, repo_id, name)

        # Split out the projection-model weights.  The checkpoint is scanned
        # once; the earlier duplicate pass and unused "blocks" dict are gone.
        proj_prefix = "ip_adapter_proj_model"
        proj = {
            key[len(proj_prefix) + 1:]: value
            for key, value in checkpoint.items()
            if key.startswith(proj_prefix)
        }

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # setup image embedding projection model
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        ip_attn_procs = {}
        # "proc_name" avoids shadowing the ``name`` parameter.
        for proc_name, current_proc in self.model.attn_processors.items():
            ip_state_dict = {
                key.replace(f'{proc_name}.', ''): value
                for key, value in checkpoint.items()
                if proc_name in key
            }
            if ip_state_dict:
                ip_proc = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_proc.load_state_dict(ip_state_dict)
                ip_proc.to(self.device, dtype=torch.bfloat16)
                ip_attn_procs[proc_name] = ip_proc
            else:
                # No IP weights for this processor: keep the installed one.
                ip_attn_procs[proc_name] = current_proc

        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint and apply it with the given weight."""
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Load a LoRA from ``self.hf_lora_collection`` by its type name.

        Raises:
            KeyError: If ``lora_type`` is not in ``self.lora_types_to_names``.
        """
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Replace every attention processor with a LoRA-aware one.

        Processors whose name starts with ``"single_blocks"`` get a plain
        ``SingleStreamBlockProcessor``.  All others get a
        ``DoubleStreamBlockLoraProcessor`` loaded from the checkpoint entries
        whose key contains the processor name, each multiplied by
        ``lora_weight``.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for proc_name in self.model.attn_processors:
            if proc_name.startswith("single_blocks"):
                lora_attn_procs[proc_name] = SingleStreamBlockProcessor()
                continue

            lora_proc = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
            # The key slice drops "<proc_name>." and so assumes matching keys
            # begin with the processor name.
            lora_state_dict = {
                key[len(proc_name) + 1:]: value * lora_weight
                for key, value in checkpoint.items()
                if proc_name in key
            }
            lora_proc.load_state_dict(lora_state_dict)
            lora_proc.to(self.device)
            lora_attn_procs[proc_name] = lora_proc

        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None,
                       repo_id: str = None, name: str = None):
        """Load a ControlNet and the annotator for ``control_type``.

        The checkpoint is loaded with ``strict=False``, so keys that do not
        match the ControlNet module are not treated as an error.
        """
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)

        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: Tensor,
    ):
        """Encode an image prompt and project it with ``self.improj``.

        ``image_prompt`` is handed to ``CLIPImageProcessor`` as ``images=``.
        Callers in this class pass PIL images or uint8 numpy arrays.
        Requires ``set_ip`` to have been called first.
        """
        # encode image-prompt embeds
        pixel_values = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(self.image_encoder.device)
        image_prompt_embeds = self.image_encoder(
            pixel_values
        ).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        # encode image
        return self.improj(image_prompt_embeds)

    def __call__(self,
                 prompt: str,
                 image_prompt: Image.Image = None,
                 controlnet_image: Image.Image = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: Image.Image = None,
                 timestep_to_start_cfg: int = 0,
                 ):
        """Generate one image and return it as a PIL image.

        ``width`` and ``height`` are rounded down to multiples of 16.  If
        either image prompt is given, the IP-Adapter must already be loaded
        and the missing one is replaced by an all-zero image.  A control image
        is preprocessed only when a ControlNet is loaded and an image was
        actually supplied.

        Raises:
            AssertionError: If an image prompt is given but ``set_ip`` has not
                been called.
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        image_proj = None
        neg_image_proj = None
        if image_prompt is not None or neg_image_prompt is not None:
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'

            # Image arrays are (height, width, channels).
            if image_prompt is None:
                image_prompt = np.zeros((height, width, 3), dtype=np.uint8)
            if neg_image_prompt is None:
                neg_image_prompt = np.zeros((height, width, 3), dtype=np.uint8)

            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        # A ControlNet left over from an earlier call must not be fed ``None``.
        if self.controlnet_loaded and controlnet_image is not None:
            controlnet_image = self.annotator(controlnet_image, width, height)
            # Rescale 0..255 pixel values to -1..1.
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    @torch.inference_mode()
    @spaces.GPU()
    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        """Gradio callback: load the requested add-ons, generate, save a JPEG.

        Image inputs arrive as arrays (or ``None``) and are converted with
        ``Image.fromarray``.  A seed of ``-1`` requests a random seed.

        Returns:
            ``(image, filename)``: the PIL image and the path of the JPEG
            written under ``output/gradio/``.
        """
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            # (Re)load when nothing is loaded yet or the control type changed.
            if not self.controlnet_loaded or control_type != self.control_type:
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")

        # NOTE(review): set_lora() replaces every attention processor, which
        # looks like it would drop IP-Adapter processors installed by an
        # earlier call while ``ip_loaded`` stays True -- confirm.
        if lora_local_path is not None:
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)

        # Convert each image prompt independently, and load the IP-Adapter if
        # either one is present, so a negative-only prompt does not trip the
        # ``ip_loaded`` assertion in __call__.
        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
        if neg_image_prompt is not None:
            neg_image_prompt = Image.fromarray(neg_image_prompt)
        wants_ip = image_prompt is not None or neg_image_prompt is not None
        if wants_ip and not self.ip_loaded:
            if ip_local_path is not None:
                self.set_ip(local_path=ip_local_path)
            else:
                self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                            name="flux-ip-adapter.safetensors")

        seed = int(seed)
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        img = self(prompt, image_prompt, controlnet_image, width, height, guidance,
                   num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt,
                   neg_image_prompt, timestep_to_start_cfg)

        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image=None,
        timestep_to_start_cfg=0,
        true_gs=3.5,
        control_weight=0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        """Run sampling and decode the result into a PIL image.

        The ControlNet sampler is used only when a ControlNet is loaded and
        ``controlnet_image`` is not ``None``; otherwise the plain sampler
        runs.  With offloading enabled, models are moved to the device just
        before use and back to the CPU afterwards.
        """
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)

        use_controlnet = self.controlnet_loaded and controlnet_image is not None

        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)

            # Keyword arguments common to both samplers.
            shared_kwargs = dict(
                timesteps=timesteps,
                guidance=guidance,
                timestep_to_start_cfg=timestep_to_start_cfg,
                neg_txt=neg_inp_cond['txt'],
                neg_txt_ids=neg_inp_cond['txt_ids'],
                neg_vec=neg_inp_cond['vec'],
                true_gs=true_gs,
                image_proj=image_proj,
                neg_image_proj=neg_image_proj,
                ip_scale=ip_scale,
                neg_ip_scale=neg_ip_scale,
            )
            if use_controlnet:
                x = denoise_controlnet(
                    self.model,
                    **inp_cond,
                    controlnet=self.controlnet,
                    controlnet_cond=controlnet_image,
                    controlnet_gs=control_weight,
                    **shared_kwargs,
                )
            else:
                x = denoise(
                    self.model,
                    **inp_cond,
                    **shared_kwargs,
                )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            # No-op unless offloading is enabled.
            self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        # Take the last batch element and convert channels-first to HWC.
        x1 = rearrange(x1[-1], "c h w -> h w c")
        # Map -1..1 to 0..255 bytes.
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move ``models`` to the CPU and empty the CUDA cache.

        Does nothing when offloading is disabled.
        """
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build the Gradio Blocks UI around a freshly created ``XFluxPipeline``.

    Args:
        model_type: Model identifier passed to ``XFluxPipeline``; also shown
            in the page title.
        device: Device string passed to ``XFluxPipeline``.
        offload: Passed to ``XFluxPipeline``.
        ckpt_dir: Directory whose ``*.safetensors`` files are offered in the
            checkpoint dropdowns.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    xflux_pipeline = XFluxPipeline(model_type, device, offload)
    # Dropdown choices are given as plain strings, not pathlib.Path objects.
    checkpoints = [str(path) for path in sorted(Path(ckpt_dir).glob("*.safetensors"))]

    with gr.Blocks() as demo:
        gr.Markdown("⚠️ Warning: Gradio is not functioning correctly. We are looking for someone to help fix it by submitting a Pull Request.")
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")

                with gr.Accordion("Generation Options", open=False):
                    with gr.Row():
                        width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                        height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                    neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                    with gr.Row():
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                        true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                    # Textbox values are strings; the callback applies int().
                    seed = gr.Textbox("-1", label="Seed (-1 for random)")

                with gr.Accordion("ControlNet Options", open=False):
                    # An explicit default keeps control_type from being None
                    # when it is formatted into the Hub repository name.
                    control_type = gr.Dropdown(["canny", "hed", "depth"], value="canny", label="Control type")
                    control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                    # value=None: the callback treats None as "no local
                    # checkpoint".  NOTE(review): some Gradio versions appear
                    # to pre-select the first choice when value is omitted --
                    # confirm against the installed version.
                    local_path = gr.Dropdown(
                        checkpoints, value=None, label="Controlnet Checkpoint",
                        info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
                    )
                    controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)

                with gr.Accordion("LoRA Options", open=False):
                    lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                    lora_local_path = gr.Dropdown(
                        checkpoints, value=None, label="LoRA Checkpoint", info="Local Path to Lora weights"
                    )

                with gr.Accordion("IP Adapter Options", open=False):
                    image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                    ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                    neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                    neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                    ip_local_path = gr.Dropdown(
                        checkpoints, value=None, label="IP Adapter Checkpoint",
                        info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
                    )
                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(label="Download full-resolution")

        @spaces.GPU()
        def gradio_generate(*args):
            """Forward the UI values to the pipeline's Gradio callback."""
            return xflux_pipeline.gradio_generate(*args)

        # Order must match the positional parameters of
        # XFluxPipeline.gradio_generate.
        inputs = [prompt, image_prompt, controlnet_image, width, height, guidance,
                  num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                  neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                  lora_weight, local_path, lora_local_path, ip_local_path
                  ]
        generate_btn.click(
            fn=gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: read the command-line options, build the demo UI
    # and start serving it.
    import argparse

    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument(
        "--name", type=str, default="flux-dev",
        help="Model name",
    )
    parser.add_argument(
        "--device", type=str, default=fallback_device,
        help="Device to use",
    )
    parser.add_argument(
        "--offload", action="store_true",
        help="Offload model to CPU when not in use",
    )
    parser.add_argument(
        "--share", action="store_true",
        help="Create a public link to your demo",
    )
    parser.add_argument(
        "--ckpt_dir", type=str, default=".",
        help="Folder with checkpoints in safetensors format",
    )
    args = parser.parse_args()

    demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
    demo.launch(share=args.share)