mistral-7b-inf2 / app.py
cszhzleo's picture
Update app.py
2243cd0 verified
raw
history blame contribute delete
No virus
5 kB
import gradio as gr
import boto3
import sagemaker
import json
import io
import os
from transformers import AutoTokenizer
from huggingface_hub import login
# --- Runtime configuration -------------------------------------------------
# Everything is read from environment variables; os.getenv returns None for
# any variable that is not set.
region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")
HF_TOKEN = os.getenv('HF_TOKEN')

# Report only whether the tokens are present. Printing the token values
# themselves would write credentials into the application log.
print("hf_token set:", hf_token is not None)
print("HF_TOKEN set:", HF_TOKEN is not None)

# --- AWS clients -----------------------------------------------------------
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region,
)
sess = sagemaker.Session(boto_session=session)
# Runtime client used by generate() to invoke the streaming endpoint.
smr = session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = (
    "You are an helpful, concise and direct Assistant."
)

# Load the tokenizer; it is used both to render the chat template and to
# count prompt tokens.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2", token=hf_token
)

# Upper bound on prompt tokens enforced by check_input_token_length().
MAX_INPUT_TOKEN_LENGTH = 256

# Generation hyperparameters sent with every request to the endpoint.
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 768,
    "repetition_penalty": 1.2,
    "return_full_text": False,
    "stop": ["</s>"],
}
# Helper for reading lines from a stream
class LineIterator:
    """Iterate over the newline-delimited lines of a chunked event stream.

    ``stream`` is any iterable of event mappings. Events that carry a
    ``"PayloadPart"`` entry contribute ``event["PayloadPart"]["Bytes"]`` to an
    internal byte buffer; a line may therefore be split across several events.
    Each call to ``next()`` returns one complete line as ``bytes`` with the
    trailing ``b"\\n"`` removed.

    If the stream ends while unterminated bytes remain in the buffer, those
    bytes are returned as a final line before ``StopIteration`` is raised.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset in ``buffer`` of the first byte not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Try to serve a complete line from what is already buffered.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]

            # No complete line yet: pull the next event from the stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a final newline. Hand back the
                    # remainder once; retrying would loop forever because no
                    # further bytes can arrive.
                    self.read_pos += len(line)
                    return line
                raise

            if "PayloadPart" not in chunk:
                # Format the event instead of concatenating it: ``chunk`` is a
                # mapping, and ``str + dict`` raises TypeError.
                print(f"Unknown event type: {chunk}")
                continue

            # Append at the end; the next loop iteration re-reads from read_pos.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
def format_prompt(message, history):
    """Render ``message`` as a single-turn prompt using the chat template.

    Only the latest user message is placed in the conversation passed to
    ``tokenizer.apply_chat_template``; ``history`` is accepted but not used
    here, so earlier turns are not included in the prompt.

    Args:
        message: Text of the current user message.
        history: Earlier chat turns (unused).

    Returns:
        The prompt produced by the tokenizer's chat template, as untokenized
        text with the generation prompt appended.
    """
    conversation = [{"role": "user", "content": message}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
def generate(
    prompt,
    history,
):
    """Stream a completion for ``prompt`` from the SageMaker endpoint.

    The prompt is formatted with ``format_prompt``, checked with
    ``check_input_token_length``, and sent to the endpoint named by
    ``sm_endpoint_name`` with streaming enabled. The response is read line by
    line; each ``data:`` line is parsed as JSON and its token text appended to
    the running output.

    Args:
        prompt: Text of the current user message.
        history: Earlier chat turns; forwarded to ``format_prompt``.

    Yields:
        The accumulated response text after each non-special token.

    Raises:
        gr.Error: Propagated from ``check_input_token_length`` when the
            formatted prompt is too long.
    """
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    data_prefix = "data:"
    stop_sequences = request["parameters"]["stop"]
    output = ""
    for raw_line in LineIterator(resp["Body"]):
        line = raw_line.decode("utf-8")
        if not line.startswith(data_prefix):
            continue

        # Slice off exactly the prefix. str.lstrip("data:") would instead strip
        # any leading run of the characters d/a/t/:, and json.loads already
        # tolerates surrounding whitespace, so no extra trimming is needed.
        chunk = json.loads(line[len(data_prefix):])
        token = chunk["token"]
        if token["special"]:
            continue
        if token["text"] in stop_sequences:
            break

        output += token["text"]
        # Drop a stop sequence that was assembled across several tokens.
        for stop_str in stop_sequences:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)].rstrip()
        yield output
def check_input_token_length(prompt: str) -> None:
    """Reject prompts that tokenize to more than MAX_INPUT_TOKEN_LENGTH ids.

    Args:
        prompt: The fully formatted prompt text.

    Raises:
        gr.Error: If the number of ``input_ids`` produced by the tokenizer
            exceeds ``MAX_INPUT_TOKEN_LENGTH``.
    """
    token_ids = tokenizer(prompt)["input_ids"]
    input_token_length = len(token_ids)
    if input_token_length <= MAX_INPUT_TOKEN_LENGTH:
        return
    raise gr.Error(
        f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
    )
# HTML header shown above the chat window.
DESCRIPTION = """
<div style="text-align: center; max-width: 650px; margin: 0 auto; display:grid; gap:25px;">
<img class="logo" src="https://huggingface.co/datasets/philschmid/assets/resolve/main/aws-neuron_hf.png" alt="Hugging Face Neuron Logo"
style="margin: auto; max-width: 14rem;">
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Mistral-7B-Instruct-v0.2 Chat on AWS INF2 ⚡
</h1>
</div>
"""

# Visual theme: Monochrome base with custom hues, small corner radius and an
# Open Sans first font stack.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

# Chat UI wired to the streaming generate() callback.
demo = gr.ChatInterface(
    fn=generate,
    description=DESCRIPTION,
    chatbot=gr.Chatbot(layout="panel"),
    theme=theme,
)

# Enable the request queue, then start the server without a public share link.
demo.queue()
demo.launch(share=False)