# MERC / hf_inference.py
# Uploaded by AlexandraDolidze ("Upload 2 files", commit 3d3ef8a, verified).
# File size: 12.5 kB; Hugging Face malware scan: no virus found.
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoTokenizer, XCLIPVisionModel, AutoModel, AutoModelForSequenceClassification
import numpy as np
import cv2
import opensmile
class TextClassificationModel:
    """Run a fine-tuned sequence classifier in inference mode.

    ``model`` is called with ``input_ids``, ``attention_mask`` and
    ``output_hidden_states`` keyword arguments and its output is indexed with
    ``'logits'`` (and ``'hidden_states'`` when requested). The model is moved
    to ``device`` once, at construction time.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, attn_mask, return_last_hidden_state=False):
        """Predict class ids for a batch of token sequences.

        Args:
            input_ids: token-id tensor; moved to ``self.device``.
            attn_mask: attention-mask tensor; moved to ``self.device``.
            return_last_hidden_state: if true, also return the first-position
                vector of the final hidden layer.

        Returns:
            The arg-max over dim 1 of the logits, or a ``(predictions,
            first_token_embedding)`` tuple when ``return_last_hidden_state``
            is true.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids.to(self.device),
                attention_mask=attn_mask.to(self.device),
                output_hidden_states=return_last_hidden_state,
            )
            predictions = outputs['logits'].argmax(dim=1)
            if not return_last_hidden_state:
                return predictions
            # Position 0 of the last layer (presumably the [CLS] slot, since
            # this file pairs the wrapper with a BERT model) is the embedding.
            first_token = outputs['hidden_states'][-1][:, 0, :]
        return predictions, first_token
class XCLIPClassificationModel(nn.Module):
    """Video classifier: X-CLIP vision tower, two mean-pooling stages, linear head.

    Each frame is encoded independently by the pretrained
    ``microsoft/xclip-base-patch32`` vision model. Its token outputs are
    averaged into one vector per frame, the frame vectors are averaged into
    one vector per clip, and that clip vector is classified.
    """

    def __init__(self, num_labels):
        super().__init__()
        self.base_model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
        self.num_labels = num_labels
        width = self.base_model.config.hidden_size
        # NOTE(review): fc_norm is never used in forward(); presumably kept so
        # checkpoints that contain its weights still map onto this module.
        self.fc_norm = nn.LayerNorm(width)
        self.classifier = nn.Linear(width, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.pool1 = nn.AdaptiveAvgPool1d(1)  # mean over tokens of a frame
        self.pool2 = nn.AdaptiveAvgPool1d(1)  # mean over frames of a clip

    def forward(self, pixel_values, labels=None, return_last_hidden_state=False):
        """Classify a batch of clips.

        Args:
            pixel_values: 5-D tensor ``(batch, frames, channels, height, width)``.
            labels: optional class-id tensor; when given, cross-entropy loss
                is computed.
            return_last_hidden_state: if true, the pooled clip vectors are
                returned under ``'last_hidden_state'``.

        Returns:
            Dict with ``'logits'`` and ``'loss'`` (``None`` without labels),
            plus ``'last_hidden_state'`` when requested.
        """
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        frames = pixel_values.reshape(-1, num_channels, height, width)
        tokens = self.base_model(frames)[0]                      # (B*T, tokens, D)
        per_frame = self.pool1(tokens.transpose(1, 2))[..., 0]   # (B*T, D)
        per_clip = per_frame.reshape(batch_size, num_frames, -1).transpose(1, 2)  # (B, D, T)
        pooled = self.pool2(per_clip)[..., 0]                    # (B, D)
        logits = self.classifier(pooled)

        if labels is None:
            loss = None
        else:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        result = {'logits': logits, 'loss': loss}
        if return_last_hidden_state:
            result['last_hidden_state'] = pooled
        return result
class VideoClassificationModel:
    """Run a video classifier (e.g. ``XCLIPClassificationModel``) in inference mode.

    ``model`` is called as ``model(pixel_values, return_last_hidden_state=...)``
    and must return a dict with ``'logits'`` (and ``'last_hidden_state'`` when
    requested). It is moved to ``device`` once, at construction time.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, pixel_values, return_last_hidden_state=False):
        """Return arg-max class ids, optionally with the model's pooled clip vectors."""
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                pixel_values.to(self.device),
                return_last_hidden_state=return_last_hidden_state,
            )
            predictions = torch.argmax(outputs['logits'], dim=1)
            if return_last_hidden_state:
                return predictions, outputs['last_hidden_state']
        return predictions
