KennethTM's picture
Create app.py
f94a9ea verified
raw
history blame contribute delete
No virus
6.18 kB
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image, ImageDraw
import cv2
# Side length, in pixels, of the square zero-padded canvas fed to the model;
# process_image uses it as the padding target and again when undoing that
# padding on the predicted keypoints.
image_size = 224
def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale pixels to [0, 1] and standardize the first three channels.

    Args:
        image: H x W x C array of raw pixel values; only channels 0-2 are
            mean/std standardized, any further channels are only scaled.
        mean: per-channel values subtracted after dividing by 255.
        std: per-channel values divided by after subtracting ``mean``.

    Returns:
        A new float32 array with the same shape as ``image``.
    """
    normalized = (image / 255.0).astype("float32")
    for channel in range(3):
        normalized[:, :, channel] = (normalized[:, :, channel] - mean[channel]) / std[channel]
    return normalized
def resize_longest_max_size(image, max_size=224):
    """Resize ``image`` so its longest side is ``max_size`` pixels, keeping aspect ratio.

    Smaller images are scaled up and larger ones scaled down. When width and
    height are equal, both become ``max_size``.

    Args:
        image: array of shape (H, W) or (H, W, C) accepted by ``cv2.resize``.
        max_size: target length in pixels of the longest side.

    Returns:
        The resized array. The shorter side is floored to a whole number of
        pixels but never drops below 1.
    """
    height, width = image.shape[:2]
    target = int(max_size)
    # Integer arithmetic gives the exact floor. Multiplying by a float ratio
    # can land just below a whole number and truncate away a pixel, even on
    # the longest side.
    if width > height:
        new_width = target
        new_height = max(1, int(height * max_size // width))
    else:
        new_height = target
        new_width = max(1, int(width * max_size // height))
    # Pass interpolation by keyword: the third positional parameter of
    # cv2.resize is ``dst``, so a positional flag never reaches ``interpolation``.
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
def pad_if_needed(image, target_size):
    """Center ``image`` on a zero-filled square canvas of side ``target_size``.

    The top/left offsets are ``(target_size - size) // 2``, so an odd leftover
    pixel goes to the bottom/right. This is the same offset that
    ``centercrop_keypoints`` subtracts when mapping keypoints back, so the two
    must agree or predicted corners end up shifted by one pixel.

    Args:
        image: array of shape (H, W) or (H, W, C) with H, W <= target_size.
        target_size: side length of the output canvas in pixels.

    Returns:
        An array of shape (target_size, target_size, *image.shape[2:]) with
        the same dtype as ``image``.

    Raises:
        ValueError: if either side of ``image`` exceeds ``target_size``.
    """
    height, width = image.shape[:2]
    if height > target_size or width > target_size:
        raise ValueError(
            f"image of size {height}x{width} does not fit in a "
            f"{target_size}x{target_size} canvas"
        )
    # Floor division of the (non-negative) leftover space. The previous
    # abs(floor(negative)) form rounded the offset up instead.
    y0 = (target_size - height) // 2
    x0 = (target_size - width) // 2
    canvas = np.zeros((target_size, target_size) + image.shape[2:], dtype=image.dtype)
    canvas[y0:y0 + height, x0:x0 + width] = image
    return canvas
def heatmap2keypoints(heatmap: np.ndarray, image_size: int = 224) -> list:
    """Return the [x, y] location of the peak of each square heatmap.

    ``heatmap`` is flattened into rows of ``image_size * image_size`` values,
    so it must hold one or more ``image_size`` x ``image_size`` maps. Ties
    resolve to the first maximum in row-major order.

    Returns:
        A list of ``[column, row]`` integer pairs, one per map, in input order.
    """
    peak_indices = heatmap.reshape(-1, image_size * image_size).argmax(axis=1)
    rows, cols = np.divmod(peak_indices, image_size)
    return np.column_stack((cols, rows)).tolist()
def centercrop_keypoints(keypoints, crop_height, crop_width, image_height, image_width):
    """Map keypoints from a canvas into the frame of the content centered in it.

    Undoes centering a ``crop_height`` x ``crop_width`` region inside an
    ``image_height`` x ``image_width`` canvas by subtracting the top/left
    offset ``(canvas - content) // 2`` from every point.

    Returns:
        A new list of ``[x, y]`` pairs.
    """
    dx = (image_width - crop_width) // 2
    dy = (image_height - crop_height) // 2
    return [[px - dx, py - dy] for px, py in keypoints]
def resize_keypoints(keypoints, current_height, current_width, target_height, target_width):
    """Rescale [x, y] keypoints from one image size to another.

    Each coordinate is multiplied by ``target / current`` along its own axis
    and then truncated toward zero with ``int``.

    Returns:
        A new list of ``[x, y]`` integer pairs.
    """
    return [
        [int((x / current_width) * target_width), int((y / current_height) * target_height)]
        for x, y in keypoints
    ]
def draw_keypoints(image, keypoints):
    """Draw a filled red dot at every keypoint on ``image``.

    Drawing happens in place on ``image``, which is also returned.

    Args:
        image: PIL image to annotate; it is modified in place.
        keypoints: iterable of ``(x, y)`` pixel coordinates.

    Returns:
        The same ``image`` object with the dots drawn on it.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    # The dot scales with the image but never collapses to radius 0. On
    # images under 100 px a 0 radius would draw an invisible single pixel.
    radius = max(1, int(min(width, height) * 0.01))
    for x, y in keypoints:
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red', outline='red')
    return image
