import spacy.cli
import errant


class Gramformer:
    """Grammar error correction wrapper around a seq2seq model and an ERRANT annotator.

    The instance holds:

    * ``annotator`` -- the object returned by ``errant.load('en_core_web_sm')``,
      used to align an original sentence with its correction.
    * ``device`` -- ``"cuda:0"`` or ``"cpu"``.
    * ``model_loaded`` -- ``True`` only when the correction model was loaded
      (i.e. ``models == 1``).
    * ``correction_tokenizer`` / ``correction_model`` -- set only when
      ``models == 1``.
    """

    def __init__(self, models=1, use_gpu=False):
        """Load the annotator and, for ``models == 1``, the correction model.

        Args:
            models: Selects what to load. ``1`` loads the correction
                tokenizer/model; ``2`` only prints a "not implemented"
                message; any other value loads nothing.
            use_gpu: When true the model is moved to ``"cuda:0"``, otherwise
                it stays on ``"cpu"``.
        """
        # Imported here rather than at module level, as in the original code.
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM

        spacy_model = "en_core_web_sm"
        # Ensure the SpaCy model is available, but only download it when it is
        # not already installed so that construction does not need network
        # access on every call.
        if not spacy.util.is_package(spacy_model):
            spacy.cli.download(spacy_model)

        # Load the correct SpaCy model for errant
        self.annotator = errant.load(spacy_model)

        device = "cuda:0" if use_gpu else "cpu"
        self.device = device

        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(
                correction_model_tag, use_auth_token=False
            )
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(
                correction_model_tag, use_auth_token=False
            )
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO: Implement this part
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        """Generate corrected candidates for ``input_sentence``.

        Args:
            input_sentence: Text to correct; ``"gec: "`` is prepended before
                it is tokenized.
            max_candidates: Passed to ``generate`` as
                ``num_return_sequences``.

        Returns:
            A ``set`` of decoded, stripped candidate strings (duplicates
            collapse, so it may hold fewer than ``max_candidates`` items), or
            ``None`` when no correction model is loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        correction_prefix = "gec: "
        input_sentence = correction_prefix + input_sentence
        input_ids = self.correction_tokenizer.encode(
            input_sentence, return_tensors='pt'
        )
        input_ids = input_ids.to(self.device)

        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            num_beams=7,
            early_stopping=True,
            num_return_sequences=max_candidates,
        )

        return {
            self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip()
            for pred in preds
        }

    def highlight(self, orig, cor):
        """Rebuild ``orig`` token by token using the edits between ``orig`` and ``cor``.

        ``orig`` is split on whitespace. For every edit returned by
        ``_get_edits``:

        * tokens after the first one inside the edit's original span are
          removed from the result;
        * for an edit whose original text is empty, a neighbouring token is
          written back unchanged;
        * otherwise the token at the edit's start is replaced by the edit's
          original text as reported by the annotator.

        Edit offsets that fall outside the whitespace-split token list are
        skipped instead of raising ``IndexError``.

        NOTE(review): the edit type and the corrected text are not emitted,
        so despite its name this method produces no markup -- confirm whether
        tagged output was intended.

        Args:
            orig: Original sentence.
            cor: Corrected sentence.

        Returns:
            The remaining tokens joined by single spaces.
        """
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        token_count = len(orig_tokens)
        ignore_indexes = set()

        for edit in edits:
            edit_str_start = edit[1]
            edit_spos = edit[2]
            edit_epos = edit[3]

            # If the edit covers more than one original token, keep the first
            # one and mark all the others for deletion.
            ignore_indexes.update(range(edit_spos + 1, edit_epos))

            if edit_str_start == "":
                # Nothing on the original side: anchor on the previous token,
                # or on the following one when the edit is at the very start.
                if edit_spos - 1 >= 0:
                    anchor = edit_spos - 1
                else:
                    # Clamp so a single-token sentence does not index past
                    # the end of the list.
                    anchor = min(edit_spos + 1, token_count - 1)
                if 0 <= anchor < token_count:
                    orig_tokens[anchor] = f"{orig_tokens[anchor]}"
            elif 0 <= edit_spos < token_count:
                orig_tokens[edit_spos] = f"{edit_str_start}"

        # Delete from the highest index down so earlier deletions do not shift
        # the positions still to be removed.
        for i in sorted(ignore_indexes, reverse=True):
            if 0 <= i < token_count:
                del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        """Placeholder; currently does nothing and returns ``None``."""
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        """Return the classified edits that turn ``orig`` into ``cor``.

        Both sentences are parsed with the annotator, aligned, merged into
        edits and classified.

        Args:
            orig: Original sentence.
            cor: Corrected sentence.

        Returns:
            A list (empty when there are no edits) of 7-tuples
            ``(type, o_str, o_start, o_end, c_str, c_start, c_end)`` where
            ``type`` is the edit's type with its first two characters removed
            (presumably the operation prefix such as ``"R:"`` -- confirm).
        """
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        edit_annotations = []
        for edit in edits:
            edit = self.annotator.classify(edit)
            edit_annotations.append(
                (
                    edit.type[2:],
                    edit.o_str,
                    edit.o_start,
                    edit.o_end,
                    edit.c_str,
                    edit.c_start,
                    edit.c_end,
                )
            )
        return edit_annotations

    def get_edits(self, orig, cor):
        """Public entry point that returns the result of ``_get_edits(orig, cor)``."""
        return self._get_edits(orig, cor)