class ConvNet(nn.Module):
    """1-D CNN classifier over a fixed-length acoustic feature vector.

    Pipeline: LayerNorm -> conv(k=3) -> BN -> ReLU -> conv(k=3) -> BN -> ReLU
    -> max-pool(2) -> dropout -> flatten -> fc1 -> ReLU -> fc2 -> log-softmax.

    Args:
        num_labels: number of output classes.
        n_input: input channels (dim 1 of the input).
        n_channel: channels of both convolution layers.
        n_features: length of the feature vector (dim 2 of the input). The
            default 6191 reproduces the original hard-coded architecture
            exactly (pooled length 3093), so existing checkpoints still load.

    Raises:
        ValueError: if ``n_features`` is too short to survive two kernel-3
            convolutions and a stride-2 pool.
    """

    def __init__(self, num_labels, n_input=1, n_channel=32, n_features=6191):
        super().__init__()
        pooled_len = self._pooled_length(n_features)
        if pooled_len < 1:
            raise ValueError(f"n_features={n_features} is too short for ConvNet (need >= 6)")
        # Normalise over every input channel and feature; the original fixed
        # (1, 6191) and therefore broke for any n_input other than 1.
        self.ln0 = nn.LayerNorm((n_input, n_features))
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        # The hidden width equals the pooled sequence length (3093 by default),
        # matching the original layer sizes.
        self.fc1 = nn.Linear(n_channel * pooled_len, pooled_len)
        self.fc2 = nn.Linear(pooled_len, num_labels)
        self.flat = nn.Flatten()
        self.dropout = nn.Dropout(0.3)

    @staticmethod
    def _pooled_length(n_features):
        """Sequence length after two unpadded kernel-3 convs and a MaxPool1d(2)."""
        return (n_features - 4) // 2

    def forward(self, x, return_last_hidden_state=False):
        """Classify ``x`` of shape ``(batch, n_input, n_features)``.

        Returns:
            Dict whose ``'logits'`` entry holds log-softmax scores (not raw
            logits); ``'last_hidden_state'`` (post-ReLU fc1 output) is added
            when requested.
        """
        x = self.ln0(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.dropout(self.pool1(x))
        hid = F.relu(self.fc1(self.flat(x)))
        log_probs = F.log_softmax(self.fc2(hid), dim=1)
        if return_last_hidden_state:
            return {'logits': log_probs, 'last_hidden_state': hid}
        return {'logits': log_probs}
class AudioClassificationModel:
    """Run an audio-feature classifier (e.g. ``ConvNet``) in inference mode.

    ``model`` is called as ``model(features, return_last_hidden_state=...)``
    and must return a dict with ``'logits'`` (and ``'last_hidden_state'`` when
    requested). It is moved to ``device`` once, at construction time.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, input_ids, return_last_hidden_state=False):
        """Predict class ids for a batch of feature vectors.

        Args:
            input_ids: despite the name, a numeric feature array (numpy array,
                nested sequence or tensor); it is cast to float32 on
                ``self.device``.
            return_last_hidden_state: if true, also return the model's hidden
                representation.

        Returns:
            Arg-max class ids over dim 1, or ``(predictions, hidden_state)``.
        """
        self.model.eval()
        with torch.no_grad():
            # as_tensor handles numpy input and tensor input alike; torch.tensor
            # warns and always copies when handed an existing tensor.
            features = torch.as_tensor(input_ids, dtype=torch.float, device=self.device)
            output = self.model(features, return_last_hidden_state=return_last_hidden_state)
            predictions = torch.argmax(output['logits'], dim=1)
            if return_last_hidden_state:
                return predictions, output['last_hidden_state']
        return predictions
class MultimodalClassificationModel(nn.Module):
    """Late-fusion classifier over text, video and audio embeddings.

    Each unimodal model is called with ``return_last_hidden_state=True``. The
    three embeddings are concatenated along dim 1 and passed through
    Linear -> ReLU -> Dropout(p=0.5, the ``nn.Dropout`` default) -> Linear.

    Args:
        text_model, video_model, audio_model: callables returning
            ``(prediction, embedding)``. When they are the plain wrapper
            objects from this file (not ``nn.Module`` instances) they are not
            registered as submodules, so ``state_dict()`` holds only the
            fusion head.
        num_labels: number of output classes.
        input_size: must equal the summed widths of the three embeddings.
        hidden_size: width of the fusion hidden layer.
    """

    def __init__(self, text_model, video_model, audio_model, num_labels, input_size, hidden_size=256):
        super().__init__()
        self.text_model = text_model
        self.video_model = video_model
        self.audio_model = audio_model
        self.num_labels = num_labels
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, self.num_labels)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout()
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, batch, labels=None):
        """Fuse the three modalities of ``batch`` and classify.

        ``batch`` needs ``['text']['input_ids']``, ``['text']['attention_mask']``,
        ``['video']['pixel_values']`` (each has dim 1 squeezed before use) and
        ``['audio']`` (passed through unchanged).

        Returns:
            ``{'logits': ..., 'loss': ...}``; ``loss`` is ``None`` without labels.
        """
        text_inputs = batch['text']
        _, text_emb = self.text_model(
            text_inputs['input_ids'].squeeze(1),
            text_inputs['attention_mask'].squeeze(1),
            return_last_hidden_state=True,
        )
        _, video_emb = self.video_model(
            batch['video']['pixel_values'].squeeze(1),
            return_last_hidden_state=True,
        )
        _, audio_emb = self.audio_model(batch['audio'], return_last_hidden_state=True)

        fused = torch.cat((text_emb, video_emb, audio_emb), dim=1)
        logits = self.linear2(self.drop1(self.relu1(self.linear1(fused))))

        loss = None
        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
        return {'logits': logits, 'loss': loss}
class MainModel:
    """Inference wrapper that turns a fusion model's logits into class ids.

    ``model`` is moved to ``device`` once. Each call switches it to eval mode,
    runs it under ``torch.no_grad()`` and returns the arg-max over dim 1 of
    ``output['logits']``.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.to(device)

    def __call__(self, batch):
        self.model.eval()
        with torch.no_grad():
            logits = self.model(batch)['logits']
            return logits.argmax(dim=1)
def _load_state_dict_reporting(model, state_dict, label):
    """Load ``state_dict`` into ``model`` non-strictly, warning about key mismatches.

    ``strict=False`` on its own silently skips every key that does not line
    up. For example, a checkpoint saved under a different prefix would leave
    the model with its initial weights and give no sign of the problem.
    """
    import warnings

    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"{label} checkpoint does not fully match the model "
            f"(missing keys: {result.missing_keys}; "
            f"unexpected keys: {result.unexpected_keys})",
            stacklevel=3,
        )
    return result


