import os from multiprocessing import Pool import numpy as np import random from PIL import Image import re import cv2 import glob from natsort import natsorted class MultiProcessImageSaver(object): def __init__(self, n_workers=1): self.pool = Pool(n_workers) def __call__(self, images, output_files, resizes=None): if resizes is None: resizes = [None for _ in range(len(images))] return self.pool.imap( self.save_image, zip(images, output_files, resizes), ) def close(self): self.pool.close() self.pool.join() @staticmethod def save_image(args): image, filename, resize = args image = Image.fromarray(image) if resize is not None: image = image.resize(tuple(resize)) image.save(filename) def list_dir_with_full_path(path): return [os.path.join(path, f) for f in os.listdir(path)] def find_all_files_in_dir(path): files = [] for root, _, files in os.walk(path): for file in files: files.append(os.path.join(root, file)) return files def is_image(path): return ( path.endswith('.jpg') or path.endswith('.png') or path.endswith('.jpeg') or path.endswith('.JPG') or path.endswith('.PNG') or path.endswith('.JPEG') ) def is_video(path): return ( path.endswith('.mp4') or path.endswith('.avi') or path.endswith('.MP4') or path.endswith('.AVI') or path.endswith('.webm') or path.endswith('.WEBM') or path.endswith('.mkv') or path.endswith('.MVK') ) def random_square_crop(img, random_generator=None): # If no random generator is provided, use numpy's default if random_generator is None: random_generator = np.random.default_rng() # Get the width and height of the image width, height = img.size # Determine the shorter side min_size = min(width, height) # Randomly determine the starting x and y coordinates for the crop if width > height: left = random_generator.integers(0, width - min_size) upper = 0 else: left = 0 upper = random_generator.integers(0, height - min_size) # Calculate the ending x and y coordinates for the crop right = left + min_size lower = upper + min_size # Crop the image return img.crop((left, upper, right, lower)) def read_image_to_tensor(path, center_crop=1.0): pil_im = Image.open(path).convert('RGB') if center_crop < 1.0: width, height = pil_im.size pil_im = pil_im.crop(( int((1 - center_crop) * height / 2), int((1 + center_crop) * height / 2), int((1 - center_crop) * width / 2), int((1 + center_crop) * width / 2), )) input_img = pil_im.resize((256, 256)) input_img = np.array(input_img) / 255.0 input_img = input_img.astype(np.float32) return input_img def match_mulitple_path(root_dir, regex): videos = [] for root, _, files in os.walk(root_dir): for file in files: videos.append(os.path.join(root, file)) videos = [v for v in videos if not v.split('/')[-1].startswith('.')] grouped_path = {} for r in regex: r = re.compile(r) for v in videos: matched = r.findall(v) if len(matched) > 0: groups = matched[0] if groups not in grouped_path: grouped_path[groups] = [] grouped_path[groups].append(v) grouped_path = { k: tuple(v) for k, v in grouped_path.items() if len(v) == len(regex) } return list(grouped_path.values()) def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True): assert length >= n_frames max_stride = min( (length - 1) // (n_frames - 1), max_stride ) stride = np.random.randint(1, max_stride + 1) if random_start: start = np.random.randint(0, length - (n_frames - 1) * stride) else: start = 0 return np.arange(n_frames) * stride + start def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0): files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)] files = natsorted([x for x in files if is_image(x)]) total_frames = len(files) if total_frames < n_frames: return None max_stride = (total_frames - 1) // (n_frames - 1) stride = min(max_stride, stride) if random_start: start = np.random.randint(0, total_frames - (n_frames - 1) * stride) else: start = 0 frame_indices = np.arange(n_frames) * stride + start frames = [] for frame_index in sorted(frame_indices): # Check if the frame_index is valid frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop)) if len(frames) < n_frames: return None frames = np.stack(frames, axis=0) return frames def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0): frames = [] cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames < n_frames: cap.release() return None max_stride = (total_frames - 1) // (n_frames - 1) stride = min(max_stride, stride) if random_start: start = np.random.randint(0, total_frames - (n_frames - 1) * stride) else: start = 0 frame_indices = np.arange(n_frames) * stride + start for frame_index in sorted(frame_indices): # Check if the frame_index is valid if 0 <= frame_index < total_frames: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) ret, frame = cap.read() if ret: if center_crop < 1.0: height, width, _ = frame.shape frame = frame[ int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), : ] frame = cv2.resize(frame, (256, 256)) frames.append(frame) else: print(f"Frame index {frame_index} is out of bounds. Skipping...") cap.release() if len(frames) < n_frames: return None frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 # From BGR to RGB return np.stack( [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 ) def read_all_frames_from_video(video_path, center_crop=1.0): frames = [] cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) for frame_index in range(total_frames): # Check if the frame_index is valid if 0 <= frame_index < total_frames: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) ret, frame = cap.read() if ret: if center_crop < 1.0: height, width, _ = frame.shape frame = frame[ int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), : ] frames.append(cv2.resize(frame, (256, 256))) else: print(f"Frame index {frame_index} is out of bounds. Skipping...") cap.release() if len(frames) == 0: return None frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 # From BGR to RGB return np.stack( [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 ) def read_max_span_frames_from_video(video_path, n_frames): frames = [] cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames < n_frames: cap.release() return None stride = (total_frames - 1) // (n_frames - 1) frame_indices = np.arange(n_frames) * stride frames = [] for frame_index in frame_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) ret, frame = cap.read() if ret: frames.append(cv2.resize(frame, (256, 256))) cap.release() if len(frames) < n_frames: return None frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 # From BGR to RGB return np.stack( [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 )