Commit cdd0075 by geekyutao
Parent(s): 8732e6b

add dilation bar and improve UI

Files changed (1):
  app.py  (+98 -67)
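The commit wires a new "Dilate Kernel Size" slider (0-100, default 15) into get_masked_img, replacing the previously hard-coded dilate_kernel_size = 15. The part of get_masked_img that actually consumes the kernel size is unchanged by this commit and therefore does not appear in the diff below; the following is only a minimal sketch of how the slider value is presumably applied, assuming the dilate_mask helper that app.py imports from utils (apply_dilation is a hypothetical name, not part of the commit):

import numpy as np
from utils import dilate_mask  # helper already imported at the top of app.py

def apply_dilation(masks, dilate_kernel_size):
    # Hypothetical illustration: dilate each predicted SAM mask by the slider
    # value so the inpainting region fully covers the object boundary.
    masks = masks.astype(np.uint8) * 255
    if dilate_kernel_size is not None and dilate_kernel_size > 0:
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
    return masks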
app.py CHANGED
@@ -1,19 +1,38 @@
+import os
+import sys
+# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
+# os.chdir("../")
 import gradio as gr
 import numpy as np
 from pathlib import Path
 from matplotlib import pyplot as plt
 import torch
 import tempfile
-import os
-from omegaconf import OmegaConf
-from sam_segment import predict_masks_with_sam
 from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
 from utils import load_img_to_array, save_array_to_img, dilate_mask, \
     show_mask, show_points
 from PIL import Image
+sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
 from segment_anything import SamPredictor, sam_model_registry
-
-
+import argparse
+
+def setup_args(parser):
+    parser.add_argument(
+        "--lama_config", type=str,
+        default="./third_party/lama/configs/prediction/default.yaml",
+        help="The path to the config file of lama model. "
+             "Default: the config of big-lama",
+    )
+    parser.add_argument(
+        "--lama_ckpt", type=str,
+        default="pretrained_models/big-lama",
+        help="The path to the lama checkpoint.",
+    )
+    parser.add_argument(
+        "--sam_ckpt", type=str,
+        default="./pretrained_models/sam_vit_h_4b8939.pth",
+        help="The path to the SAM checkpoint to use for mask generation.",
+    )
 def mkstemp(suffix, dir=None):
     fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
     os.close(fd)
@@ -21,9 +40,7 @@ def mkstemp(suffix, dir=None):


 def get_sam_feat(img):
-    # predictor.set_image(img)
     model['sam'].set_image(img)
-    # self.is_image_set = False
     features = model['sam'].features
     orig_h = model['sam'].orig_h
     orig_w = model['sam'].orig_w
@@ -33,24 +50,18 @@ def get_sam_feat(img):
     return features, orig_h, orig_w, input_h, input_w


-def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
+def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size):
     point_coords = [w, h]
     point_labels = [1]
-    dilate_kernel_size = 15

-    # model['sam'].is_image_set = False
     model['sam'].is_image_set = True
     model['sam'].features = features
     model['sam'].orig_h = orig_h
     model['sam'].orig_w = orig_w
     model['sam'].input_h = input_h
     model['sam'].input_w = input_w
-    # model['sam'].image_embedding = image_embedding
-    # model['sam'].original_size = original_size
-    # model['sam'].input_size = input_size
-    # model['sam'].is_image_set = True
-
-    model['sam'].set_image(img)
+
+    # model['sam'].set_image(img)  # todo : update here for accelerating
     masks, _, _ = model['sam'].predict(
         point_coords=np.array([point_coords]),
         point_labels=np.array(point_labels),
@@ -77,6 +88,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
         show_points(plt.gca(), [point_coords], point_labels,
                     size=(width*0.04)**2)
         show_mask(plt.gca(), mask, random_color=False)
+        plt.tight_layout()
         plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
         figs.append(fig)
         plt.close()
@@ -84,8 +96,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):


 def get_inpainted_img(img, mask0, mask1, mask2):
-    lama_config = "third_party/lama/configs/prediction/default.yaml"
-    # lama_ckpt = "pretrained_models/big-lama"
+    lama_config = args.lama_config
     device = "cuda" if torch.cuda.is_available() else "cpu"
     out = []
     for mask in [mask0, mask1, mask2]:
@@ -97,25 +108,27 @@ def get_inpainted_img(img, mask0, mask1, mask2):
     return out


-## build models
+# get args
+parser = argparse.ArgumentParser()
+setup_args(parser)
+args = parser.parse_args(sys.argv[1:])
+# build models
 model = {}
 # build the sam model
 model_type="vit_h"
-ckpt_p="pretrained_models/sam_vit_h_4b8939.pth"
+ckpt_p=args.sam_ckpt
 model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_sam.to(device=device)
-# predictor = SamPredictor(model_sam)
 model['sam'] = SamPredictor(model_sam)

 # build the lama model
-lama_config = "third_party/lama/configs/prediction/default.yaml"
-lama_ckpt = "pretrained_models/big-lama"
+lama_config = args.lama_config
+lama_ckpt = args.lama_ckpt
 device = "cuda" if torch.cuda.is_available() else "cpu"
-# model_lama = build_lama_model(lama_config, lama_ckpt, device=device)
 model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)

-
+button_size = (100,50)
 with gr.Blocks() as demo:
     features = gr.State(None)
     orig_h = gr.State(None)
@@ -123,36 +136,59 @@ with gr.Blocks() as demo:
     input_h = gr.State(None)
     input_w = gr.State(None)

-    with gr.Row():
-        img = gr.Image(label="Image")
-        img_pointed = gr.Plot(label='Pointed Image')
-        with gr.Column():
+    with gr.Row().style(mobile_collapse=False, equal_height=True):
+        with gr.Column(variant="panel"):
+            with gr.Row():
+                gr.Markdown("## Input Image")
+            with gr.Row():
+                img = gr.Image(label="Input Image").style(height="200px")
+        with gr.Column(variant="panel"):
+            with gr.Row():
+                gr.Markdown("## Pointed Image")
+            with gr.Row():
+                img_pointed = gr.Plot(label='Pointed Image')
+        with gr.Column(variant="panel"):
+            with gr.Row():
+                gr.Markdown("## Control Panel")
             with gr.Row():
                 w = gr.Number(label="Point Coordinate W")
                 h = gr.Number(label="Point Coordinate H")
-            # sam_feat = gr.Button("Prepare for Segmentation")
-            sam_mask = gr.Button("Predict Mask Using SAM")
-            lama = gr.Button("Inpaint Image Using LaMA")
-            # clear_button_image = gr.Button(value="Clear Image", interactive=True)
+            dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=100, step=1, value=15)
+            sam_mask = gr.Button("Predict Mask", variant="primary").style(full_width=True, size="sm")
+            lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
+            clear_button_image = gr.Button(value="Reset", label="Reset", variant="secondary").style(full_width=True, size="sm")

     # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
-    with gr.Row():
-        mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0")
-        mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1")
-        mask_2 = gr.outputs.Image(type="numpy", label="Segmentation Mask 2")
-
-    with gr.Row():
-        img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
-        img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
-        img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")
-
-    with gr.Row():
-        img_rm_with_mask_0 = gr.outputs.Image(
-            type="numpy", label="Image Removed with Segmentation Mask 0")
-        img_rm_with_mask_1 = gr.outputs.Image(
-            type="numpy", label="Image Removed with Segmentation Mask 1")
-        img_rm_with_mask_2 = gr.outputs.Image(
-            type="numpy", label="Image Removed with Segmentation Mask 2")
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                gr.Markdown("## Segmentation Mask")
+            with gr.Row():
+                mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0").style(height="200px")
+                mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1").style(height="200px")
+                mask_2 = gr.outputs.Image(type="numpy", label="Segmentation Mask 2").style(height="200px")
+
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                gr.Markdown("## Image with Mask")
+            with gr.Row():
+                img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
+                img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
+                img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")
+
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                gr.Markdown("## Image Removed with Mask")
+            with gr.Row():
+                img_rm_with_mask_0 = gr.outputs.Image(
+                    type="numpy", label="Image Removed with Segmentation Mask 0").style(height="200px")
+                img_rm_with_mask_1 = gr.outputs.Image(
+                    type="numpy", label="Image Removed with Segmentation Mask 1").style(height="200px")
+                img_rm_with_mask_2 = gr.outputs.Image(
+                    type="numpy", label="Image Removed with Segmentation Mask 2").style(height="200px")
+

     def get_select_coords(img, evt: gr.SelectData):
         dpi = plt.rcParams['figure.dpi']
@@ -160,22 +196,17 @@ with gr.Blocks() as demo:
         fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
         plt.imshow(img)
         plt.axis('off')
+        plt.tight_layout()
         show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
                     size=(width*0.04)**2)
         return evt.index[0], evt.index[1], fig

     img.select(get_select_coords, [img], [w, h, img_pointed])
-    # sam_feat.click(
-    #     get_sam_feat,
-    #     [img],
-    #     []
-    # )
-    # img.change(get_sam_feat, [img], [])
     img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])

     sam_mask.click(
         get_masked_img,
-        [img, w, h, features, orig_h, orig_w, input_h, input_w],
+        [img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size],
         [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
     )

@@ -185,16 +216,16 @@ with gr.Blocks() as demo:
         [img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
     )

-    # clear_button_image.click(
-    #     lambda: ([], [], [], []),
-    #     [],
-    #     [img, img_pointed, w, h],
-    #     queue=False,
-    #     show_progress=False
-    # )
+
+    def reset(*args):
+        return [None for _ in args]
+
+    clear_button_image.click(
+        reset,
+        [img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2],
+        [img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
+    )

 if __name__ == "__main__":
-    # demo.queue(concurrency_count=4, max_size=25)
-    # demo.launch(max_threads=8)
-    demo.launch()
+    demo.launch(share=True)
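With the argparse setup added above, the checkpoint and config locations can now be overridden from the command line; the defaults are the same paths that were previously hard-coded. A typical launch (the paths shown are just the defaults from this diff) might look like:

python app.py \
    --sam_ckpt ./pretrained_models/sam_vit_h_4b8939.pth \
    --lama_config ./third_party/lama/configs/prediction/default.yaml \
    --lama_ckpt pretrained_models/big-lama

demo.launch(share=True) additionally requests a temporary public Gradio share link on top of the local server.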