HusnaManakkot committed on
Commit
f1efe67
1 Parent(s): db1852e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -17
app.py CHANGED
@@ -1,35 +1,52 @@
1
  import gradio as gr
 
2
  from datasets import load_dataset
3
- from difflib import get_close_matches
4
 
# Load a small slice of the WikiSQL training split
wikisql_dataset = load_dataset("wikisql", split='train[:100]')

# Map each natural-language question to its human-readable SQL string
query_sql_mapping = {
    entry['question']: entry['sql']['human_readable']
    for entry in wikisql_dataset
}
 
 
 
 
 
10
 
11
def find_closest_match(query, dataset):
    """Return the dataset question most similar to *query*, or None.

    Uses difflib's fuzzy matching (default similarity cutoff 0.6) over the
    'question' field of every dataset item.
    """
    candidates = [entry['question'] for entry in dataset]
    best = get_close_matches(query, candidates, n=1)
    if best:
        return best[0]
    return None
 
 
 
 
 
 
 
 
 
 
 
15
 
16
def generate_sql_from_user_input(query):
    """Look up the SQL string for the dataset question closest to *query*.

    Returns a fixed message when no sufficiently similar question exists.
    """
    matched_query = find_closest_match(query, wikisql_dataset)
    if not matched_query:
        return "No close match found in the dataset."

    # The mapping was built from the same dataset, so a miss here is
    # unexpected — fall back to a fixed message rather than raising.
    sql_query = query_sql_mapping.get(matched_query, "SQL query not found.")
    return sql_query
25
 
26
# Build the Gradio UI: one text input in, one text output back
interface = gr.Interface(
    fn=generate_sql_from_user_input,
    title="NL to SQL using WikiSQL Dataset",
    description="Enter a natural language query and get the corresponding SQL query from the WikiSQL dataset.",
    inputs=gr.Textbox(label="Enter your natural language query"),
    outputs=gr.Textbox(label="SQL Query from Dataset"),
)
34
 
35
  # Launch the app
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
 
4
 
5
# Load a small slice of the WikiSQL training split
wikisql_dataset = load_dataset("wikisql", split='train[:100]')  # Load a subset of the dataset

# Collect the distinct table and column names appearing in the subset;
# set comprehensions replace the manual add() loops.
table_names = {row['table']['name'] for row in wikisql_dataset}
column_names = {header for row in wikisql_dataset for header in row['table']['header']}

# Pre-trained T5 checkpoint fine-tuned for English-to-SQL on WikiSQL
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
19
+
20
def post_process_sql_query(sql_query, tables=None, columns=None):
    """Rewrite generic placeholder identifiers in a generated SQL string.

    The seq2seq model emits "TABLE" / "COLUMN" placeholders; this substitutes
    real identifiers taken from the dataset schema.

    Args:
        sql_query: SQL string possibly containing "TABLE"/"COLUMN" placeholders.
        tables: Optional iterable of table names; defaults to the module-level
            ``table_names`` set extracted from the dataset.
        columns: Optional iterable of column names; defaults to the
            module-level ``column_names`` set.

    Returns:
        The SQL string with placeholders replaced. Identifiers are used in
        sorted order so the output is deterministic — the original iterated
        plain ``set``s, whose arbitrary iteration order made the substitution
        vary between runs.
    """
    tables = sorted(table_names if tables is None else tables)
    columns = sorted(column_names if columns is None else columns)

    # Only one table is assumed to be referenced, so a single pass that
    # replaces every "TABLE" occurrence with the first name suffices.
    if tables and "TABLE" in sql_query:
        sql_query = sql_query.replace("TABLE", tables[0])

    # Replace successive "COLUMN" placeholders with successive column names,
    # one occurrence per name; stop early once no placeholder remains.
    for column in columns:
        if "COLUMN" not in sql_query:
            break
        sql_query = sql_query.replace("COLUMN", column, 1)
    return sql_query
31
 
32
def generate_sql_from_user_input(query):
    """Translate a natural-language question into SQL with the T5 model.

    Encodes the query with the task prefix the checkpoint expects, decodes
    the generated tokens, then maps placeholder identifiers onto the
    dataset's schema.
    """
    # The prefix tells the T5 checkpoint which task to perform
    prompt = "translate English to SQL: " + query
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    generated = model.generate(**encoded, max_length=512)
    raw_sql = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Swap generic TABLE/COLUMN placeholders for real schema identifiers
    sql_query = post_process_sql_query(raw_sql)
    return sql_query
42
 
43
# Wire the generator into a simple single-textbox Gradio UI
interface = gr.Interface(
    fn=generate_sql_from_user_input,
    title="NL to SQL with T5 using WikiSQL Dataset",
    description="This model generates an SQL query for your natural language input based on the WikiSQL dataset.",
    inputs=gr.Textbox(label="Enter your natural language query"),
    outputs=gr.Textbox(label="Generated SQL Query"),
)
51
 
52
  # Launch the app