# ------------------------------------------------------------------------
# Modified from Grounded-SAM (https://github.com/IDEA-Research/Grounded-Segment-Anything)
# ------------------------------------------------------------------------
"""Gradio demo: text-prompted video matting with Matting Anything (MAM).

Each frame of the uploaded clip is matted by MAM using the highest-confidence
GroundingDINO box for the text prompt, then composited over a Stable Diffusion
background and over a green screen. The three resulting frame sequences are
re-encoded as videos.
"""
import os
import sys
import random
import warnings

# ``os.system("export ...")`` only sets the variable inside a throw-away
# subshell, so it never reached the installs below. Setting it on this
# process lets the pip sub-processes inherit it (presumably read by the
# GroundingDINO CUDA extension build -- confirm).
os.environ["BUILD_WITH_CUDA"] = "True"
os.system("python -m pip install -e segment-anything")
os.system("python -m pip install -e GroundingDINO")
os.system("pip install --upgrade diffusers[torch]")

sys.path.insert(0, './GroundingDINO')
sys.path.insert(0, './segment-anything')
warnings.filterwarnings("ignore")

import cv2
from scipy import ndimage

import gradio as gr
import argparse

import numpy as np
from PIL import Image
from moviepy.editor import ImageSequenceClip, VideoFileClip
import torch
from torch.nn import functional as F
import torchvision

import networks
import utils

# Grounding DINO
from groundingdino.util.inference import Model

# SAM
from segment_anything.utils.transforms import ResizeLongestSide

# SD
from diffusers import StableDiffusionPipeline

# Side length of the square MAM/SAM input: longest side resized to this, then zero-padded.
SAM_INPUT_SIZE = 1024
transform = ResizeLongestSide(SAM_INPUT_SIZE)

# Green Screen
PALETTE_back = (51, 255, 146)

GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swint_ogc.pth"
mam_checkpoint = "checkpoints/mam_sam_vitb.pth"
output_dir = "outputs"
device = 'cuda'

background_list = os.listdir('assets/backgrounds')

# Background reused across frames of one video when ``bg_already`` is True
# (assigned through ``global`` in run_grounded_sam).
background_img = None

# initialize MAM
mam_model = networks.get_generator_m2m(seg='sam', m2m='sam_decoder_deep')
mam_model.to(device)
checkpoint = torch.load(mam_checkpoint, map_location=device)
mam_model.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
mam_model = mam_model.eval()

# initialize GroundingDINO
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=device,
)

# initialize StableDiffusionPipeline
generator = StableDiffusionPipeline.from_pretrained("checkpoints/stable-diffusion-v1-5", torch_dtype=torch.float16)
generator.to(device)


def get_frames(video_in):
    """Resize *video_in* to 512 px height and dump every frame to a JPEG.

    The clip is re-encoded to ``video_resized.mp4`` in the working directory,
    with the frame rate capped at 30. Each decoded frame is then written to
    ``kang<i>.jpg``.

    Returns:
        (frames, fps): the list of written frame paths and the frame rate
        OpenCV reports for the resized video.
    """
    frames = []

    # ``with`` closes the source clip (and its ffmpeg reader) once the resized copy is written.
    with VideoFileClip(video_in) as clip:
        if clip.fps > 30:
            print("vide rate is over 30, resetting to 30")
            target_fps = 30
        else:
            print("video rate is OK")
            target_fps = clip.fps
        clip_resized = clip.resize(height=512)
        clip_resized.write_videofile("video_resized.mp4", fps=target_fps)
    print("video resized to 512 height")

    # Opens the Video file with CV2
    cap = cv2.VideoCapture("video_resized.mp4")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("video fps: " + str(fps))
        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = 'kang' + str(i) + '.jpg'
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
            i += 1
    finally:
        # No HighGUI window is ever opened here, so releasing the capture is enough.
        cap.release()
    print("broke the video into frames")
    return frames, fps


def create_video(frames, fps, type):
    """Encode the image files in *frames* into ``video_<type>_result.mp4``.

    ``type`` shadows the builtin; the name is kept for caller compatibility.
    Returns the path of the written video.
    """
    print("building video result")
    out_path = f"video_{type}_result.mp4"
    with ImageSequenceClip(frames, fps=fps) as clip:
        clip.write_videofile(out_path, fps=fps)
    return out_path


