# Hugging Face file-viewer header (scraped page chrome, not part of the program):
#   author: ohtaman | commit: "Create app.py" (419b7c9) | size: 2.97 kB | scan: no virus
import gradio as gr
import torch
import transformers
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, GenerationConfig
# Hugging Face Hub ids: the base checkpoint and the LoRA adapter applied to it.
BASE_MODEL_NAME = "tiiuae/falcon-7b"
MODEL_NAME = "ohtaman/falcon-7b-kokkai2022-lora"

# The tokenizer is loaded from the base checkpoint. `torch_dtype` is a
# model-loading option and means nothing to a tokenizer, so it is not passed.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    BASE_MODEL_NAME,
    trust_remote_code=True,
)

# Base model weights are requested in bfloat16 with automatic device placement.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Wrap the base model with the LoRA adapter published under MODEL_NAME.
model = PeftModel.from_pretrained(base_model, MODEL_NAME)
def generate_prompt(question: str, questioner: str = "", answerer: str = "") -> str:
    """Build the text prompt fed to the model.

    The prompt has a ``# question`` section (questioner on the first line,
    the question text below it) followed by an ``# answer`` section that
    names the answerer and is left open for the model to continue.

    Args:
        question: The question text.
        questioner: Name of the person asking; may be empty.
        answerer: Name of the person expected to answer; may be empty.

    Returns:
        The assembled prompt, ending with a newline after the answerer line.
    """
    return (
        "# question\n"
        f"{questioner}\n"
        f"{question}\n"
        "# answer\n"
        f"{answerer}\n"
    )
def evaluate(
    question: str,
    questioner: str = "",
    answerer: str = "",
    temperature: float = 0.1,
    top_p: float = 0.75,
    top_k: int = 40,
    num_beams: int = 4,
    repetition_penalty: float = 1.05,
    max_new_tokens: int = 256,
    **kwargs,
) -> str:
    """Generate the model's answer to a question.

    Builds a prompt with ``generate_prompt``, runs ``model.generate`` on it,
    and returns only the newly generated text (the prompt is not echoed).

    Args:
        question: The question text.
        questioner: Name of the person asking; may be empty.
        answerer: Name of the person answering; may be empty.
        temperature: Sampling temperature forwarded to ``GenerationConfig``.
        top_p: Nucleus-sampling threshold forwarded to ``GenerationConfig``.
        top_k: Top-k cutoff forwarded to ``GenerationConfig``.
        num_beams: Beam count forwarded to ``GenerationConfig``.
        repetition_penalty: Repetition penalty forwarded to ``GenerationConfig``.
        max_new_tokens: Upper bound on the number of generated tokens.
        **kwargs: Extra options forwarded to ``GenerationConfig``.

    Returns:
        The decoded completion with special tokens removed.
    """
    prompt = generate_prompt(question, questioner, answerer)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    n_input_tokens = input_ids.shape[1]

    # Integer-valued options are cast explicitly because UI sliders may
    # hand them over as floats (e.g. 40.0).
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        num_beams=int(num_beams),
        repetition_penalty=repetition_penalty,
        **kwargs,
    )

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=int(max_new_tokens),
        )

    # Keep only the tokens produced after the prompt. Special tokens (such as
    # the end-of-sequence marker) are stripped by the tokenizer rather than by
    # unconditionally chopping the last token, which would lose a real token
    # whenever generation stops at max_new_tokens without emitting EOS.
    completion_ids = generation_output.sequences[0, n_input_tokens:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# Web UI definition. The input components are listed in the same order as the
# positional parameters of `evaluate`, which receives their values.
g = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(lines=5, label="Question", placeholder="Question"),
        gr.components.Textbox(lines=1, label="Questioner", placeholder="Questioner"),
        gr.components.Textbox(lines=1, label="Answerer", placeholder="Answerer"),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
        gr.components.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty"),
        gr.components.Slider(minimum=1, maximum=512, step=1, value=128, label="Max tokens"),
    ],
    outputs=[
        # Output uses the same `gr.components` namespace as the inputs; the
        # `gr.inputs` namespace is for input widgets and is deprecated.
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🏛️: Kokkai 2022",
    description="falcon-7b-kokkai2022 is a 7B-parameter model trained on Japan's 2022 Diet proceedings using LoRA based on [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b).",
)

# Queue requests with a concurrency of one, then start the app.
g.queue(concurrency_count=1)
g.launch()