def prepare_models(num_labels: int,
                   text_model_path: str,
                   video_model_path: str,
                   audio_model_path: str,
                   device: str = 'cpu'):
    """Build the three unimodal inference wrappers and load their fine-tuned weights.

    Args:
        num_labels: number of output classes for every model.
        text_model_path: state_dict file for the BERT-large classifier.
        video_model_path: state_dict file for ``XCLIPClassificationModel``.
        audio_model_path: checkpoint file whose ``'model_state_dict'`` entry
            holds the ``ConvNet`` weights.
        device: device the wrapped models are moved to.

    Returns:
        ``(text_model, video_model, audio_model)`` inference wrappers.

    Note:
        ``torch.load`` unpickles these files; only load checkpoints from a
        trusted source.
    """
    cpu = torch.device('cpu')

    # TEXT: loaded leniently (strict=False); mismatches are reported, not fatal.
    text_base_model = AutoModelForSequenceClassification.from_pretrained(
        'bert-large-uncased',
        num_labels=num_labels,
    )
    _load_state_dict_reporting(
        text_base_model, torch.load(text_model_path, map_location=cpu), 'text')
    text_model = TextClassificationModel(text_base_model, device=device)

    # VIDEO: loaded leniently (strict=False); mismatches are reported, not fatal.
    video_base_model = XCLIPClassificationModel(num_labels)
    _load_state_dict_reporting(
        video_base_model, torch.load(video_model_path, map_location=cpu), 'video')
    video_model = VideoClassificationModel(video_base_model, device=device)

    # AUDIO: loaded strictly, so any key mismatch raises.
    audio_base_model = ConvNet(num_labels)
    checkpoint = torch.load(audio_model_path, map_location=cpu)
    audio_base_model.load_state_dict(checkpoint['model_state_dict'])
    audio_model = AudioClassificationModel(audio_base_model, device=device)

    return text_model, video_model, audio_model
