Spaces:
Sleeping
Sleeping
File size: 2,267 Bytes
6c6d2f7 f38ba4d 6c6d2f7 887c95b f1efe67 769f777 2922759 5f6cbaa 769f777 f38ba4d 514fc02 4b8f9d6 769f777 0f2dfa7 769f777 0f2dfa7 769f777 0f2dfa7 769f777 0f2dfa7 769f777 f38ba4d 0f2dfa7 769f777 f38ba4d 7cbc7f5 6c6d2f7 f38ba4d 769f777 f38ba4d 887c95b 6c6d2f7 db1852e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
# Load the Spider text-to-SQL dataset (full training split).
spider_dataset = load_dataset("spider", split='train')

# Walk every example's parsed SQL to collect schema references:
#   db_table_names : set of (db_id, table_ref) pairs
#   column_names   : set of column references used in SELECT clauses
# NOTE(review): table_units entries of list form appear to hold a table
# reference at index 1, and select items nest a column reference at [1][1]
# — presumably Spider's internal indices; verify against the dataset docs.
db_table_names = set()
column_names = set()
for example in spider_dataset:
    database_id = example['db_id']
    for table_unit in example['sql']['from']['table_units']:
        if isinstance(table_unit, list):
            db_table_names.add((database_id, table_unit[1]))
    for select_item in example['sql']['select'][1]:
        column_names.add(select_item[1][1])
# Load tokenizer and model
# T5-base fine-tuned on WikiSQL for English -> SQL translation
# (Hub model: mrm8488/t5-base-finetuned-wikiSQL). Downloads weights on
# first run; both objects are shared module-level singletons.
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
def post_process_sql_query(sql_query, tables=None, columns=None):
    """Substitute placeholder tokens in a generated SQL query with schema names.

    Replaces the literal token "TABLE" with a table name and successive
    "COLUMN" tokens with column names drawn from the Spider schema.

    Parameters
    ----------
    sql_query : str
        Raw SQL emitted by the model, possibly containing the literal
        placeholders "TABLE" and "COLUMN".
    tables : iterable of (db_id, table_name) pairs, optional
        Schema tables to draw from. Defaults to the module-level
        ``db_table_names`` built from the Spider dataset.
    columns : iterable of str, optional
        Column names to draw from. Defaults to the module-level
        ``column_names``.

    Returns
    -------
    str
        The query with placeholders substituted; unchanged if no
        placeholder is present or the schema iterables are empty.

    Notes
    -----
    Sets are unordered, so with the module-level defaults the chosen
    table/column is arbitrary — this preserves the original best-effort
    behavior.  NOTE(review): the table entries extracted upstream look
    like integer indices rather than name strings, which would make
    ``str.replace`` raise TypeError — confirm against the Spider schema.
    """
    tables = db_table_names if tables is None else tables
    columns = column_names if columns is None else columns
    # Fill every "TABLE" token with the first available table name,
    # assuming only one table is referenced in the query.
    for _db_id, table_name in tables:
        if "TABLE" in sql_query:
            sql_query = sql_query.replace("TABLE", table_name)
        break
    # Fill "COLUMN" tokens one at a time, a different column per token.
    for column_name in columns:
        if "COLUMN" in sql_query:
            sql_query = sql_query.replace("COLUMN", column_name, 1)
    return sql_query
def generate_sql_from_user_input(query):
    """Translate a natural-language question into an SQL query string.

    The question is prefixed with the T5 task instruction, encoded and run
    through the WikiSQL-fine-tuned model, decoded, and finally post-processed
    so placeholder table/column tokens map onto the Spider schema.
    """
    # T5 expects a task prefix before the actual question.
    prompt = "translate English to SQL: " + query
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    generated_ids = model.generate(**encoded, max_length=512)
    raw_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Swap any "TABLE"/"COLUMN" placeholders for real schema names.
    return post_process_sql_query(raw_sql)
# Wire the generator into a simple one-input, one-output Gradio UI.
query_box = gr.Textbox(label="Enter your natural language query")
sql_box = gr.Textbox(label="Generated SQL Query")
interface = gr.Interface(
    fn=generate_sql_from_user_input,
    inputs=query_box,
    outputs=sql_box,
    title="NL to SQL with T5 using Spider Dataset",
    description="This model generates an SQL query for your natural language input based on the Spider dataset.",
)

# Start the web server only when run as a script (not on import).
if __name__ == "__main__":
    interface.launch()
|