Remove extra import
Browse files
src/evaluation/commit_message_generation/cmg_metrics.py
CHANGED
@@ -3,25 +3,19 @@ from typing import Dict, List
|
|
3 |
import evaluate # type: ignore[import]
|
4 |
|
5 |
from ..base_task_metrics import BaseTaskMetrics
|
6 |
-
from .b_norm import BNorm
|
7 |
|
8 |
|
9 |
class CMGMetrics(BaseTaskMetrics):
|
10 |
def __init__(self):
|
11 |
-
self.bnorm = BNorm()
|
12 |
self.bleu = evaluate.load("sacrebleu")
|
13 |
self.chrf = evaluate.load("chrf")
|
14 |
self.rouge = evaluate.load("rouge")
|
15 |
self.bertscore = evaluate.load("bertscore")
|
16 |
self.bertscore_normalized = evaluate.load("bertscore")
|
17 |
|
18 |
-
def
|
19 |
-
self.bnorm.reset()
|
20 |
-
|
21 |
-
def update(
|
22 |
self, predictions: List[str], references: List[str], *args, **kwargs
|
23 |
) -> None:
|
24 |
-
self.bnorm.update(predictions=predictions, references=references)
|
25 |
self.bleu.add_batch(
|
26 |
predictions=predictions, references=[[ref] for ref in references]
|
27 |
)
|
@@ -41,7 +35,6 @@ class CMGMetrics(BaseTaskMetrics):
|
|
41 |
lang="en", rescale_with_baseline=True
|
42 |
)
|
43 |
return {
|
44 |
-
"bnorm": self.bnorm.compute(),
|
45 |
"bleu": self.bleu.compute(tokenize="13a")["score"],
|
46 |
"chrf": self.chrf.compute()["score"],
|
47 |
"rouge1": rouge["rouge1"] * 100,
|
|
|
3 |
import evaluate # type: ignore[import]
|
4 |
|
5 |
from ..base_task_metrics import BaseTaskMetrics
|
|
|
6 |
|
7 |
|
8 |
class CMGMetrics(BaseTaskMetrics):
|
9 |
def __init__(self):
|
|
|
10 |
self.bleu = evaluate.load("sacrebleu")
|
11 |
self.chrf = evaluate.load("chrf")
|
12 |
self.rouge = evaluate.load("rouge")
|
13 |
self.bertscore = evaluate.load("bertscore")
|
14 |
self.bertscore_normalized = evaluate.load("bertscore")
|
15 |
|
16 |
+
def add_batch(
|
|
|
|
|
|
|
17 |
self, predictions: List[str], references: List[str], *args, **kwargs
|
18 |
) -> None:
|
|
|
19 |
self.bleu.add_batch(
|
20 |
predictions=predictions, references=[[ref] for ref in references]
|
21 |
)
|
|
|
35 |
lang="en", rescale_with_baseline=True
|
36 |
)
|
37 |
return {
|
|
|
38 |
"bleu": self.bleu.compute(tokenize="13a")["score"],
|
39 |
"chrf": self.chrf.compute()["score"],
|
40 |
"rouge1": rouge["rouge1"] * 100,
|