cszhzleo committed on
Commit
7ea4cc0
1 Parent(s): 84df004

Create app.py

Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
import gradio as gr
import boto3
import json
import io

# hyperparameters for llm
parameters = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "max_new_tokens": 1024,
    "return_full_text": False,
    "stop": ["</s>"],
}

system_prompt = (
    "You are a helpful assistant called Llama 2. You know everything about AWS."
)


# Helper for reassembling complete lines from a SageMaker byte stream
class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                print("Unknown event type: " + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
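
# Note (illustrative, not from the original file): each event yielded by the
# response stream's "Body" is a dict shaped roughly like
#   {"PayloadPart": {"Bytes": b'data:{"token": {"text": "...", "special": false}}\n'}}
# and a single server-sent-events line may be split across several PayloadParts,
# which is why LineIterator buffers bytes until a complete newline-terminated
# line is available.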


# helper to build the chat messages list from the prompt and history
def create_messages_dict(message, history, system_prompt):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages
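
# Note (illustrative): for Llama 2 chat models, tokenizer.apply_chat_template
# renders the messages list into the model's expected prompt format, roughly:
#   <s>[INST] <<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {user message} [/INST] {assistant reply} </s><s>[INST] {next message} [/INST]
# The exact tokens and spacing depend on the tokenizer's chat template.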


def create_gradio_app(
    endpoint_name,
    session=boto3,
    parameters=parameters,
    system_prompt=system_prompt,
    tokenizer=None,
    concurrency_count=4,
    share=True,
):
    smr = session.client("sagemaker-runtime")

    def generate(
        prompt,
        history,
    ):
        messages = create_messages_dict(prompt, history, system_prompt)
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
        resp = smr.invoke_endpoint_with_response_stream(
            EndpointName=endpoint_name,
            Body=json.dumps(request),
            ContentType="application/json",
        )

        output = ""
        for c in LineIterator(resp["Body"]):
            c = c.decode("utf-8")
            if c.startswith("data:"):
                # drop the SSE "data:" prefix and the trailing newline before parsing
                chunk = json.loads(c[len("data:"):].rstrip("\n"))
                if chunk["token"]["special"]:
                    continue
                if chunk["token"]["text"] in request["parameters"]["stop"]:
                    break
                output += chunk["token"]["text"]
                for stop_str in request["parameters"]["stop"]:
                    if output.endswith(stop_str):
                        output = output[: -len(stop_str)]
                        output = output.rstrip()
                yield output

        yield output

    demo = gr.ChatInterface(
        generate, title="Chat with Amazon SageMaker", chatbot=gr.Chatbot(layout="panel")
    )

    demo.queue(concurrency_count=concurrency_count).launch(share=share)
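
A minimal usage sketch, not part of the commit: it assumes a Llama 2 model is already deployed behind a SageMaker endpoint and that a matching tokenizer is available locally; the endpoint name and model ID below are placeholders to replace with your own.

    from transformers import AutoTokenizer

    # placeholders: substitute your own endpoint name and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    create_gradio_app(
        endpoint_name="my-llama2-endpoint",
        tokenizer=tokenizer,
        share=True,
    )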