# Mini-GPT / app.py -- by TharunSivamani
# Commit f3f33cf (verified): "updated" -- 654 bytes
import warnings

import gradio as gr
import torch

from model import *
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = GPTLanguageModel().to(DEVICE)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
# strict=False tolerates key mismatches between the checkpoint and the
# architecture, but silently ignoring them would leave the affected parameters
# at their freshly-constructed values, so report any mismatch instead.
_load_result = model.load_state_dict(
    torch.load("mini-gpt.pth", map_location=DEVICE), strict=False
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "mini-gpt.pth does not fully match GPTLanguageModel: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )
# Put the model in inference mode (affects layers such as dropout, if present).
model.eval()
def display(text, number):
    """Generate text with the loaded model for the Gradio interface.

    Args:
        text: Contents of the "Story Lines" textbox. NOTE(review): currently
            unused -- generation always starts from the module-level
            ``context`` tensor (presumably provided by ``from model import *``).
            TODO: encode ``text`` as the prompt if ``model`` exposes an encoder.
        number: Value of the token slider, used as ``max_new_tokens``.

    Returns:
        The decoded string for the first row of the generated output.
    """
    # Gradio sliders may deliver the value as a float (e.g. 300.0); coerce it so
    # generate() receives an integer token count.
    max_new_tokens = int(number)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=max_new_tokens)
    return decode(generated[0].tolist())
# Gradio UI: a "Story Lines" textbox and a slider choosing how many tokens to
# generate (200-500 in steps of 100), wired to display().
# NOTE(review): display() does not currently use the textbox value.
input_box = gr.Textbox(label="Story Lines", value="Once Upon a Time")
input_slider = gr.Slider(
    minimum=200,
    maximum=500,
    label="Select the maximum number of tokens/words:",
    step=100,
)
output_text = gr.Textbox()
# Keep a handle on the Interface under the conventional name `demo`; the app
# is still launched as soon as this module runs, exactly as before.
demo = gr.Interface(fn=display, inputs=[input_box, input_slider], outputs=output_text)
demo.launch()