"""Gradio demo that runs text through the PleIAs/OCRonos-Vintage GPT-2 model.

The user's text is split into token-bounded chunks, each chunk is sent to the
model inside a "### Text ### / ### Correction ###" prompt, and the corrected
chunks are rendered back as HTML.
"""

import difflib
import html
import os
import re
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# OCR Correction Model
model_name = "PleIAs/OCRonos-Vintage"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pre-trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# CSS for formatting
# NOTE(review): this is a single space; any stylesheet that was meant to be
# here is not present in this file -- confirm.
css = """ """


def generate_html_diff(old_text: str, new_text: str) -> str:
    """Return the words of *new_text* as an HTML-safe, space-joined string.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Words that are unchanged or that only occur in
    *new_text* are kept; words that only occur in *old_text* and Differ's
    ``?`` hint lines are dropped.

    Every emitted word is HTML-escaped because the result is rendered by a
    ``gr.HTML`` component and originates from user/model text.
    """
    # NOTE(review): added words are emitted without any highlighting markup,
    # exactly like unchanged ones -- confirm whether a wrapper was intended.
    differ = difflib.Differ()
    kept_words = []
    for entry in differ.compare(old_text.split(), new_text.split()):
        # Differ prefixes: "  " unchanged, "+ " added, "- " removed, "? " hint.
        if entry.startswith(("  ", "+ ")):
            kept_words.append(html.escape(entry[2:]))
    return " ".join(kept_words)


def split_text(text: str, max_tokens: int = 400) -> list:
    """Split *text* into strings of at most *max_tokens* tokenizer tokens.

    The text is tokenized with the module-level ``tokenizer``; consecutive
    runs of ``max_tokens`` tokens are converted back to strings. Empty input
    yields an empty list.
    """
    tokens = tokenizer.tokenize(text)
    # A non-positive limit degenerates to one token per chunk rather than an
    # invalid range step.
    step = max(max_tokens, 1)
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + step])
        for start in range(0, len(tokens), step)
    ]


def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Generate the model's correction for one chunk of text.

    Args:
        prompt: The raw text chunk to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Value passed to ``torch.set_num_threads``. Falls back to
            1 when it is ``None`` or not positive.

    Returns:
        The decoded, whitespace-stripped text generated after the prompt.
    """
    # os.cpu_count() may return None, which torch.set_num_threads rejects.
    if not num_threads or num_threads < 1:
        num_threads = 1

    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    torch.set_num_threads(num_threads)
    # Called directly: handing a single task to a thread pool and blocking on
    # its result is just a synchronous call with extra overhead.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the tokens generated after the prompt. Splitting the full
    # decoded text on the "### Correction ###" marker returns the wrong
    # segment whenever the input text itself contains that marker.
    generated_ids = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()


def process_text(user_message: str) -> str:
    """Correct *user_message* chunk by chunk and return the HTML to display.

    The message is split with ``split_text``, each chunk is corrected with
    ``ocr_correction``, and the joined result is passed through
    ``generate_html_diff`` together with the original message. The returned
    string is ``css`` followed by the "OCR Correction" heading text and the
    diff output.
    """
    corrected_chunks = [ocr_correction(chunk) for chunk in split_text(user_message)]
    corrected_text = " \n".join(corrected_chunks)
    html_diff = generate_html_diff(user_message, corrected_text)

    # NOTE(review): the heading and diff are separated only by newlines, which
    # HTML collapses; heading/container markup looks to be missing -- confirm.
    ocr_result = f"\n\nOCR Correction\n\n\n\n{html_diff}\n"
    return f"{css}{ocr_result}"


# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("\n\nVintage OCR corrector (CPU)\n\n")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    demo.queue().launch()