Spaces:
Sleeping
Sleeping
File size: 3,468 Bytes
7ea4cc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import boto3
import json
import io
# hyperparameters for llm
# Default generation parameters sent with every inference request
# (TGI-style parameter names — see the "stream" request built in generate()).
parameters = {
    "do_sample": True,  # sample tokens instead of greedy decoding
    "top_p": 0.6,  # nucleus-sampling cutoff
    "temperature": 0.9,
    "max_new_tokens": 1024,  # cap on generated tokens per turn
    "return_full_text": False,  # return only the continuation, not the prompt
    "stop": ["</s>"],  # stop sequence(s); matching tokens end the stream
}
# Default system message prepended to every conversation.
# BUG FIX: corrected typos in the prompt text ("an helpful" -> "a helpful",
# "everyting" -> "everything").
system_prompt = (
    "You are a helpful Assistant, called Llama 2. Knowing everything about AWS."
)
# Helper for reading lines from a stream
class LineIterator:
    """Iterate over newline-terminated lines from a SageMaker response event stream.

    Buffers the ``PayloadPart`` byte chunks of a botocore event stream and
    yields one line at a time (as ``bytes``, without the trailing newline),
    correctly handling lines that span chunk boundaries.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0  # offset of the next unread byte in the buffer

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            # A complete line ends in "\n"; return it without the newline.
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # BUG FIX: the original `continue`d here, which spun
                    # forever when the stream ended without a trailing "\n"
                    # (readline never finds a newline, next() keeps raising).
                    # Flush the final partial line instead.
                    self.buffer.seek(self.read_pos)
                    tail = self.buffer.read()
                    self.read_pos += len(tail)
                    return tail
                raise
            if "PayloadPart" not in chunk:
                # BUG FIX: `"..." + chunk` raised TypeError (str + dict);
                # format the event instead of concatenating it.
                print(f"Unknown event type: {chunk}")
                continue
            # Append the new bytes at the end, then retry readline above.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
# helper method to format prompt
def create_messages_dict(message, history, system_prompt):
    """Build the role/content message list expected by a chat template.

    Emits an optional system message, then alternating user/assistant turns
    from ``history`` (a sequence of ``(user, assistant)`` pairs), and finally
    the new user ``message``.
    """
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    for user_turn, assistant_turn in history:
        messages.extend(
            (
                {"role": "user", "content": user_turn},
                {"role": "assistant", "content": assistant_turn},
            )
        )
    messages.append({"role": "user", "content": message})
    return messages
def create_gradio_app(
    endpoint_name,
    session=boto3,
    parameters=parameters,
    system_prompt=system_prompt,
    tokenizer=None,
    concurrency_count=4,
    share=True,
):
    """Launch a Gradio ChatInterface streaming completions from a SageMaker endpoint.

    Args:
        endpoint_name: Name of the deployed SageMaker endpoint.
        session: ``boto3`` module or Session used to create the runtime client.
        parameters: Generation hyperparameters sent with every request.
        system_prompt: System message prepended to each conversation.
        tokenizer: Tokenizer providing ``apply_chat_template`` (must not be None).
        concurrency_count: Maximum concurrent generations in the Gradio queue.
        share: Whether to create a public Gradio share link.
    """
    smr = session.client("sagemaker-runtime")

    def generate(prompt, history):
        # Render the conversation with the model's chat template.
        messages = create_messages_dict(prompt, history, system_prompt)
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
        resp = smr.invoke_endpoint_with_response_stream(
            EndpointName=endpoint_name,
            Body=json.dumps(request),
            ContentType="application/json",
        )
        output = ""
        for raw_line in LineIterator(resp["Body"]):
            line = raw_line.decode("utf-8")
            # Server-sent-events style: only "data:"-prefixed lines carry tokens.
            if not line.startswith("data:"):
                continue
            # BUG FIX: lstrip("data:")/rstrip("/n") strip *character sets*,
            # not a literal prefix/suffix — "/n" was a typo for "\n", and
            # lstrip could eat leading d/a/t/: characters of the payload.
            # Slice the exact prefix and strip the real newline instead.
            chunk = json.loads(line[len("data:"):].rstrip("\n"))
            token = chunk["token"]
            if token["special"]:
                continue
            if token["text"] in request["parameters"]["stop"]:
                break
            output += token["text"]
            # Trim any stop sequence that arrived glued to normal text.
            for stop_str in request["parameters"]["stop"]:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
                    output = output.rstrip()
                    yield output
            yield output
        return output

    demo = gr.ChatInterface(
        generate, title="Chat with Amazon SageMaker", chatbot=gr.Chatbot(layout="panel")
    )

    demo.queue(concurrency_count=concurrency_count).launch(share=share)