import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer once at import time; every request reuses them.
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature):
    """Generate a chat-style completion for *prompt*.

    The prompt is sent as a single user turn after a fixed system message,
    formatted with the tokenizer's chat template, and sampled with
    top-k (100) / top-p (0.95) sampling.

    Args:
        prompt: User text to respond to.
        max_length: Maximum number of *new* tokens to generate (prompt tokens
            are not counted). Cast to ``int`` because UI sliders may deliver
            a float.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        Only the newly generated text (prompt tokens are sliced off), decoded
        with special tokens removed; ``""`` for an empty/whitespace prompt.
    """
    # Nothing to respond to: don't spend a generation call on an empty turn.
    if not (prompt or "").strip():
        return ""

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # The rendered chat template already contains any BOS token it needs;
    # letting the tokenizer add special tokens again would duplicate it.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=int(max_length),
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
        # Llama 3 tokenizers define no pad token; pass EOS explicitly so
        # generate() doesn't fall back (and warn) on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # outputs[0] echoes the prompt tokens first; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


# Custom CSS
css = """
body {
    background-color: #1a1a2e;
    color: #e0e0e0;
    font-family: 'Arial', sans-serif;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: #16213e;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
    background-color: #0f3460;
    padding: 20px;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 20px;
}
.header h1 {
    color: #e94560;
    font-size: 2.5em;
    margin-bottom: 10px;
}
.header p {
    color: #a0a0a0;
}
.header img {
    max-width: 300px;
    border-radius: 10px;
    margin: 15px auto;
    display: block;
}
.input-group, .output-group {
    background-color: #1a1a2e;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.input-group label, .output-group label {
    color: #e94560;
    font-weight: bold;
}
.generate-btn {
    background-color: #e94560 !important;
    color: white !important;
    border: none !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
    background-color: #c81e45 !important;
}
.example-prompts {
    background-color: #1f2b47;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.example-prompts h3 {
    color: #e94560;
    margin-bottom: 10px;
}
.example-prompts ul {
    list-style-type: none;
    padding-left: 0;
}
.example-prompts li {
    margin-bottom: 5px;
    cursor: pointer;
    transition: color 0.3s ease;
}
.example-prompts li:hover {
    color: #e94560;
}
"""

# Example prompts
example_prompts = [
    "Write a Python function to find the n-th Fibonacci number.",
    "Explain the concept of recursion in programming.",
    "What are the key differences between Python and JavaScript?",
    "Tell me a short story about a time-traveling robot.",
    "Describe the process of photosynthesis in simple terms.",
]

# Page header, styled by the `.header*` rules in `css`.
# NOTE(review): this header appears to have also contained a logo image
# (alt text "Llama", styled by `.header img`), but its URL is not recoverable
# from this copy of the file. TODO: restore `<img src="..." alt="Llama">`
# inside the div below.
header_html = """
<div class="header">
    <h1>Llama-3.1-Storm-8B Text Generation</h1>
    <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
</div>
"""

# Gradio interface
with gr.Blocks(css=css) as iface:
    gr.HTML(header_html)

    with gr.Group():
        with gr.Group(elem_classes="example-prompts"):
            gr.HTML("<h3>Example Prompts:</h3>")
            example_buttons = [gr.Button(example) for example in example_prompts]

        with gr.Group(elem_classes="input-group"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate", elem_classes="generate-btn")

        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(label="Generated Text", lines=10)

    # Event listeners are registered inside the Blocks context.
    generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)

    # Set up example prompt buttons: a Button used as an input contributes its
    # label, so the identity function copies the example text into the prompt box.
    for button in example_buttons:
        button.click(lambda x: x, inputs=[button], outputs=[prompt])

# Launch the app
iface.launch()