Ashhar committed on
Commit
a94a1dc
1 Parent(s): f97efd9

changes to context window + image prompt

Browse files
Files changed (1) hide show
  1. app.py +144 -39
app.py CHANGED
@@ -12,43 +12,51 @@ from gradio_client import Client
12
  from dotenv import load_dotenv
13
  load_dotenv()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- from groq import Groq
17
- client = Groq(
18
- api_key=os.environ.get("GROQ_API_KEY"),
19
- )
20
 
21
- MODEL = "llama-3.1-70b-versatile"
22
  JSON_SEPARATOR = ">>>>"
23
 
24
 
25
- tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
26
-
27
-
28
  def countTokens(text):
29
- # Tokenize the input text
30
  tokens = tokenizer.encode(text, add_special_tokens=False)
31
- # Return the number of tokens
32
  return len(tokens)
33
 
34
 
35
  SYSTEM_MSG = f"""
36
  You're a storytelling assistant who guides users through four phases of narrative development, helping them craft compelling personal or professional stories. The story created should be in simple language, yet evoke great emotions.
37
- Ask one question at a time, give the options in a well formatted manner in different lines
38
  If your response has number of options to choose from, only then append your final response with this exact keyword "{JSON_SEPARATOR}", and only after this, append with the JSON of options to choose from. The JSON should be of the format:
39
  {{
40
  "options": [
41
  {{ "id": "1", "label": "Option 1"}},
42
- {{ "id": "2", "label": "Option 2"}},
43
  ]
44
  }}
45
  Do not write "Choose one of the options below:"
46
- Keep options to less than 9
 
47
 
48
  # Tier 1: Story Creation
49
  You initiate the storytelling process through a series of engaging prompts:
50
  Story Origin:
51
- Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known fictional stories to choose from).
52
 
53
  Story Use Case:
54
  Asks users to define the purpose of building a story (e.g., profile story, for social media content).
@@ -146,6 +154,37 @@ def pprint(log: str):
146
 
147
  pprint("\n")
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  def __isInvalidResponse(response: str):
151
  # new line followed by small case char
@@ -161,7 +200,7 @@ def __isInvalidResponse(response: str):
161
  return True
162
 
163
  # json response without json separator
164
- if ('\n{\n "options"' in response) and (JSON_SEPARATOR not in response):
165
  return True
166
 
167
 
@@ -180,23 +219,60 @@ def __isStringNumber(s: str) -> bool:
180
  return False
181
 
182
 
183
- def __getImageGenerationPrompt(prompt: str, response: str):
184
- responseLower = response.lower()
 
 
 
 
 
 
 
185
  if (
186
  __matchingKeywordsCount(
187
  ["adapt", "profile", "social media", "purpose", "use case"],
188
- responseLower
189
  ) > 2
190
  and not __isStringNumber(prompt)
191
- and prompt.lower() in responseLower
 
192
  ):
193
- return f'a scene from (({prompt})). Include main character'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  if __matchingKeywordsCount(
196
- ["Tier 2", "Tier-2"],
197
- response
198
  ) > 0:
199
- return f"photo of a scene from this text: {response}"
 
 
 
 
 
 
 
 
 
200
 
201
 
202
  def __resetButtonState():
@@ -210,6 +286,9 @@ def __setStartMsg(msg):
210
  if "messages" not in st.session_state:
211
  st.session_state.messages = []
212
 
 
 
 
213
  if "buttonValue" not in st.session_state:
214
  __resetButtonState()
215
 
@@ -217,19 +296,33 @@ if "startMsg" not in st.session_state:
217
  st.session_state.startMsg = ""
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def predict(prompt):
221
- historyFormatted = [{"role": "system", "content": SYSTEM_MSG}]
222
- historyFormatted.extend([
223
- {"role": message["role"], "content": message["content"]}
224
- for message in st.session_state.messages
225
- ])
226
- historyFormatted.append({"role": "user", "content": prompt })
227
- contextSize = countTokens(str(historyFormatted))
228
- pprint(f"{contextSize=}")
229
 
230
  response = client.chat.completions.create(
231
- model="llama-3.1-70b-versatile",
232
- messages=historyFormatted,
233
  temperature=0.8,
234
  max_tokens=4000,
235
  stream=True
@@ -245,13 +338,13 @@ def predict(prompt):
245
 
246
  def generateImage(prompt: str):
247
  pprint(f"imagePrompt={prompt}")
248
- client = Client("black-forest-labs/FLUX.1-schnell")
249
- result = client.predict(
250
  prompt=prompt,
251
  seed=0,
252
  randomize_seed=True,
253
- width=1152,
254
- height=896,
255
  num_inference_steps=4,
256
  api_name="/infer"
257
  )
@@ -321,14 +414,26 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
321
  [response, jsonStr] = responseParts
322
 
323
  imagePath = None
 
324
  try:
325
- imagePrompt = __getImageGenerationPrompt(prompt, response)
326
  if imagePrompt:
327
- imageContainer = st.empty().image(IMAGE_LOADER)
 
 
 
 
 
 
 
 
 
 
328
  (imagePath, seed) = generateImage(imagePrompt)
329
  imageContainer.image(imagePath)
330
  except Exception as e:
331
  pprint(e)
 
332
 
333
  if jsonStr:
334
  try:
 
12
  from dotenv import load_dotenv
13
  load_dotenv()
14
 
15
+ useGpt4 = os.environ.get("USE_GPT_4") == "1"
16
+
17
+ if useGpt4:
18
+ from openai import OpenAI
19
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
20
+ MODEL = "gpt-4o-mini"
21
+ MAX_CONTEXT = 128000
22
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
23
+ else:
24
+ from groq import Groq
25
+ client = Groq(
26
+ api_key=os.environ.get("GROQ_API_KEY"),
27
+ )
28
+ MODEL = "llama-3.1-70b-versatile"
29
+ MAX_CONTEXT = 8000
30
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
31
 
 
 
 
 
32
 
 
33
  JSON_SEPARATOR = ">>>>"
34
 
35
 
 
 
 
36
  def countTokens(text):
37
+ text = str(text)
38
  tokens = tokenizer.encode(text, add_special_tokens=False)
 
39
  return len(tokens)
40
 
41
 
42
  SYSTEM_MSG = f"""
