# Hugging Face Space file "app.py" (3.99 kB).
# Last commit: 0b2655c (verified) by jwang2373 - "Update app.py".
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
MODEL = "jwang2373/UW-SBEL-ChronoPhi-4b-it"
TITLE = "<h1><center>UW-SBEL-ChronoPhi-4b</center></h1>"
PLACEHOLDER = """
<center>
<p>Hi! I'm a PyChrono Digital Twin expert. How can I assist you today?</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# Pick the device name: "cuda" when PyTorch reports CUDA as available, else "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# NOTE: trust_remote_code=True lets code shipped in the model repository run
# locally while loading the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

# Load the weights in bfloat16 with device_map="auto", then switch the module
# to evaluation mode (eval() returns the module itself).
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
).eval()
def _build_prompt(system_prompt, history, message):
    """Assemble the ``<<SYS>>`` / ``[INST]`` prompt string for one request.

    Args:
        system_prompt: Text placed inside the ``<<SYS>>`` header.
        history: Previous turns. Either ``(user, assistant)`` pairs or
            ``{"role": ..., "content": ...}`` dicts; ``None`` is treated as
            an empty history.
        message: The new user message, appended as the final ``[INST]`` turn.

    Returns:
        The full prompt as a single string.
    """
    # NOTE(review): these are Llama-2 style markers, not the tokenizer's own
    # chat template -- confirm this matches how the model was fine-tuned.
    parts = [f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history: one dict per message.
            content = turn.get("content", "")
            if turn.get("role") == "user":
                parts.append(f"[INST]{content}[/INST]")
            else:
                parts.append(f"{content}")
        else:
            # Pair-style history: (user prompt, assistant answer).
            prompt, answer = turn
            parts.append(f"[INST]{prompt}[/INST]{answer}")
    parts.append(f"[INST]{message}[/INST]")
    return "".join(parts)


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.1,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    """Generate a reply to ``message`` and stream it back incrementally.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and yields
    the accumulated reply after every new chunk.

    Args:
        message: The latest user message.
        history: Earlier conversation turns (pairs or role/content dicts).
        system_prompt: System instruction placed at the top of the prompt.
        temperature: Sampling temperature. Values <= 0 select greedy decoding,
            because sampling requires a strictly positive temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).

    Yields:
        The reply text generated so far (cumulative, not just the new chunk).
    """
    print(f'message: {message}')
    print(f'history: {history}')

    full_prompt = _build_prompt(system_prompt, history, message)

    # Move the inputs to the device the model was actually placed on rather
    # than relying on a separately computed device string.
    encoded = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(model.device)

    # timeout=60.0 bounds how long the loop below waits for the next chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs=encoded["input_ids"],
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        streamer=streamer,
    )
    # Forward the attention mask when the tokenizer produced one, so that
    # generate() does not have to infer it.
    if "attention_mask" in encoded:
        generate_kwargs["attention_mask"] = encoded["attention_mask"]

    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        # A temperature of 0 is not valid for sampling; decode greedily instead.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # The stream has ended, so generation is finishing; wait for the worker
    # thread so it does not outlive this call.
    thread.join()
# Example prompts offered by the chat interface; each inner list is one example.
_EXAMPLE_PROMPTS = [
    ["Run a PyChrono simulation of a sedan driving on a flat surface with a detailed vehicle dynamics model."],
    ["Run a real-time simulation of an HMMWV vehicle on a bumpy and textured road."],
    ["Set up a Curiosity rover driving simulation on flat, rigid ground in PyChrono."],
    ["Simulate a FEDA vehicle driving on rigid terrain in PyChrono."],
]

# Chat display component handed to the ChatInterface below.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # The components below are created with render=False, so they are not
    # rendered at this point; they are handed to ChatInterface further down.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    system_prompt_box = gr.Textbox(
        value="You are a PyChrono expert.",
        label="System Prompt",
        render=False,
    )
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.5,
        label="Temperature",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=1024,
        maximum=4096,
        step=1024,
        value=4096,
        label="Max new tokens",
        render=False,
    )
    top_p_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.1,
        value=1.0,
        label="Top p",
        render=False,
    )
    top_k_slider = gr.Slider(
        minimum=1,
        maximum=100,
        step=1,
        value=100,
        label="Top k",
        render=False,
    )

    # The additional inputs are listed in the same order as stream_chat's
    # parameters after (message, history): system_prompt, temperature,
    # max_new_tokens, top_p, top_k.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[
            system_prompt_box,
            temperature_slider,
            max_tokens_slider,
            top_p_slider,
            top_k_slider,
        ],
        examples=_EXAMPLE_PROMPTS,
        cache_examples=False,
    )

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()