AlexandraDolidze committed on
Commit
f2c3ec9
1 Parent(s): 327ed58

Inference and app.py added

Files changed (2)
  1. app.py +18 -0
  2. hf_inference.py +313 -0
app.py ADDED
@@ -0,0 +1,18 @@
+ import tempfile
+
+ import streamlit as st
+ import hf_inference
+
+ # Dictionary of model checkpoint paths (the *_model_path variables are placeholders
+ # and must point to the actual checkpoint files).
+ models_dict = {'text_model_path': text_model_path,
+                'video_model_path': video_model_path,
+                'audio_model_path': audio_model_path,
+                # required by hf_inference.infer_multimodal_model
+                'multimodal_model_path': multimodal_model_path}
+
+ st.title("Multimodal ERC project")
+
+ uploaded_file = st.file_uploader("Choose a video")
+ input_text = st.text_area("Please write the transcript", '''That's obligatory.''')
+
+ if uploaded_file is not None and input_text != '''That's obligatory.''':
+     # cv2 and openSMILE need a file path, so write the uploaded video to a temporary file.
+     with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
+         tmp.write(uploaded_file.read())
+         video_path = tmp.name
+     # Pass the video and transcript to inference and get the predicted emotion.
+     output_emotion = hf_inference.infer_multimodal_model(input_text, video_path, models_dict)
+     st.write(f"We think that's {output_emotion}")
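The *_model_path variables referenced in app.py are not defined in this commit; a minimal sketch of the expected definitions, with hypothetical file names used purely as placeholders:

text_model_path = 'checkpoints/text_model.pt'              # hypothetical path
video_model_path = 'checkpoints/video_model.pt'            # hypothetical path
audio_model_path = 'checkpoints/audio_model.pt'            # hypothetical path
multimodal_model_path = 'checkpoints/multimodal_model.pt'  # hypothetical path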
hf_inference.py ADDED
@@ -0,0 +1,313 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import AutoProcessor, AutoTokenizer, XCLIPVisionModel, AutoModel, AutoModelForSequenceClassification
+
+ import numpy as np
+ import cv2
+ import opensmile
+
+
+ class TextClassificationModel:
+     def __init__(self, model, device):
+         self.model = model
+         self.device = device
+         self.model.to(device)
+
+     def __call__(self, input_ids, attn_mask, return_last_hidden_state=False):
+         self.model.eval()
+         with torch.no_grad():
+             input_ids = input_ids.to(self.device)
+             attn_mask = attn_mask.to(self.device)
+             output = self.model(input_ids=input_ids, attention_mask=attn_mask,
+                                 output_hidden_states=return_last_hidden_state)
+             logits = output['logits']
+             pred = torch.argmax(logits, dim=1)
+             if return_last_hidden_state:
+                 # also return the [CLS] embedding of the last hidden layer for fusion
+                 return pred, output['hidden_states'][-1][:, 0, :]
+             return pred
+
+
+ class XCLIPClassificationModel(nn.Module):
+     def __init__(self, num_labels):
+         super(XCLIPClassificationModel, self).__init__()
+         self.base_model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
+         self.num_labels = num_labels
+         hidden_size = self.base_model.config.hidden_size
+         self.fc_norm = nn.LayerNorm(hidden_size)
+         self.classifier = nn.Linear(hidden_size, self.num_labels)
+         self.loss_fct = nn.CrossEntropyLoss()
+         self.pool1 = nn.AdaptiveAvgPool1d(1)
+         self.pool2 = nn.AdaptiveAvgPool1d(1)
+
+     def forward(self, pixel_values, labels=None, return_last_hidden_state=False):
+         batch_size, num_frames, num_channels, height, width = pixel_values.shape
+         pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+         out = self.base_model(pixel_values)[0]             # [48, 50, 768]
+         out = torch.transpose(out, 1, 2)                   # [48, 768, 50]
+         out = self.pool1(out)                              # [48, 768, 1]
+         out = torch.transpose(out, 1, 2)                   # [48, 1, 768]
+         out = out.squeeze(1)                               # [48, 768]
+         hidden_out = out.view(batch_size, num_frames, -1)  # [3, 16, 768]
+         hidden_out = torch.transpose(hidden_out, 1, 2)     # [3, 768, 16]
+         pooled_out = self.pool2(hidden_out)                # [3, 768, 1]
+         pooled_out = torch.transpose(pooled_out, 1, 2)     # [3, 1, 768]
+         pooled_out = pooled_out[:, 0, :]                   # [3, 768]
+         logits = self.classifier(pooled_out)
+         loss = None
+         if labels is not None:
+             loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if return_last_hidden_state:
+             return {'logits': logits, 'loss': loss, 'last_hidden_state': pooled_out}
+         else:
+             return {'logits': logits, 'loss': loss}
+
+
+ class VideoClassificationModel:
+     def __init__(self, model, device):
+         self.model = model
+         self.device = device
+         self.model.to(device)
+
+     def __call__(self, pixel_values, return_last_hidden_state=False):
+         self.model.eval()
+         with torch.no_grad():
+             pixel_values = pixel_values.to(self.device)
+             output = self.model(pixel_values, return_last_hidden_state=return_last_hidden_state)
+             logits = output['logits']
+             pred = torch.argmax(logits, dim=1)
+             if return_last_hidden_state:
+                 return pred, output['last_hidden_state']
+             return pred
+
+
+ class ConvNet(nn.Module):
+     # 1-D CNN over the openSMILE ComParE functionals (6191 features per clip after
+     # dropping the redundant ones listed in redundant_feat.txt).
+     def __init__(self, num_labels, n_input=1, n_channel=32):
+         super(ConvNet, self).__init__()
+         self.ln0 = nn.LayerNorm((1, 6191))
+         self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3)
+         self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
+         self.bn1 = nn.BatchNorm1d(n_channel)
+         self.bn2 = nn.BatchNorm1d(n_channel)
+         self.pool1 = nn.MaxPool1d(2)
+         self.fc1 = nn.Linear(n_channel * 3093, 3093)
+         self.fc2 = nn.Linear(3093, num_labels)
+         self.flat = nn.Flatten()
+         self.dropout = nn.Dropout(0.3)
+
+     def forward(self, x, return_last_hidden_state=False):
+         x = self.ln0(x)
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.conv2(x)
+         x = F.relu(self.bn2(x))
+         x = self.pool1(x)
+         x = self.dropout(x)
+         x = self.flat(x)
+         hid = F.relu(self.fc1(x))
+         x = self.fc2(hid)
+         if return_last_hidden_state:
+             return {'logits': F.log_softmax(x, dim=1), 'last_hidden_state': hid}
+         return {'logits': F.log_softmax(x, dim=1)}
+
+
+ class AudioClassificationModel:
+     def __init__(self, model, device):
+         self.model = model
+         self.device = device
+         self.model.to(device)
+
+     def __call__(self, input_ids, return_last_hidden_state=False):
+         self.model.eval()
+         with torch.no_grad():
+             input_ids = torch.tensor(input_ids, dtype=torch.float).to(self.device)
+             output = self.model(input_ids, return_last_hidden_state=return_last_hidden_state)
+             logits = output['logits']
+             pred = torch.argmax(logits, dim=1)
+             if return_last_hidden_state:
+                 return pred, output['last_hidden_state']
+             return pred
+
+
+ class MultimodalClassificationModel(nn.Module):
+     # Late-fusion head: concatenates the text, video and audio hidden states and
+     # classifies the joint representation with a two-layer MLP.
+     def __init__(self, text_model, video_model, audio_model, num_labels, input_size, hidden_size=256):
+         super(MultimodalClassificationModel, self).__init__()
+         self.text_model = text_model
+         self.video_model = video_model
+         self.audio_model = audio_model
+         self.num_labels = num_labels
+         self.linear1 = nn.Linear(input_size, hidden_size)
+         self.linear2 = nn.Linear(hidden_size, self.num_labels)
+         self.relu1 = nn.ReLU()
+         self.drop1 = nn.Dropout()
+         self.loss_func = nn.CrossEntropyLoss()
+
+     def forward(self, batch, labels=None):
+         text_pred, text_last_hidden = self.text_model(
+             batch['text']['input_ids'].squeeze(1),
+             batch['text']['attention_mask'].squeeze(1),
+             return_last_hidden_state=True
+         )
+         video_pred, video_last_hidden = self.video_model(
+             batch['video']['pixel_values'].squeeze(1),
+             return_last_hidden_state=True
+         )
+         audio_pred, audio_last_hidden = self.audio_model(
+             batch['audio'],
+             return_last_hidden_state=True
+         )
+         concat_input = torch.cat((text_last_hidden, video_last_hidden, audio_last_hidden), dim=1)
+         hidden_state = self.linear1(concat_input)
+         hidden_state = self.drop1(self.relu1(hidden_state))
+         logits = self.linear2(hidden_state)
+         loss = None
+         if labels is not None:
+             loss = self.loss_func(logits.view(-1, self.num_labels), labels.view(-1))
+         return {'logits': logits, 'loss': loss}
+
+
+ class MainModel:
+     def __init__(self, model, device):
+         self.model = model
+         self.device = device
+         self.model.to(device)
+
+     def __call__(self, batch):
+         self.model.eval()
+         with torch.no_grad():
+             output = self.model(batch)
+             logits = output['logits']
+             pred = torch.argmax(logits, dim=1)
+             return pred
+
+ def prepare_models(num_labels: int,
+                    text_model_path: str,
+                    video_model_path: str,
+                    audio_model_path: str,
+                    device: str = 'cuda'):
+     # TEXT
+     text_model_name = 'bert-large-uncased'
+     text_base_model = AutoModelForSequenceClassification.from_pretrained(
+         text_model_name,
+         num_labels=num_labels
+     )
+     state_dict = torch.load(text_model_path, map_location=device)
+     text_base_model.load_state_dict(state_dict, strict=False)
+     text_model = TextClassificationModel(text_base_model, device=device)
+
+     # VIDEO
+     video_base_model = XCLIPClassificationModel(num_labels)
+     state_dict = torch.load(video_model_path, map_location=device)
+     video_base_model.load_state_dict(state_dict, strict=False)
+     video_model = VideoClassificationModel(video_base_model, device=device)
+
+     # AUDIO
+     audio_base_model = ConvNet(num_labels)
+     checkpoint = torch.load(audio_model_path, map_location=device)
+     audio_base_model.load_state_dict(checkpoint['model_state_dict'])
+     audio_model = AudioClassificationModel(audio_base_model, device=device)
+
+     return text_model, video_model, audio_model
+
+ def sample_frame_indices(seg_len, clip_len=16, frame_sample_rate=4, mode="video"):
+     # seg_len -- number of frames available in the clip
+     # clip_len -- number of frames to return
+     converted_len = int(clip_len * frame_sample_rate)
+     converted_len = min(converted_len, seg_len - 1)
+     end_idx = np.random.randint(converted_len, seg_len)
+     start_idx = end_idx - converted_len
+     if mode == "video":
+         indices = np.linspace(start_idx, end_idx, num=clip_len)
+     else:
+         indices = np.linspace(start_idx, end_idx, num=clip_len * frame_sample_rate)
+     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+     return indices
+
+ def get_frames(file_path, clip_len=16):
+     cap = cv2.VideoCapture(file_path)
+     v_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     indices = sample_frame_indices(v_len)
+
+     frames = []
+     for fn in range(v_len):
+         success, frame = cap.read()
+         if not success:
+             continue
+         if fn in indices:
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             # fixed crop (dataset-specific), then resize to the 224x224 input expected by X-CLIP
+             res = cv2.resize(frame[90:-80, 60:-100], dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
+             frames.append(res)
+     cap.release()
+
+     # pad by repeating the last frame if fewer than clip_len frames were collected
+     if len(frames) < clip_len:
+         add_num = clip_len - len(frames)
+         frames_to_add = [frames[-1]] * add_num
+         frames.extend(frames_to_add)
+
+     return frames
+
+ def prepare_data_input(text: str,
+                        video_path: str):
+     # VIDEO
+     video_frames = get_frames(video_path)
+     video_model_name = "microsoft/xclip-base-patch32"
+     video_feature_extractor = AutoProcessor.from_pretrained(video_model_name)
+     video_encoding = video_feature_extractor(videos=video_frames, return_tensors="pt")
+     # AUDIO
+     smile = opensmile.Smile(
+         opensmile.FeatureSet.ComParE_2016,
+         opensmile.FeatureLevel.Functionals,
+         sampling_rate=16000,
+         resample=True,
+         num_workers=5,
+         verbose=True,
+     )
+     audio_features = smile.process_files([video_path])
+     # drop the features listed in redundant_feat.txt so the vector matches the
+     # 6191-dimensional input expected by the audio ConvNet
+     with open('redundant_feat.txt') as f:
+         redundant_feat = f.read().split(',')
+     audio_features.drop(columns=redundant_feat, inplace=True)
+     # TEXT
+     text_model_name = 'bert-large-uncased'
+     tokenizer = AutoTokenizer.from_pretrained(text_model_name)
+     text_encoding = tokenizer(text,
+                               padding='max_length',
+                               truncation=True,
+                               max_length=128,
+                               return_tensors='pt')
+     return {'text': text_encoding, 'video': video_encoding, 'audio': audio_features.values.reshape((1, 1, 6191))}
+
+ def infer_multimodal_model(text: str,
+                            video_path: str,
+                            model_pathes: dict):
+     label2id = {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'neutral': 4, 'sadness': 5, 'surprise': 6}
+     id2label = {v: k for k, v in label2id.items()}
+     num_labels = 7
+     # fall back to CPU when CUDA is not available
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     text_model, video_model, audio_model = prepare_models(num_labels,
+                                                           model_pathes['text_model_path'],
+                                                           model_pathes['video_model_path'],
+                                                           model_pathes['audio_model_path'],
+                                                           device=device)
+     multi_model = MultimodalClassificationModel(
+         text_model,
+         video_model,
+         audio_model,
+         num_labels,
+         input_size=4885,  # 1024 (BERT-large) + 768 (X-CLIP) + 3093 (audio ConvNet hidden)
+         hidden_size=512
+     )
+     checkpoint = torch.load(model_pathes['multimodal_model_path'], map_location=device)
+     multi_model.load_state_dict(checkpoint)
+     final_model = MainModel(multi_model, device=device)
+     batch = prepare_data_input(text, video_path)
+     label = final_model(batch).detach().cpu().tolist()
+     return id2label[label[0]]
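
For completeness, a minimal sketch of calling the new inference entry point directly, outside the Streamlit app. The checkpoint paths and the video file name are hypothetical placeholders, and the call assumes redundant_feat.txt is in the working directory, as prepare_data_input requires.

from hf_inference import infer_multimodal_model

# Hypothetical checkpoint locations -- replace with the real paths.
model_paths = {
    'text_model_path': 'checkpoints/text_model.pt',
    'video_model_path': 'checkpoints/video_model.pt',
    'audio_model_path': 'checkpoints/audio_model.pt',
    'multimodal_model_path': 'checkpoints/multimodal_model.pt',
}

emotion = infer_multimodal_model(
    text="I can't believe we actually won!",  # utterance transcript
    video_path='sample_clip.mp4',             # hypothetical local video file
    model_pathes=model_paths,
)
print(emotion)  # one of: anger, disgust, fear, joy, neutral, sadness, surprise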