def _detect_box_from_text(image_ori, text_prompt, box_threshold, text_threshold, iou_threshold):
    """Return the highest-confidence GroundingDINO xyxy box for *text_prompt*, in input-pixel coordinates.

    Raises:
        gr.Error: if nothing matching the prompt is detected (the original
            code crashed on ``np.argmax`` of an empty array).
    """
    with torch.no_grad():
        detections, _ = grounding_dino_model.predict_with_caption(
            image=cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR),
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        if len(detections.xyxy) == 0:
            raise gr.Error(f"No object matching '{text_prompt}' was detected")
        if len(detections.xyxy) > 1:
            nms_idx = torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                iou_threshold,
            ).numpy().tolist()
            detections.xyxy = detections.xyxy[nms_idx]
            detections.confidence = detections.confidence[nms_idx]
        return detections.xyxy[np.argmax(detections.confidence)]


def _prepare_sam_input(image_ori):
    """Resize, normalize and zero-pad an RGB image for MAM.

    The image's longest side is resized to SAM_INPUT_SIZE, it is normalized with
    the mean/std below, and it is zero-padded on the bottom/right to
    SAM_INPUT_SIZE x SAM_INPUT_SIZE.

    Returns:
        (image, pad_size): the (3, S, S) tensor on ``device`` and the (h, w)
        size of the resized image before padding.
    """
    image = transform.apply_image(image_ori)
    image = torch.as_tensor(image).to(device)
    image = image.permute(2, 0, 1).contiguous()

    pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1).to(device)
    pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1).to(device)
    image = (image - pixel_mean) / pixel_std

    h, w = image.shape[-2:]
    pad_size = image.shape[-2:]
    image = F.pad(image, (0, SAM_INPUT_SIZE - w, 0, SAM_INPUT_SIZE - h))
    return image, pad_size


def _scribble_centers(scribble):
    """Return an (N, 2) array of (x, y) centers of mass of the scribbled regions.

    Regions are the connected areas where channel 0 of *scribble* is >= 255.
    *scribble* is assumed to be an H x W x C array (TODO confirm); it is
    transposed to (C, W, H) so that the center coordinates come out as (x, y).
    """
    scribble = scribble.transpose(2, 1, 0)[0]
    labeled_array, num_features = ndimage.label(scribble >= 255)
    if num_features == 0:
        raise gr.Error("Scribble is empty")
    centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
    return np.array(centers)


def _make_sample(image, pad_size, original_size, **prompt):
    """Build the MAM ``forward_inference`` input dict (key order: image, prompt keys, ori_shape, pad_shape)."""
    return {'image': image.unsqueeze(0), **prompt, 'ori_shape': original_size, 'pad_shape': pad_size}


def _refine_alpha(pred, post_mask, sample, guidance_mode):
    """Fuse MAM's os8/os4/os1 alpha predictions into one 2-D numpy matte at input resolution.

    Each prediction is cropped to the unpadded size and resized back to the
    original image size. Starting from the os8 (or mask-guided) result, the
    pixels selected by the ``utils`` uncertainty helpers are progressively
    overwritten with the os4 and then the os1 predictions.
    """
    pad_h, pad_w = sample['pad_shape']

    def _to_input_res(alpha):
        alpha = alpha[..., :pad_h, :pad_w]
        return F.interpolate(alpha, sample['ori_shape'], mode="bilinear", align_corners=False)

    alpha_pred_os8 = _to_input_res(pred['alpha_os8'])
    alpha_pred_os4 = _to_input_res(pred['alpha_os4'])
    alpha_pred_os1 = _to_input_res(pred['alpha_os1'])

    if guidance_mode == 'mask':
        weight_os8 = utils.get_unknown_tensor_from_mask_oneside(post_mask, rand_width=10, train_mode=False)
        post_mask[weight_os8 > 0] = alpha_pred_os8[weight_os8 > 0]
        alpha_pred = post_mask.clone().detach()
    else:
        weight_os8 = utils.get_unknown_box_from_mask(post_mask)
        alpha_pred_os8[weight_os8 > 0] = post_mask[weight_os8 > 0]
        alpha_pred = alpha_pred_os8.clone().detach()

    weight_os4 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=20, train_mode=False)
    alpha_pred[weight_os4 > 0] = alpha_pred_os4[weight_os4 > 0]

    weight_os1 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=10, train_mode=False)
    alpha_pred[weight_os1 > 0] = alpha_pred_os1[weight_os1 > 0]

    return alpha_pred[0][0].cpu().numpy()


def _composite(alpha, foreground, background):
    """Alpha-blend *foreground* over *background* (broadcast against H x W x 3) and cast to uint8."""
    return np.uint8(alpha[..., None] * foreground + (1 - alpha[..., None]) * background)


