gdTharusha commited on
Commit
eb9a051
1 Parent(s): d4b4ea7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- from PIL import Image, ImageFilter
3
  import numpy as np
4
  import io
5
  import tempfile
6
  import vtracer
7
- from skimage import color, filters
8
  import torch
9
  from torchvision import transforms
10
 
11
- # Load a pretrained edge detection model (e.g., HED - Holistically-Nested Edge Detection)
12
  class EdgeDetectionModel:
13
  def __init__(self, model_name='resnet18', device='cpu'):
14
  self.device = device
@@ -25,18 +25,25 @@ class EdgeDetectionModel:
25
  input_tensor = preprocess(image).unsqueeze(0).to(self.device)
26
  with torch.no_grad():
27
  output = self.model(input_tensor)
 
28
  output = output.cpu().numpy()[0].transpose((1, 2, 0))
29
- edges = filters.sobel(color.rgb2gray(output))
 
30
  return edges
31
 
32
  # Load the edge detection model
33
- edge_model = EdgeDetectionModel(device='cpu') # Use 'cuda' if you have a GPU
34
 
35
- def preprocess_image(image, blur_radius, edge_enhance):
36
  """Applies advanced preprocessing steps to the image before tracing."""
37
  if blur_radius > 0:
38
  image = image.filter(ImageFilter.GaussianBlur(blur_radius))
39
 
 
 
 
 
 
40
  if edge_enhance:
41
  edges = edge_model.detect_edges(image)
42
  edges_img = Image.fromarray((edges * 255).astype(np.uint8))
@@ -44,13 +51,13 @@ def preprocess_image(image, blur_radius, edge_enhance):
44
 
45
  return image
46
 
47
- def convert_image(image, blur_radius, edge_enhance, color_mode, hierarchical, mode, filter_speckle,
48
  color_precision, layer_difference, corner_threshold, length_threshold,
49
  max_iterations, splice_threshold, path_precision):
50
  """Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
51
 
52
  # Preprocess the image
53
- image = preprocess_image(image, blur_radius, edge_enhance)
54
 
55
  # Convert Gradio image to bytes for vtracer compatibility
56
  img_byte_array = io.BytesIO()
@@ -87,13 +94,14 @@ def convert_image(image, blur_radius, edge_enhance, color_mode, hierarchical, mo
87
  iface = gr.Blocks()
88
 
89
  with iface:
90
- gr.Markdown("# AI-Enhanced Image to SVG Vectors")
91
- gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection and preprocessing ensure superior vectorization.")
92
 
93
  with gr.Row():
94
  image_input = gr.Image(type="pil", label="Upload Image")
95
  blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
96
  edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
 
97
 
98
  with gr.Row():
99
  color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
@@ -121,7 +129,7 @@ with iface:
121
  convert_button.click(
122
  fn=convert_image,
123
  inputs=[
124
- image_input, blur_radius_input, edge_enhance_input, color_mode_input, hierarchical_input, mode_input,
125
  filter_speckle_input, color_precision_input, layer_difference_input, corner_threshold_input,
126
  length_threshold_input, max_iterations_input, splice_threshold_input, path_precision_input
127
  ],
 
1
  import gradio as gr
2
+ from PIL import Image, ImageFilter, ImageOps
3
  import numpy as np
4
  import io
5
  import tempfile
6
  import vtracer
7
+ from skimage import color, filters, exposure
8
  import torch
9
  from torchvision import transforms
10
 
11
+ # AI-based edge detection using a pre-trained PyTorch model (e.g., ResNet)
12
  class EdgeDetectionModel:
13
  def __init__(self, model_name='resnet18', device='cpu'):
14
  self.device = device
 
25
  input_tensor = preprocess(image).unsqueeze(0).to(self.device)
26
  with torch.no_grad():
27
  output = self.model(input_tensor)
28
+
29
  output = output.cpu().numpy()[0].transpose((1, 2, 0))
30
+ gray_output = color.rgb2gray(output)
31
+ edges = filters.sobel(gray_output)
32
  return edges
33
 
34
  # Load the edge detection model
35
+ edge_model = EdgeDetectionModel(device='cpu')
36
 
37
+ def preprocess_image(image, blur_radius, edge_enhance, contrast_stretch):
38
  """Applies advanced preprocessing steps to the image before tracing."""
39
  if blur_radius > 0:
40
  image = image.filter(ImageFilter.GaussianBlur(blur_radius))
41
 
42
+ if contrast_stretch:
43
+ # Apply contrast stretching
44
+ p2, p98 = np.percentile(np.array(image), (2, 98))
45
+ image = ImageOps.autocontrast(image, cutoff=(p2, p98))
46
+
47
  if edge_enhance:
48
  edges = edge_model.detect_edges(image)
49
  edges_img = Image.fromarray((edges * 255).astype(np.uint8))
 
51
 
52
  return image
53
 
54
+ def convert_image(image, blur_radius, edge_enhance, contrast_stretch, color_mode, hierarchical, mode, filter_speckle,
55
  color_precision, layer_difference, corner_threshold, length_threshold,
56
  max_iterations, splice_threshold, path_precision):
57
  """Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
58
 
59
  # Preprocess the image
60
+ image = preprocess_image(image, blur_radius, edge_enhance, contrast_stretch)
61
 
62
  # Convert Gradio image to bytes for vtracer compatibility
63
  img_byte_array = io.BytesIO()
 
94
  iface = gr.Blocks()
95
 
96
  with iface:
97
+ gr.Markdown("# Advanced AI-Enhanced Image to SVG Vectors")
98
+ gr.Markdown("Upload an image and customize the conversion parameters for high-quality vector results. AI-enhanced edge detection, contrast stretching, and other preprocessing options ensure superior vectorization.")
99
 
100
  with gr.Row():
101
  image_input = gr.Image(type="pil", label="Upload Image")
102
  blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
103
  edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
104
+ contrast_stretch_input = gr.Checkbox(value=False, label="Contrast Stretch")
105
 
106
  with gr.Row():
107
  color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
 
129
  convert_button.click(
130
  fn=convert_image,
131
  inputs=[
132
+ image_input, blur_radius_input, edge_enhance_input, contrast_stretch_input, color_mode_input, hierarchical_input, mode_input,
133
  filter_speckle_input, color_precision_input, layer_difference_input, corner_threshold_input,
134
  length_threshold_input, max_iterations_input, splice_threshold_input, path_precision_input
135
  ],