import gradio as gr
from PIL import Image, ImageFilter
import numpy as np
import io
import tempfile
import vtracer
from skimage import filters, exposure
import torch
from torchvision import transforms


# AI-based edge detection built on a pre-trained PyTorch backbone (e.g., ResNet).
# The classification head is discarded; early convolutional feature maps serve as
# an edge proxy that Sobel filtering then sharpens.
class EdgeDetectionModel:
    def __init__(self, model_name='resnet18', device='cpu'):
        self.device = device
        backbone = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=True)
        # Keep only the first convolutional block (conv1 + bn1 + relu) so the forward
        # pass returns a spatial feature map rather than classification logits.
        self.model = torch.nn.Sequential(*list(backbone.children())[:3]).to(self.device)
        self.model.eval()

    def detect_edges(self, image):
        """Returns an edge map in [0, 1] covering the whole input image."""
        preprocess = transforms.Compose([
            # Resize the full image (no centre crop) so the edge map stays aligned with the input
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(image.convert('RGB')).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.model(input_tensor)  # (1, 64, 112, 112) for resnet18
        # Collapse the channel dimension into a single activation map and normalize it
        activation = features.cpu().numpy()[0].mean(axis=0)
        activation = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
        edges = filters.sobel(activation)
        return edges / (edges.max() + 1e-8)


# Load the edge detection model once at startup
edge_model = EdgeDetectionModel(device='cpu')


def preprocess_image(image, blur_radius, edge_enhance, contrast_stretch):
    """Applies preprocessing (blur, contrast stretch, AI edge overlay) before tracing."""
    image = image.convert('RGB')

    if blur_radius > 0:
        image = image.filter(ImageFilter.GaussianBlur(blur_radius))

    if contrast_stretch:
        # Stretch intensities so the 2nd-98th percentile range spans the full 0-255 scale
        arr = np.array(image)
        p2, p98 = np.percentile(arr, (2, 98))
        arr = exposure.rescale_intensity(arr, in_range=(p2, p98), out_range=(0, 255)).astype(np.uint8)
        image = Image.fromarray(arr)

    if edge_enhance:
        # Blend the AI-derived edge map over the image to emphasise boundaries for the tracer
        edges = edge_model.detect_edges(image)
        edges_img = Image.fromarray((edges * 255).astype(np.uint8)).resize(image.size)
        image = Image.blend(image, edges_img.convert('RGB'), alpha=0.5)

    return image


def convert_image(image, blur_radius, edge_enhance, contrast_stretch, color_mode, hierarchical, mode,
                  filter_speckle, color_precision, layer_difference, corner_threshold,
                  length_threshold, max_iterations, splice_threshold, path_precision):
    """Converts an image to SVG using vtracer with customizable parameters and AI enhancements."""
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Preprocess the image
    image = preprocess_image(image, blur_radius, edge_enhance, contrast_stretch)

    # Convert the PIL image to raw PNG bytes for vtracer
    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    img_bytes = img_byte_array.getvalue()

    # Perform the conversion with the selected settings
    svg_str = vtracer.convert_raw_image_to_svg(
        img_bytes,
        img_format='png',
        colormode=color_mode.lower(),
        hierarchical=hierarchical.lower(),
        mode=mode.lower(),
        filter_speckle=int(filter_speckle),
        color_precision=int(color_precision),
        layer_difference=int(layer_difference),
        corner_threshold=int(corner_threshold),
        length_threshold=float(length_threshold),
        max_iterations=int(max_iterations),
        splice_threshold=int(splice_threshold),
        path_precision=int(path_precision),
    )

    # Save the SVG string to a temporary file for the download link
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.svg')
    temp_file.write(svg_str.encode('utf-8'))
    temp_file.close()

    # Return the SVG markup for inline display and the file path for download
    return svg_str, temp_file.name


# Gradio interface
iface = gr.Blocks()

with iface:
    gr.Markdown("# Advanced AI-Enhanced Image to SVG Vectors")
    gr.Markdown(
        "Upload an image and customize the conversion parameters for high-quality vector results. "
        "AI-enhanced edge detection, contrast stretching, and other preprocessing options help produce "
        "cleaner vectorization."
    )

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        blur_radius_input = gr.Slider(minimum=0, maximum=10, value=0, step=0.5, label="Blur Radius (for smoothing)")
        edge_enhance_input = gr.Checkbox(value=False, label="AI Edge Enhance")
        contrast_stretch_input = gr.Checkbox(value=False, label="Contrast Stretch")

    with gr.Row():
        color_mode_input = gr.Radio(choices=["Color", "Binary"], value="Color", label="Color Mode")
        hierarchical_input = gr.Radio(choices=["Stacked", "Cutout"], value="Stacked", label="Hierarchical")
        mode_input = gr.Radio(choices=["Spline", "Polygon", "None"], value="Spline", label="Mode")

    with gr.Row():
        filter_speckle_input = gr.Slider(minimum=1, maximum=100, value=4, step=1, label="Filter Speckle")
        color_precision_input = gr.Slider(minimum=1, maximum=100, value=6, step=1, label="Color Precision")
        layer_difference_input = gr.Slider(minimum=1, maximum=100, value=16, step=1, label="Layer Difference")

    with gr.Row():
        corner_threshold_input = gr.Slider(minimum=1, maximum=100, value=60, step=1, label="Corner Threshold")
        length_threshold_input = gr.Slider(minimum=1, maximum=100, value=4.0, step=0.5, label="Length Threshold")
        max_iterations_input = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Max Iterations")

    with gr.Row():
        splice_threshold_input = gr.Slider(minimum=1, maximum=100, value=45, step=1, label="Splice Threshold")
        path_precision_input = gr.Slider(minimum=1, maximum=100, value=8, step=1, label="Path Precision")

    convert_button = gr.Button("Convert Image to SVG")
    svg_output = gr.HTML(label="SVG Output")
    download_output = gr.File(label="Download SVG")

    convert_button.click(
        fn=convert_image,
        inputs=[
            image_input, blur_radius_input, edge_enhance_input, contrast_stretch_input,
            color_mode_input, hierarchical_input, mode_input, filter_speckle_input,
            color_precision_input, layer_difference_input, corner_threshold_input,
            length_threshold_input, max_iterations_input, splice_threshold_input,
            path_precision_input,
        ],
        outputs=[svg_output, download_output],
    )

iface.launch()
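# Usage note (a sketch of the assumed environment, not pinned versions): the imports above map to the
# PyPI packages gradio, pillow, numpy, vtracer, scikit-image, torch, and torchvision, e.g.
#   pip install gradio pillow numpy vtracer scikit-image torch torchvision
# Running the script starts a local Gradio server and prints the URL to open in a browser.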