scampion commited on
Commit
f3a36cd
1 Parent(s): 851b190

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +75 -0
handler.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import numpy as np
3
+ import pickle
4
+
5
+ from sklearn.preprocessing import MultiLabelBinarizer
6
+ from transformers import AutoTokenizer
7
+ import torch
8
+
9
+ from eurovoc import EurovocTagger
10
+
11
+ BERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
12
+ MAX_LEN = 512
13
+ TEXT_MAX_LEN = MAX_LEN * 50
14
+ tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
15
+
16
+
17
+ class EndpointHandler:
18
+ mlb = MultiLabelBinarizer()
19
+
20
+ def __init__(self, path=""):
21
+ self.mlb = pickle.load(open(f"{path}/mlb.pickle", "rb"))
22
+
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ self.model = EurovocTagger.from_pretrained(path,
25
+ bert_model_name=BERT_MODEL_NAME,
26
+ n_classes=len(self.mlb.classes_),
27
+ map_location=self.device)
28
+ self.model.eval()
29
+ self.model.freeze()
30
+
31
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
32
+ """
33
+ data args:
34
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
35
+ kwargs
36
+ Return:
37
+ A :obj:`list` | `dict`: will be serialized and returned
38
+ """
39
+
40
+ text = data.pop("inputs", data)
41
+ topk = data.pop("topk", 5)
42
+ threshold = data.pop("threshold", 0.16)
43
+ debug = data.pop("debug", False)
44
+ prediction = self.get_prediction(text)
45
+ results = [{"label": label, "score": float(score)} for label, score in
46
+ zip(self.mlb.classes_, prediction[0].tolist())]
47
+ results = sorted(results, key=lambda x: x["score"], reverse=True)
48
+ results = [r for r in results if r["score"] > threshold]
49
+ results = results[:topk]
50
+ if debug:
51
+ return {"results": results, "values": prediction, "input": text}
52
+ else:
53
+ return {"results": results}
54
+
55
+ def get_prediction(self, text):
56
+ # split text into chunks of MAX_LEN and get average prediction for each chunk
57
+ chunks = [text[i:i + MAX_LEN] for i in range(0, min(len(text), TEXT_MAX_LEN), MAX_LEN)]
58
+ predictions = [self._get_prediction(chunk) for chunk in chunks]
59
+ predictions = np.array(predictions).mean(axis=0)
60
+ return predictions
61
+
62
+ def _get_prediction(self, text):
63
+ item = tokenizer.encode_plus(
64
+ text,
65
+ add_special_tokens=True,
66
+ max_length=MAX_LEN,
67
+ return_token_type_ids=False,
68
+ padding="max_length",
69
+ truncation=True,
70
+ return_attention_mask=True,
71
+ return_tensors='pt')
72
+ _, prediction = self.model(item["input_ids"], item["attention_mask"])
73
+ prediction = prediction.cpu().detach().numpy()
74
+ return prediction
75
+