martintomov commited on
Commit
9552c68
1 Parent(s): c382f0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -61
app.py CHANGED
@@ -9,18 +9,11 @@ import time
9
  import base64
10
  import json
11
 
12
- # Local Dev
13
- import os
14
- from dotenv import load_dotenv
15
-
16
- load_dotenv()
17
- FAL_KEY = os.getenv("FAL_KEY")
18
-
19
  with open("examples/examples.json") as f:
20
  examples = json.load(f)
21
 
22
  # IC Light, Replace Background
23
- async def submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color):
24
  if not lightsource_start_color.startswith("#"):
25
  lightsource_start_color = f"#{lightsource_start_color}"
26
  if not lightsource_end_color.startswith("#"):
@@ -29,7 +22,6 @@ async def submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lig
29
  retries = 3
30
  for attempt in range(retries):
31
  try:
32
- fal_client.set_api_key(FAL_KEY) # Set the API key before making the request
33
  handler = await fal_client.submit_async(
34
  "comfy/martintmv-git/ic-light-bria",
35
  arguments={
@@ -38,7 +30,8 @@ async def submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lig
38
  "Negative Prompt": negative_prompt,
39
  "lightsource_start_color": lightsource_start_color,
40
  "lightsource_end_color": lightsource_end_color
41
- }
 
42
  )
43
 
44
  log_index = 0
@@ -70,18 +63,18 @@ async def submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lig
70
  return [f"Error: {str(e)}"], None
71
 
72
  # SDXL, Depth Anything, Replace Background
73
- async def submit_sdxl_rembg(image_data, positive_prompt, negative_prompt):
74
  retries = 3
75
  for attempt in range(retries):
76
  try:
77
- fal_client.set_api_key(FAL_KEY) # Set the API key before making the request
78
  handler = await fal_client.submit_async(
79
  "comfy/martintmv-git/sdxl-depthanything-rembg",
80
  arguments={
81
  "loadimage_1": image_data,
82
  "Positive prompt": positive_prompt,
83
  "Negative prompt": negative_prompt
84
- }
 
85
  )
86
 
87
  log_index = 0
@@ -108,16 +101,15 @@ async def submit_sdxl_rembg(image_data, positive_prompt, negative_prompt):
108
  except Exception as e:
109
  print(f"Attempt {attempt + 1} failed: {e}")
110
  if attempt < retries - 1:
111
- time.sleep(2) # HTTP req retry mechanism
112
  else:
113
  return [f"Error: {str(e)}"], None
114
 
115
  # SV3D, AnimateDiff
116
- async def submit_sv3d(image_data, fps, loop_frames_count, gif_loop):
117
  retries = 3
118
  for attempt in range(retries):
119
  try:
120
- fal_client.set_api_key(FAL_KEY) # Set the API key before making the request
121
  handler = await fal_client.submit_async(
122
  "comfy/martintmv-git/sv3d",
123
  arguments={
@@ -126,7 +118,7 @@ async def submit_sv3d(image_data, fps, loop_frames_count, gif_loop):
126
  "Loop Frames Count": loop_frames_count,
127
  "GIF Loop": gif_loop
128
  },
129
- credentials={"fal_key": FAL_KEY}
130
  )
131
 
132
  log_index = 0
@@ -158,17 +150,17 @@ def convert_image_to_base64(image):
158
  image.save(buffered, format="PNG")
159
  return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()
160
 
161
- def submit_sync_ic_light_bria(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color):
162
  image_data = convert_image_to_base64(Image.open(image_upload))
163
- return asyncio.run(submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color))
164
 
165
- def submit_sync_sdxl_rembg(image_upload, positive_prompt, negative_prompt):
166
  image_data = convert_image_to_base64(Image.open(image_upload))
167
- return asyncio.run(submit_sdxl_rembg(image_data, positive_prompt, negative_prompt))
168
 
169
- def submit_sync_sv3d(image_upload, fps, loop_frames_count, gif_loop):
170
  image_data = convert_image_to_base64(Image.open(image_upload))
171
- return asyncio.run(submit_sv3d(image_data, fps, loop_frames_count, gif_loop))
172
 
173
  def run_gradio_app():
174
  with gr.Blocks() as demo:
@@ -203,55 +195,36 @@ def run_gradio_app():
203
  output_result = gr.Image(label="Result")
204
 
205
  def validate_api_key(api_key):
206
- global FAL_KEY
207
- FAL_KEY = api_key
208
- return gr.Row(visible=True)
209
 
210
- api_key_submit.click(
211
- fn=validate_api_key,
212
- inputs=api_key_input,
213
- outputs=main_content
214
- )
215
 
216
- def update_ui(workflow):
217
  if workflow == "IC Light, Replace Background":
218
- return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
 
219
  elif workflow == "SDXL, Depth Anything, Replace Background":
220
- return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
 
221
  elif workflow == "SV3D":
222
- return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
 
 
 
223
 
224
- workflow.change(fn=update_ui, inputs=workflow, outputs=[positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop])
225
 
226
- def on_submit(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow):
227
  if workflow == "IC Light, Replace Background":
228
- logs, image = submit_sync_ic_light_bria(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color)
229
- return logs, image
230
  elif workflow == "SDXL, Depth Anything, Replace Background":
231
- logs, image = submit_sync_sdxl_rembg(image_upload, positive_prompt, negative_prompt)
232
- return logs, image
233
  elif workflow == "SV3D":
