HusnaManakkot commited on
Commit
22376c3
1 Parent(s): 1f0417a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -5,30 +5,18 @@ from datasets import load_dataset
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
- # Extract schema information from the dataset
9
- db_table_names = set()
10
  column_names = set()
11
  for item in spider_dataset:
12
- db_id = item['db_id']
13
- for table in item['db']['table_names_original']:
14
- db_table_names.add((db_id, table))
15
- for column in item['db']['column_names_original']:
16
- column_names.add(column[1])
17
 
18
  # Load tokenizer and model
19
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
20
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
21
-
22
- def post_process_sql_query(sql_query):
23
- # Modify the SQL query to match the dataset's schema
24
- for db_id, table_name in db_table_names:
25
- if "TABLE" in sql_query:
26
- sql_query = sql_query.replace("TABLE", table_name)
27
- break # Assuming only one table is referenced in the query
28
- for column_name in column_names:
29
- if "COLUMN" in sql_query:
30
- sql_query = sql_query.replace("COLUMN", column_name, 1)
31
- return sql_query
32
 
33
  def generate_sql_from_user_input(query):
34
  # Generate SQL for the user's query
@@ -38,7 +26,13 @@ def generate_sql_from_user_input(query):
38
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
  # Post-process the SQL query to match the dataset's schema
41
- sql_query = post_process_sql_query(sql_query)
 
 
 
 
 
 
42
  return sql_query
43
 
44
  # Create a Gradio interface
 
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
+ # Extract schema information from the Spider dataset
9
+ table_names = set()
10
  column_names = set()
11
  for item in spider_dataset:
12
+ for table in item['db_id']:
13
+ table_names.add(table)
14
+ for column in item['question']:
15
+ column_names.add(column)
 
16
 
17
  # Load tokenizer and model
18
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def generate_sql_from_user_input(query):
22
  # Generate SQL for the user's query
 
26
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
 
28
  # Post-process the SQL query to match the dataset's schema
29
+ for table_name in table_names:
30
+ if "TABLE" in sql_query:
31
+ sql_query = sql_query.replace("TABLE", table_name)
32
+ break # Assuming only one table is referenced in the query
33
+ for column_name in column_names:
34
+ if "COLUMN" in sql_query:
35
+ sql_query = sql_query.replace("COLUMN", column_name, 1)
36
  return sql_query
37
 
38
  # Create a Gradio interface