File size: 3,145 Bytes
c5d1577
 
388f879
 
c5d1577
f561c25
c5d1577
ebf9292
c5d1577
 
 
 
 
f561c25
 
 
 
 
 
ebf9292
f561c25
 
 
 
 
 
 
 
 
 
c5d1577
388f879
 
 
c5d1577
 
d02110f
c5d1577
f561c25
388f879
c5d1577
 
d02110f
c5d1577
f561c25
388f879
c5d1577
 
 
 
 
 
388f879
 
 
 
c5d1577
 
 
 
 
 
 
 
 
f561c25
 
 
c5d1577
 
 
 
 
 
 
388f879
 
c5d1577
 
 
 
 
 
 
 
 
 
 
f561c25
c5d1577
 
 
 
 
 
bfaed2e
c5d1577
 
 
 
 
 
 
 
 
 
388f879
 
 
 
 
 
c5d1577
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from setup import setup
import torch
import gc
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
from anime2sketch.model import Anime2Sketch
import spaces

# One-time environment preparation (project-local helper; presumably downloads
# model weights / example assets -- confirm against setup.py). Runs at import.
setup()

print("Setup finished")

# Hugging Face repo id for the MangaLineExtraction model used below.
MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf" 

class MangaLineExtractor:
    """Callable that extracts line art from an illustration via the
    MangaLineExtraction-hf model.

    NOTE(review): ``model`` and ``processor`` are *class* attributes, so the
    weights are downloaded and loaded once at class-definition (import) time
    and shared by every instance.
    """

    # trust_remote_code is required: the repo ships custom modeling and
    # processing code rather than a stock transformers architecture.
    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)

    @spaces.GPU
    @torch.no_grad()
    def __call__(self, image: Image.Image) -> Image.Image:
        """Return the extracted line drawing as a grayscale ("L") PIL image."""
        inputs = self.processor(image, return_tensors="pt")
        outputs = self.model(inputs.pixel_values)

        # The remote-code model's output exposes pixel_values; the first (only)
        # batch item is converted to uint8 and wrapped as a PIL image.
        # NOTE(review): assumes the output tensor lives on CPU -- a .cpu() call
        # would be needed if the model were ever moved to GPU explicitly.
        line_image = Image.fromarray(outputs.pixel_values[0].numpy().astype("uint8"), mode="L")
        return line_image

# Module-level singletons used by the Gradio callbacks below.
mle_model = MangaLineExtractor()
# Anime2Sketch generator; weights loaded from a local path, inference on CPU.
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")

def flush():
    """Release unused memory.

    Runs a Python garbage-collection pass and, only when CUDA is available,
    empties the CUDA caching allocator. The availability guard avoids
    needlessly initializing the CUDA runtime on CPU-only hosts.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@torch.no_grad()
def extract(image):
    """Run the MangaLineExtraction model on *image* and return the line art."""
    return mle_model(image)


@torch.no_grad()
def convert_to_sketch(image):
    """Convert *image* into a sketch with the Anime2Sketch model."""
    return a2s_model.predict(image)


def start(image):
    """Produce both UI outputs for one input image.

    The raw input (a numpy array from ``gr.Image``) is fed to ``extract`` as
    received, while Anime2Sketch is given an RGB PIL image.
    """
    line_art = extract(image)
    sketch = convert_to_sketch(Image.fromarray(image).convert("RGB"))
    return [line_art, sketch]


def clear():
    """Reset both output image components to empty."""
    return [None] * 2


def ui():
    """Build and return the Gradio Blocks interface for the demo.

    Layout: a header, an input column (image + Start/Clear buttons), an output
    column (one image per model), example images, and the button wiring.
    """
    with gr.Blocks() as blocks:
        # Header / attribution text shown above the widgets.
        gr.Markdown(
            """
        # Anime to Sketch 
        Unofficial demo for converting illustrations into sketches. 
        Original repos:
        - [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
        - [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)

        Using with 🤗 transformers:
        - [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)
        """
        )

        with gr.Row():
            # Left column: input image and action buttons.
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)

                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            # Right column: one output image per model.
            with gr.Column():
                # with gr.Row():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)

        # Clickable example inputs; outputs are computed on demand
        # (example caching is intentionally left disabled).
        gr.Examples(
            fn=start,
            examples=[
                ["./examples/0.jpg"],
                ["./examples/1.jpg"],
                ["./examples/2.jpg"],
            ],
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
            label="Examples",
            # cache_examples=True,
        )

        gr.Markdown("Images are from nijijourney.")

        # Start: run both models and fill the two output images.
        extract_btn.click(
            fn=start,
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
        )

        # Clear: blank out both output images.
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[extract_output_img, to_sketch_output_img],
        )

    return blocks


if __name__ == "__main__":
    # Build the interface and start the Gradio server when run as a script.
    app = ui()
    app.launch()