black-agenta commited on
Commit
a7bb749
1 Parent(s): 9f8b2da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, File, Form, UploadFile
2
- from fastapi.responses import StreamingResponse
3
  import torch
4
  from transformers import GroundingDinoForObjectDetection, AutoProcessor
5
  from PIL import Image, ImageDraw
@@ -10,13 +10,13 @@ import threading
10
 
11
  app = FastAPI()
12
 
13
- # MySQL database connection setting
14
  username = 'ukrqsqxg_kacafix'
15
  password = 'kdm#{k&4@&Y+'
16
  host = 'www.kacafix.com'
17
  database = 'ukrqsqxg_millbox_storage'
18
 
19
- # Create a connection to the MySQL databased
20
  cnx = mysql.connector.connect(
21
  user=username,
22
  password=password,
@@ -32,10 +32,6 @@ model = GroundingDinoForObjectDetection.from_pretrained('IDEA-Research/grounding
32
  processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
33
  model.to(device)
34
 
35
- # Cache the pre-trained model and processor
36
- model_cache = model
37
- processor_cache = processor
38
-
39
  lock = threading.Lock()
40
 
41
  @app.post("/predict")
@@ -52,13 +48,12 @@ async def predict(
52
  labels = [label if label.endswith(".") else label + "." for label in labels]
53
  labels = " ".join(labels)
54
 
55
- # Use cached model and processor
56
  with lock:
57
- inputs = processor_cache(images=image, text=labels, return_tensors="pt").to(device)
58
  with torch.no_grad():
59
- outputs = model_cache(**inputs)
60
 
61
- result = processor_cache.post_process_grounded_object_detection(
62
  outputs,
63
  inputs.input_ids,
64
  box_threshold=box_threshold,
@@ -81,7 +76,8 @@ async def predict(
81
  # Save the data in the MySQL database asynchronously
82
  asyncio.create_task(save_to_database(labels, box_threshold, text_threshold, output_image_io.getvalue()))
83
 
84
- return JSONResponse(content={"message": "complete"})
 
85
 
86
  async def save_to_database(labels, box_threshold, text_threshold, output_image):
87
  query = ("INSERT INTO model_requests (labels, box_threshold, text_threshold, output_image) "
 
1
  from fastapi import FastAPI, File, Form, UploadFile
2
+ from fastapi.responses import JSONResponse
3
  import torch
4
  from transformers import GroundingDinoForObjectDetection, AutoProcessor
5
  from PIL import Image, ImageDraw
 
10
 
11
  app = FastAPI()
12
 
13
+ # MySQL database connection settings
14
  username = 'ukrqsqxg_kacafix'
15
  password = 'kdm#{k&4@&Y+'
16
  host = 'www.kacafix.com'
17
  database = 'ukrqsqxg_millbox_storage'
18
 
19
+ # Create a connection to the MySQL database
20
  cnx = mysql.connector.connect(
21
  user=username,
22
  password=password,
 
32
  processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
33
  model.to(device)
34
 
 
 
 
 
35
  lock = threading.Lock()
36
 
37
  @app.post("/predict")
 
48
  labels = [label if label.endswith(".") else label + "." for label in labels]
49
  labels = " ".join(labels)
50
 
 
51
  with lock:
52
+ inputs = processor(images=image, text=labels, return_tensors="pt").to(device)
53
  with torch.no_grad():
54
+ outputs = model(**inputs)
55
 
56
+ result = processor.post_process_grounded_object_detection(
57
  outputs,
58
  inputs.input_ids,
59
  box_threshold=box_threshold,
 
76
  # Save the data in the MySQL database asynchronously
77
  asyncio.create_task(save_to_database(labels, box_threshold, text_threshold, output_image_io.getvalue()))
78
 
79
+ # Return a success message instead of the image
80
+ return JSONResponse(content={"message": "Scan Completed Successfully"})
81
 
82
  async def save_to_database(labels, box_threshold, text_threshold, output_image):
83
  query = ("INSERT INTO model_requests (labels, box_threshold, text_threshold, output_image) "