HusnaManakkot commited on
Commit
4f13759
1 Parent(s): 1d4b29f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -28
app.py CHANGED
@@ -5,50 +5,35 @@ from datasets import load_dataset
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
- # Extract schema information from the dataset
9
- db_table_names = set()
10
- column_names = set()
11
- for item in spider_dataset:
12
- db_id = item['db_id']
13
- for table in item['table_names']:
14
- db_table_names.add((db_id, table))
15
- for column in item['column_names']:
16
- column_names.add(column[1])
17
-
18
  # Load tokenizer and model
19
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
20
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
21
 
22
- def post_process_sql_query(sql_query):
23
- # Modify the SQL query to match the dataset's schema
24
- # This is just an example and might need to be adapted based on the dataset and model output
25
- for db_id, table_name in db_table_names:
26
- if "TABLE" in sql_query:
27
- sql_query = sql_query.replace("TABLE", table_name)
28
- break # Assuming only one table is referenced in the query
29
- for column_name in column_names:
30
- if "COLUMN" in sql_query:
31
- sql_query = sql_query.replace("COLUMN", column_name, 1)
32
- return sql_query
33
-
34
  def generate_sql_from_user_input(query):
35
  # Generate SQL for the user's query
36
  input_text = "translate English to SQL: " + query
37
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
38
  outputs = model.generate(**inputs, max_length=512)
39
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
-
41
- # Post-process the SQL query to match the dataset's schema
42
- sql_query = post_process_sql_query(sql_query)
43
  return sql_query
44
 
 
 
 
 
 
 
 
45
  # Create a Gradio interface
46
  interface = gr.Interface(
47
- fn=generate_sql_from_user_input,
 
 
 
48
  inputs=gr.Textbox(label="Enter your natural language query"),
49
- outputs=gr.Textbox(label="Generated SQL Query"),
50
  title="NL to SQL with T5 using Spider Dataset",
51
- description="This model generates an SQL query for your natural language input based on the Spider dataset."
52
  )
53
 
54
  # Launch the app
 
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
 
 
 
 
 
 
 
 
 
 
8
  # Load tokenizer and model
9
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def generate_sql_from_user_input(query):
13
  # Generate SQL for the user's query
14
  input_text = "translate English to SQL: " + query
15
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
16
  outputs = model.generate(**inputs, max_length=512)
17
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
18
  return sql_query
19
 
20
+ def find_matching_sql(nl_query):
21
+ # Find the matching SQL query from the Spider dataset
22
+ for item in spider_dataset:
23
+ if item['question'].lower() == nl_query.lower():
24
+ return item['query']
25
+ return "No matching SQL query found in the Spider dataset."
26
+
27
  # Create a Gradio interface
28
  interface = gr.Interface(
29
+ fn=lambda query: {
30
+ "Generated SQL Query": generate_sql_from_user_input(query),
31
+ "Matching SQL Query from Spider Dataset": find_matching_sql(query)
32
+ },
33
  inputs=gr.Textbox(label="Enter your natural language query"),
34
+ outputs=[gr.Textbox(label="Generated SQL Query"), gr.Textbox(label="Matching SQL Query from Spider Dataset")],
35
  title="NL to SQL with T5 using Spider Dataset",
36
+ description="This model generates an SQL query for your natural language input and finds a matching SQL query from the Spider dataset."
37
  )
38
 
39
  # Launch the app