def sample_frame_indices(seg_len, clip_len=16, frame_sample_rate=4, mode="video", rng=None):
    """Pick evenly spaced frame indices from a randomly placed window of a video.

    The window spans ``clip_len * frame_sample_rate`` frames, shrunk to
    ``seg_len - 1`` for short videos, and its end is drawn uniformly at
    random. In ``"video"`` mode ``clip_len`` indices are returned; in any
    other mode ``clip_len * frame_sample_rate`` indices are returned.

    Args:
        seg_len: number of frames available in the video.
        clip_len: number of indices to return in ``"video"`` mode.
        frame_sample_rate: nominal stride between sampled frames.
        mode: ``"video"`` or anything else (dense sampling).
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            When ``None``, the global ``np.random`` state is used as before.

    Returns:
        Non-decreasing ``int64`` array of indices in ``[0, seg_len - 1]``.
        Short videos may produce repeated indices.

    Raises:
        ValueError: if ``seg_len < 1``.
    """
    if seg_len < 1:
        raise ValueError(f"cannot sample frames from a video with {seg_len} frames")
    converted_len = min(int(clip_len * frame_sample_rate), seg_len - 1)
    if rng is None:
        end_idx = np.random.randint(converted_len, seg_len)
    else:
        end_idx = int(rng.integers(converted_len, seg_len))
    start_idx = end_idx - converted_len

    num = clip_len if mode == "video" else clip_len * frame_sample_rate
    indices = np.linspace(start_idx, end_idx, num=num)
    # Indices stay strictly below end_idx, as before. For a one-frame video
    # that bound (-1) would fall under start_idx and np.clip would map every
    # index to -1, so the upper bound never drops below start_idx.
    upper = max(end_idx - 1, start_idx)
    return np.clip(indices, start_idx, upper).astype(np.int64)
def get_frames(file_path, clip_len=16):
    """Decode ``clip_len`` RGB frames (224x224) sampled from a video file.

    Frame positions come from ``sample_frame_indices``. Each chosen frame is
    converted BGR -> RGB, cropped with the fixed window ``[90:-80, 60:-100]``
    (NOTE(review): dataset-specific framing; confirm before reusing on other
    footage) and resized to 224x224 with bicubic interpolation. Repeated
    indices yield a single frame. If fewer than ``clip_len`` frames were
    collected, the last one is repeated.

    Raises:
        ValueError: if the video reports no frames or none of the sampled
            frames could be decoded.
    """
    cap = cv2.VideoCapture(file_path)
    frames = []
    try:
        v_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if v_len < 1:
            raise ValueError(f"could not determine any frames in {file_path!r}")
        # Forward clip_len; previously the sampler always drew its default 16.
        wanted = set(sample_frame_indices(v_len, clip_len=clip_len).tolist())
        # Reading stops after the last wanted frame instead of decoding the rest.
        for fn in range(max(wanted) + 1):
            success, frame = cap.read()
            if not success:
                continue
            if fn in wanted:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(frame[90:-80, 60:-100], dsize=(224, 224),
                                         interpolation=cv2.INTER_CUBIC))
    finally:
        # Release the capture even when decoding or preprocessing raises.
        cap.release()

    if not frames:
        raise ValueError(f"none of the sampled frames of {file_path!r} could be decoded")
    if len(frames) < clip_len:
        frames.extend([frames[-1]] * (clip_len - len(frames)))
    return frames
