File size: 5,939 Bytes
1ce6f28
408fb2e
 
880e945
408fb2e
 
 
 
 
 
 
880e945
408fb2e
880e945
dcf6e59
 
 
 
880e945
dcf6e59
 
 
 
880e945
6ac567d
 
 
 
 
 
 
c314120
dcf6e59
 
408fb2e
71a6f99
 
 
408fb2e
70a5709
408fb2e
70a5709
 
71a6f99
70a5709
 
 
 
8244168
408fb2e
70a5709
 
 
 
 
 
 
 
 
408fb2e
70a5709
9ab3033
70a5709
acc1fb0
70a5709
 
 
 
 
8312b78
 
4c1f576
ac31486
25819b2
 
ae4438b
dcf6e59
 
463c3f1
 
 
 
 
 
f7c578d
 
dcf6e59
 
f7c578d
 
 
 
 
dcf6e59
a4cd409
 
 
71a6f99
dcf6e59
880e945
dcf6e59
f7c578d
 
 
 
 
 
 
 
880e945
dcf6e59
 
 
f7c578d
 
71a6f99
f7c578d
 
 
71a6f99
 
 
 
 
 
 
f7c578d
6ac567d
 
 
 
 
71a6f99
 
bce3dcd
90981c9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import re
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub checkpoint served by this app.
model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"

# Load the weights in 8-bit and let device_map="auto" decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="main",
    trust_remote_code=False,
    device_map="auto",
    load_in_8bit=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
)

# Fallback system prompt used by chat() when the System Message box is blank.
BASE_SYSTEM_MESSAGE = (
    "I carefully provide accurate, factual, thoughtful, nuanced answers "
    "and am brilliant at reasoning."
)

def clear_chat(chat_history_state, chat_message):
    """Start a new topic.

    The incoming history and textbox values are ignored. A brand-new empty
    history list and an empty message string are returned, in that order,
    for the two outputs wired to this handler.
    """
    fresh_history, blank_message = [], ""
    return fresh_history, blank_message

def user(message, history):
    """Record a newly submitted user message.

    Appends a ``[message, ""]`` pair (empty reply slot) to *history*, which is
    mutated in place when it is a non-empty list; a falsy *history* is replaced
    by a new list. Returns ``("", history)`` so the textbox is cleared.
    """
    turns = history if history else []
    turns.append([message, ""])
    cleared_textbox = ""
    return cleared_textbox, turns

