Spaces:
Running
Running
gdTharusha
commited on
Commit
•
eb9a051
1
Parent(s):
d4b4ea7
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
-
from PIL import Image, ImageFilter
|
3 |
import numpy as np
|
4 |
import io
|
5 |
import tempfile
|
6 |
import vtracer
|
7 |
-
from skimage import color, filters
|
8 |
import torch
|
9 |
from torchvision import transforms
|
10 |
|
11 |
-
#
|
12 |
class EdgeDetectionModel:
|
13 |
def __init__(self, model_name='resnet18', device='cpu'):
|
14 |
self.device = device
|
@@ -25,18 +25,25 @@ class EdgeDetectionModel:
|
|
25 |
input_tensor = preprocess(image).unsqueeze(0).to(self.device)
|
26 |
with torch.no_grad():
|
27 |
output = self.model(input_tensor)
|
|
|
28 |
output = output.cpu().numpy()[0].transpose((1, 2, 0))
|
29 |
-
|
|
|
30 |
return edges
|
31 |
|
32 |
# Load the edge detection model
|
33 |
-
edge_model = EdgeDetectionModel(device='cpu')
|
34 |
|
35 |
-
def preprocess_image(image, blur_radius, edge_enhance):
|
36 |
"""Applies advanced preprocessing steps to the image before tracing."""
|
37 |
if blur_radius > 0:
|
38 |
image = image.filter(ImageFilter.GaussianBlur(blur_radius))
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
if edge_enhance:
|
41 |
edges = edge_model.detect_edges(image)
|
42 |
edges_img = Image.fromarray((edges * 255).astype(np.uint8))
|
@@ -44,13 +51,13 @@ def preprocess_image(image, blur_radius, edge_enhance):
|
|
44 |
|
45 |
return image
|
46 |
|
47 |
-
def convert_image(image, blur_radius, edge_enhance, color_mode, hierarchical, mode, filter_speckle,
|
48 |
color_precision, layer_difference, corner_threshold, length_threshold,
|
49 |
max_iterations, splice_threshold, path_precision):
|
50 |
"""Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
|
51 |
|
52 |
# Preprocess the image
|
53 |
-
image = preprocess_image(image, blur_radius, edge_enhance)
|
54 |
|
55 |
# Convert Gradio image to bytes for vtracer compatibility
|
56 |
img_byte_array = io.BytesIO()
|
@@ -87,13 +94,14 @@ def convert_image(image, blur_radius, edge_enhance, color_mode, hierarchical, mo
|
|
87 |
iface = gr.Blocks()
|
88 |
|
89 |
with iface:
|
90 |
-
gr.Markdown("# AI-Enhanced Image to SVG Vectors")
|
91 |
-
gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection and preprocessing ensure superior vectorization.")
|
92 |
|
93 |
with gr.Row():
|
94 |
image_input = gr.Image(type="pil", label="Upload Image")
|
95 |
blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
|
96 |
edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
|
|
|
97 |
|
98 |
with gr.Row():
|
99 |
color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
|
@@ -121,7 +129,7 @@ with iface:
|
|
121 |
convert_button.click(
|
122 |
fn=convert_image,
|
123 |
inputs=[
|
124 |
-
image_input, blur_radius_input, edge_enhance_input, color_mode_input, hierarchical_input, mode_input,
|
125 |
filter_speckle_input, color_precision_input, layer_difference_input, corner_threshold_input,
|
126 |
length_threshold_input, max_iterations_input, splice_threshold_input, path_precision_input
|
127 |
],
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image, ImageFilter, ImageOps
|
3 |
import numpy as np
|
4 |
import io
|
5 |
import tempfile
|
6 |
import vtracer
|
7 |
+
from skimage import color, filters, exposure
|
8 |
import torch
|
9 |
from torchvision import transforms
|
10 |
|
11 |
+
# AI-based edge detection using a pre-trained PyTorch model (e.g., ResNet)
|
12 |
class EdgeDetectionModel:
|
13 |
def __init__(self, model_name='resnet18', device='cpu'):
|
14 |
self.device = device
|
|
|
25 |
input_tensor = preprocess(image).unsqueeze(0).to(self.device)
|
26 |
with torch.no_grad():
|
27 |
output = self.model(input_tensor)
|
28 |
+
|
29 |
output = output.cpu().numpy()[0].transpose((1, 2, 0))
|
30 |
+
gray_output = color.rgb2gray(output)
|
31 |
+
edges = filters.sobel(gray_output)
|
32 |
return edges
|
33 |
|
34 |
# Load the edge detection model
|
35 |
+
edge_model = EdgeDetectionModel(device='cpu')
|
36 |
|
37 |
+
def preprocess_image(image, blur_radius, edge_enhance, contrast_stretch):
|
38 |
"""Applies advanced preprocessing steps to the image before tracing."""
|
39 |
if blur_radius > 0:
|
40 |
image = image.filter(ImageFilter.GaussianBlur(blur_radius))
|
41 |
|
42 |
+
if contrast_stretch:
|
43 |
+
# Apply contrast stretching
|
44 |
+
p2, p98 = np.percentile(np.array(image), (2, 98))
|
45 |
+
image = ImageOps.autocontrast(image, cutoff=(p2, p98))
|
46 |
+
|
47 |
if edge_enhance:
|
48 |
edges = edge_model.detect_edges(image)
|
49 |
edges_img = Image.fromarray((edges * 255).astype(np.uint8))
|
|
|
51 |
|
52 |
return image
|
53 |
|
54 |
+
def convert_image(image, blur_radius, edge_enhance, contrast_stretch, color_mode, hierarchical, mode, filter_speckle,
|
55 |
color_precision, layer_difference, corner_threshold, length_threshold,
|
56 |
max_iterations, splice_threshold, path_precision):
|
57 |
"""Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
|
58 |
|
59 |
# Preprocess the image
|
60 |
+
image = preprocess_image(image, blur_radius, edge_enhance, contrast_stretch)
|
61 |
|
62 |
# Convert Gradio image to bytes for vtracer compatibility
|
63 |
img_byte_array = io.BytesIO()
|
|
|
94 |
iface = gr.Blocks()
|
95 |
|
96 |
with iface:
|
97 |
+
gr.Markdown("# Advanced AI-Enhanced Image to SVG Vectors")
|
98 |
+
gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection, contrast stretching, and other preprocessing options ensure superior vectorization.")
|
99 |
|
100 |
with gr.Row():
|
101 |
image_input = gr.Image(type="pil", label="Upload Image")
|
102 |
blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
|
103 |
edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
|
104 |
+
contrast_stretch_input = gr.Checkbox(value=False, label="Contrast Stretch")
|
105 |
|
106 |
with gr.Row():
|
107 |
color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
|
|
|
129 |
convert_button.click(
|
130 |
fn=convert_image,
|
131 |
inputs=[
|
132 |
+
image_input, blur_radius_input, edge_enhance_input, contrast_stretch_input, color_mode_input, hierarchical_input, mode_input,
|
133 |
filter_speckle_input, color_precision_input, layer_difference_input, corner_threshold_input,
|
134 |
length_threshold_input, max_iterations_input, splice_threshold_input, path_precision_input
|
135 |
],
|