Spaces:
Sleeping
Sleeping
File size: 6,177 Bytes
f94a9ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image, ImageDraw
import cv2
# Side length, in pixels, of the square image the input is padded to before it is fed to the model.
image_size = 224
def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Divide an image by 255 and standardize its first three channels.

    Args:
        image: Array indexed as (row, column, channel) with at least
            three channels.
        mean: Per-channel value subtracted after the division by 255.
        std: Per-channel divisor applied after the mean is subtracted.

    Returns:
        A new float32 array of the same shape; the input is left untouched.
    """
    scaled = (image / 255.0).astype("float32")
    # Only channels 0..2 are standardized; any extra channel keeps its scaled value.
    for channel in range(3):
        plane = scaled[:, :, channel]
        scaled[:, :, channel] = (plane - mean[channel]) / std[channel]
    return scaled
def resize_longest_max_size(image, max_size=224):
    """Resize an image so its longest side equals ``max_size``, keeping the aspect ratio.

    Args:
        image: Array whose first two dimensions are (height, width).
        max_size: Target length, in pixels, of the longest side.

    Returns:
        The resized image as returned by ``cv2.resize`` (bilinear interpolation).
    """
    height, width = image.shape[:2]
    # Pin the longest side to max_size exactly instead of deriving it from a
    # float ratio, and never let the short side collapse to 0 pixels (which
    # cv2.resize rejects) for very elongated inputs.
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * max_size / width))
    else:
        new_height = max_size
        new_width = max(1, int(width * max_size / height))
    # The third positional parameter of cv2.resize is `dst`, so the
    # interpolation flag has to be passed by keyword to take effect.
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
def pad_if_needed(image, target_size):
    """Center an image on a black square canvas of side ``target_size``.

    The top/left offset is the floor of half the free space, which is the same
    offset ``centercrop_keypoints`` subtracts when mapping keypoints back, so
    the two stay aligned even when the padding is an odd number of pixels.

    Args:
        image: Array of shape (height, width, channels) that fits inside the
            canvas.
        target_size: Side length, in pixels, of the square output.

    Returns:
        A uint8 array of shape (target_size, target_size, 3) with the image
        pasted in the middle and zeros elsewhere.

    Raises:
        ValueError: If the image is taller or wider than ``target_size``.
    """
    height, width, _ = image.shape
    if height > target_size or width > target_size:
        raise ValueError(
            f"image of size {width}x{height} does not fit in a "
            f"{target_size}x{target_size} canvas"
        )
    y0 = (target_size - height) // 2
    x0 = (target_size - width) // 2
    background = np.zeros((target_size, target_size, 3), dtype="uint8")
    background[y0:y0 + height, x0:x0 + width, :] = image
    return background
def heatmap2keypoints(heatmap: np.ndarray, image_size: int = 224) -> list:
    """Return the [x, y] location of the maximum of each heatmap.

    The input is viewed as a stack of flattened ``image_size * image_size``
    maps; for each map the index of its largest value is converted back to a
    column (x) and row (y).
    """
    flat_peaks = heatmap.reshape(-1, image_size * image_size).argmax(axis=1)
    rows, cols = np.divmod(flat_peaks, image_size)
    return np.stack((cols, rows), axis=1).tolist()
def centercrop_keypoints(keypoints, crop_height, crop_width, image_height, image_width):
    """Shift [x, y] keypoints from full-image coordinates into a centered crop.

    The crop of size (crop_height, crop_width) is assumed to sit in the middle
    of the (image_height, image_width) image; its top-left offset is the floor
    of half the size difference along each axis.
    """
    left = (image_width - crop_width) // 2
    top = (image_height - crop_height) // 2
    shifted = []
    for x, y in keypoints:
        shifted.append([x - left, y - top])
    return shifted
def resize_keypoints(keypoints, current_height, current_width, target_height, target_width):
    """Rescale [x, y] keypoints from one image size to another.

    Each coordinate is expressed as a fraction of the current size, multiplied
    by the target size, and truncated to an integer.
    """
    return [
        [
            int((x / current_width) * target_width),
            int((y / current_height) * target_height),
        ]
        for x, y in keypoints
    ]
def draw_keypoints(image, keypoints):
    """Draw a filled red dot at every [x, y] keypoint, directly on ``image``.

    The image is modified in place and also returned. The dot radius is 1% of
    the shorter image side, truncated to an integer.
    """
    canvas = ImageDraw.Draw(image)
    img_w, img_h = image.size
    # Dot size scales with the image so markers stay visible on large photos
    radius = int(min(img_w, img_h) * 0.01)
    for x, y in keypoints:
        bbox = (x - radius, y - radius, x + radius, y + radius)
        canvas.ellipse(bbox, fill='red', outline='red')
    return image
def point_dist(p0, p1):
    """Return the Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = p0, p1
    squared = (ax - bx) ** 2 + (ay - by) ** 2
    return squared ** 0.5
