Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -42,22 +42,25 @@ def draw_heatmap(image, mask):
|
|
42 |
# Define callable method for the demo
|
43 |
def get_mask(image):
|
44 |
if image is None:
|
45 |
-
return None
|
46 |
|
47 |
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
|
48 |
dm_image = feature_extractor(image).unsqueeze(0)
|
49 |
-
|
|
|
|
|
|
|
50 |
|
51 |
masked_img = draw_mask(image, mask)
|
52 |
heatmap = draw_heatmap(image, mask)
|
53 |
-
return np.hstack((masked_img, heatmap))
|
54 |
|
55 |
|
56 |
# Launch demo interface
|
57 |
gr.Interface(
|
58 |
get_mask,
|
59 |
inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
60 |
-
outputs=[gr.outputs.Image(label="Output")],
|
61 |
title="Vision DiffMask Demo",
|
62 |
live=True,
|
63 |
).launch()
|
|
|
42 |
# Define callable method for the demo
|
43 |
def get_mask(image):
|
44 |
if image is None:
|
45 |
+
return None, None
|
46 |
|
47 |
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
|
48 |
dm_image = feature_extractor(image).unsqueeze(0)
|
49 |
+
dm_out = diffmask.get_mask(dm_image)
|
50 |
+
mask = dm_out["mask"][0].detach()
|
51 |
+
pred = dm_out["pred_class"][0].detach()
|
52 |
+
pred = diffmask.model.config.id2label[pred.item()]
|
53 |
|
54 |
masked_img = draw_mask(image, mask)
|
55 |
heatmap = draw_heatmap(image, mask)
|
56 |
+
return np.hstack((masked_img, heatmap)), pred
|
57 |
|
58 |
|
59 |
# Launch demo interface
|
60 |
gr.Interface(
|
61 |
get_mask,
|
62 |
inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
63 |
+
outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
|
64 |
title="Vision DiffMask Demo",
|
65 |
live=True,
|
66 |
).launch()
|