File size: 6,162 Bytes
604b23b
 
 
 
 
 
 
 
 
 
 
a25255e
c140265
604b23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5c5e60
604b23b
 
 
 
 
 
 
 
 
f90d0a9
29d509b
604b23b
a25255e
 
604b23b
 
 
 
c140265
604b23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34de65c
604b23b
 
 
 
 
 
 
 
 
c140265
 
604b23b
 
 
 
 
 
 
327a801
604b23b
dc11589
604b23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c140265
604b23b
 
 
 
c140265
604b23b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import spaces
import os
import random
import argparse

import torch
import gradio as gr
import numpy as np

import ChatTTS

from OpenVoice import se_extractor
from OpenVoice.api import ToneColorConverter
import soundfile

# Load the ChatTTS model once at import time; the module-level `chat`
# instance is reused by every chat_tts() call below.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()


def generate_seed():
    """Return an update payload holding a fresh random seed.

    The seed is an integer drawn with ``random.randint`` from the inclusive
    range 1..100000000. The ``{"__type__": "update", "value": ...}`` shape is
    presumably consumed by Gradio as a component-value update for the seed
    ``gr.Number`` it is wired to.
    """
    return dict(__type__="update", value=random.randint(1, 100000000))

@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize ``text`` with ChatTTS, optionally refining the text first.

    Seeding order:
        torch's RNG is seeded with ``audio_seed_input`` and a random 768-element
        speaker embedding is drawn. The RNG is then re-seeded with
        ``text_seed_input`` before any ``chat.infer`` call.

    Args:
        text: Input text handed to ``chat.infer``.
        temperature, top_P, top_K: Sampling settings forwarded in
            ``params_infer_code``.
        audio_seed_input: Seed used to draw the speaker embedding.
        text_seed_input: Seed applied before refinement and synthesis.
        refine_text_flag: When truthy, a refine-only ``chat.infer`` pass runs
            first and its result is what gets synthesized.
        refine_text_input: Custom refine prompt. An empty value keeps the
            default ``'[oral_2][laugh_0][break_6]'``.
        output_path: When given, the audio is written there with
            ``soundfile`` at 24000 Hz.

    Returns:
        ``[(24000, audio_array), text]`` when ``output_path`` is None,
        otherwise only the text. The text is the first element when the
        (refined) text is a list.
    """
    torch.manual_seed(audio_seed_input)
    speaker_embedding = torch.randn(768)
    infer_code_params = {
        'spk_emb': speaker_embedding,
        'temperature': temperature,
        'top_P': top_P,
        'top_K': top_K,
    }
    refine_params = {'prompt': '[oral_2][laugh_0][break_6]'}

    # Re-seed so text refinement / synthesis is reproducible independently
    # of the speaker seed.
    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        if refine_text_input:
            refine_params['prompt'] = refine_text_input
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=refine_params,
            params_infer_code=infer_code_params,
        )
        print("Text has been refined!")

    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=refine_params,
        params_infer_code=infer_code_params,
    )

    sample_rate = 24000
    audio_data = np.array(wav[0]).flatten()
    spoken_text = text[0] if isinstance(text, list) else text

    if output_path is not None:
        soundfile.write(output_path, audio_data, sample_rate)
        return spoken_text
    return [(sample_rate, audio_data), spoken_text]

# OpenVoice Clone
# Build the tone-color converter used by generate_audio(). It is placed on the
# first CUDA device when torch reports one, otherwise on the CPU.
ckpt_converter = 'OpenVoice/checkpoints/converter'
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tone_color_converter = ToneColorConverter(ckpt_converter + '/config.json', device=device)
tone_color_converter.load_ckpt(ckpt_converter + '/checkpoint.pth')

def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
    """Synthesize ``text`` with ChatTTS and optionally clone a reference voice.

    Without a reference, ChatTTS output is written straight to ``output.wav``.
    With a reference audio path:
        1. The ChatTTS base speech is written to ``tmp.wav``.
        2. Tone-color embeddings for that file and for ``audio_ref`` are
           extracted with ``se_extractor.get_se`` (``target_dir='processed'``).
        3. OpenVoice converts ``tmp.wav`` into ``output.wav`` using those
           embeddings.

    Args:
        text: Text to speak.
        audio_ref: Path to the reference voice. ``""`` or ``None`` means "no
            reference" (a cleared Gradio audio input is typically ``None``).
        temperature, top_P, top_K, audio_seed_input, text_seed_input,
        refine_text_flag, refine_text_input: Forwarded unchanged to chat_tts().

    Returns:
        ``[save_path, text_data]``: the path of the generated WAV and the text
        that was spoken (the refined text when refinement is enabled).
    """
    # NOTE(review): fixed file names in the working directory mean concurrent
    # requests could overwrite each other's audio; confirm the app's queueing.
    save_path = "output.wav"

    # Truthiness covers both "" and None, so a cleared reference input falls
    # back to plain synthesis instead of crashing in se_extractor.
    if not audio_ref:
        # Previously this branch discarded the returned text, leaving
        # `text_data` unbound at the return statement.
        text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, save_path)
    else:
        # Run the base speaker TTS into an intermediate file.
        src_path = "tmp.wav"
        text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path)
        print("Ready for voice cloning!")

        # Tone-color embeddings of the synthesized source and of the target voice.
        source_se, _ = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
        target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)
        print("Get voices segment!")

        # Re-voice the source file with the target tone color, writing save_path.
        tone_color_converter.convert(
            audio_src_path=src_path,
            src_se=source_se,
            tgt_se=target_se,
            output_path=save_path,
        )

    print("Finished!")

    return [save_path, text_data]


# Gradio UI. Components are created in display order, so the statement order
# below defines the page layout. The "Generate" button runs generate_audio();
# the dice buttons refresh the seeds via generate_seed().
with gr.Blocks() as demo:
    gr.Markdown("# <center>🥳 ChatTTS x OpenVoice 🥳</center>")
    gr.Markdown("## <center>🌟 Make it sound super natural and switch it up to any voice you want, nailing the mood and tone also!🌟 </center>")

    default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
    text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)


    # Same string as the fallback refine prompt hard-coded in chat_tts().
    default_refine_text = "[oral_2][laugh_0][break_6]"
    refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text)
    refine_text_checkbox = gr.Checkbox(label="Refine text", info="use oral_(0-9), laugh_(0-2), break_(0-7).'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause.", value=True)
    with gr.Column():
        # type="filepath": the click handler receives the reference audio as a file path.
        voice_ref = gr.Audio(label="Reference Audio", type="filepath", value="Examples/speaker.mp3")

    # Sampling parameters, forwarded to chat_tts() via generate_audio().
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
        top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")

    # Seeds with a die button (U+1F3B2) next to each to randomize it.
    with gr.Row():
        audio_seed_input = gr.Number(value=42, label="Speaker Seed")
        generate_audio_seed = gr.Button("\U0001F3B2")
        text_seed_input = gr.Number(value=42, label="Text Seed")
        generate_text_seed = gr.Button("\U0001F3B2")

    generate_button = gr.Button("Generate")

    # Shows the text actually spoken (refined when refinement is enabled).
    text_output = gr.Textbox(label="Refined Text", interactive=False)
    audio_output = gr.Audio(label="Output Audio")

    generate_audio_seed.click(generate_seed,
                              inputs=[],
                              outputs=audio_seed_input)

    generate_text_seed.click(generate_seed,
                             inputs=[],
                             outputs=text_seed_input)

    # `inputs` order must match generate_audio's positional parameters;
    # `outputs` map to its [audio path, text] return value.
    generate_button.click(generate_audio,
                          inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
                          outputs=[audio_output,text_output])

# Command-line options for running the demo directly.
# Both default to None: when a flag is omitted, demo.launch() receives None and
# falls back to Gradio's own defaults (and its GRADIO_SERVER_NAME /
# GRADIO_SERVER_PORT environment variables). This matches the previous bare
# `demo.launch()` call, which silently ignored these options.
parser = argparse.ArgumentParser(description='ChatTTS-OpenVoice Launch')
parser.add_argument('--server_name', type=str, default=None, help='Server name')
parser.add_argument('--server_port', type=int, default=None, help='Server port')

if __name__ == '__main__':
    # Parse only when run as a script, so importing this module never consumes
    # (or exits on) an unrelated sys.argv.
    args = parser.parse_args()
    demo.launch(server_name=args.server_name, server_port=args.server_port)