43
  You're a storytelling assistant who guides users through four phases of narrative development, helping them craft compelling personal or professional stories. The story created should be in simple language, yet evoke great emotions.
44
+ Ask one question at a time, give the options in a numbered and well formatted manner in different lines
45
  If your response has number of options to choose from, only then append your final response with this exact keyword "{JSON_SEPARATOR}", and only after this, append with the JSON of options to choose from. The JSON should be of the format:
46
  {{
47
  "options": [
48
  {{ "id": "1", "label": "Option 1"}},
49
+ {{ "id": "2", "label": "Option 2"}}
50
  ]
51
  }}
52
  Do not write "Choose one of the options below:"
53
+ Keep options to less than 9.
54
+ Summarise options chosen so far in each step.
55
 
56
  # Tier 1: Story Creation
57
  You initiate the storytelling process through a series of engaging prompts:
58
  Story Origin:
59
+ Asks users to choose between personal anecdotes or adapting a well-known story (creating a story database here of well-known stories to choose from).
60
 
61
  Story Use Case:
62
  Asks users to define the purpose of building a story (e.g., profile story, for social media content).
 
154
 
155
  pprint("\n")
156
 
157
+ st.markdown(
158
+ """
159
+ <style>
160
+ @keyframes blinker {
161
+ 0% {
162
+ opacity: 1;
163
+ }
164
+ 50% {
165
+ opacity: 0.2;
166
+ }
167
+ 100% {
168
+ opacity: 1;
169
+ }
170
+ }
171
+
172
+ .blinking {
173
+ animation: blinker 3s ease-out infinite;
174
+ }
175
+
176
+ .code {
177
+ color: green;
178
+ border-radius: 3px;
179
+ padding: 2px 4px; /* Padding around the text */
180
+ font-family: 'Courier New', Courier, monospace; /* Monospace font */
181
+ }
182
+
183
+ </style>
184
+ """,
185
+ unsafe_allow_html=True
186
+ )
187
+
188
 
189
  def __isInvalidResponse(response: str):
190
  # new line followed by small case char
 
200
  return True
201
 
202
  # json response without json separator
203
+ if ('{\n "options"' in response) and (JSON_SEPARATOR not in response):
204
  return True
205
 
206
 
 
219
  return False
220
 
221
 
