# UniPortrait / src/generation.py
# Author: Junjie96
# Last commit: "add nsfw safety checker" (6475329, verified)
# Hugging Face file view: raw | history | blame | contribute | delete
# Virus scan: no virus found
# Size: 3.05 kB
import json
import os
import time
import gradio as gr
import requests
from src.log import logger
from src.util import download_images
# Task statuses that end polling without results. The service appears to be
# DashScope-style (see the "X-DashScope-Async" header below); its async task
# API also reports CANCELED and UNKNOWN as final states, so checking only
# FAILED would leave the poll loop spinning forever on them.
_FAILED_TASK_STATUSES = frozenset({"FAILED", "CANCELED", "UNKNOWN"})


def _build_task_payload(data, model_id, batch_size, repeat_index):
    """Return a request body for one generation task without mutating ``data``."""
    params = dict(data["parameters"])
    params["n"] = batch_size
    base_seed = params.get("seed", -1)
    if base_seed != -1:
        # Derive each repeat's seed from the caller's base seed (not from the
        # previous repeat's value) so repeats get distinct, reproducible seeds.
        # The first repeat keeps the seed exactly as given.
        params["seed"] = base_seed + repeat_index
    return {**data, "model": model_id, "parameters": params}


def _create_generation_task(url_task, payload, headers, timeout):
    """Submit one async generation task and return its task id.

    Raises:
        gr.Error: if the request fails or the service does not answer 200.
    """
    try:
        resp = requests.post(url_task, data=json.dumps(payload), headers=headers, timeout=timeout)
    except requests.RequestException as err:
        logger.error(f'Fail to create Generation task: {err}')
        raise gr.Error("Fail to create Generation task.") from err
    if resp.status_code != 200:
        logger.error(f'Fail to create Generation task: {resp.content}')
        raise gr.Error("Fail to create Generation task.")
    task_id = json.loads(resp.content.decode())['output']['task_id']
    logger.info(f"task_id: {task_id}: Create request success. Params: {payload}")
    return task_id


def _wait_for_image_urls(url_query, task_id, headers, timeout, poll_interval):
    """Poll ``task_id`` until it succeeds and return the result image URLs.

    Raises:
        gr.Error: if a query fails, answers non-200, or the task ends in one
            of ``_FAILED_TASK_STATUSES``.
    """
    while True:
        try:
            resp = requests.post(f'{url_query}/{task_id}', headers=headers, timeout=timeout)
        except requests.RequestException as err:
            logger.error(f'task_id: {task_id}: Fail to query task result: {err}')
            raise gr.Error("Fail to query task result.") from err
        if resp.status_code != 200:
            logger.error(f'task_id: {task_id}: Fail to query task result: {resp.content}')
            raise gr.Error("Fail to query task result.")

        res = json.loads(resp.content.decode())
        status = res['output']['task_status']
        if status == "SUCCEEDED":
            logger.info(f"task_id: {task_id}: Generation task query success.")
            logger.info(f"task_id: {task_id}: {res}")
            return [x['url'] for x in res['output']['results']]
        if status in _FAILED_TASK_STATUSES:
            logger.error(f"task_id: {task_id}: task ended with status {status}: {res}")
            raise gr.Error(
                "Fail to get results from Generation task. Make sure all the ID images have a clear face. If it still doesn't work, you can contact us or open an issue.")

        # Any other status is treated as still in progress.
        logger.debug(f"task_id: {task_id}: query result...")
        time.sleep(poll_interval)


def call_generation(data, *, request_timeout=60, poll_interval=1.0):
    """Generate images for ``data`` through the async generation service.

    Reads the endpoint configuration from the environment (``URL_TASK``,
    ``URL_QUERY``, ``API_KEY_GENERATION``, ``MODEL_ID``). Submits
    ``repeat_times`` async tasks of ``batch_size`` images each, then polls
    every task until it finishes and downloads its images with
    ``download_images``.

    Args:
        data: Request body. Must contain a ``"parameters"`` dict. Its optional
            ``"seed"`` entry (``-1`` meaning "unset") is offset per repeat.
            The caller's dict is not modified.
        request_timeout: Seconds passed as ``timeout`` to every HTTP request,
            so a stalled connection cannot block the worker indefinitely.
        poll_interval: Seconds to sleep between task status queries.

    Returns:
        The concatenation of the lists returned by ``download_images`` for
        each task.

    Raises:
        gr.Error: if a task cannot be created or queried, ends in a failed
            state, or the number of downloaded images differs from
            ``repeat_times * batch_size``.
    """
    url_task = os.getenv("URL_TASK")
    api_key = os.getenv("API_KEY_GENERATION")
    model_id = os.getenv("MODEL_ID")
    url_query = os.getenv("URL_QUERY")

    batch_size = 4
    repeat_times = 1

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-DashScope-Async": "enable",
    }

    # Submit every task before polling any of them, so that they can be
    # processed concurrently on the service side.
    task_ids = [
        _create_generation_task(
            url_task,
            _build_task_payload(data, model_id, batch_size, i),
            headers,
            request_timeout,
        )
        for i in range(repeat_times)
    ]

    all_image_data = []
    for task_id in task_ids:
        img_urls = _wait_for_image_urls(url_query, task_id, headers, request_timeout, poll_interval)
        logger.info(f"task_id: {task_id}: download generated images.")
        img_data = download_images(img_urls, batch_size)
        logger.info(f"task_id: {task_id}: Generate done.")
        all_image_data += img_data

    if len(all_image_data) != repeat_times * batch_size:
        raise gr.Error("Fail to Generation.")
    return all_image_data