# BaseChat / app_single.py
# yuchenlin's picture
# side by side
# d8f6559
# raw
# history blame contribute delete
# No virus
# 4.63 kB
import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime
# Emit INFO-level (and above) log records to the console.
logging.basicConfig(level=logging.INFO)

# Which URIAL prompt file to fetch from the Re-Align/URIAL repository.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"

# Download the prompt once at import time. The context manager guarantees the
# HTTP response is closed even if reading or decoding raises.
with urllib.request.urlopen(URIAL_URL) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```

# Substrings that end a generated answer (passed to the API and re-checked
# locally while streaming).
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Requests served per client address since LAST_UPDATE_TIME.
addr_limit_counter = {}
# Moment the per-address counters were last reset.
LAST_UPDATE_TIME = datetime.datetime.now()
def _extract_token(choice):
    """Return the text fragment carried by one streamed choice ("" if none)."""
    if hasattr(choice, "delta"):
        delta = choice.delta
        # The delta may be a mapping or an attribute-style object depending on
        # the client library; support both instead of assuming subscripting.
        if isinstance(delta, dict):
            token = delta.get("content")
        else:
            token = getattr(delta, "content", None)
    else:
        token = choice.text
    # Some chunks carry no content; normalise None to an empty string so the
    # caller can concatenate unconditionally.
    return token or ""


def _hide_partial_fence(text):
    """Return *text* without a trailing, still-incomplete `\"\"\"` stop marker."""
    if text.endswith('\n""'):
        return text[:-2]
    if text.endswith('\n"'):
        return text[:-1]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    api_key,
    request: gr.Request,
):
    """Stream the model's answer for one chat turn.

    Builds a prompt from the URIAL prefix, the chat history and the new
    message, requests a streamed completion, and yields the accumulated
    answer after every received chunk. Generation stops locally as soon as
    one of ``STOP_STRS`` appears in the accumulated text.

    Args:
        message: The user's new message.
        history: Previous turns, passed through to ``urial_template``.
        max_tokens, temperature, top_p: Forwarded to ``openai_base_request``.
        rp: Repetition penalty from the UI (currently overwritten, see below).
        model_name: UI model label, translated via ``model_name_mapping``.
        api_key: Optional user key; only forwarded when it is 64 characters long.
        request: Gradio request object; its client host keys the rate limiter.

    Yields:
        The answer accumulated so far, or a single rate-limit notice.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    # NOTE(review): the UI-supplied value is overwritten here, so the
    # "Repetition penalty" field has no effect -- confirm this is intended.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    # Anything that is not a 64-character key is dropped and None is passed on.
    if not (api_key and len(api_key) == 64):
        api_key = None

    # Reset all per-address counters once more than a day has passed.
    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = datetime.datetime.now()

    host_addr = request.client.host
    if host_addr not in addr_limit_counter:
        addr_limit_counter[host_addr] = 0
    # ">=" so that exactly 100 requests are served, matching the message.
    if addr_limit_counter[host_addr] >= 100:
        # This function is a generator: a bare `return <value>` would never
        # reach the chat window, so the notice has to be yielded.
        yield "You have reached the limit of 100 requests for today. Please use your own API key."
        return

    infer_request = openai_base_request(
        prompt=prompt,
        model=_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=rp,
        stop=STOP_STRS,
        api_key=api_key,
    )
    addr_limit_counter[host_addr] += 1
    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    response = ""
    for msg in infer_request:
        token = _extract_token(msg.choices[0])
        # Stop before appending if the new chunk would complete a stop string.
        if any(_stop in response + token for _stop in STOP_STRS):
            break
        response += token
        # Hide a half-received `"""` fence from the user, but keep the raw
        # text in `response` so that (a) a fence split across chunks is still
        # detected above and (b) a genuine line-leading quote is not lost.
        yield _hide_partial_fence(response)
# Page layout: a header row (intro + model picker on the left, API key and
# sampling parameters on the right) followed by the chat interface.
# NOTE(review): container nesting was inferred from statement order -- confirm
# that the parameters accordion belongs in the right-hand column.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                [
                    "Llama-3.1-405B-FP8", "Llama-3-70B", "Llama-3-8B",
                    "Mistral-7B-v0.1",
                    "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO",
                ],
                value="Llama-3.1-405B-FP8",
                label="Base LLM name",
            )
        with gr.Column():
            # The key field is currently hidden (visible=False).
            api_key = gr.Textbox(
                label="🔑 APIKey",
                placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
                visible=False,
            )
            with gr.Accordion("⚙️ Parameters for Base LLM", open=True):
                with gr.Row():
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")

    # The extra inputs are passed to `respond` after (message, history),
    # in exactly this order.
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, api_key],
    )
    chat.chatbot.label = "Chat with Base LLMs via URIAL"
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True

if __name__ == "__main__":
    demo.launch(show_api=False)