222
+ def __getImagePromptDetails(prompt: str, response: str):
223
+ regex = r'[^a-z0-9 \n\.\-]|((the) +)'
224
+
225
+ cleanedResponse = re.sub(regex, '', response.lower())
226
+ pprint(f"{cleanedResponse=}")
227
+
228
+ cleanedPrompt = re.sub(regex, '', prompt.lower())
229
+ pprint(f"{cleanedPrompt=}")
230
+
231
  if (
232
  __matchingKeywordsCount(
233
  ["adapt", "profile", "social media", "purpose", "use case"],
234
+ cleanedResponse
235
  ) > 2
236
  and not __isStringNumber(prompt)
237
+ and cleanedPrompt in cleanedResponse
238
+ and "story so far" not in cleanedResponse
239
  ):
240
+ return (
241
+ f'''
242
+ Subject: {prompt}.
243
+ Style: Fantastical, in a storybook, surreal, bokeh
244
+ ''',
245
+ "Painting your character ..."
246
+ )
247
+
248
+ '''
249
+ Mood: ethereal lighting that emphasizes the fantastical nature of the scene.
250
+
251
+ storybook style
252
+
253
+ 4d model, unreal engine
254
+
255
+ Alejandro Bursido
256
+
257
+ vintage, nostalgic
258
+
259
+ Dreamlike, Mystical, Fantastical, Charming
260
+ '''
261
 
262
  if __matchingKeywordsCount(
263
+ ["tier 2", "tier-2"],
264
+ cleanedResponse
265
  ) > 0:
266
+ possibleStoryEndIdx = [response.find("tier 2"), response.find("tier-2")]
267
+ storyEndIdx = max(possibleStoryEndIdx)
268
+ relevantResponse = response[:storyEndIdx]
269
+ pprint(f"{relevantResponse=}")
270
+ return (
271
+ f"photo of a scene from this text: {relevantResponse}",
272
+ "Imagining your scene (beta) ..."
273
+ )
274
+
275
+ return (None, None)
276
 
277
 
278
  def __resetButtonState():
 
286
  if "messages" not in st.session_state:
287
  st.session_state.messages = []
288
 
289
+ if "history" not in st.session_state:
290
+ st.session_state.history = []
291
+
292
  if "buttonValue" not in st.session_state:
293
  __resetButtonState()
294
 
 
296
  st.session_state.startMsg = ""
297
 
298
 
299
+ def __getChatMessages(prompt: str):
300
+ st.session_state.history.append({
301
+ "role": "user",
302
+ "content": prompt
303
+ })
304
+
305
+ def getContextSize():
306
+ currContextSize = countTokens(SYSTEM_MSG) + countTokens(st.session_state.history) + 100
307
+ pprint(f"{currContextSize=}")
308
+ return currContextSize
309
+
310
+ while getContextSize() > MAX_CONTEXT:
311
+ pprint("Context size exceeded, removing first message")
312
+ st.session_state.history.pop(0)
313
+
314
+ return st.session_state.history
315
+
316
+
317
  def predict(prompt):
318
+ messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
319
+ messagesFormatted.extend(__getChatMessages(prompt))
320
+ contextSize = countTokens(messagesFormatted)
321
+ pprint(f"{contextSize=} | {MODEL}")
 
 
 
 
322
 
323
  response = client.chat.completions.create(
324
+ model=MODEL,
325
+ messages=messagesFormatted,
326
  temperature=0.8,
327
  max_tokens=4000,
328
  stream=True
 
338
 
339
  def generateImage(prompt: str):
340
  pprint(f"imagePrompt={prompt}")
341
+ fluxClient = Client("black-forest-labs/FLUX.1-schnell")
342
+ result = fluxClient.predict(
343
  prompt=prompt,
344
  seed=0,
345
  randomize_seed=True,
346
+ width=1024,
347
+ height=768,
348
  num_inference_steps=4,
349
  api_name="/infer"
350
  )
 
414
  [response, jsonStr] = responseParts
415
 
416
  imagePath = None
417
+ imageContainer = st.empty()
418
  try:
419
+ (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
420
  if imagePrompt:
421
+ imgContainer = imageContainer.container()
422
+ imgContainer.write(
423
+ f"""
424
+ <div class='blinking code'>
425
+ {loaderText}
426
+ </div>
427
+ """,
428
+ unsafe_allow_html=True
429
+ )
430
+ # imgContainer.markdown(f"`{loaderText}`")
431
+ imgContainer.image(IMAGE_LOADER)
432
  (imagePath, seed) = generateImage(imagePrompt)
433
  imageContainer.image(imagePath)
434
  except Exception as e:
435
  pprint(e)
436
+ imageContainer.empty()
437
 
438
  if jsonStr:
439
  try: