from typing import Generator

import numpy as np
import torch

from modules import config, models
from modules.utils.SeedContext import SeedContext


@torch.inference_mode()
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Run the ChatTTS refiner over ``text`` and return the refined text.

    The model is obtained from ``models.load_chat_tts()`` and its
    ``refiner_prompt`` method is invoked inside a ``SeedContext(seed)``
    block, with torch inference mode enabled by the decorator.

    Args:
        text: Input text handed to the refiner.
        prompt: Refiner prompt string, forwarded under the ``"prompt"`` key.
        seed: Value passed to ``SeedContext``.
            NOTE(review): the meaning of the default ``-1`` is defined by
            ``SeedContext`` and is not visible here -- confirm.
        top_P: Forwarded under the ``"top_P"`` key.
        top_K: Forwarded under the ``"top_K"`` key.
        temperature: Forwarded under the ``"temperature"`` key.
        repetition_penalty: Forwarded under the ``"repetition_penalty"`` key.
        max_new_token: Forwarded under the ``"max_new_token"`` key.

    Returns:
        The refiner output. A list result is joined with newlines; any
        other non-generator result is returned unchanged.

    Raises:
        NotImplementedError: If the refiner returns a generator.
    """
    tts_model = models.load_chat_tts()

    with SeedContext(seed):
        # Options are assembled inside the seed context, right before the
        # call, so the progress-bar setting is read at the same point as
        # it would be when written inline as a call argument.
        refine_params = {
            "prompt": prompt,
            "top_K": top_K,
            "top_P": top_P,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_token": max_new_token,
            "disable_tqdm": config.runtime_env_vars.off_tqdm,
        }
        result = tts_model.refiner_prompt(text, refine_params)

        # Streaming (generator) output is not supported by this helper.
        if isinstance(result, Generator):
            raise NotImplementedError(
                "Refiner is not yet implemented for generator output"
            )

        # Multiple refined segments are collapsed into one string.
        if isinstance(result, list):
            return "\n".join(result)

        return result