merve HF staff committed on
Commit
426e73b
1 Parent(s): d4d395d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Model setup: a zero-shot object detector (OWL-ViT) proposes text-conditioned
# boxes, and SAM converts each box into a segmentation mask.
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np

# NOTE(review): this checkpoint is OWL-ViT v1 ("owlvit-base-patch16"), while
# the UI description later in the file says OWLv2 — confirm which is intended.
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
# SAM is pinned to CUDA here; this will raise on CPU-only hardware.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
10
def query(image, texts, threshold):
    """Detect objects named in *texts* with OWL-ViT, then segment each
    detection with SAM.

    Args:
        image: PIL image to process.
        texts: comma-separated candidate labels, e.g. "cat,dog".
        threshold: minimum detection confidence forwarded to the zero-shot
            object-detection pipeline; boxes scoring below it are dropped.

    Returns:
        ``(image, result_labels)`` where ``result_labels`` is a list of
        ``(mask, label)`` tuples in the shape gr.AnnotatedImage expects.
    """
    # Tolerate spaces after commas: "cat, dog" works like "cat,dog".
    labels = [t.strip() for t in texts.split(",")]

    # Bug fix: `threshold` was previously ignored, so the UI slider had no
    # effect. Forward it so low-confidence detections are filtered out.
    predictions = detector(image, candidate_labels=labels, threshold=threshold)

    result_labels = []
    for pred in predictions:
        # xyxy box from the detector, rounded to 2 decimal places.
        box = [
            round(pred["box"]["xmin"], 2),
            round(pred["box"]["ymin"], 2),
            round(pred["box"]["xmax"], 2),
            round(pred["box"]["ymax"], 2),
        ]

        # One box for one image — SamProcessor takes input_boxes nested as
        # [image][box set][4 coords].
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],
            return_tensors="pt",
        ).to("cuda")

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = sam_model(**inputs)

        # Rescale the low-res predicted mask back to the original image size,
        # then keep the first mask of the first box of the first image.
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        # AnnotatedImage wants a leading axis on the mask array.
        mask = mask[np.newaxis, ...]
        result_labels.append((mask, pred["label"]))

    return image, result_labels
45
+
46
import gradio as gr

# NOTE(review): this text says OWLv2, but the checkpoint loaded above is
# OWL-ViT v1 ("google/owlvit-base-patch16") — confirm which model/claim is
# intended before re-enabling the description below.
description = "This Space combines OWLv2, the state-of-the-art zero-shot object detection model with SAM, the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with OWLv2 makes SAM text promptable."
demo = gr.Interface(
    query,
    # Inputs: PIL image, comma-separated labels, detection-threshold slider.
    inputs=[gr.Image(type="pil"), "text", gr.Slider(0, 1, value=0.2)],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    #description=description,
    # NOTE(review): "/content/..." is a Colab filesystem path; it likely does
    # not exist in this Space, so the example image may fail to load — verify.
    examples=[
        ["/content/cats.png", "cat", 0.1],
    ],
)
# debug=True streams worker logs; typical for a demo Space.
demo.launch(debug=True)