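"""Microsoft AutoGen Retrieve Chat demo: a Gradio chatbot built with RetrieveAssistantAgent and RetrieveUserProxyAgent."""
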
import gradio as gr
import os
from pathlib import Path
import autogen
import chromadb
import multiprocessing as mp
from autogen.retrieve_utils import TEXT_FORMATS, get_file_from_url, is_url
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import (
RetrieveUserProxyAgent,
PROMPT_CODE,
)
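
# Timeout (seconds) for each LLM request and for waiting on the chat subprocess.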
TIMEOUT = 60
def initialize_agents(config_list, docs_path=None):
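    """Create the assistant and RAG proxy agents; defaults to the AutoGen README as the retrieval corpus."""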
if isinstance(config_list, gr.State):
_config_list = config_list.value
else:
_config_list = config_list
if docs_path is None:
docs_path = "https://raw.githubusercontent.com/microsoft/autogen/main/README.md"
assistant = RetrieveAssistantAgent(
name="assistant",
system_message="You are a helpful assistant.",
)
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=5,
retrieve_config={
"task": "code",
"docs_path": docs_path,
"chunk_token_size": 2000,
"model": _config_list[0]["model"],
"client": chromadb.PersistentClient(path="/tmp/chromadb"),
"embedding_model": "all-mpnet-base-v2",
"customized_prompt": PROMPT_CODE,
"get_or_create": True,
"collection_name": "autogen_rag",
},
)
return assistant, ragproxyagent
def initiate_chat(config_list, problem, queue, n_results=3):
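    """Run one retrieve-chat round and put the resulting messages on the queue; meant to run in a child process."""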
global assistant, ragproxyagent
if isinstance(config_list, gr.State):
_config_list = config_list.value
else:
_config_list = config_list
if len(_config_list[0].get("api_key", "")) < 2:
queue.put(
["Hi, nice to meet you! Please enter your API keys in below text boxs."]
)
return
else:
        llm_config = {
            "request_timeout": TIMEOUT,
            # "seed": 42,
            "config_list": _config_list,
            "use_cache": False,
        }
        assistant.llm_config.update(llm_config)
assistant.reset()
try:
ragproxyagent.initiate_chat(
assistant, problem=problem, silent=False, n_results=n_results
)
messages = ragproxyagent.chat_messages
        messages = next(iter(messages.values()))
messages = [m["content"] for m in messages if m["role"] == "user"]
print("messages: ", messages)
except Exception as e:
messages = [str(e)]
queue.put(messages)
def chatbot_reply(input_text):
"""Chat with the agent through terminal."""
queue = mp.Queue()
process = mp.Process(
target=initiate_chat,
args=(config_list, input_text, queue),
)
process.start()
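    # Wait (up to TIMEOUT seconds) for the child process to deliver the chat messages.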
try:
# process.join(TIMEOUT+2)
messages = queue.get(timeout=TIMEOUT)
except Exception as e:
messages = [
str(e)
if len(str(e)) > 0
else "Invalid Request to OpenAI, please check your API keys."
]
finally:
try:
process.terminate()
        except Exception:
            pass
return messages
def get_description_text():
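    """Return the Markdown description shown at the top of the demo."""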
return """
# Microsoft AutoGen: Retrieve Chat Demo
This demo shows how to use the RetrieveUserProxyAgent and RetrieveAssistantAgent to build a chatbot.
#### [AutoGen](https://github.com/microsoft/autogen) [Discord](https://discord.gg/pAbnFJrkgZ) [Blog](https://microsoft.github.io/autogen/blog/2023/10/18/RetrieveChat) [Paper](https://arxiv.org/abs/2308.08155) [SourceCode](https://github.com/thinkall/autogen-demos)
"""
global assistant, ragproxyagent
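# Build the Gradio UI: chat window, model and key settings, context selection, and prompt editor.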
with gr.Blocks() as demo:
config_list, assistant, ragproxyagent = (
gr.State(
[
{
"api_key": "",
"api_base": "",
"api_type": "azure",
"api_version": "2023-07-01-preview",
"model": "gpt-35-turbo",
}
]
),
None,
None,
)
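    # Create the agents once with the default (empty-key) config; real keys are supplied through the UI.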
assistant, ragproxyagent = initialize_agents(config_list)
gr.Markdown(get_description_text())
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
avatar_images=(None, (os.path.join(os.path.dirname(__file__), "autogen.png"))),
# height=600,
)
txt_input = gr.Textbox(
scale=4,
show_label=False,
placeholder="Enter text and press enter",
container=False,
)
with gr.Row():
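        # Model selection and API-credential inputs.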
def update_config(config_list):
global assistant, ragproxyagent
config_list = autogen.config_list_from_models(
model_list=[os.environ.get("MODEL", "gpt-35-turbo")],
)
if not config_list:
config_list = [
{
"api_key": "",
"api_base": "",
"api_type": "azure",
"api_version": "2023-07-01-preview",
"model": "gpt-35-turbo",
}
]
            llm_config = {
                "request_timeout": TIMEOUT,
                # "seed": 42,
                "config_list": config_list,
            }
            assistant.llm_config.update(llm_config)
ragproxyagent._model = config_list[0]["model"]
return config_list
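
        # Persist the UI-provided model and keys as environment variables; update_config rebuilds the config list from them.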
def set_params(model, oai_key, aoai_key, aoai_base):
os.environ["MODEL"] = model
os.environ["OPENAI_API_KEY"] = oai_key
os.environ["AZURE_OPENAI_API_KEY"] = aoai_key
os.environ["AZURE_OPENAI_API_BASE"] = aoai_base
return model, oai_key, aoai_key, aoai_base
txt_model = gr.Dropdown(
label="Model",
choices=[
"gpt-4",
"gpt-35-turbo",
"gpt-3.5-turbo",
],
allow_custom_value=True,
value="gpt-35-turbo",
container=True,
)
txt_oai_key = gr.Textbox(
label="OpenAI API Key",
placeholder="Enter key and press enter",
max_lines=1,
show_label=True,
value=os.environ.get("OPENAI_API_KEY", ""),
container=True,
type="password",
)
txt_aoai_key = gr.Textbox(
label="Azure OpenAI API Key",
placeholder="Enter key and press enter",
max_lines=1,
show_label=True,
value=os.environ.get("AZURE_OPENAI_API_KEY", ""),
container=True,
type="password",
)
txt_aoai_base_url = gr.Textbox(
label="Azure OpenAI API Base",
placeholder="Enter base url and press enter",
max_lines=1,
show_label=True,
value=os.environ.get("AZURE_OPENAI_API_BASE", ""),
container=True,
type="password",
)
clear = gr.ClearButton([txt_input, chatbot])
with gr.Row():
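        # Context selection: upload a local file or point to a URL.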
def upload_file(file):
return update_context_url(file.name)
upload_button = gr.UploadButton(
"Click to upload a context file or enter a url in the right textbox",
file_types=[f".{i}" for i in TEXT_FORMATS],
file_count="single",
)
txt_context_url = gr.Textbox(
label="Enter the url to your context file and chat on the context",
info=f"File must be in the format of [{', '.join(TEXT_FORMATS)}]",
max_lines=1,
show_label=True,
value="https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
container=True,
)
txt_prompt = gr.Textbox(
label="Enter your prompt for Retrieve Agent and press enter to replace the default prompt",
max_lines=40,
show_label=True,
value=PROMPT_CODE,
container=True,
show_copy_button=True,
)
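    # Main chat callback: refresh the config, run the chat in a subprocess, and append the reply to the history.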
def respond(message, chat_history, model, oai_key, aoai_key, aoai_base):
global config_list
set_params(model, oai_key, aoai_key, aoai_base)
config_list = update_config(config_list)
messages = chatbot_reply(message)
_msg = (
messages[-1]
if len(messages) > 0 and messages[-1] != "TERMINATE"
else messages[-2]
if len(messages) > 1
else "Context is not enough for answering the question. Please press `enter` in the context url textbox to make sure the context is activated for the chat."
)
chat_history.append((message, _msg))
return "", chat_history
def update_prompt(prompt):
ragproxyagent.customized_prompt = prompt
return prompt
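    # Switch the retrieval context: fetch the URL (or use the uploaded file), drop the old Chroma collection, and re-create the agents.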
def update_context_url(context_url):
global assistant, ragproxyagent
file_extension = Path(context_url).suffix
print("file_extension: ", file_extension)
if file_extension.lower() not in [f".{i}" for i in TEXT_FORMATS]:
return f"File must be in the format of {TEXT_FORMATS}"
if is_url(context_url):
try:
file_path = get_file_from_url(
context_url,
save_path=os.path.join("/tmp", os.path.basename(context_url)),
)
except Exception as e:
return str(e)
else:
file_path = context_url
context_url = os.path.basename(context_url)
try:
chromadb.PersistentClient(path="/tmp/chromadb").delete_collection(
name="autogen_rag"
)
except:
pass
assistant, ragproxyagent = initialize_agents(config_list, docs_path=file_path)
return context_url
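    # Wire UI events to their handlers.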
txt_input.submit(
respond,
[txt_input, chatbot, txt_model, txt_oai_key, txt_aoai_key, txt_aoai_base_url],
[txt_input, chatbot],
)
txt_prompt.submit(update_prompt, [txt_prompt], [txt_prompt])
txt_context_url.submit(update_context_url, [txt_context_url], [txt_context_url])
upload_button.upload(upload_file, upload_button, [txt_context_url])
if __name__ == "__main__":
demo.launch(share=True, server_name="0.0.0.0")