def run_grounded_sam(input_image, text_prompt, task_type, background_prompt, bg_already, scribble=None):
    """Matte one RGB frame with MAM and composite it.

    Args:
        input_image: RGB image as a numpy array (H, W, 3). ``infer`` passes
            ``np.array(PIL.Image.convert("RGB"))``, i.e. uint8.
        text_prompt: caption GroundingDINO uses to locate the target when
            ``task_type == 'text'``.
        task_type: ``'text'``, ``'scribble_point'`` or ``'scribble_box'``.
        background_prompt: Stable Diffusion prompt for the generated background.
        bg_already: if True, reuse the background generated by a previous call
            instead of running Stable Diffusion again.
        scribble: scribble mask array required by the scribble task types.
            The gradio sketch input that used to supply it is no longer part of
            this demo, so it must be passed explicitly.

    Returns:
        (com_img, green_img, alpha_rgb): uint8 RGB images. These are the frame
        composited over the background, the frame composited over the green
        screen, and the alpha matte rendered as RGB.

    Raises:
        gr.Error: for an unknown task type, a missing prompt or scribble, or
            when the text prompt matches nothing.
    """
    global background_img

    background_type = "generated_by_text"
    box_threshold = 0.25
    text_threshold = 0.25
    iou_threshold = 0.5
    scribble_mode = "split"
    guidance_mode = "alpha"

    # Fail fast, before any GPU work. These cases used to print a warning and
    # then crash later with NameError / UnboundLocalError / empty-argmax errors.
    if task_type not in ('text', 'scribble_point', 'scribble_box'):
        raise gr.Error("task_type:{} error!".format(task_type))
    if task_type == 'text' and not text_prompt:
        raise gr.Error('Please input non-empty text prompt')
    if task_type != 'text' and scribble is None:
        raise gr.Error("task_type:{} requires a scribble mask".format(task_type))
    if background_type != 'real_world_sample' and not background_prompt:
        raise gr.Error('Please input non-empty background prompt')

    # make dir
    os.makedirs(output_dir, exist_ok=True)

    image_ori = input_image
    original_size = image_ori.shape[:2]
    image, pad_size = _prepare_sam_input(image_ori)

    if task_type == 'scribble_point':
        centers = transform.apply_coords(_scribble_centers(scribble), original_size)  # (x, y)
        point_coords = torch.from_numpy(centers).unsqueeze(0).to(device)
        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
        if scribble_mode == 'split':
            # Move the N points from the point axis into the batch axis: (1, N, 2) -> (N, 1, 2).
            point_coords = point_coords.permute(1, 0, 2)
            point_labels = point_labels.permute(1, 0)
        sample = _make_sample(image, pad_size, original_size, point=point_coords, label=point_labels)
    else:
        if task_type == 'scribble_box':
            centers = _scribble_centers(scribble)
            # (x1, y1, x2, y2) spanning the scribble region centers
            bbox = np.array([centers[:, 0].min(), centers[:, 1].min(), centers[:, 0].max(), centers[:, 1].max()])
        else:
            bbox = _detect_box_from_text(image_ori, text_prompt, box_threshold, text_threshold, iou_threshold)
        bbox = transform.apply_boxes(bbox, original_size)
        bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
        sample = _make_sample(image, pad_size, original_size, bbox=bbox.unsqueeze(0))

    with torch.no_grad():
        _, pred, post_mask = mam_model.forward_inference(sample)
        alpha_pred = _refine_alpha(pred, post_mask, sample, guidance_mode)

    # alpha matte
    alpha_rgb = cv2.cvtColor(np.uint8(alpha_pred * 255), cv2.COLOR_GRAY2RGB)

    # com img with background
    height, width = image_ori.shape[:2]
    if background_type == 'real_world_sample':
        background_img_file = os.path.join('assets/backgrounds', random.choice(background_list))
        background_img = cv2.cvtColor(cv2.imread(background_img_file), cv2.COLOR_BGR2RGB)
        background_img = cv2.resize(background_img, (width, height))
    else:
        # Generate once per video; also regenerate if no cached background exists yet.
        if not bg_already or background_img is None:
            background_img = np.array(generator(background_prompt).images[0])
        if background_img.shape[:2] != (height, width):
            background_img = cv2.resize(background_img, (width, height))
    com_img = _composite(alpha_pred, image_ori, np.uint8(background_img))

    # com img with green screen
    green_img = _composite(alpha_pred, image_ori, np.array([PALETTE_back], dtype='uint8'))

    return com_img, green_img, alpha_rgb


