import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, MistralConfig

# Initialize model and tokenizer
model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"

model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             load_in_8bit=True,
                                             revision="main")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
config = MistralConfig()
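
# Optional alternative (sketch only, not used above): newer transformers releases
# deprecate the bare load_in_8bit kwarg in favour of an explicit BitsAndBytesConfig.
# Assumes a transformers/bitsandbytes combination that ships BitsAndBytesConfig.
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_8bit=True)
#   model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
#                                                device_map="auto",
#                                                quantization_config=bnb_config)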

# Text parsing function
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
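
# Illustrative trace of _parse_text (comment only): fenced code blocks become
# <pre><code> markup and the characters inside them are HTML-escaped, e.g.
#   _parse_text("Here:\n```python\nprint(1)\n```")
#   returns 'Here:<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'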

# Demo launching function
def _launch_demo(args, model, tokenizer, config):
    def predict(_query, _chatbot, _task_history):
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        
        # Tokenize the input
        input_ids = tokenizer.encode(_query, return_tensors='pt')
        print("Input IDs:", input_ids)

        # Move input_ids to the same device as the model (GPU when available)
        input_ids = input_ids.to(model.device)

        # Build an explicit attention mask for the un-padded prompt
        attention_mask = torch.ones_like(input_ids)

        # Generate a response using the model
        generated_ids = model.generate(input_ids,
                                       attention_mask=attention_mask,
                                       max_new_tokens=300,
                                       pad_token_id=tokenizer.eos_token_id)
        print("Generated IDs:", generated_ids)

        # Decode only the newly generated tokens, skipping the echoed prompt
        full_response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:],
                                         skip_special_tokens=True)
        
        # Update the chatbot state
        _chatbot[-1] = (_parse_text(_query), _parse_text(full_response))
        yield _chatbot
        
        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"OpenHermes: {_parse_text(full_response)}")


    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""
    ## OpenHermes V2 - Mistral 7B: a Mistral 7B fine-tune by Teknium!
    **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for the GPU!**
    **OpenHermes V2 Mistral 7B was trained on 900,000 instructions; it surpasses all previous Hermes models at 13B and below and matches 70B models on some benchmarks!**
    """)
        chatbot = gr.Chatbot(label='OpenHermes-V2', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            empty_btn = gr.Button("🧹 Clear History")
            regen_btn = gr.Button("🤔️ Regenerate")
    
        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue
        submit_btn.click(reset_user_input, [], [query], queue=False)  # No queue for resetting
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True, queue=False)  # No queue for clearing
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue
    demo.queue(max_size=20)
    demo.launch()
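
    # Optional (sketch): when running outside Hugging Face Spaces, the demo can be
    # exposed on the local network or via a temporary public link instead, e.g.
    #   demo.launch(server_name="0.0.0.0", server_port=7860)  # serve on the LAN
    #   demo.launch(share=True)                                # temporary public URL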


# Main execution
if __name__ == "__main__":
    _launch_demo(None, model, tokenizer, config)