gdTharusha committed on
Commit cbbdd98
1 Parent(s): eb9a051

Update app.py

Files changed (1)
app.py +22 -50
app.py CHANGED
@@ -1,63 +1,34 @@
  import gradio as gr
- from PIL import Image, ImageFilter, ImageOps
+ from PIL import Image, ImageFilter
  import numpy as np
  import io
  import tempfile
  import vtracer
- from skimage import color, filters, exposure
- import torch
- from torchvision import transforms
-
- # AI-based edge detection using a pre-trained PyTorch model (e.g., ResNet)
- class EdgeDetectionModel:
-     def __init__(self, model_name='resnet18', device='cpu'):
-         self.device = device
-         self.model = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=True).to(self.device)
-         self.model.eval()
-
-     def detect_edges(self, image):
-         preprocess = transforms.Compose([
-             transforms.Resize(256),
-             transforms.CenterCrop(224),
-             transforms.ToTensor(),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-         ])
-         input_tensor = preprocess(image).unsqueeze(0).to(self.device)
-         with torch.no_grad():
-             output = self.model(input_tensor)
-
-         output = output.cpu().numpy()[0].transpose((1, 2, 0))
-         gray_output = color.rgb2gray(output)
-         edges = filters.sobel(gray_output)
-         return edges
-
- # Load the edge detection model
- edge_model = EdgeDetectionModel(device='cpu')
-
- def preprocess_image(image, blur_radius, edge_enhance, contrast_stretch):
+ from skimage import color, filters
+ from skimage.feature import canny
+ from skimage.transform import resize
+
+ def preprocess_image(image, blur_radius, edge_enhance, edge_threshold):
      """Applies advanced preprocessing steps to the image before tracing."""
      if blur_radius > 0:
          image = image.filter(ImageFilter.GaussianBlur(blur_radius))
-
-     if contrast_stretch:
-         # Apply contrast stretching
-         p2, p98 = np.percentile(np.array(image), (2, 98))
-         image = ImageOps.autocontrast(image, cutoff=(p2, p98))
-
+
      if edge_enhance:
-         edges = edge_model.detect_edges(image)
+         # Convert image to grayscale and apply edge detection
+         gray_image = np.array(image.convert('L'))
+         edges = canny(gray_image, sigma=edge_threshold)
          edges_img = Image.fromarray((edges * 255).astype(np.uint8))
          image = Image.blend(image.convert('RGB'), edges_img.convert('RGB'), alpha=0.5)
-
+
      return image
 
- def convert_image(image, blur_radius, edge_enhance, contrast_stretch, color_mode, hierarchical, mode, filter_speckle,
-                   color_precision, layer_difference, corner_threshold, length_threshold,
+ def convert_image(image, blur_radius, edge_enhance, edge_threshold, color_mode, hierarchical, mode,
+                   filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold,
                    max_iterations, splice_threshold, path_precision):
-     """Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
+     """Converts an image to SVG using vtracer with customizable parameters."""
 
      # Preprocess the image
-     image = preprocess_image(image, blur_radius, edge_enhance, contrast_stretch)
+     image = preprocess_image(image, blur_radius, edge_enhance, edge_threshold)
 
      # Convert Gradio image to bytes for vtracer compatibility
      img_byte_array = io.BytesIO()
@@ -94,14 +65,14 @@ def convert_image(image, blur_radius, edge_enhance, contrast_stretch, color_mode
  iface = gr.Blocks()
 
  with iface:
-     gr.Markdown("# Advanced AI-Enhanced Image to SVG Vectors")
-     gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection, contrast stretching, and other preprocessing options ensure superior vectorization.")
+     gr.Markdown("# CPU-Optimized AI-Enhanced Image to SVG Vectors")
+     gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection and preprocessing ensure superior vectorization.")
 
      with gr.Row():
          image_input = gr.Image(type="pil", label="Upload Image")
          blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
          edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
-         contrast_stretch_input = gr.Checkbox(value=False, label="Contrast Stretch")
+         edge_threshold_input = gr.Slider(minimum=0.1, maximum=3.0, value=1.0, step=0.1, label="Edge Detection Threshold")
 
      with gr.Row():
          color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
@@ -129,9 +100,10 @@ with iface:
      convert_button.click(
          fn=convert_image,
          inputs=[
-             image_input, blur_radius_input, edge_enhance_input, contrast_stretch_input, color_mode_input, hierarchical_input, mode_input,
-             filter_speckle_input, color_precision_input, layer_difference_input, corner_threshold_input,
-             length_threshold_input, max_iterations_input, splice_threshold_input, path_precision_input
+             image_input, blur_radius_input, edge_enhance_input, edge_threshold_input, color_mode_input,
+             hierarchical_input, mode_input, filter_speckle_input, color_precision_input, layer_difference_input,
+             corner_threshold_input, length_threshold_input, max_iterations_input, splice_threshold_input,
+             path_precision_input
          ],
          outputs=[svg_output, download_output]
      )
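
The committed preprocess_image path replaces the torch/ResNet edge detector with scikit-image's Canny filter. A minimal standalone sketch of the same blur-then-Canny-then-blend steps outside the Gradio app (the enhance_edges name, the input.png/edges_preview.png file names, and the __main__ harness are illustrative and not part of this commit):

import numpy as np
from PIL import Image, ImageFilter
from skimage.feature import canny

def enhance_edges(image, blur_radius=0.0, edge_threshold=1.0):
    """Mirrors the committed preprocessing: optional blur, Canny edges, 50% blend."""
    if blur_radius > 0:
        image = image.filter(ImageFilter.GaussianBlur(blur_radius))
    # canny expects a 2-D grayscale array; sigma plays the role of the
    # "Edge Detection Threshold" slider in the app
    gray = np.array(image.convert('L'))
    edges = canny(gray, sigma=edge_threshold)
    # Turn the boolean edge map into an image and blend it over the original
    edges_img = Image.fromarray((edges * 255).astype(np.uint8))
    return Image.blend(image.convert('RGB'), edges_img.convert('RGB'), alpha=0.5)

if __name__ == "__main__":
    result = enhance_edges(Image.open("input.png"), blur_radius=1.0, edge_threshold=1.0)
    result.save("edges_preview.png")  # inspect the pre-trace image before handing it to vtracer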