from pathlib import Path

import torch
import gradio as gr

from src.flux.xflux_pipeline import XFluxPipeline


def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build (but do not launch) the XFlux Gradio demo UI.

    The UI collects a prompt plus ControlNet, LoRA and IP-Adapter options. It
    forwards them to ``XFluxPipeline.gradio_generate``, whose two results are
    shown as the generated image and a downloadable file.

    Args:
        model_type: Model name passed straight to ``XFluxPipeline``.
        device: Device passed to ``XFluxPipeline``. Defaults to ``"cuda"`` when
            available at import time, otherwise ``"cpu"``.
        offload: Passed to ``XFluxPipeline`` as its offload flag.
        ckpt_dir: Directory scanned (non-recursively) for ``*.safetensors``
            files. Those files become the choices in the ControlNet, LoRA and
            IP-Adapter checkpoint dropdowns. An empty string scans the current
            working directory.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        calling ``launch()``.
    """
    xflux_pipeline = XFluxPipeline(model_type, device, offload)

    # Dropdown choices are serialized for the browser, and the selection comes
    # back as a string. Use plain strings instead of Path objects so each
    # submitted value matches one of the registered choices exactly.
    checkpoints = [str(path) for path in sorted(Path(ckpt_dir).glob("*.safetensors"))]

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")

                with gr.Accordion("Generation Options", open=False):
                    with gr.Row():
                        width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                        height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                    neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                    with gr.Row():
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                        true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")

                with gr.Accordion("ControlNet Options", open=False):
                    control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                    control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                    local_path = gr.Dropdown(
                        checkpoints,
                        label="Controlnet Checkpoint",
                        info="Local Path to Controlnet weights (if no, it will be downloaded from HF)",
                    )
                    controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)

                with gr.Accordion("LoRA Options", open=False):
                    lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                    lora_local_path = gr.Dropdown(
                        checkpoints,
                        label="LoRA Checkpoint",
                        info="Local Path to Lora weights",
                    )

                with gr.Accordion("IP Adapter Options", open=False):
                    image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                    ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                    neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                    neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                    ip_local_path = gr.Dropdown(
                        checkpoints,
                        label="IP Adapter Checkpoint",
                        info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)",
                    )

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(label="Download full-resolution")

        # Gradio passes these component values positionally, so this order
        # must match the parameter order of XFluxPipeline.gradio_generate.
        inputs = [
            prompt, image_prompt, controlnet_image, width, height, guidance,
            num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
            neg_image_prompt, timestep_to_start_cfg, control_type,
            control_weight, lora_weight, local_path, lora_local_path,
            ip_local_path,
        ]
        # Event listeners must be registered inside the Blocks context.
        generate_btn.click(
            fn=xflux_pipeline.gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

    return demo


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use",
    )
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=".",
        help="Folder with checkpoints in safetensors format",
    )
    args = parser.parse_args()

    demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
    demo.launch(share=args.share)