import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

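# Load the model with 8-bit weights so the 7B model fits on a single GPU;
# load_in_8bit relies on the bitsandbytes package being installed.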
model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             load_in_8bit=True,
                                             revision="main")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."

# Defaults mirror the UI sliders so regenerate() gets sensible settings.
def make_prediction(prompt, max_tokens=512, temperature=0.4, top_p=0.95, top_k=40, repetition_penalty=1.1):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # max_new_tokens caps only the generated tokens, and do_sample=True is
    # required for temperature/top_p/top_k to take effect.
    out = model.generate(input_ids, max_new_tokens=max_tokens, do_sample=True, temperature=temperature,
                         top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    # Decode only the continuation, skipping the prompt tokens.
    text = tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True)
    yield text

def clear_chat(chat_history_state, chat_message):
    # Reset the stored history and clear the message box.
    return [], ""

def user(message, history):
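    # Append the new user message with an empty assistant slot for chat() to fill in.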
    history = history or []
    history.append([message, ""])
    return "", history

def regenerate(_chatbot, _task_history):
    # Re-run generation for the most recent user turn.
    if not _task_history:
        yield _chatbot
        return
    last_user_message = _task_history[-1][0]
    # Replace the last chatbot turn and stream the fresh completion into it.
    _chatbot[-1] = [last_user_message, ""]
    for text in make_prediction(last_user_message):
        _chatbot[-1][1] = text
        yield _chatbot


def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    history = history or []
    
    # Use BASE_SYSTEM_MESSAGE if system_message is empty
    system_message_to_use = system_message if system_message.strip() else BASE_SYSTEM_MESSAGE
    
    # The most recent user message
    user_prompt = history[-1][0] if history else ""
    
    # Build the model input using the ChatML format OpenHermes-2 was trained on
    prompt_template = f'''<|im_start|>system
{system_message_to_use.strip()}<|im_end|>
<|im_start|>user
{user_prompt}<|im_end|>
<|im_start|>assistant
'''
    input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(model.device)
    
    # Generate the output; do_sample=True makes temperature/top_p/top_k effective
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty
    )
    
    # Decode only the newly generated tokens and drop the end-of-turn marker
    decoded_output = tokenizer.decode(output[0][input_ids.shape[-1]:])
    assistant_response = decoded_output.split('<|im_end|>')[0].strip()
    
    # Store the assistant's reply in the last history turn
    if history:
        history[-1][1] += assistant_response
    else:
        history.append(["", assistant_response])
    
    return history, history, ""


start_message = ""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
    ## OpenHermes-V2 Finetuned on Mistral 7B
    **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for GPU!**
    **OpenHermes-V2 is currently SOTA in some benchmarks for 7B models.**
    **The Hermes 2 model was trained on 900,000 instructions; it surpasses all previous Hermes versions at 13B and below, and matches 70B models on some benchmarks! Hermes 2 changes the game with strong multiturn chat skills, system prompt capabilities, and the ChatML prompt format. Its quality, diversity, and scale are unmatched in the current open-source LM landscape. It performs well not only on benchmarks but also in unmeasured capabilities such as roleplaying, tasks, and more.**
    """)
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot")
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )
    with gr.Row():
        submit = gr.Button(value="Send message", variant="secondary", scale=1)
        clear = gr.Button(value="New topic", variant="secondary", scale=0)
        stop = gr.Button(value="Stop", variant="secondary", scale=0)
        regen_btn = gr.Button(value="Regenerate", variant="secondary", scale=0)
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 2500, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.4)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

        system_msg = gr.Textbox(
            start_message, label="System Message", interactive=True, visible=True, placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5)

    chat_history_state = gr.State()
    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

    submit_click_event = submit.click(
        fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
    ).then(
        fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True
    )
    
    
    # Cancel the in-flight generation started by the submit event
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event], queue=False)

    regen_btn.click(
        fn=regenerate, inputs=[chatbot, chat_history_state], outputs=[chatbot], queue=True
    )

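# With the Gradio 3.x queue API, concurrency_count=2 lets two generations run in parallel.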
demo.queue(max_size=128, concurrency_count=2)
demo.launch()