def _encode_video(video_path):
    """Sample frames from ``video_path`` and run the X-CLIP processor over them."""
    video_frames = get_frames(video_path)
    video_feature_extractor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
    return video_feature_extractor(videos=video_frames, return_tensors="pt")


def _extract_audio_features(video_path, redundant_feat_path='files/redundant_feat.txt'):
    """Compute ComParE_2016 functionals for ``video_path`` minus the listed redundant columns."""
    smile = opensmile.Smile(
        opensmile.FeatureSet.ComParE_2016,
        opensmile.FeatureLevel.Functionals,
        sampling_rate=16000,
        resample=True,
        num_workers=5,
        verbose=True,
    )
    audio_features = smile.process_files([video_path])
    # The file is a comma-separated list of column names. Strip whitespace and
    # skip empty entries so a trailing newline or comma does not turn into a
    # nonexistent column name (which would make drop() raise KeyError).
    with open(redundant_feat_path, encoding='utf-8') as fh:
        redundant_feat = [name.strip() for name in fh.read().split(',') if name.strip()]
    audio_features.drop(columns=redundant_feat, inplace=True)
    # 6191 must match the input length ConvNet was built for; reshape raises
    # if the remaining feature count differs.
    return audio_features.values.reshape((1, 1, 6191))


def _encode_text(text):
    """Tokenize ``text`` for BERT-large, padded/truncated to 128 tokens."""
    tokenizer = AutoTokenizer.from_pretrained('bert-large-uncased')
    return tokenizer(text,
                     padding='max_length',
                     truncation=True,
                     max_length=128,
                     return_tensors='pt')


def prepare_data_input(text: str,
                       video_path: str):
    """Build the model input batch for one utterance.

    Args:
        text: utterance transcript.
        video_path: video file; both frames and audio features are taken
            from it.

    Returns:
        Dict with ``'text'`` (tokenizer output), ``'video'`` (X-CLIP processor
        output) and ``'audio'`` (numpy array of shape ``(1, 1, 6191)``).
    """
    # Same order as before: video, then audio, then text.
    video_encoding = _encode_video(video_path)
    audio_input = _extract_audio_features(video_path)
    text_encoding = _encode_text(text)
    return {'text': text_encoding, 'video': video_encoding, 'audio': audio_input}
def infer_multimodal_model(text: str,
                           video_path: str,
                           model_pathes: dict,
                           device: str = 'cpu'):
    """Predict the emotion label of an utterance from its transcript and video.

    Args:
        text: utterance transcript.
        video_path: video file; audio features are extracted from it as well.
        model_pathes: mapping with keys ``'text_model_path'``,
            ``'video_model_path'``, ``'audio_model_path'`` and
            ``'multimodal_model_path'``.
        device: torch device for every model. Defaults to ``'cpu'``, the
            previously hard-coded value.

    Returns:
        One of the label strings in ``label2id`` below.

    Note:
        Checkpoints are read with ``torch.load`` (pickle); use trusted files only.
    """
    label2id = {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'neutral': 4, 'sadness': 5, 'surprise': 6}
    id2label = {v: k for k, v in label2id.items()}
    num_labels = len(label2id)

    text_model, video_model, audio_model = prepare_models(num_labels,
                                                          model_pathes['text_model_path'],
                                                          model_pathes['video_model_path'],
                                                          model_pathes['audio_model_path'],
                                                          device=device)
    multi_model = MultimodalClassificationModel(
        text_model,
        video_model,
        audio_model,
        num_labels,
        # Sum of the three embedding widths. Presumably 1024 (BERT-large)
        # + 768 (X-CLIP base) + 3093 (ConvNet fc1); verify if any encoder changes.
        input_size=4885,
        hidden_size=512
    )
    checkpoint = torch.load(model_pathes['multimodal_model_path'], map_location=torch.device('cpu'))
    multi_model.load_state_dict(checkpoint)

    final_model = MainModel(multi_model, device=device)
    batch = prepare_data_input(text, video_path)
    predicted_ids = final_model(batch).detach().cpu().tolist()
    return id2label[predicted_ids[0]]