def infer(video_in, trim_value, prompt, background_prompt):
    """Matte the first *trim_value* seconds of *video_in* and build the result videos.

    Returns:
        (vid_bg, vid_green, vid_matte): paths of the videos composited over the
        generated background, composited over the green screen, and of the
        alpha matte.
    """
    print(prompt)
    if not video_in:
        raise gr.Error("Please upload a video")
    # Validate prompts before the (slow) frame extraction.
    if not prompt:
        raise gr.Error('Please input non-empty text prompt')
    if not background_prompt:
        raise gr.Error('Please input non-empty background prompt')

    frames_list, fps = get_frames(video_in)

    n_frame = int(trim_value * fps)
    if n_frame >= len(frames_list):
        print("video is shorter than the cut value")
        n_frame = len(frames_list)
    if n_frame == 0:
        raise gr.Error("No frames to process: the video is empty or the cut value is too small")

    with_bg_result_frames = []
    with_green_result_frames = []
    with_matte_result_frames = []
    print("set stop frames to: " + str(n_frame))

    bg_already = False
    for idx, frame_path in enumerate(frames_list[:n_frame], start=1):
        # ``with`` releases the file handle PIL keeps open on the frame.
        with Image.open(frame_path) as frame:
            image_array = np.array(frame.convert("RGB"))

        results = run_grounded_sam(image_array, prompt, "text", background_prompt, bg_already)
        # The background is generated on the first frame and reused afterwards.
        bg_already = True

        outputs = (
            (results[0], "bg", with_bg_result_frames),
            (results[1], "green", with_green_result_frames),
            (results[2], "matte", with_matte_result_frames),
        )
        for array, prefix, collected in outputs:
            out_path = f"{prefix}_result_img-{frame_path}.jpg"
            Image.fromarray(array).save(out_path)
            collected.append(out_path)

        print(f"frame {idx}/{n_frame}: done;")

    vid_bg = create_video(with_bg_result_frames, fps, "bg")
    vid_green = create_video(with_green_result_frames, fps, "greenscreen")
    vid_matte = create_video(with_matte_result_frames, fps, "matte")

    print("finished !")
    return vid_bg, vid_green, vid_matte


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    # Default None keeps gradio's own port selection (previously the flag was parsed but ignored).
    parser.add_argument('--port', type=int, default=None, help='port to run the server (default: gradio default)')
    parser.add_argument('--no-gradio-queue', action="store_true", help='disable the gradio request queue')
    args = parser.parse_args()

    print(args)

    block = gr.Blocks()
    with block:
        gr.Markdown(
            """
            # Matting Anything in Video Demo
            Welcome to the Matting Anything in Video demo by @fffiloni and upload your video to get started

            You may open usage details below to understand how to use this demo.

            ## Usage

            You may upload a video to start, for the moment we only support 1 prompt type to get the alpha matte of the target:

            **text**: Send text prompt to identify the target instance in the `Text prompt` box.

            We also only support 1 background type to support image composition with the alpha matte output:

            **generated_by_text**: Send background text prompt to create a background image with stable diffusion model in the `Background prompt` box.

            Duplicate Space for longer sequences, more control and no queue.
            """
        )

        with gr.Row():
            with gr.Column():
                video_in = gr.Video(source='upload', type="filepath")
                trim_in = gr.Slider(label="Cut video at (s)", minimum=1, maximum=10, step=1, value=1)
                text_prompt = gr.Textbox(
                    label="Text prompt",
                    placeholder="the girl in the middle",
                    info="Describe the subject visible in your video that you want to matte",
                )
                background_prompt = gr.Textbox(label="Background prompt", placeholder="downtown area in New York")
                run_button = gr.Button(value="Run")
                # Prompt/background type and detection/guidance thresholds are fixed
                # inside run_grounded_sam; they are not exposed in this UI.

            with gr.Column():
                vid_bg_out = gr.Video(label="Video with background")
                with gr.Row():
                    vid_green_out = gr.Video(label="Video green screen")
                    vid_matte_out = gr.Video(label="Video matte")

        run_button.click(
            fn=infer,
            inputs=[video_in, trim_in, text_prompt, background_prompt],
            outputs=[vid_bg_out, vid_green_out, vid_matte_out],
        )

    # Queue exactly once so --no-gradio-queue actually takes effect.
    if not args.no_gradio_queue:
        block = block.queue(max_size=24)
    block.launch(server_port=args.port, debug=args.debug, share=args.share, show_error=True)