234
- logs, gif_url = submit_sync_sv3d(image_upload, fps, loop_frames_count, gif_loop)
235
- return logs, gif_url
236
-
237
- submit_btn.click(
238
- fn=on_submit,
239
- inputs=[image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow],
240
- outputs=[output_logs, output_result]
241
- )
242
-
243
- gr.Examples(
244
- examples=[
245
- [example['input_image'], example['positive_prompt'], example['negative_prompt'], example.get('lightsource_start_color', "#FFFFFF"), example.get('lightsource_end_color', "#000000"), example.get('fps', 8), example.get('loop_frames_count', 30), example.get('gif_loop', True), example['workflow']]
246
- for example in examples
247
- ],
248
- inputs=[image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow],
249
- outputs=[output_logs, output_result],
250
- fn=on_submit,
251
- cache_examples=True
252
- )
253
 
254
  demo.launch()
255
 
256
  if __name__ == "__main__":
257
- run_gradio_app()
 
9
  import base64
10
  import json
11
 
 
 
 
 
 
 
 
12
  with open("examples/examples.json") as f:
13
  examples = json.load(f)
14
 
15
  # IC Light, Replace Background
16
+ async def submit_ic_light_bria(api_key, image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color):
17
  if not lightsource_start_color.startswith("#"):
18
  lightsource_start_color = f"#{lightsource_start_color}"
19
  if not lightsource_end_color.startswith("#"):
 
22
  retries = 3
23
  for attempt in range(retries):
24
  try:
 
25
  handler = await fal_client.submit_async(
26
  "comfy/martintmv-git/ic-light-bria",
27
  arguments={
 
30
  "Negative Prompt": negative_prompt,
31
  "lightsource_start_color": lightsource_start_color,
32
  "lightsource_end_color": lightsource_end_color
33
+ },
34
+ credentials={"api_key": api_key} # Pass the user's API key dynamically
35
  )
36
 
37
  log_index = 0
 
63
  return [f"Error: {str(e)}"], None
64
 
65
  # SDXL, Depth Anything, Replace Background
66
+ async def submit_sdxl_rembg(api_key, image_data, positive_prompt, negative_prompt):
67
  retries = 3
68
  for attempt in range(retries):
69
  try:
 
70
  handler = await fal_client.submit_async(
71
  "comfy/martintmv-git/sdxl-depthanything-rembg",
72
  arguments={
73
  "loadimage_1": image_data,
74
  "Positive prompt": positive_prompt,
75
  "Negative prompt": negative_prompt
76
+ },
77
+ credentials={"api_key": api_key} # Pass the user's API key dynamically
78
  )
79
 
80
  log_index = 0
 
101
  except Exception as e:
102
  print(f"Attempt {attempt + 1} failed: {e}")
103
  if attempt < retries - 1:
104
+ time.sleep(2) # HTTP req retry mechanism
105
  else:
106
  return [f"Error: {str(e)}"], None
107
 
108
  # SV3D, AnimateDiff
109
+ async def submit_sv3d(api_key, image_data, fps, loop_frames_count, gif_loop):
110
  retries = 3
111
  for attempt in range(retries):
112
  try:
 
113
  handler = await fal_client.submit_async(
114
  "comfy/martintmv-git/sv3d",
115
  arguments={
 
118
  "Loop Frames Count": loop_frames_count,
119
  "GIF Loop": gif_loop
120
  },
121
+ credentials={"api_key": api_key} # Pass the user's API key dynamically
122
  )
123
 
124
  log_index = 0
 
150
  image.save(buffered, format="PNG")
151
  return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()
152
 
153
+ def submit_sync_ic_light_bria(api_key, image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color):
154
  image_data = convert_image_to_base64(Image.open(image_upload))
155
+ return asyncio.run(submit_ic_light_bria(api_key, image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color))
156
 
157
+ def submit_sync_sdxl_rembg(api_key, image_upload, positive_prompt, negative_prompt):
158
  image_data = convert_image_to_base64(Image.open(image_upload))
159
+ return asyncio.run(submit_sdxl_rembg(api_key, image_data, positive_prompt, negative_prompt))
160
 
161
+ def submit_sync_sv3d(api_key, image_upload, fps, loop_frames_count, gif_loop):
162
  image_data = convert_image_to_base64(Image.open(image_upload))
163
+ return asyncio.run(submit_sv3d(api_key, image_data, fps, loop_frames_count, gif_loop))
164
 
165
  def run_gradio_app():
166
  with gr.Blocks() as demo:
 
195
  output_result = gr.Image(label="Result")
196
 
197
  def validate_api_key(api_key):
198
+ return gr.Row(visible=True), api_key
 
 
199
 
200
+ api_key_submit.click(validate_api_key, [api_key_input], [main_content, api_key_input])
 
 
 
 
201
 
202
+ def submit_handler(api_key, workflow, image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop):
203
  if workflow == "IC Light, Replace Background":
204
+ logs, result_image = submit_sync_ic_light_bria(api_key, image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color)
205
+ return logs, result_image
206
  elif workflow == "SDXL, Depth Anything, Replace Background":
207
+ logs, result_image = submit_sync_sdxl_rembg(api_key, image_upload, positive_prompt, negative_prompt)
208
+ return logs, result_image
209
  elif workflow == "SV3D":
210
+ logs, gif_url = submit_sync_sv3d(api_key, image_upload, fps, loop_frames_count, gif_loop)
211
+ response = requests.get(gif_url)
212
+ gif_bytes = BytesIO(response.content)
213
+ return logs, Image.open(gif_bytes)
214
 
215
+ submit_btn.click(submit_handler, [api_key_input, workflow, image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop], [output_logs, output_result])
216
 
217
+ def update_fields(workflow):
218
  if workflow == "IC Light, Replace Background":
219
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
220
  elif workflow == "SDXL, Depth Anything, Replace Background":
221
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
222
  elif workflow == "SV3D":
223
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
224
+
225
+ workflow.change(update_fields, workflow, [positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  demo.launch()
228
 
229
  if __name__ == "__main__":
230
+ run_gradio_app()