artificialguybr committed on
Commit 1ce6f28
1 Parent(s): 42883b5

Update app.py

Files changed (1)
  1. app.py +56 -40
app.py CHANGED
@@ -1,60 +1,74 @@
- import gradio as gr
+ import os
  import re
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ import logging
+ import gradio as gr
+ import openai

- model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
-                                              device_map="auto",
-                                              trust_remote_code=False,
-                                              load_in_8bit=True,
-                                              revision="main")
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+ print(os.environ)
+ openai.api_base = os.environ.get("OPENAI_API_BASE")
+ openai.api_key = os.environ.get("OPENAI_API_KEY")

- BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."
+ BASE_SYSTEM_MESSAGE = """I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning.
+ I am an assistant who thinks through their answers step-by-step to be sure I always get the right answer.
+ I think more clearly if I write out my thought process in a scratchpad manner first; therefore, I always explain background context, assumptions, and step-by-step thinking BEFORE trying to answer or solve anything."""

  def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repetition_penalty=None):
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-     out = model.generate(input_ids, max_length=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-     text = tokenizer.decode(out[0], skip_special_tokens=True)
-     yield text
+     completion = openai.Completion.create(model="teknium/OpenHermes-2-Mistral-7B", prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, stream=True, stop=["</s>", "<|im_end|>"])
+     for chunk in completion:
+         yield chunk["choices"][0]["text"]
+

  def clear_chat(chat_history_state, chat_message):
      chat_history_state = []
      chat_message = ''
      return chat_history_state, chat_message

+
  def user(message, history):
      history = history or []
+     # Append the user's message to the conversation history
      history.append([message, ""])
      return "", history

+
  def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
      history = history or []
-
-     # The last message from the user
-     user_prompt = history[-1][0]
-
-     # Define the template and the prompt
-     prompt_template = f'''system
- {system_message.strip()}
- user
- {user_prompt}
- assistant
- '''
-
-     # Prepare the input
-     input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids  # .cuda() if you are using a GPU
-
-     # Generate the output
-     output = model.generate(input_ids=input_ids, temperature=temperature, do_sample=True, top_p=top_p, top_k=top_k, max_length=max_tokens)
-
-     # Decode the output
-     decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     # Update the history
-     history[-1][1] += decoded_output
-
-     yield history, history, ""
+
+     if system_message.strip():
+         messages = "<|im_start|> "+"system\n" + system_message.strip() + "<|im_end|>\n" + \
+             "\n".join(["\n".join(["<|im_start|> "+"user\n"+item[0]+"<|im_end|>", "<|im_start|> assistant\n"+item[1]+"<|im_end|>"])
+                        for item in history])
+     else:
+         messages = "<|im_start|> "+"system\n" + BASE_SYSTEM_MESSAGE + "<|im_end|>\n" + \
+             "\n".join(["\n".join(["<|im_start|> "+"user\n"+item[0]+"<|im_end|>", "<|im_start|> assistant\n"+item[1]+"<|im_end|>"])
+                        for item in history])
+     # strip the last `<|im_end|>` from the messages
+     messages = messages.rstrip("<|im_end|>")
+     # remove the trailing space after "assistant"; some models output a ZWSP if you leave a space
+     messages = messages.rstrip()
+
+     # If temperature is set to 0, force Top P to 1 and Top K to -1
+     if temperature == 0:
+         top_p = 1
+         top_k = -1
+
+     prediction = make_prediction(
+         messages,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+     )
+     for tokens in prediction:
+         tokens = re.findall(r'(.*?)(\s|$)', tokens)
+         for subtoken in tokens:
+             subtoken = "".join(subtoken)
+             answer = subtoken
+             history[-1][1] += answer
+             # stream the response
+             yield history, history, ""
+

  start_message = ""

@@ -64,6 +78,8 @@ CSS ="""
  #component-0 { height: 100%; }
  #chatbot { flex-grow: 1; overflow: auto; resize: vertical; }
  """
+
+ #with gr.Blocks() as demo:
  with gr.Blocks(css=CSS) as demo:
      with gr.Row():
          with gr.Column():
@@ -109,4 +125,4 @@ with gr.Blocks(css=CSS) as demo:
      )
      stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event], queue=False)

- demo.queue(max_size=128, concurrency_count=48).launch(debug=True, server_name="0.0.0.0", server_port=7860)
+ demo.queue(max_size=128, concurrency_count=48).launch(debug=True, server_name="0.0.0.0", server_port=7860)
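
For reference, the sketch below (not part of the commit) reconstructs the ChatML-style prompt string that the new chat() assembles before passing it to make_prediction(). The format strings are copied from the diff above; the system_message and history values are invented placeholders for illustration only.

# Minimal sketch of the prompt assembly performed by the new chat() above.
# system_message and history are hypothetical example values, not app defaults.
system_message = "You are a helpful assistant."      # placeholder; the app falls back to BASE_SYSTEM_MESSAGE when the box is empty
history = [["What is the capital of France?", ""]]   # [user, assistant] pairs; the last assistant slot fills in as tokens stream

messages = "<|im_start|> "+"system\n" + system_message.strip() + "<|im_end|>\n" + \
    "\n".join(["\n".join(["<|im_start|> "+"user\n"+item[0]+"<|im_end|>", "<|im_start|> assistant\n"+item[1]+"<|im_end|>"])
               for item in history])
messages = messages.rstrip("<|im_end|>")  # drop the final <|im_end|> so the model continues the assistant turn
messages = messages.rstrip()              # drop the trailing newline so generation starts right after "assistant"

print(messages)
# <|im_start|> system
# You are a helpful assistant.<|im_end|>
# <|im_start|> user
# What is the capital of France?<|im_end|>
# <|im_start|> assistant

Note that make_prediction() passes top_k and repetition_penalty alongside the standard completion parameters; these are not part of the official OpenAI completion API, so OPENAI_API_BASE presumably points at an OpenAI-compatible inference server (for example, vLLM) that accepts them.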