p1atdev commited on
Commit
388f879
1 Parent(s): 37bf661

fix: memory efficient

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from setup import setup
3
- import cv2
 
4
  from PIL import Image
5
  from manga_line_extraction.model import MangaLineExtractor
6
  from anime2sketch.model import Anime2Sketch
@@ -9,24 +10,36 @@ setup()
9
 
10
  print("Setup finished")
11
 
12
- extractor = MangaLineExtractor("./models/erika.pth", "cpu")
13
- to_sketch = Anime2Sketch("./models/netG.pth", "cpu")
14
 
15
- print("Model loaded")
 
 
16
 
17
 
18
  def extract(image):
19
- return extractor.predict(image)
 
 
 
 
20
 
21
 
22
  def convert_to_sketch(image):
23
- return to_sketch.predict(image)
 
 
 
 
24
 
25
 
26
  def start(image):
27
  return [extract(image), convert_to_sketch(Image.fromarray(image).convert("RGB"))]
28
 
29
 
 
 
 
 
30
  def ui():
31
  with gr.Blocks() as blocks:
32
  gr.Markdown(
@@ -43,7 +56,8 @@ def ui():
43
  with gr.Column():
44
  input_img = gr.Image(label="Input", interactive=True)
45
 
46
- extract_btn = gr.Button("Extract", variant="primary")
 
47
 
48
  with gr.Column():
49
  # with gr.Row():
@@ -78,6 +92,12 @@ def ui():
78
  outputs=[extract_output_img, to_sketch_output_img],
79
  )
80
 
 
 
 
 
 
 
81
  return blocks
82
 
83
 
 
1
  import gradio as gr
2
  from setup import setup
3
+ import torch
4
+ import gc
5
  from PIL import Image
6
  from manga_line_extraction.model import MangaLineExtractor
7
  from anime2sketch.model import Anime2Sketch
 
10
 
11
  print("Setup finished")
12
 
 
 
13
 
14
+ def flush():
15
+ gc.collect()
16
+ torch.cuda.empty_cache()
17
 
18
 
19
  def extract(image):
20
+ extractor = MangaLineExtractor("./models/erika.pth", "cpu")
21
+ result = extractor.predict(image)
22
+ del extractor
23
+ flush()
24
+ return result
25
 
26
 
27
  def convert_to_sketch(image):
28
+ to_sketch = Anime2Sketch("./models/netG.pth", "cpu")
29
+ result = to_sketch.predict(image)
30
+ del to_sketch
31
+ flush()
32
+ return result
33
 
34
 
35
  def start(image):
36
  return [extract(image), convert_to_sketch(Image.fromarray(image).convert("RGB"))]
37
 
38
 
39
+ def clear():
40
+ return [None, None]
41
+
42
+
43
  def ui():
44
  with gr.Blocks() as blocks:
45
  gr.Markdown(
 
56
  with gr.Column():
57
  input_img = gr.Image(label="Input", interactive=True)
58
 
59
+ extract_btn = gr.Button("Start", variant="primary")
60
+ clear_btn = gr.Button("Clear", variant="secondary")
61
 
62
  with gr.Column():
63
  # with gr.Row():
 
92
  outputs=[extract_output_img, to_sketch_output_img],
93
  )
94
 
95
+ clear_btn.click(
96
+ fn=clear,
97
+ inputs=[],
98
+ outputs=[extract_output_img, to_sketch_output_img],
99
+ )
100
+
101
  return blocks
102
 
103