import math

import torch
import torch.nn as nn
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo


def get_clip_timepoints(clip_sampler, duration):
    """
    Collect the (start, end) time span of every clip the sampler yields.

    The sampler is called repeatedly, each time with the end time of the
    previous clip (0.0 for the first call), until it reports the last clip.

    Args:
        clip_sampler (callable): called as
            `clip_sampler(last_end, duration, annotation=None)`; must return a
            5-tuple whose first two items are the clip start/end and whose
            last item is the "is last clip" flag.
        duration (float): total video duration handed to the sampler.
    Returns:
        all_clips_timepoints (list): list of `(start, end)` tuples, one per
            clip. Always contains at least one entry.
    """
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when `boxes` is None.
    """
    # The documented contract allows `boxes=None`; without this guard the
    # `.copy()` call below would raise AttributeError.
    if boxes is None:
        return None

    # Work on a copy so the caller's array is left untouched.
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`. A 3-D tensor is
            also accepted; it is given a leading dimension for processing and
            returned 3-D again.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images so that the
            shorter side equals scale_size before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        # Scale the shorter side to `scale_size`, keeping the aspect ratio.
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Default to a centered crop on both axes ...
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    # ... then slide along the longer axis for the first/last crop.
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size

    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset)
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
        temporal crops to get spatial crops, and should be used with
        -2 in the spatial crop at the slowfast augmentation stage (so full
        frames are passed in here). Will return a larger list with the
        3x spatial crops as well.

    With `num_crops=1` only the center crop is taken, so the output list has
    the same length as the input list.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        """
        Args:
            crop_size (int): side length of each square crop.
            num_crops (int): 3 for the three uniform crops (indices 0, 1, 2)
                or 1 for the center crop only.
        Raises:
            NotImplementedError: for any other `num_crops`.
        """
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError("Nothing else supported yet")

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with `num_crops` times the number of elements
                (3x by default). Each video converted to C, T, H', W' by
                spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all(video.ndim == 4 for video in videos), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            # Flipped crops are currently never configured by __init__, but the
            # code path is kept for when they are.
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res


def load_and_transform_video_data(
    video_file,
    video_path,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
    with_audio=False,
):
    """
    Decode a video, cut it into clips, and return normalized spatial crops.

    Each clip is temporally subsampled, divided by 255, short-side scaled to
    224, normalized with fixed per-channel mean/std, and expanded into three
    224x224 spatial crops. All crops of all clips are stacked along a new
    leading dimension.

    Args:
        video_file: either a path string (decoded via
            `EncodedVideo.from_path` with the "decord" decoder) or any other
            object, which is handed to `EncodedVideoDecord` directly
            (presumably a file-like object -- confirm against callers).
        video_path: used only as `video_name` when `video_file` is not a str.
        clip_duration: clip length passed to the clip sampler. It is also used
            as the number of frames kept per clip by the temporal subsampler.
        clips_per_video: number of clips requested from the sampler.
        sample_rate: audio sample rate; only forwarded when `video_file` is
            not a str.
        with_audio: whether to decode audio and return it alongside the video.
    Returns:
        The stacked video tensor when `with_audio` is False; otherwise a tuple
        `(stacked_video, audio)` where `audio` is the `'audio'` entry of the
        last clip read.
    Raises:
        ValueError: if the decoder returns no clip for a sampled time span.
    """
    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    if isinstance(video_file, str):
        # NOTE(review): `sample_rate` is intentionally not forwarded on this
        # path (it was commented out upstream), so the decoder's own default
        # applies here -- confirm this asymmetry with the else-branch is wanted.
        video = EncodedVideo.from_path(
            video_file,
            decoder="decord",
            decode_audio=with_audio,
        )
    else:
        video = EncodedVideoDecord(
            video_file,
            video_name=video_path,
            decode_video=True,
            decode_audio=with_audio,
            sample_rate=sample_rate,
        )

    all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)

    all_video = []
    for clip_start, clip_end in all_clips_timepoints:
        # Read the clip, get frames
        clip = video.get_clip(clip_start, clip_end)
        if clip is None:
            raise ValueError("No clip found")
        video_clip = frame_sampler(clip["video"])
        video_clip = video_clip / 255.0  # since this is float, need 0-1
        all_video.append(video_transform(video_clip))

    all_video = SpatialCrop(224, num_crops=3)(all_video)
    all_video = torch.stack(all_video, dim=0)

    if not with_audio:
        return all_video

    # NOTE(review): only the audio of the LAST clip read above is returned,
    # not audio for every clip -- confirm this is what callers expect.
    return all_video, clip['audio']


if __name__ == '__main__':
    video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
    video, audio = load_and_transform_video_data(
        video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True
    )
    # Manual inspection hook for ad-hoc debugging of the loaded tensors.
    import pdb
    pdb.set_trace()