File size: 6,177 Bytes
f94a9ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image, ImageDraw
import cv2

# Side length (pixels) of the square model input; all preprocessing targets this.
image_size = 224

def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale a uint8 RGB image to [0, 1] and normalize each channel.

    Args:
        image: HxWx3 array of uint8 pixel values.
        mean: per-channel means (defaults are the ImageNet statistics).
        std: per-channel standard deviations (ImageNet defaults).

    Returns:
        float32 array of the same shape, channel-wise normalized.
    """
    image = (image / 255.0).astype("float32")

    # Broadcast over the trailing channel axis instead of normalizing each
    # channel by hand; float32 operand arrays keep the result in float32.
    image -= np.asarray(mean, dtype="float32")
    image /= np.asarray(std, dtype="float32")

    return image

def resize_longest_max_size(image, max_size=224):
    """Resize an image so its longest side equals max_size, keeping aspect ratio.

    Args:
        image: HxW(xC) numpy image.
        max_size: target length in pixels of the longest side.

    Returns:
        The bilinearly resized image.
    """
    height, width = image.shape[:2]

    # Scale relative to whichever side is longest so neither dimension
    # exceeds max_size.
    if width > height:
        ratio = max_size / width
    else:
        ratio = max_size / height

    new_width = int(width * ratio)
    new_height = int(height * ratio)

    # BUG FIX: the interpolation flag must be passed by keyword — the third
    # positional argument of cv2.resize is `dst`, so the flag was silently
    # ignored before (harmless only because INTER_LINEAR is also the default).
    resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

    return resized_image

def pad_if_needed(image, target_size):
    """Center an image on a black square canvas of side target_size.

    Assumes both image dimensions are <= target_size (true after
    resize_longest_max_size).
    """
    height, width, _ = image.shape

    # Offsets that center the image on the canvas.
    top = abs((height - target_size) // 2)
    left = abs((width - target_size) // 2)

    canvas = np.zeros((target_size, target_size, 3), dtype="uint8")
    canvas[top:top + height, left:left + width, :] = image

    return canvas

def heatmap2keypoints(heatmap: np.ndarray, image_size: int = 224) -> list:
    """Convert a stack of heatmaps into [x, y] keypoints via per-map argmax."""

    # Flatten each heatmap and find the peak index, then split it into
    # row (y) and column (x) coordinates.
    flat_idx = heatmap.reshape(-1, image_size * image_size).argmax(axis=1)
    ys, xs = np.divmod(flat_idx, image_size)

    return np.stack((xs, ys), axis=1).tolist()

def centercrop_keypoints(keypoints, crop_height, crop_width, image_height, image_width):
    """Shift keypoints from full-image coordinates into centered-crop coordinates."""
    dy = (image_height - crop_height) // 2
    dx = (image_width - crop_width) // 2

    shifted = []
    for x, y in keypoints:
        shifted.append([x - dx, y - dy])
    return shifted

def resize_keypoints(keypoints, current_height, current_width, target_height, target_width):
    """Scale keypoints from one image resolution to another (truncated to int)."""
    # Normalize each coordinate to [0, 1] before scaling; keep this exact
    # evaluation order so int() truncation matches across resolutions.
    return [
        [int(x / current_width * target_width), int(y / current_height * target_height)]
        for x, y in keypoints
    ]

def draw_keypoints(image, keypoints):
    """Draw a red dot at each keypoint, mutating and returning the image.

    Args:
        image: PIL image to draw on (modified in place).
        keypoints: iterable of [x, y] pixel coordinates.

    Returns:
        The same PIL image with the keypoints drawn.
    """
    draw = ImageDraw.Draw(image)
    w, h = image.size
    # Radius scales with the image's shorter side; hoisted out of the loop
    # since it is identical for every keypoint.
    radius = int(min(w, h) * 0.01)
    for x, y in keypoints:
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red', outline='red')
    return image

def point_dist(p0, p1):
    """Return the Euclidean distance between two (x, y) points."""
    dx = p0[0] - p1[0]
    dy = p0[1] - p1[1]
    return (dx * dx + dy * dy) ** 0.5

def receipt_asp_ratio(keypoints, mode = "mean"):
    """Estimate the width/height aspect ratio of a detected quadrilateral.

    Args:
        keypoints: four [x, y] corners; assumes the order TL, TR, BR, BL
            (matches the target corner order in process_image — TODO confirm
            against the model's training labels).
        mode: "mean" averages opposite sides, "max" takes the longer of each pair.

    Returns:
        Aspect ratio width / height as a float.

    Raises:
        ValueError: if mode is not "mean" or "max".
    """
    # Lengths of the two vertical and two horizontal sides.
    h0 = point_dist(keypoints[0], keypoints[3])
    h1 = point_dist(keypoints[1], keypoints[2])

    w0 = point_dist(keypoints[0], keypoints[1])
    w1 = point_dist(keypoints[2], keypoints[3])

    if mode == "max":
        h = max(h0, h1)
        w = max(w0, w1)
    elif mode == "mean":
        h = (h0+h1)/2
        w = (w0+w1)/2
    else:
        # BUG FIX: previously returned the string "UNKNOWN MODE", which would
        # crash later arithmetic with a confusing error; fail loudly instead.
        raise ValueError(f"unknown mode: {mode!r} (expected 'mean' or 'max')")

    return w/h

# Load the ONNX model once at startup; the session is reused for every request.
# NOTE(review): the UI text mentions a UNet, but the filename suggests a
# MobileNetV3-based export — confirm which architecture this actually is.
session = ort.InferenceSession("models/timm-mobilenetv3_small_100.onnx")

# Names of the model's first input and output tensors, used in session.run().
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Main function to handle the image input, apply preprocessing, run the model, and apply postprocessing
def process_image(input_image):
    """Detect document corners, draw them, and return a perspective-corrected crop.

    Args:
        input_image: PIL image uploaded through the Gradio UI (drawn on in place).

    Returns:
        Tuple of (input image with predicted corners drawn,
        perspective-corrected document image).
    """
    # Convert PIL image to an RGB numpy array for OpenCV-style processing.
    image = np.array(input_image.convert("RGB"))

    h, w, _ = image.shape

    # Preprocess: shrink so the longest side is image_size, then pad to a square.
    image_resize = resize_longest_max_size(image)

    h_small, w_small, _ = image_resize.shape

    image_pad = pad_if_needed(image_resize, target_size=image_size)

    image_norm = normalize_image(image_pad)

    # HWC -> CHW, then add a batch dimension for the ONNX model.
    image_array = np.transpose(image_norm, (2, 0, 1))
    image_array = np.expand_dims(image_array, axis=0)

    # Run model inference; output is assumed to be one heatmap per corner
    # (4 x image_size x image_size) — TODO confirm against the exported model.
    output = session.run([output_name], {input_name: image_array})

    # Map heatmap peaks back into original-image coordinates: undo the center
    # padding, then scale from the model input size up to the original size.
    output_keypoints = heatmap2keypoints(output[0].squeeze())
    crop_keypoints = centercrop_keypoints(output_keypoints, h_small, w_small, image_size, image_size)
    large_keypoints = resize_keypoints(crop_keypoints, h_small, w_small, h, w)

    # Draw keypoints on the image (mutates input_image in place).
    image_with_keypoints = draw_keypoints(input_image, large_keypoints)

    # Target size of the rectified document: fixed height, width derived from
    # the estimated aspect ratio of the detected quadrilateral.
    persp_h = 1024
    persp_asp = receipt_asp_ratio(large_keypoints, mode="max")
    persp_w = int(persp_asp*persp_h)

    # NOTE(review): assumes keypoints arrive ordered TL, TR, BR, BL to match
    # the target corners below — confirm against the model's training labels.
    origin_points = np.float32([[x, y] for x, y in large_keypoints])
    target_points = np.float32([[0, 0], [persp_w-1, 0], [persp_w-1, persp_h-1], [0, persp_h-1]])

    persp_matrix = cv2.getPerspectiveTransform(origin_points, target_points)
    # BUG FIX: interpolation must be passed via the `flags` keyword — the
    # fourth positional argument of cv2.warpPerspective is `dst`, so the flag
    # was silently ignored before (harmless only because INTER_LINEAR is the
    # default anyway).
    persp_image = cv2.warpPerspective(image, persp_matrix, (persp_w, persp_h), flags=cv2.INTER_LINEAR)

    output_image = Image.fromarray(persp_image)

    return image_with_keypoints, output_image

# Example images offered in the UI (from the CORD dataset and
# FLUX.1-schnell-generated documents); paths are relative to the app root.
demo_images = [
    "demo_images/image_1.jpg",
    "demo_images/image_2.jpg",
    "demo_images/image_3.jpg",
    "demo_images/image_flux_1.png",
    "demo_images/image_flux_2.png",
]

# Create Gradio interface: one input column and two output columns side by side.
with gr.Blocks() as iface:
    gr.Markdown("# Document corner detection and perspective correction")
    gr.Markdown("Upload an image to detect the corners of a document and correct the perspective.\n\nUses a UNet model to detect corners and OpenCV to correct the perspective.")
    
    with gr.Row():
        with gr.Column():
            # Upload target; changing this image triggers inference (see below).
            input_image = gr.Image(type="pil", label="Image", show_label=True, scale=1)
        
        with gr.Column():
            output_image1 = gr.Image(type="pil", label="Image with predicted corners", show_label=True, scale=1)

        with gr.Column():
            output_image2 = gr.Image(type="pil", label="Image with perspective correction", show_label=True, scale=1)

    with gr.Row():
            # Clickable example images; cache_examples=False recomputes on click.
            examples = gr.Examples(demo_images, input_image, cache_examples=False, label="Exampled documents (CORD dataset and FLUX.1-schnell generated)")
    
    # Run the full detect-and-rectify pipeline whenever the input image changes.
    input_image.change(fn=process_image, inputs=input_image, outputs=[output_image1, output_image2])

    gr.Markdown("Created by Kenneth Thorø Martinsen ([email protected])")

# Start the Gradio server (blocking call).
iface.launch()