def regenerate(history, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Replace the last assistant reply with a freshly generated one.

    The last ``[user, assistant]`` pair keeps its user message; only its reply
    slot is emptied before delegating to ``chat``, which reads the user message
    from ``history[-1][0]`` and fills ``history[-1][1]``. (Popping the whole
    pair, as before, made chat() answer the *previous* user message and append
    the new text onto that turn's existing reply.)

    Args and return value are exactly those of ``chat``.
    """
    if history:
        history[-1][1] = ""
    return chat(history, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty)

def _build_chatml_prompt(system_message, history):
    """Render *system_message* and *history* as a ChatML prompt string.

    Every earlier ``[user, assistant]`` turn is replayed so the model sees the
    whole conversation; the last turn contributes only its user message and
    leaves an open assistant slot for the model to complete.
    """
    segments = [f"<|im_start|>system\n{system_message}<|im_end|>\n"]
    for past_user, past_assistant in history[:-1]:
        segments.append(f"<|im_start|>user\n{past_user}<|im_end|>\n")
        segments.append(f"<|im_start|>assistant\n{past_assistant}<|im_end|>\n")
    segments.append(f"<|im_start|>user\n{history[-1][0]}<|im_end|>\n")
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)


def _stop_token_ids():
    """Return the token ids that should end generation.

    Always includes the tokenizer's EOS token. ChatML's ``<|im_end|>`` is added
    only when the tokenizer knows it as a real (non-unknown) token, so the model
    stops at the end of its own turn instead of writing the next one.
    """
    stop_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != tokenizer.unk_token_id and im_end_id not in stop_ids:
        stop_ids.append(im_end_id)
    return stop_ids


def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Generate the assistant's reply to the last user message in *history*.

    Args:
        history: list of ``[user_message, assistant_message]`` pairs. The last
            pair holds the message to answer; its reply slot is overwritten.
        system_message: system prompt; blank or None falls back to
            ``BASE_SYSTEM_MESSAGE``.
        max_tokens: maximum number of *new* tokens to generate.
        temperature: sampling temperature; 0 selects greedy decoding.
        top_p: nucleus-sampling probability mass (only used when sampling).
        top_k: top-k sampling cutoff (only used when sampling).
        repetition_penalty: repetition penalty passed to ``generate``.

    Returns:
        ``(history, history, "")`` for the chatbot, the history state and the
        message textbox (cleared).
    """
    history = history or []
    if not history:
        # Nothing to answer (e.g. "Regenerate" pressed on an empty chat).
        return history, history, ""

    system_message_to_use = (system_message or "").strip() or BASE_SYSTEM_MESSAGE
    prompt = _build_chatml_prompt(system_message_to_use, history)

    # Keep the attention mask and follow whatever device device_map chose.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        # max_length would also count the prompt; the slider means reply length.
        "max_new_tokens": int(max_tokens),
        "repetition_penalty": repetition_penalty,
        "eos_token_id": _stop_token_ids(),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        # Without do_sample=True, generate() ignores temperature/top_p/top_k.
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    # temperature == 0 stays greedy: transformers rejects sampling at temperature 0.

    output = model.generate(**inputs, **generation_kwargs)

    # Decode only the generated continuation; splitting the full text on the
    # word "assistant" truncated any reply that contained that word.
    assistant_response = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    history[-1][1] = assistant_response
    return history, history, ""


# Initial (empty) contents of the System Message textbox.
start_message = ""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
    ## OpenHermes-V2 Finetuned on Mistral 7B
    **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for GPU!**
    **OpenHermes-V2 is currently SOTA in some benchmarks for 7B models.**
    **Hermes 2 model was trained on 900,000 instructions, and surpasses all previous versions of Hermes 13B and below, and matches 70B on some benchmarks! Hermes 2 changes the game with strong multiturn chat skills, system prompt capabilities, and uses ChatML format. It's quality, diversity and scale is unmatched in the current OS LM landscape. Not only does it do well in benchmarks, but also in unmeasured capabilities, like Roleplaying, Tasks, and more.**
    """)
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot")
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )
    with gr.Row():
        submit = gr.Button(value="Send message", variant="secondary", scale=1)
        clear = gr.Button(value="New topic", variant="secondary", scale=0)
        stop = gr.Button(value="Stop", variant="secondary", scale=0)
        regen_btn = gr.Button(value="Regenerate", variant="secondary", scale=0)
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 2500, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.4)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

        system_msg = gr.Textbox(
            start_message, label="System Message", interactive=True, visible=True, placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5)

    # Conversation history shared by the event handlers below.
    chat_history_state = gr.State()

    # "New topic": reset the stored history and the textbox, and blank the visible chat.
    # Each handler is registered exactly once per click.
    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

    # Send: record the user's message, then run chat() to generate the reply.
    submit_click_event = submit.click(
        fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
    ).then(
        fn=chat,
        inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[chatbot, chat_history_state, message],
        queue=True,
    )

    # Regenerate: run regenerate() with the same inputs/outputs as Send's chat step.
    regen_click_event = regen_btn.click(
        fn=regenerate,
        inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[chatbot, chat_history_state, message],
        queue=True,
    )

    # Stop cancels whichever generation is running, including a regeneration.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event, regen_click_event], queue=False)


# NOTE(review): concurrency_count is Gradio 3.x queue API; it was replaced by
# default_concurrency_limit in Gradio 4 — revisit if upgrading.
demo.queue(max_size=128, concurrency_count=2)
demo.launch()