def receipt_asp_ratio(keypoints, mode="mean"):
    """Estimate the width/height aspect ratio of a quadrilateral.

    The four keypoints are treated as consecutive corners: edges 0-1 and 2-3
    are the two "width" edges, edges 0-3 and 1-2 the two "height" edges.

    Args:
        keypoints: Sequence of four (x, y) points.
        mode: ``"mean"`` averages each pair of opposite edges, ``"max"`` takes
            the longer edge of each pair.

    Returns:
        Width divided by height as a float.

    Raises:
        ValueError: If ``mode`` is neither ``"mean"`` nor ``"max"``.
        ZeroDivisionError: If the resulting height is zero.
    """
    # Validate before doing any work: a string error code returned from a
    # numeric function only makes the caller fail later, far from the cause.
    if mode not in ("max", "mean"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'max' or 'mean'")

    h0 = point_dist(keypoints[0], keypoints[3])
    h1 = point_dist(keypoints[1], keypoints[2])
    w0 = point_dist(keypoints[0], keypoints[1])
    w1 = point_dist(keypoints[2], keypoints[3])

    if mode == "max":
        h = max(h0, h1)
        w = max(w0, w1)
    else:
        h = (h0 + h1) / 2
        w = (w0 + w1) / 2
    return w / h
# Load the ONNX model once at import time; process_image reuses this session for every call
session = ort.InferenceSession("models/timm-mobilenetv3_small_100.onnx")
# Names of the model's first input and first output, needed to call session.run
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# Main function to handle the image input, apply preprocessing, run the model, and apply postprocessing
def process_image(input_image):
    """Detect four document corners and return an annotated and a rectified image.

    Args:
        input_image: PIL image from the Gradio input component, or ``None``
            (NOTE(review): Gradio is expected to pass ``None`` when the input
            is cleared - confirm for the installed version).

    Returns:
        A tuple ``(image_with_keypoints, output_image)`` of PIL images: an RGB
        copy of the input with the predicted corners drawn on it, and the
        perspective-corrected document. Both are ``None`` when there is no
        input; ``output_image`` alone is ``None`` when the predicted corners
        are too degenerate to define a target rectangle.
    """
    if input_image is None:
        return None, None

    # convert() yields a new image, so drawing the corners below never alters
    # the caller's object, and the annotated output is always RGB.
    rgb_image = input_image.convert("RGB")
    image = np.array(rgb_image)
    h, w, _ = image.shape

    # Preprocess the image: fit into image_size, pad to a square, normalize,
    # then reorder to (batch, channel, row, column) for the model
    image_resize = resize_longest_max_size(image, max_size=image_size)
    h_small, w_small, _ = image_resize.shape
    image_pad = pad_if_needed(image_resize, target_size=image_size)
    image_norm = normalize_image(image_pad)
    image_array = np.transpose(image_norm, (2, 0, 1))
    image_array = np.expand_dims(image_array, axis=0)

    # Run model inference
    output = session.run([output_name], {input_name: image_array})

    # Map heatmap peaks back: padded square -> resized image -> original image
    output_keypoints = heatmap2keypoints(output[0].squeeze(), image_size=image_size)
    crop_keypoints = centercrop_keypoints(output_keypoints, h_small, w_small, image_size, image_size)
    large_keypoints = resize_keypoints(crop_keypoints, h_small, w_small, h, w)

    # Draw keypoints on the image
    image_with_keypoints = draw_keypoints(rgb_image, large_keypoints)

    # Size of the rectified output: fixed height, width from the estimated aspect ratio
    persp_h = 1024
    try:
        persp_asp = receipt_asp_ratio(large_keypoints, mode="max")
    except ZeroDivisionError:
        # Coincident corners give a zero-height quadrilateral; nothing to rectify
        return image_with_keypoints, None
    persp_w = int(persp_asp * persp_h)
    if persp_w < 1:
        return image_with_keypoints, None

    # Keypoints 0..3 are mapped to top-left, top-right, bottom-right, bottom-left
    origin_points = np.float32([[x, y] for x, y in large_keypoints])
    target_points = np.float32([[0, 0], [persp_w - 1, 0], [persp_w - 1, persp_h - 1], [0, persp_h - 1]])
    persp_matrix = cv2.getPerspectiveTransform(origin_points, target_points)
    # The fourth positional parameter of warpPerspective is `dst`; the
    # interpolation flag must be passed as `flags`.
    persp_image = cv2.warpPerspective(image, persp_matrix, (persp_w, persp_h), flags=cv2.INTER_LINEAR)
    output_image = Image.fromarray(persp_image)
    return image_with_keypoints, output_image
# Paths of the sample images offered as clickable examples in the UI
demo_images = [
"demo_images/image_1.jpg",
"demo_images/image_2.jpg",
"demo_images/image_3.jpg",
"demo_images/image_flux_1.png",
"demo_images/image_flux_2.png",
]
# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Document corner detection and perspective correction")
    gr.Markdown("Upload an image to detect the corners of a document and correct the perspective.\n\nUses a UNet model to detect corners and OpenCV to correct the perspective.")

    # One row with three columns: input, predicted corners, rectified result
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Image", show_label=True, scale=1)
        with gr.Column():
            output_image1 = gr.Image(type="pil", label="Image with predicted corners", show_label=True, scale=1)
        with gr.Column():
            output_image2 = gr.Image(type="pil", label="Image with perspective correction", show_label=True, scale=1)

    with gr.Row():
        examples = gr.Examples(demo_images, input_image, cache_examples=False, label="Exampled documents (CORD dataset and FLUX.1-schnell generated)")

    # Event listeners must be registered inside the Blocks context;
    # process_image runs every time the input image changes.
    input_image.change(fn=process_image, inputs=input_image, outputs=[output_image1, output_image2])

    gr.Markdown("Created by Kenneth Thorø Martinsen ([email protected])")

iface.launch()