Nakhwa commited on
Commit
9887d71
1 Parent(s): 617947e

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +7 -0
test.py CHANGED
@@ -13,6 +13,9 @@ def get_model_and_tokenizer(model_name):
13
  default_model_name = "cahya/bert-base-indonesian-522M"
14
  tokenizer, model = load_model(default_model_name)
15
 
 
 
 
16
  # Prediction function
17
  def predict_hoax(title, content):
18
  if tokenizer is None or model is None:
@@ -23,6 +26,8 @@ def predict_hoax(title, content):
23
 
24
  text = f"{title} [SEP] {content}"
25
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
 
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  probs = softmax(outputs.logits, dim=1)
@@ -36,6 +41,8 @@ def predict_proba_for_lime(texts):
36
  results = []
37
  for text in texts:
38
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
 
 
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
  probs = softmax(outputs.logits, dim=1).detach().cpu().numpy()
 
13
  default_model_name = "cahya/bert-base-indonesian-522M"
14
  tokenizer, model = load_model(default_model_name)
15
 
16
+ # Move model to GPU
17
+ model = model.to('cuda')
18
+
19
  # Prediction function
20
  def predict_hoax(title, content):
21
  if tokenizer is None or model is None:
 
26
 
27
  text = f"{title} [SEP] {content}"
28
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
29
+ inputs = {key: value.to('cuda') for key, value in inputs.items()} # Move inputs to GPU
30
+
31
  with torch.no_grad():
32
  outputs = model(**inputs)
33
  probs = softmax(outputs.logits, dim=1)
 
41
  results = []
42
  for text in texts:
43
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
44
+ inputs = {key: value.to('cuda') for key, value in inputs.items()} # Move inputs to GPU
45
+
46
  with torch.no_grad():
47
  outputs = model(**inputs)
48
  probs = softmax(outputs.logits, dim=1).detach().cpu().numpy()