File size: 5,399 Bytes
f9aa991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47bc811
d3260f5
f9aa991
47bc811
f9aa991
846b367
 
 
 
 
 
 
 
 
 
 
f9aa991
 
846b367
05693ff
 
 
 
 
 
 
846b367
05693ff
d3260f5
f9aa991
 
 
 
 
 
 
846b367
afc1021
f9aa991
 
 
 
 
 
 
 
 
 
846b367
 
f9aa991
846b367
f9aa991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Hugging Face Spaces GPU decorator plus stdlib/third-party imports.
# NOTE(review): `argparse` and `load_file` appear unused in this file —
# confirm before removing (they may be used in a part of the repo not shown).
import spaces
import argparse
import os
import time
from os import path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Redirect all Hugging Face caches into a "models" directory next to this
# file so downloaded weights persist with the Space instead of the default
# home-directory cache.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# Imported *after* the cache env vars are set so the libraries pick up the
# custom cache location — do not reorder these above the os.environ lines.
import gradio as gr
import torch
from diffusers import FluxPipeline

# Allow TF32 matmul kernels on Ampere+ GPUs: faster, negligible precision
# loss for bf16 inference.
torch.backends.cuda.matmul.allow_tf32 = True

class timer:
    """Context manager that prints the wall-clock duration of a code block.

    Usage:
        with timer("inference"):
            ...  # timed work
    """

    def __init__(self, method_name="timed process"):
        # Label used in the start/finish log lines.
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time can jump
        # with system clock adjustments and is the wrong tool for intervals.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return self so `with timer(...) as t:` binds the instance
        # (the original returned None implicitly).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start
        # round() keeps the original "12.34s" log format.
        print(f"{self.method} took {round(elapsed, 2)}s")

# Ensure the local model cache directory exists. exist_ok=True makes this
# idempotent and race-free, so the previous `if not path.exists(...)` guard
# was redundant (and racy between check and create).
os.makedirs(cache_path, exist_ok=True)

# Load the base FLUX.1-dev pipeline in bfloat16, apply the TDD distillation
# LoRA (beta), and bake it into the base weights so inference needs no
# separate adapter pass. Runs once at startup; downloads on first launch.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "TDD-FLUX.1-dev-lora-beta.safetensors"))
# lora_scale=0.125 is the strength the TDD beta release was tuned for.
pipe.fuse_lora(lora_scale=0.125)
pipe.to("cuda")

css = """
h1 {
    text-align: center;
    display:block;
}
.gradio-container {
  max-width: 70.5rem !important;
}
"""

# Declarative UI layout. Widget handles created here (prompt, height, width,
# steps, scales, seed, generate_btn, output) are wired to the inference
# callback further down in this same `with` block.
with gr.Blocks(css=css) as demo:
    # Header: project description and links (rendered as Markdown).
    gr.Markdown(
        """
        # FLUX.1-dev(beta) distilled by ✨Target-Driven Distillation✨
        
        Compared to Hyper-FLUX, the beta version of TDD has its parameters reduced by half(600MB), resulting in more realistic details. 
        
        Due to limitations in machine resources, there are still many imperfections in the beta version. The official version is still being optimized and is expected to be released after the National Day holiday. 
        
        Besides, TDD is also available for distilling video generation models. This space presents TDD-distilled [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
        
        [**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)
        
        The codes of this space are built on [Hyper-FLUX](https://huggingface.co/spaces/ByteDance/Hyper-FLUX-8Steps-LoRA) and we acknowledge their contribution.
        """
    )

    # Two-column layout: inputs on the left (scale 3), output image on the
    # right (scale 4).
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    value="portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
                    lines=3
                )
                
                # Collapsed-by-default advanced controls.
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        with gr.Row():
                            # Resolution sliders step by 64 to stay compatible
                            # with the pipeline's latent downsampling.
                            height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                            width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
                        
                        with gr.Row():
                            # TDD targets few-step sampling, hence the 4-10 range.
                            steps = gr.Slider(label="Inference Steps", minimum=4, maximum=10, step=1, value=8)
                            scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=3.5, step=0.1, value=2.0)
                        
                        # precision=0 forces an integer seed.
                        seed = gr.Number(label="Seed", value=3413, precision=0)
                
                generate_btn = gr.Button("Generate Image", variant="primary", scale=1)

        with gr.Column(scale=4):
            output = gr.Image(label="Your Generated Image")
    
    # Footer: usage instructions (static HTML inside Markdown).
    gr.Markdown(
        """
        <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
            <h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
            <ol style="padding-left: 1.5rem;">
                <li>Enter a detailed description of the image you want to create.</li>
                <li>Adjust advanced settings if desired (tap to expand).</li>
                <li>Tap "Generate Image" and wait for your creation!</li>
            </ol>
            <p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
        </div>
        """
    )

    @spaces.GPU
    def process_image(height, width, steps, scales, prompt, seed):
        """Run one FLUX text-to-image generation on the GPU.

        Args:
            height: Output height in pixels (slider supplies multiples of 64).
            width: Output width in pixels (slider supplies multiples of 64).
            steps: Number of denoising steps.
            scales: Classifier-free guidance scale.
            prompt: Text prompt describing the desired image.
            seed: RNG seed for reproducible generation.

        Returns:
            The single generated image (first element of the pipeline's
            ``images`` list).
        """
        # `pipe` is only read here, never reassigned, so the original
        # `global pipe` declaration was a no-op and has been removed.
        # inference_mode + bf16 autocast keep memory low; timer logs duration.
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            return pipe(
                prompt=[prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

    # Wire the button to the inference callback; input order must match the
    # process_image parameter order.
    generate_btn.click(
        process_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=output
    )

# Start the Gradio server when run as a script (Spaces invokes this directly).
if __name__ == "__main__":
    demo.launch()