HusnaManakkot commited on
Commit
0e64ed5
1 Parent(s): 27a9983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -30
app.py CHANGED
@@ -5,46 +5,16 @@ from datasets import load_dataset
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
- # Extract schema information from the dataset
9
- db_ids = set()
10
- table_names = set()
11
- column_names = set()
12
- for item in spider_dataset:
13
- db_ids.add(item['db_id'])
14
- for table in item['table_names']:
15
- table_names.add(table)
16
- for column in item['column_names']:
17
- column_names.add(column[1])
18
-
19
  # Load tokenizer and model
20
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
21
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
22
 
23
- def post_process_sql_query(sql_query):
24
- # Modify the SQL query to match the dataset's schema
25
- # This is just an example and might need to be adapted based on the dataset and model output
26
- for db_id in db_ids:
27
- if "DB_ID" in sql_query:
28
- sql_query = sql_query.replace("DB_ID", db_id)
29
- break # Assuming only one database is referenced in the query
30
- for table_name in table_names:
31
- if "TABLE" in sql_query:
32
- sql_query = sql_query.replace("TABLE", table_name)
33
- break # Assuming only one table is referenced in the query
34
- for column_name in column_names:
35
- if "COLUMN" in sql_query:
36
- sql_query = sql_query.replace("COLUMN", column_name, 1)
37
- return sql_query
38
-
39
  def generate_sql_from_user_input(query):
40
  # Generate SQL for the user's query
41
  input_text = "translate English to SQL: " + query
42
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
43
  outputs = model.generate(**inputs, max_length=512)
44
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
-
46
- # Post-process the SQL query to match the dataset's schema
47
- sql_query = post_process_sql_query(sql_query)
48
  return sql_query
49
 
50
  # Create a Gradio interface
 
5
  # Load the Spider dataset
6
  spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Load tokenizer and model
9
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def generate_sql_from_user_input(query):
13
  # Generate SQL for the user's query
14
  input_text = "translate English to SQL: " + query
15
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
16
  outputs = model.generate(**inputs, max_length=512)
17
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
18
  return sql_query
19
 
20
  # Create a Gradio interface