import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoTokenizer, XCLIPVisionModel, AutoModelForSequenceClassification
import numpy as np
import cv2
import opensmile


class TextClassificationModel:
    """Frozen text classifier wrapping a fine-tuned sequence-classification model."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, attn_mask, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            input_ids = input_ids.to(self.device)
            attn_mask = attn_mask.to(self.device)
            output = self.model(
                input_ids=input_ids,
                attention_mask=attn_mask,
                output_hidden_states=return_last_hidden_state,
            )
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
        if return_last_hidden_state:
            # [CLS] token embedding of the last hidden layer
            return pred, output['hidden_states'][-1][:, 0, :]
        return pred


class XCLIPClassificationModel(nn.Module):
    """Video classifier: X-CLIP vision encoder + frame pooling + linear head."""

    def __init__(self, num_labels):
        super().__init__()
        self.base_model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
        self.num_labels = num_labels
        hidden_size = self.base_model.config.hidden_size
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.pool1 = nn.AdaptiveAvgPool1d(1)  # pools over patch tokens
        self.pool2 = nn.AdaptiveAvgPool1d(1)  # pools over frames

    def forward(self, pixel_values, labels=None, return_last_hidden_state=False):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        # Fold frames into the batch dimension so the vision encoder sees single images.
        pixel_values = pixel_values.reshape(-1, num_channels, height, width)
        out = self.base_model(pixel_values)[0]              # [48, 50, 768]
        out = torch.transpose(out, 1, 2)                    # [48, 768, 50]
        out = self.pool1(out)                               # [48, 768, 1]
        out = torch.transpose(out, 1, 2)                    # [48, 1, 768]
        out = out.squeeze(1)                                # [48, 768]
        hidden_out = out.view(batch_size, num_frames, -1)   # [3, 16, 768]
        hidden_out = torch.transpose(hidden_out, 1, 2)      # [3, 768, 16]
        pooled_out = self.pool2(hidden_out)                 # [3, 768, 1]
        pooled_out = torch.transpose(pooled_out, 1, 2)      # [3, 1, 768]
        pooled_out = pooled_out[:, 0, :]                    # [3, 768]
        logits = self.classifier(pooled_out)

        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_last_hidden_state:
            return {'logits': logits, 'loss': loss, 'last_hidden_state': pooled_out}
        return {'logits': logits, 'loss': loss}


class VideoClassificationModel:
    """Frozen wrapper around XCLIPClassificationModel for inference."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, pixel_values, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            pixel_values = pixel_values.to(self.device)
            output = self.model(pixel_values, return_last_hidden_state=return_last_hidden_state)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
        if return_last_hidden_state:
            return pred, output['last_hidden_state']
        return pred


class ConvNet(nn.Module):
    """1D CNN over the openSMILE ComParE functionals (6191 features after filtering)."""

    def __init__(self, num_labels, n_input=1, n_channel=32):
        super().__init__()
        self.ln0 = nn.LayerNorm((1, 6191))
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(n_channel * 3093, 3093)
        self.fc2 = nn.Linear(3093, num_labels)
        self.flat = nn.Flatten()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, return_last_hidden_state=False):
        x = self.ln0(x)
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool1(x)
        x = self.dropout(x)
        x = self.flat(x)
        hid = F.relu(self.fc1(x))  # 3093-dim hidden state reused by the fusion model
        x = self.fc2(hid)
        if not return_last_hidden_state:
            return {'logits': F.log_softmax(x, dim=1)}
        return {'logits': F.log_softmax(x, dim=1), 'last_hidden_state': hid}
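
# --- Illustrative addition (not part of the original pipeline) ---------------
# A minimal shape sanity check for ConvNet, assuming the 6191-dimensional
# openSMILE feature vector produced by prepare_data_input below. The helper
# name `_check_convnet_shapes` is hypothetical and exists only for documentation.
def _check_convnet_shapes(num_labels: int = 7) -> None:
    net = ConvNet(num_labels)
    dummy = torch.randn(2, 1, 6191)  # [batch, channel, ComParE features kept after filtering]
    out = net(dummy, return_last_hidden_state=True)
    assert out['logits'].shape == (2, num_labels)       # log-probabilities per emotion
    assert out['last_hidden_state'].shape == (2, 3093)  # hidden state fed to the fusion model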

class AudioClassificationModel:
    """Frozen wrapper around ConvNet for inference on openSMILE features."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, return_last_hidden_state=False):
        self.model.eval()
        with torch.no_grad():
            input_ids = torch.tensor(input_ids, dtype=torch.float).to(self.device)
            output = self.model(input_ids, return_last_hidden_state=return_last_hidden_state)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
        if return_last_hidden_state:
            return pred, output['last_hidden_state']
        return pred


class MultimodalClassificationModel(nn.Module):
    """Late-fusion head over the concatenated text, video and audio hidden states."""

    def __init__(self, text_model, video_model, audio_model, num_labels, input_size, hidden_size=256):
        super().__init__()
        self.text_model = text_model
        self.video_model = video_model
        self.audio_model = audio_model
        self.num_labels = num_labels
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, self.num_labels)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout()
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, batch, labels=None):
        text_pred, text_last_hidden = self.text_model(
            batch['text']['input_ids'].squeeze(1),
            batch['text']['attention_mask'].squeeze(1),
            return_last_hidden_state=True
        )
        video_pred, video_last_hidden = self.video_model(
            batch['video']['pixel_values'].squeeze(1),
            return_last_hidden_state=True
        )
        audio_pred, audio_last_hidden = self.audio_model(
            batch['audio'],
            return_last_hidden_state=True
        )
        concat_input = torch.cat((text_last_hidden, video_last_hidden, audio_last_hidden), dim=1)
        hidden_state = self.linear1(concat_input)
        hidden_state = self.drop1(self.relu1(hidden_state))
        logits = self.linear2(hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
        return {'logits': logits, 'loss': loss}


class MainModel:
    """Inference-only wrapper around the fusion model."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, batch):
        self.model.eval()
        with torch.no_grad():
            output = self.model(batch)
            logits = output['logits']
            pred = torch.argmax(logits, dim=1)
        return pred


def prepare_models(num_labels: int,
                   text_model_path: str,
                   video_model_path: str,
                   audio_model_path: str,
                   device: str = 'cuda'):
    """Load the three fine-tuned unimodal models from their checkpoints."""
    # TEXT
    text_model_name = 'bert-large-uncased'
    text_base_model = AutoModelForSequenceClassification.from_pretrained(
        text_model_name,
        num_labels=num_labels
    )
    state_dict = torch.load(text_model_path)
    text_base_model.load_state_dict(state_dict, strict=False)
    text_model = TextClassificationModel(text_base_model, device=device)

    # VIDEO
    video_base_model = XCLIPClassificationModel(num_labels)
    state_dict = torch.load(video_model_path)
    video_base_model.load_state_dict(state_dict, strict=False)
    video_model = VideoClassificationModel(video_base_model, device=device)

    # AUDIO
    audio_base_model = ConvNet(num_labels)
    checkpoint = torch.load(audio_model_path)
    audio_base_model.load_state_dict(checkpoint['model_state_dict'])
    audio_model = AudioClassificationModel(audio_base_model, device=device)

    return text_model, video_model, audio_model
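
# --- Illustrative addition (not part of the original pipeline) ---------------
# Where input_size=4885 in infer_multimodal_model comes from, based on the
# configurations used in this file: bert-large-uncased [CLS] embedding (1024)
# + X-CLIP pooled frame embedding (768) + ConvNet fc1 hidden state (3093).
# The helper name `_fusion_input_size` is hypothetical.
def _fusion_input_size(text_hidden: int = 1024, video_hidden: int = 768, audio_hidden: int = 3093) -> int:
    return text_hidden + video_hidden + audio_hidden  # 1024 + 768 + 3093 = 4885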

def sample_frame_indices(seg_len, clip_len=16, frame_sample_rate=4, mode="video"):
    # seg_len -- how many frames the video contains
    # clip_len -- how many frames to return
    converted_len = int(clip_len * frame_sample_rate)
    converted_len = min(converted_len, seg_len - 1)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    if mode == "video":
        indices = np.linspace(start_idx, end_idx, num=clip_len)
    else:
        indices = np.linspace(start_idx, end_idx, num=clip_len * frame_sample_rate)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


def get_frames(file_path, clip_len=16):
    """Decode the video, keep the sampled frames, crop and resize them to 224x224."""
    cap = cv2.VideoCapture(file_path)
    v_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = sample_frame_indices(v_len)

    frames = []
    for fn in range(v_len):
        success, frame = cap.read()
        if not success:
            continue
        if fn in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Crop the borders before resizing to the X-CLIP input resolution.
            res = cv2.resize(frame[90:-80, 60:-100], dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
            frames.append(res)
    cap.release()

    # Pad with the last frame if fewer than clip_len frames were collected.
    if len(frames) < clip_len:
        add_num = clip_len - len(frames)
        frames.extend([frames[-1]] * add_num)
    return frames


def prepare_data_input(text: str, video_path: str):
    """Build the per-modality model inputs for a single utterance."""
    # VIDEO
    video_frames = get_frames(video_path)
    video_model_name = "microsoft/xclip-base-patch32"
    video_feature_extractor = AutoProcessor.from_pretrained(video_model_name)
    video_encoding = video_feature_extractor(videos=video_frames, return_tensors="pt")

    # AUDIO
    smile = opensmile.Smile(
        opensmile.FeatureSet.ComParE_2016,
        opensmile.FeatureLevel.Functionals,
        sampling_rate=16000,
        resample=True,
        num_workers=5,
        verbose=True,
    )
    audio_features = smile.process_files([video_path])
    redundant_feat = open('redundant_feat.txt').read().split(',')
    audio_features.drop(columns=redundant_feat, inplace=True)

    # TEXT
    text_model_name = 'bert-large-uncased'
    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    text_encoding = tokenizer(text, padding='max_length', truncation=True,
                              max_length=128, return_tensors='pt')

    return {'text': text_encoding,
            'video': video_encoding,
            'audio': audio_features.values.reshape((1, 1, 6191))}


def infer_multimodal_model(text: str, video_path: str, model_pathes: dict):
    """Run the full multimodal pipeline and return the predicted emotion label."""
    label2id = {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3,
                'neutral': 4, 'sadness': 5, 'surprise': 6}
    id2label = {v: k for k, v in label2id.items()}
    num_labels = 7

    text_model, video_model, audio_model = prepare_models(
        num_labels,
        model_pathes['text_model_path'],
        model_pathes['video_model_path'],
        model_pathes['audio_model_path'],
    )
    multi_model = MultimodalClassificationModel(
        text_model,
        video_model,
        audio_model,
        num_labels,
        input_size=4885,
        hidden_size=512
    )
    checkpoint = torch.load(model_pathes['multimodal_model_path'])
    multi_model.load_state_dict(checkpoint)

    device = 'cuda'
    final_model = MainModel(multi_model, device=device)

    batch = prepare_data_input(text, video_path)
    label = final_model(batch).detach().cpu().tolist()
    return id2label[label[0]]
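
# --- Illustrative addition (not part of the original module) -----------------
# A hedged usage sketch: the checkpoint paths, transcript and clip below are
# placeholders and must point to real fine-tuned weights and a real video file.
if __name__ == "__main__":
    example_model_pathes = {
        'text_model_path': 'checkpoints/text_model.pt',        # hypothetical path
        'video_model_path': 'checkpoints/video_model.pt',      # hypothetical path
        'audio_model_path': 'checkpoints/audio_model.pt',      # hypothetical path
        'multimodal_model_path': 'checkpoints/multimodal.pt',  # hypothetical path
    }
    emotion = infer_multimodal_model(
        text="I can't believe you did that!",
        video_path="samples/utterance.mp4",                    # hypothetical clip
        model_pathes=example_model_pathes,
    )
    print(emotion)  # one of: anger, disgust, fear, joy, neutral, sadness, surprise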