"""Compute the average next-frame perplexity of a VQLM checkpoint over a set of
videos (or directories of pre-extracted frames)."""

import os
import glob

from tqdm import tqdm
import mlxu
import numpy as np
import torch

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video,
    read_frames_from_dir,
    read_frames_from_video,
)


FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    center_crop=1.0,
    n_context_frames=15,
    n_target_frames=1,
    n_workers=8,
    stride=8,
    batch_size=2,
    torch_devices='',
    shuffle=False,
    random_start=True,
    max_examples=0,
)


class VideoDataset(torch.utils.data.Dataset):
    """Yields (context_frames, target_frames) pairs from videos or frame directories."""

    def __init__(self, videos, frame_input=False, n_context_frames=15,
                 n_target_frames=1, stride=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_context_frames = n_context_frames
        self.n_target_frames = n_target_frames
        self.stride = stride

    def __getitem__(self, index):
        if self.frame_input:
            # Each example is a directory of frame images.
            frames = read_frames_from_dir(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        else:
            # Each example is a video file.
            frames = read_frames_from_video(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        if frames is None:
            # Unreadable or too-short example: retry with a random other one.
            return self[np.random.randint(0, len(self))]
        return frames[:self.n_context_frames], frames[self.n_context_frames:]

    def __len__(self):
        return len(self.videos)


def main(_):
    assert FLAGS.checkpoint != ''
    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        perplexity_batch_size=FLAGS.batch_size,
    )

    # Collect input paths either from a file listing or from a glob pattern.
    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_context_frames=FLAGS.n_context_frames,
        n_target_frames=FLAGS.n_target_frames,
        stride=FLAGS.stride,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        # Load enough examples per step to keep every inference process busy.
        batch_size=FLAGS.batch_size * model.n_processes * 4,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    perplexities = []

    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
        batch_context_frames = batch_context_frames.numpy()
        batch_target_frames = batch_target_frames.numpy()
        perplexity = model.compute_perplexity(
            batch_context_frames, batch_target_frames
        )
        perplexities.append(perplexity)

    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')


if __name__ == '__main__':
    mlxu.run(main)
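

# Example invocation (a sketch, not taken from this repo): mlxu's
# define_flags_with_default exposes each flag above as an absl-style
# --flag=value command-line option, and mlxu.run(main) parses them before
# calling main. The script name, checkpoint path, and glob pattern below are
# placeholders.
#
#   python compute_perplexity.py \
#       --checkpoint=/path/to/vqlm_checkpoint \
#       --input_files='/path/to/videos/*.mp4' \
#       --n_context_frames=15 \
#       --n_target_frames=1 \
#       --batch_size=2
#
# Pass --frame_input=True (with directories of frames instead of video files)
# or --read_file_list=/path/to/list.txt to change how inputs are gathered.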