shujaatalishariati committed
Commit 116d721
1 parent: 6eddba2

gramformer

Files changed (9)
  1. LICENSE +21 -0
  2. __init__.py +0 -1
  3. app.py +23 -123
  4. gramformer.py +0 -126
  5. gramformer/__init__.py +1 -0
  6. gramformer/demo.py +30 -0
  7. gramformer/gramformer.py +128 -0
  8. requirements.txt +4 -9
  9. setup.py +20 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Prithivida
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
__init__.py DELETED
@@ -1 +0,0 @@
- from .gramformer import Gramformer
 
 
app.py CHANGED
@@ -1,126 +1,26 @@
- import os
  import gradio as gr
- from transformers import pipeline
- import spacy
- import subprocess
- import nltk
- from nltk.corpus import wordnet
- import torch
  from gramformer import Gramformer
 
- # Initialize the English text classification pipeline for AI detection
- pipeline_en = pipeline(task="text-classification", model="Hello-SimpleAI/chatgpt-detector-roberta")
-
- # Initialize Gramformer
- gf = Gramformer(models=1, use_gpu=False)  # 1 = corrector
-
- # Function to predict the label and score for English text (AI Detection)
- def predict_en(text):
-     res = pipeline_en(text)[0]
-     return res['label'], res['score']
-
- # Ensure necessary NLTK data is downloaded for Humanifier
- nltk.download('wordnet')
- nltk.download('omw-1.4')
-
- # Ensure the SpaCy model is installed for Humanifier
- try:
-     nlp = spacy.load("en_core_web_sm")
- except OSError:
-     subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
-     nlp = spacy.load("en_core_web_sm")
-
- # Function to get synonyms using NLTK WordNet (Humanifier)
- def get_synonyms_nltk(word, pos):
-     synsets = wordnet.synsets(word, pos=pos)
-     if synsets:
-         lemmas = synsets[0].lemmas()
-         return [lemma.name() for lemma in lemmas]
-     return []
-
- # Function to capitalize the first letter of sentences and proper nouns (Humanifier)
- def capitalize_sentences_and_nouns(text):
-     doc = nlp(text)
-     corrected_text = []
-
-     for sent in doc.sents:
-         sentence = []
-         for token in sent:
-             if token.i == sent.start:  # First word of the sentence
-                 sentence.append(token.text.capitalize())
-             elif token.pos_ == "PROPN":  # Proper noun
-                 sentence.append(token.text.capitalize())
-             else:
-                 sentence.append(token.text)
-         corrected_text.append(' '.join(sentence))
-
-     return ' '.join(corrected_text)
-
- # Paraphrasing function using SpaCy and NLTK (Humanifier)
- def paraphrase_with_spacy_nltk(text):
-     doc = nlp(text)
-     paraphrased_words = []
-
-     for token in doc:
-         # Map SpaCy POS tags to WordNet POS tags
-         pos = None
-         if token.pos_ in {"NOUN"}:
-             pos = wordnet.NOUN
-         elif token.pos_ in {"VERB"}:
-             pos = wordnet.VERB
-         elif token.pos_ in {"ADJ"}:
-             pos = wordnet.ADJ
-         elif token.pos_ in {"ADV"}:
-             pos = wordnet.ADV
-
-         synonyms = get_synonyms_nltk(token.text.lower(), pos) if pos else []
-
-         # Replace with a synonym only if it makes sense
-         if synonyms and token.pos_ in {"NOUN", "VERB", "ADJ", "ADV"} and synonyms[0] != token.text.lower():
-             paraphrased_words.append(synonyms[0])
-         else:
-             paraphrased_words.append(token.text)
-
-     # Join the words back into a sentence
-     paraphrased_sentence = ' '.join(paraphrased_words)
-
-     # Capitalize sentences and proper nouns
-     corrected_text = capitalize_sentences_and_nouns(paraphrased_sentence)
-
-     return corrected_text
-
- # Combined function: Paraphrase -> Capitalization -> Grammar Correction (Humanifier)
- def paraphrase_correct_and_grammar(text):
-     # Step 1: Paraphrase the text
-     paraphrased_text = paraphrase_with_spacy_nltk(text)
-
-     # Step 2: Capitalize sentences and proper nouns
-     capitalized_text = capitalize_sentences_and_nouns(paraphrased_text)
-
-     # Step 3: Grammar correction using Gramformer
-     corrected_sentences = gf.correct(capitalized_text, max_candidates=1)
-     final_text = next(iter(corrected_sentences)) if corrected_sentences else capitalized_text
-
-     return final_text
-
- # Gradio app setup with two tabs
- with gr.Blocks() as demo:
-     with gr.Tab("AI Detection"):
-         t1 = gr.Textbox(lines=5, label='Text')
-         button1 = gr.Button("🤖 Predict!")
-         label1 = gr.Textbox(lines=1, label='Predicted Label 🎃')
-         score1 = gr.Textbox(lines=1, label='Prob')
-
-         # Connect the prediction function to the button
-         button1.click(predict_en, inputs=[t1], outputs=[label1, score1], api_name='predict_en')
-
-     with gr.Tab("Humanifier"):
-         text_input = gr.Textbox(lines=5, label="Input Text")
-         paraphrase_button = gr.Button("Paraphrase, Correct & Grammar Check")
-         output_text = gr.Textbox(label="Processed Text")
-
-         # Connect the paraphrasing and grammar correction function to the button
-         paraphrase_button.click(paraphrase_correct_and_grammar, inputs=text_input, outputs=output_text)
-
- # Launch the app with both functionalities
- demo.launch()
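A note on the removed Humanifier path: get_synonyms_nltk always returned the lemmas of the first synset, and WordNet lemma names use underscores for multiword entries, so the swap could inject tokens like "travel_rapidly" verbatim. A small sketch of what the removed helper saw (requires NLTK's wordnet corpus; the word choice is illustrative):

import nltk
from nltk.corpus import wordnet

nltk.download('wordnet', quiet=True)

# Lemmas of the first verb synset; multiword lemmas keep their underscores
print([lemma.name() for lemma in wordnet.synsets("buy", pos=wordnet.VERB)[0].lemmas()])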
 
 
  import gradio as gr
  from gramformer import Gramformer
 
+ # Initialize the Gramformer model (using default settings for now)
+ gf = Gramformer(models=1, use_gpu=False)
+
+ def correct_grammar(text):
+     # Correct the input text using Gramformer
+     corrected_sentences = gf.correct(text)
+     return " ".join(corrected_sentences)
+
+ # Gradio Interface
+ def main():
+     interface = gr.Interface(
+         fn=correct_grammar,
+         inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
+         outputs="text",
+         title="Grammar Correction App",
+         description="This app corrects grammar using the Gramformer model. Enter a sentence to correct its grammar.",
+     )
+
+     # Launch the Gradio interface
+     interface.launch()
+
+ if __name__ == "__main__":
+     main()
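One subtlety in the new correct_grammar: per the Gramformer class added in this commit, correct() returns a set of candidate corrections for the whole input, so " ".join(...) concatenates candidates rather than sentences. A minimal sketch that picks a single candidate instead, falling back to the raw text when nothing is returned (same API as in this commit):

from gramformer import Gramformer

gf = Gramformer(models=1, use_gpu=False)

def correct_grammar(text):
    # correct() yields a set; with max_candidates=1 it holds at most one entry
    candidates = gf.correct(text, max_candidates=1)
    return next(iter(candidates)) if candidates else text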
gramformer.py DELETED
@@ -1,126 +0,0 @@
- import spacy.cli
- import errant
-
- class Gramformer:
-
-     def __init__(self, models=1, use_gpu=False):
-         from transformers import AutoTokenizer
-         from transformers import AutoModelForSeq2SeqLM
-
-         # Ensure the SpaCy model 'en_core_web_sm' is downloaded
-         spacy.cli.download("en_core_web_sm")
-
-         # Load the correct SpaCy model for errant
-         self.annotator = errant.load('en_core_web_sm')
-
-         if use_gpu:
-             device = "cuda:0"
-         else:
-             device = "cpu"
-
-         batch_size = 1
-         self.device = device
-         correction_model_tag = "prithivida/grammar_error_correcter_v1"
-         self.model_loaded = False
-
-         if models == 1:
-             self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
-             self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
-             self.correction_model = self.correction_model.to(device)
-             self.model_loaded = True
-             print("[Gramformer] Grammar error correct/highlight model loaded..")
-         elif models == 2:
-             # TODO: Implement this part
-             print("TO BE IMPLEMENTED!!!")
-
-     def correct(self, input_sentence, max_candidates=1):
-         if self.model_loaded:
-             correction_prefix = "gec: "
-             input_sentence = correction_prefix + input_sentence
-             input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
-             input_ids = input_ids.to(self.device)
-
-             preds = self.correction_model.generate(
-                 input_ids,
-                 do_sample=True,
-                 max_length=128,
-                 num_beams=7,
-                 early_stopping=True,
-                 num_return_sequences=max_candidates
-             )
-
-             corrected = set()
-             for pred in preds:
-                 corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
-
-             return corrected
-         else:
-             print("Model is not loaded")
-             return None
-
-     def highlight(self, orig, cor):
-         edits = self._get_edits(orig, cor)
-         orig_tokens = orig.split()
-
-         ignore_indexes = []
-
-         for edit in edits:
-             edit_type = edit[0]
-             edit_str_start = edit[1]
-             edit_spos = edit[2]
-             edit_epos = edit[3]
-             edit_str_end = edit[4]
-
-             # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
-             for i in range(edit_spos + 1, edit_epos):
-                 ignore_indexes.append(i)
-
-             if edit_str_start == "":
-                 if edit_spos - 1 >= 0:
-                     new_edit_str = orig_tokens[edit_spos - 1]
-                     edit_spos -= 1
-                 else:
-                     new_edit_str = orig_tokens[edit_spos + 1]
-                     edit_spos += 1
-                 if edit_type == "PUNCT":
-                     st = f"<a type='{edit_type}' edit='{edit_str_end}'>{new_edit_str}</a>"
-                 else:
-                     st = f"<a type='{edit_type}' edit='{new_edit_str} {edit_str_end}'>{new_edit_str}</a>"
-                 orig_tokens[edit_spos] = st
-             elif edit_str_end == "":
-                 st = f"<d type='{edit_type}' edit=''>{edit_str_start}</d>"
-                 orig_tokens[edit_spos] = st
-             else:
-                 st = f"<c type='{edit_type}' edit='{edit_str_end}'>{edit_str_start}</c>"
-                 orig_tokens[edit_spos] = st
-
-         for i in sorted(ignore_indexes, reverse=True):
-             del orig_tokens[i]
-
-         return " ".join(orig_tokens)
-
-     def detect(self, input_sentence):
-         # TO BE IMPLEMENTED
-         pass
-
-     def _get_edits(self, orig, cor):
-         orig = self.annotator.parse(orig)
-         cor = self.annotator.parse(cor)
-         alignment = self.annotator.align(orig, cor)
-         edits = self.annotator.merge(alignment)
-
-         if len(edits) == 0:
-             return []
-
-         edit_annotations = []
-         for e in edits:
-             e = self.annotator.classify(e)
-             edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
-
-         if len(edit_annotations) > 0:
-             return edit_annotations
-         else:
-             return []
-
-     def get_edits(self, orig, cor):
-         return self._get_edits(orig, cor)
gramformer/__init__.py ADDED
@@ -0,0 +1 @@
+ from gramformer.gramformer import Gramformer
gramformer/demo.py ADDED
@@ -0,0 +1,30 @@
+ from gramformer import Gramformer
+ import torch
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+ set_seed(1212)
+
+
+ gf = Gramformer(models=1, use_gpu=False)  # 1=corrector, 2=detector
+
+ influent_sentences = [
+     "Matt like fish",
+     "the collection of letters was original used by the ancient Romans",
+     "We enjoys horror movies",
+     "Anna and Mike is going skiing",
+     "I walk to the store and I bought milk",
+     "We all eat the fish and then made dessert",
+     "I will eat fish for dinner and drank milk",
+     "what be the reason for everyone leave the company",
+ ]
+
+ for influent_sentence in influent_sentences:
+     corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
+     print("[Input] ", influent_sentence)
+     for corrected_sentence in corrected_sentences:
+         print("[Correction] ", corrected_sentence)
+     print("-" * 100)
gramformer/gramformer.py ADDED
@@ -0,0 +1,128 @@
+ class Gramformer:
+
+     def __init__(self, models=1, use_gpu=False):
+         from transformers import AutoTokenizer
+         from transformers import AutoModelForSeq2SeqLM
+         # from lm_scorer.models.auto import AutoLMScorer as LMScorer
+         import errant
+         self.annotator = errant.load('en')
+
+         if use_gpu:
+             device = "cuda:0"
+         else:
+             device = "cpu"
+         batch_size = 1
+         # self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
+         self.device = device
+         correction_model_tag = "prithivida/grammar_error_correcter_v1"
+         self.model_loaded = False
+
+         if models == 1:
+             self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
+             self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
+             self.correction_model = self.correction_model.to(device)
+             self.model_loaded = True
+             print("[Gramformer] Grammar error correct/highlight model loaded..")
+         elif models == 2:
+             # TODO
+             print("TO BE IMPLEMENTED!!!")
+
+     def correct(self, input_sentence, max_candidates=1):
+         if self.model_loaded:
+             correction_prefix = "gec: "
+             input_sentence = correction_prefix + input_sentence
+             input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
+             input_ids = input_ids.to(self.device)
+
+             preds = self.correction_model.generate(
+                 input_ids,
+                 do_sample=True,
+                 max_length=128,
+                 # top_k=50,
+                 # top_p=0.95,
+                 num_beams=7,
+                 early_stopping=True,
+                 num_return_sequences=max_candidates)
+
+             corrected = set()
+             for pred in preds:
+                 corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
+
+             # corrected = list(corrected)
+             # scores = self.scorer.sentence_score(corrected, log=True)
+             # ranked_corrected = [(c, s) for c, s in zip(corrected, scores)]
+             # ranked_corrected.sort(key=lambda x: x[1], reverse=True)
+             return corrected
+         else:
+             print("Model is not loaded")
+             return None
+
+     def highlight(self, orig, cor):
+         edits = self._get_edits(orig, cor)
+         orig_tokens = orig.split()
+
+         ignore_indexes = []
+
+         for edit in edits:
+             edit_type = edit[0]
+             edit_str_start = edit[1]
+             edit_spos = edit[2]
+             edit_epos = edit[3]
+             edit_str_end = edit[4]
+
+             # if no_of_tokens(edit_str_start) > 1 ==> excluding the first token, mark all other tokens for deletion
+             for i in range(edit_spos + 1, edit_epos):
+                 ignore_indexes.append(i)
+
+             if edit_str_start == "":
+                 if edit_spos - 1 >= 0:
+                     new_edit_str = orig_tokens[edit_spos - 1]
+                     edit_spos -= 1
+                 else:
+                     new_edit_str = orig_tokens[edit_spos + 1]
+                     edit_spos += 1
+                 if edit_type == "PUNCT":
+                     st = "<a type='" + edit_type + "' edit='" + \
+                         edit_str_end + "'>" + new_edit_str + "</a>"
+                 else:
+                     st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
+                         " " + edit_str_end + "'>" + new_edit_str + "</a>"
+                 orig_tokens[edit_spos] = st
+             elif edit_str_end == "":
+                 st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
+                 orig_tokens[edit_spos] = st
+             else:
+                 st = "<c type='" + edit_type + "' edit='" + \
+                     edit_str_end + "'>" + edit_str_start + "</c>"
+                 orig_tokens[edit_spos] = st
+
+         for i in sorted(ignore_indexes, reverse=True):
+             del orig_tokens[i]
+
+         return " ".join(orig_tokens)
+
+     def detect(self, input_sentence):
+         # TO BE IMPLEMENTED
+         pass
+
+     def _get_edits(self, orig, cor):
+         orig = self.annotator.parse(orig)
+         cor = self.annotator.parse(cor)
+         alignment = self.annotator.align(orig, cor)
+         edits = self.annotator.merge(alignment)
+
+         if len(edits) == 0:
+             return []
+
+         edit_annotations = []
+         for e in edits:
+             e = self.annotator.classify(e)
+             edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
+
+         if len(edit_annotations) > 0:
+             return edit_annotations
+         else:
+             return []
+
+     def get_edits(self, orig, cor):
+         return self._get_edits(orig, cor)
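_get_edits() runs errant's parse, align, merge, and classify steps and flattens every edit into a 7-tuple. A minimal sketch of consuming the public get_edits() wrapper (tuple layout taken from the code above; the actual edit types and spans come from errant's classifier):

from gramformer import Gramformer

gf = Gramformer(models=1, use_gpu=False)
# Each edit is (type, orig_str, orig_start, orig_end, cor_str, cor_start, cor_end)
for edit in gf.get_edits("Matt like fish", "Matt likes fish"):
    edit_type, o_str, o_start, o_end, c_str, c_start, c_end = edit
    print(f"{edit_type}: '{o_str}' [{o_start}:{o_end}] -> '{c_str}'")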
requirements.txt CHANGED
@@ -1,9 +1,4 @@
- gradio==3.50.2
- transformers==4.36.2
- spacy==3.5.3
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz
- nltk==3.8.1
- torch==2.1.2
- git+https://github.com/PrithivirajDamodaran/Gramformer.git
- typer==0.9.0
- click==8.0.4
 
+ transformers
+ torch
+ gradio
+ errant
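The trimmed requirements stop pinning versions and drop the spaCy/NLTK stack, but errant.load('en') in the new gramformer/gramformer.py still expects an English spaCy pipeline at runtime. A hedged setup sketch, mirroring the fallback the old app.py used (it assumes the model is not preinstalled in the Space image):

import spacy

# errant resolves an English spaCy pipeline internally; fetch the small
# model on first run if it is missing
try:
    spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")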
setup.py ADDED
@@ -0,0 +1,20 @@
+ import setuptools
+
+ setuptools.setup(
+     name="gramformer",
+     version="1.0",
+     author="prithiviraj damodaran",
+     author_email="",
+     description="Gramformer",
+     long_description="A framework for detecting, highlighting and correcting grammatical errors on natural language text",
+     url="https://github.com/PrithivirajDamodaran/Gramformer.git",
+     packages=setuptools.find_packages(),
+     # install_requires=['transformers', 'sentencepiece==0.1.95', 'python-Levenshtein==0.12.2', 'fuzzywuzzy==0.18.0', 'tokenizers==0.10.2', 'fsspec==2021.5.0', 'lm-scorer==0.4.2', 'errant'],
+     install_requires=['transformers', 'sentencepiece', 'python-Levenshtein', 'fuzzywuzzy', 'tokenizers', 'fsspec', 'errant'],
+     classifiers=[
+         "Programming Language :: Python :: 3.7",
+         "License :: Apache 2.0",
+         "Operating System :: OS Independent",
+     ],
+ )
+
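With this setup.py, find_packages() picks up the gramformer/ package, so the import path matches the gramformer/__init__.py added above. A hedged usage sketch after an editable install (pip install -e . from the repo root is assumed; it is not part of this commit):

from gramformer import Gramformer

gf = Gramformer(models=1, use_gpu=False)
print(gf.correct("Anna and Mike is going skiing", max_candidates=1))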