# mistral-7b-inf2 / app.py
import io
import json
import os

import boto3
import gradio as gr
from transformers import AutoTokenizer
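
# Configuration comes from environment variables (set as Space secrets), so no
# endpoint names or AWS credentials are hard-coded in this file.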
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")
# Log only whether the token is present; never print secret values.
print("hf_token set:", bool(hf_token))
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
# SageMaker runtime client used to invoke the streaming endpoint.
smr = session.client("sagemaker-runtime")
DEFAULT_SYSTEM_PROMPT = "You are a helpful, concise and direct assistant."
# Load the tokenizer so prompts can be rendered with the model's chat template.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2", token=hf_token
)
# Requests whose rendered prompt exceeds this token count are rejected up front.
MAX_INPUT_TOKEN_LENGTH = 256
# Generation hyperparameters forwarded to the endpoint with every request.
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 768,
    "repetition_penalty": 1.2,
    "return_full_text": False,
    # The streaming loop below reads request["parameters"]["stop"], so a stop
    # list must be present; "</s>" is Mistral's end-of-sequence token.
    "stop": ["</s>"],
}
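
# For reference, each invocation sends a JSON payload shaped roughly like this
# (illustrative values, not captured traffic):
#   {"inputs": "<s>[INST] Hello [/INST]", "parameters": {...the dict above...}, "stream": true}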
# Helper for reading lines from the SageMaker response stream. PayloadPart
# events can split a line across chunks, so bytes are buffered until a full
# newline-terminated line is available.
class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                # chunk is a dict, so convert it before concatenating.
                print("Unknown event type: " + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
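
# Each line yielded above is a server-sent event from the endpoint; a token
# line looks roughly like this (an illustrative sketch, not captured output):
#   data:{"token": {"id": 5613, "text": " Sure", "logprob": -0.31, "special": false}}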
def format_prompt(message, history):
    # Mistral-Instruct's chat template rejects a separate "system" role, so the
    # system prompt is folded into the first user turn instead.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    messages[0]["content"] = DEFAULT_SYSTEM_PROMPT + "\n\n" + messages[0]["content"]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return prompt
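
# For a single user turn, the rendered prompt looks roughly like
# "<s>[INST] <system prompt>\n\n<message> [/INST]" (exact spacing is
# determined by the model's chat template).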
def generate(prompt, history):
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)
    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )
    output = ""
    for c in LineIterator(resp["Body"]):
        c = c.decode("utf-8")
        if c.startswith("data:"):
            # Strip the "data:" prefix and the trailing newline before parsing.
            chunk = json.loads(c[len("data:"):].rstrip("\n"))
            if chunk["token"]["special"]:
                continue
            if chunk["token"]["text"] in request["parameters"]["stop"]:
                break
            output += chunk["token"]["text"]
            # Trim any stop string that may have been emitted verbatim.
            for stop_str in request["parameters"]["stop"]:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
                    output = output.rstrip()
                    yield output
            yield output
    return output
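
# gr.ChatInterface consumes generate() as a generator: each yielded string is
# the cumulative response so far, which Gradio re-renders to produce the
# token-by-token streaming effect.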
def check_input_token_length(prompt: str) -> None:
    input_token_length = len(tokenizer(prompt)["input_ids"])
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f"The accumulated input is too long ({input_token_length} > "
            f"{MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
        )
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(layout="panel"),
    theme=theme,
)

# Note: queue(concurrency_count=...) is the Gradio 3.x signature; Gradio 4
# moved concurrency control to per-event settings.
demo.queue(concurrency_count=5).launch(share=False)