File size: 2,267 Bytes
6c6d2f7
f38ba4d
6c6d2f7
 
887c95b
 
f1efe67
769f777
 
 
 
2922759
5f6cbaa
 
 
 
 
769f777
f38ba4d
514fc02
 
4b8f9d6
769f777
0f2dfa7
 
769f777
0f2dfa7
 
 
769f777
0f2dfa7
769f777
0f2dfa7
 
769f777
f38ba4d
 
 
 
 
0f2dfa7
 
769f777
f38ba4d
7cbc7f5
6c6d2f7
 
f38ba4d
769f777
f38ba4d
887c95b
 
6c6d2f7
 
 
db1852e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

# Load the full Spider training split (note: this is the entire train split,
# not a subset; pass e.g. split='train[:100]' to subsample).
spider_dataset = load_dataset("spider", split='train')

# Harvest schema vocabulary from the parsed SQL of every training example.
# NOTE(review): entries in 'table_units' look like ['table_unit', ref] — ref may
# be an integer table index rather than a name; verify against the Spider docs.
# Likewise column[1][1] may be a column index, not a column-name string.
db_table_names = set()
column_names = set()
for example in spider_dataset:
    database_id = example['db_id']
    table_units = example['sql']['from']['table_units']
    db_table_names.update(
        (database_id, unit[1]) for unit in table_units if isinstance(unit, list)
    )
    column_names.update(col[1][1] for col in example['sql']['select'][1])

# Load tokenizer and model
# T5-base fine-tuned on wikiSQL: translates English questions to SQL text.
# Downloads weights from the Hugging Face hub on first run (network required).
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")

def post_process_sql_query(sql_query, tables=None, columns=None):
    """Substitute real schema names for generic TABLE/COLUMN placeholders.

    The wikiSQL-tuned model tends to emit literal "TABLE" and "COLUMN"
    tokens; this heuristic rewrites them with names harvested from the
    Spider dataset. It assumes at most one table is referenced per query.

    Args:
        sql_query: Generated SQL string, possibly containing "TABLE" and/or
            "COLUMN" placeholder tokens.
        tables: Optional iterable of (db_id, table_name) pairs. Defaults to
            the module-level ``db_table_names`` extracted from Spider.
        columns: Optional iterable of column names. Defaults to the
            module-level ``column_names``.

    Returns:
        The query with placeholders replaced. Inputs are iterated in
        sorted() order so the result is deterministic — the original
        iterated raw sets, making the chosen names vary between runs.
    """
    tables = db_table_names if tables is None else tables
    columns = column_names if columns is None else columns

    # All "TABLE" occurrences get the first (sorted) table name — the
    # original broke out of its loop after one replacement anyway.
    if "TABLE" in sql_query:
        for _db_id, table_name in sorted(tables):
            sql_query = sql_query.replace("TABLE", table_name)
            break  # single-table assumption

    # Each "COLUMN" occurrence gets a successive column name; stop as soon
    # as no placeholders remain (original kept scanning the whole set).
    for column_name in sorted(columns):
        if "COLUMN" not in sql_query:
            break
        sql_query = sql_query.replace("COLUMN", column_name, 1)
    return sql_query

def generate_sql_from_user_input(query):
    """Translate a natural-language question into an SQL query string.

    Prefixes the T5 task instruction, runs greedy generation (up to 512
    tokens), decodes the result, and post-processes placeholder names
    against the Spider-derived schema vocabulary.
    """
    prompt = "translate English to SQL: " + query
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    generated_ids = model.generate(**encoded, max_length=512)
    raw_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Swap generic TABLE/COLUMN tokens for dataset schema names.
    return post_process_sql_query(raw_sql)

# Build the Gradio UI: one text box in, one text box out.
query_input = gr.Textbox(label="Enter your natural language query")
sql_output = gr.Textbox(label="Generated SQL Query")
interface = gr.Interface(
    fn=generate_sql_from_user_input,
    inputs=query_input,
    outputs=sql_output,
    title="NL to SQL with T5 using Spider Dataset",
    description="This model generates an SQL query for your natural language input based on the Spider dataset.",
)

# Launch the app
# Guarded so importing this module (e.g. for testing) does not start the server.
if __name__ == "__main__":
    interface.launch()