from typing import Dict, List

import evaluate  # type: ignore[import]

from ..base_task_metrics import BaseTaskMetrics


class CMGMetrics(BaseTaskMetrics):
    def __init__(self):
        # Two BERTScore instances are kept: one reported as-is and one
        # rescaled against the English baseline in compute().
        self.bleu = evaluate.load("sacrebleu")
        self.chrf = evaluate.load("chrf")
        self.rouge = evaluate.load("rouge")
        self.bertscore = evaluate.load("bertscore")
        self.bertscore_normalized = evaluate.load("bertscore")

    def add_batch(
        self, predictions: List[str], references: List[str], *args, **kwargs
    ) -> None:
        # sacreBLEU and chrF expect a list of references per prediction,
        # so each single reference is wrapped in its own list.
        self.bleu.add_batch(
            predictions=predictions, references=[[ref] for ref in references]
        )
        self.chrf.add_batch(
            predictions=predictions, references=[[ref] for ref in references]
        )
        self.rouge.add_batch(predictions=predictions, references=references)
        self.bertscore.add_batch(predictions=predictions, references=references)
        self.bertscore_normalized.add_batch(
            predictions=predictions, references=references
        )

    def compute(self, *args, **kwargs) -> Dict[str, float]:
        rouge = self.rouge.compute()
        bertscore = self.bertscore.compute(lang="en")
        bertscore_normalized = self.bertscore_normalized.compute(
            lang="en", rescale_with_baseline=True
        )
        # BLEU and chrF are already on a 0-100 scale; ROUGE is scaled by 100
        # to match. BERTScore is reported as the mean F1 over the batch,
        # both raw and baseline-rescaled.
        return {
            "bleu": self.bleu.compute(tokenize="13a")["score"],
            "chrf": self.chrf.compute()["score"],
            "rouge1": rouge["rouge1"] * 100,
            "rouge2": rouge["rouge2"] * 100,
            "rougeL": rouge["rougeL"] * 100,
            "bertscore": sum(bertscore["f1"]) / len(bertscore["f1"]),
            "bertscore_normalized": sum(bertscore_normalized["f1"])
            / len(bertscore_normalized["f1"]),
        }
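

# Usage sketch (illustrative only; the example strings below are hypothetical
# inputs, not taken from the module or its tests):
#
#     metrics = CMGMetrics()
#     metrics.add_batch(
#         predictions=["Fix off-by-one error in pagination"],
#         references=["Fix off-by-one bug in page index calculation"],
#     )
#     scores = metrics.compute()  # e.g. scores["bleu"], scores["bertscore_normalized"]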