Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import boto3
|
3 |
+
import json
|
4 |
+
import io
|
5 |
+
|
6 |
+
# Default generation hyperparameters sent to the LLM endpoint (TGI format).
parameters = {
    "do_sample": True,          # sample instead of greedy decoding
    "top_p": 0.6,               # nucleus-sampling cutoff
    "temperature": 0.9,         # sampling temperature
    "max_new_tokens": 1024,     # generation length cap
    "return_full_text": False,  # only return the completion, not the prompt
    "stop": ["</s>"],           # stop sequence(s); also stripped from output
}

# Default system prompt prepended to every conversation.
# BUG FIX: corrected typos in the original string ("an helpful",
# "everyting") — this text is sent verbatim to the model.
system_prompt = (
    "You are a helpful Assistant, called Llama 2. Knowing everything about AWS."
)
|
19 |
+
|
20 |
+
|
21 |
+
# Helper for reading newline-delimited lines from a SageMaker response stream.
class LineIterator:
    """Iterate over complete lines (as ``bytes``, newline stripped) from a
    SageMaker ``invoke_endpoint_with_response_stream`` event stream.

    Events arrive as dicts shaped like ``{"PayloadPart": {"Bytes": b"..."}}``
    and a single line may span several events, so incoming bytes are
    accumulated in an in-memory buffer and only newline-terminated lines
    are yielded. When the stream ends, any unterminated trailing bytes are
    yielded as one final line.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0  # offset of the first byte not yet handed out

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # BUG FIX: the original `continue`d here, which spun
                    # forever when the stream ended without a trailing
                    # newline. Flush the remaining bytes as a final line.
                    self.buffer.seek(self.read_pos)
                    remainder = self.buffer.read()
                    self.read_pos += len(remainder)
                    return remainder
                raise
            if "PayloadPart" not in chunk:
                # BUG FIX: `"..." + chunk` raised TypeError (str + dict);
                # format the event instead of concatenating it.
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
49 |
+
|
50 |
+
|
51 |
+
# helper: turn (new message, chat history, system prompt) into the
# role/content message list expected by a chat template
def create_messages_dict(message, history, system_prompt):
    """Build the chat-completions message list for the tokenizer.

    The system prompt (if non-empty) comes first, then each
    (user, assistant) turn from *history* in order, and finally the
    new user *message*.
    """
    messages = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    for user_turn, assistant_turn in history:
        messages.extend(
            (
                {"role": "user", "content": user_turn},
                {"role": "assistant", "content": assistant_turn},
            )
        )
    messages.append({"role": "user", "content": message})
    return messages
|
61 |
+
|
62 |
+
|
63 |
+
def create_gradio_app(
    endpoint_name,
    session=boto3,
    parameters=parameters,
    system_prompt=system_prompt,
    tokenizer=None,
    concurrency_count=4,
    share=True,
):
    """Launch a Gradio ChatInterface backed by a streaming SageMaker endpoint.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        session: boto3 module or Session used to create the
            ``sagemaker-runtime`` client.
        parameters: Generation hyperparameters sent with every request.
        system_prompt: System message prepended to every conversation.
        tokenizer: Tokenizer providing ``apply_chat_template``; required.
        concurrency_count: Gradio queue concurrency.
        share: Whether to create a public Gradio share link.

    Raises:
        ValueError: If *tokenizer* is None (the original code deferred this
            to an opaque AttributeError on the first chat turn).
    """
    if tokenizer is None:
        raise ValueError("tokenizer is required (used to format the chat prompt)")

    smr = session.client("sagemaker-runtime")

    def generate(prompt, history):
        # history is a list of (user, assistant) tuples from gr.ChatInterface
        messages = create_messages_dict(prompt, history, system_prompt)
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
        resp = smr.invoke_endpoint_with_response_stream(
            EndpointName=endpoint_name,
            Body=json.dumps(request),
            ContentType="application/json",
        )

        output = ""
        for raw in LineIterator(resp["Body"]):
            raw = raw.decode("utf-8")
            # Server-sent-events format: payload lines start with "data:".
            if not raw.startswith("data:"):
                continue
            # BUG FIX: the original used c.lstrip("data:").rstrip("/n").
            # lstrip/rstrip strip *character sets*, not prefixes/suffixes,
            # so they could eat leading 'd'/'a'/'t'/':' characters of the
            # JSON payload (and "/n" was a typo for "\n"). Slice instead.
            chunk = json.loads(raw[len("data:"):])
            if chunk["token"]["special"]:
                continue
            if chunk["token"]["text"] in request["parameters"]["stop"]:
                break
            output += chunk["token"]["text"]
            # Strip any stop sequence that arrived as part of the text.
            for stop_str in request["parameters"]["stop"]:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
            output = output.rstrip()
            yield output

        # Emit the final accumulated text once more after the stream ends.
        yield output

    demo = gr.ChatInterface(
        generate, title="Chat with Amazon SageMaker", chatbot=gr.Chatbot(layout="panel")
    )

    # NOTE(review): queue(concurrency_count=...) is the gradio 3.x API;
    # confirm the pinned gradio version before upgrading.
    demo.queue(concurrency_count=concurrency_count).launch(share=share)
|