def point_dist(p0, p1):
    """Return the Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = p0, p1
    return ((ax - bx) ** 2 + (ay - by) ** 2) ** 0.5
def receipt_asp_ratio(keypoints, mode="mean"):
    """Estimate the width / height aspect ratio of a four-corner document.

    Heights are measured along the edges keypoints[0]-[3] and [1]-[2], widths
    along [0]-[1] and [2]-[3]. This matches process_image, which maps the
    points onto a rectangle in the order top-left, top-right, bottom-right,
    bottom-left.

    Args:
        keypoints: sequence of at least four ``(x, y)`` points.
        mode: ``"max"`` uses the longer of each pair of opposite edges;
            ``"mean"`` uses their average.

    Returns:
        The estimated width divided by the estimated height, as a float.

    Raises:
        ValueError: if ``mode`` is not ``"max"`` or ``"mean"``.
        ZeroDivisionError: if both height edges have zero length.
    """
    h0 = point_dist(keypoints[0], keypoints[3])
    h1 = point_dist(keypoints[1], keypoints[2])
    w0 = point_dist(keypoints[0], keypoints[1])
    w1 = point_dist(keypoints[2], keypoints[3])
    if mode == "max":
        h = max(h0, h1)
        w = max(w0, w1)
    elif mode == "mean":
        h = (h0 + h1) / 2
        w = (w0 + w1) / 2
    else:
        # Fail loudly. A sentinel string here would be multiplied by the
        # caller as if it were a number, producing a confusing downstream error.
        raise ValueError(f"unknown mode {mode!r}; expected 'max' or 'mean'")
    return w / h
# Load the ONNX model once at import time and remember the names of its first
# input and first output, which process_image uses to feed and fetch tensors.
session = ort.InferenceSession("models/timm-mobilenetv3_small_100.onnx")
input_name, output_name = session.get_inputs()[0].name, session.get_outputs()[0].name
def _predict_corners(image):
    """Run the corner model on an RGB array; return corners in that array's pixel frame.

    Preprocessing: longest side resized to ``image_size``, zero-padded to a
    centered square, normalized, and laid out as a 1 x C x H x W batch. The
    heatmap peaks are then mapped back through the padding and the resize.

    Args:
        image: H x W x 3 RGB array.

    Returns:
        List of ``[x, y]`` integer pixel coordinates in ``image``.
    """
    height, width = image.shape[:2]
    resized = resize_longest_max_size(image, max_size=image_size)
    small_height, small_width = resized.shape[:2]
    padded = pad_if_needed(resized, target_size=image_size)
    # HWC -> CHW, then add a leading batch axis of size 1.
    batch = np.expand_dims(np.transpose(normalize_image(padded), (2, 0, 1)), axis=0)
    heatmaps = session.run([output_name], {input_name: batch})[0]
    corners = heatmap2keypoints(heatmaps.squeeze(), image_size=image_size)
    corners = centercrop_keypoints(corners, small_height, small_width, image_size, image_size)
    return resize_keypoints(corners, small_height, small_width, height, width)


def _correct_perspective(image, corners, target_height=1024):
    """Warp the quadrilateral ``corners`` of ``image`` onto an upright rectangle.

    Args:
        image: H x W x 3 array to warp.
        corners: four ``[x, y]`` points, mapped in order onto the rectangle's
            top-left, top-right, bottom-right and bottom-left corners.
        target_height: output height in pixels; the width follows the
            estimated aspect ratio.

    Returns:
        The corrected PIL image, or ``None`` when the corners are too
        degenerate to define a rectangle (zero height or near-zero width).
    """
    try:
        aspect = receipt_asp_ratio(corners, mode="max")
    except ZeroDivisionError:
        return None
    target_width = int(aspect * target_height)
    if target_width < 2:
        return None
    source = np.float32([[x, y] for x, y in corners])
    destination = np.float32([
        [0, 0],
        [target_width - 1, 0],
        [target_width - 1, target_height - 1],
        [0, target_height - 1],
    ])
    matrix = cv2.getPerspectiveTransform(source, destination)
    # Pass flags by keyword: the fourth positional parameter of
    # cv2.warpPerspective is ``dst``, not the interpolation flag.
    warped = cv2.warpPerspective(image, matrix, (target_width, target_height), flags=cv2.INTER_LINEAR)
    return Image.fromarray(warped)


# Main function: preprocess the image, run the model, and post-process the result.
def process_image(input_image):
    """Detect document corners in ``input_image`` and return two views of it.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` (e.g. after
            the input image has been cleared).

    Returns:
        ``(annotated, corrected)``. ``annotated`` is a copy of the input with the
        predicted corners drawn as red dots. ``corrected`` is the
        perspective-corrected document as a PIL image, or ``None`` if the
        predicted corners are degenerate. Both are ``None`` when
        ``input_image`` is ``None``.
    """
    if input_image is None:
        return None, None
    # PIL -> numpy array in RGB channel order (not OpenCV's usual BGR).
    image = np.array(input_image.convert("RGB"))
    corners = _predict_corners(image)
    # Draw on a copy so the caller's image object is left untouched.
    annotated = draw_keypoints(input_image.copy(), corners)
    return annotated, _correct_perspective(image, corners)
# Example inputs offered below the upload widget.
demo_images = [
    "demo_images/image_1.jpg",
    "demo_images/image_2.jpg",
    "demo_images/image_3.jpg",
    "demo_images/image_flux_1.png",
    "demo_images/image_flux_2.png",
]

# Create Gradio interface: input image, annotated corners, and the
# perspective-corrected result side by side, with examples underneath.
with gr.Blocks() as iface:
    gr.Markdown("# Document corner detection and perspective correction")
    gr.Markdown("Upload an image to detect the corners of a document and correct the perspective.\n\nUses a UNet model to detect corners and OpenCV to correct the perspective.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Image", show_label=True, scale=1)
        with gr.Column():
            output_image1 = gr.Image(type="pil", label="Image with predicted corners", show_label=True, scale=1)
        with gr.Column():
            output_image2 = gr.Image(type="pil", label="Image with perspective correction", show_label=True, scale=1)
    with gr.Row():
        # Label typo fixed: "Exampled" -> "Example".
        examples = gr.Examples(demo_images, input_image, cache_examples=False, label="Example documents (CORD dataset and FLUX.1-schnell generated)")
    # Re-run detection whenever the input image changes.
    input_image.change(fn=process_image, inputs=input_image, outputs=[output_image1, output_image2])
    gr.Markdown("Created by Kenneth Thorø Martinsen ([email protected])")
iface.launch()