HusnaManakkot commited on
Commit
7b2c62f
1 Parent(s): b649083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -27
app.py CHANGED
@@ -1,41 +1,24 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
- # Load the WikiSQL dataset (only the table schemas are needed for validation)
6
- wikisql_dataset = load_dataset("wikisql", split='train')
7
 
8
- # Load tokenizer and model
9
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
11
-
12
- def validate_sql_against_schema(sql_query, schema):
13
- # This is a placeholder function. You need to implement the logic to validate
14
- # the SQL query against the table schema. The validation can be as simple or as
15
- # complex as you need, depending on the requirements.
16
- return True # Assume the query is valid for now
17
 
18
  def generate_sql_from_user_input(query):
19
- # Generate SQL for the user's query
20
- input_text = "translate English to SQL: " + query
21
- inputs = tokenizer(input_text, return_tensors="pt", padding=True)
22
- outputs = model.generate(**inputs, max_length=512)
23
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
-
25
- # Validate the generated SQL query against the schemas in the dataset
26
- for item in wikisql_dataset:
27
- if validate_sql_against_schema(sql_query, item['sql']):
28
- return query, sql_query
29
-
30
- return query, "Generated SQL query is not consistent with the dataset."
31
 
32
  # Create a Gradio interface
33
  interface = gr.Interface(
34
  fn=generate_sql_from_user_input,
35
  inputs=gr.Textbox(label="Enter your natural language query"),
36
- outputs=[gr.Textbox(label="Your Query"), gr.Textbox(label="Generated SQL Query")],
37
- title="NL to SQL with T5 using WikiSQL Dataset",
38
- description="This model generates an SQL query for your natural language input and validates it against the WikiSQL dataset."
39
  )
40
 
41
  # Launch the app
 
1
  import gradio as gr
 
2
  from datasets import load_dataset
3
 
4
+ # Load the WikiSQL dataset
5
+ wikisql_dataset = load_dataset("wikisql", split='train[:100]') # Load a subset of the dataset
6
 
7
+ # Create a mapping between natural language queries and SQL queries
8
+ query_sql_mapping = {item['question']: item['sql']['human_readable'] for item in wikisql_dataset}
 
 
 
 
 
 
 
9
 
10
  def generate_sql_from_user_input(query):
11
+ # Look up the SQL query corresponding to the user's input
12
+ sql_query = query_sql_mapping.get(query, "No exact match found in the dataset.")
13
+ return sql_query
 
 
 
 
 
 
 
 
 
14
 
15
  # Create a Gradio interface
16
  interface = gr.Interface(
17
  fn=generate_sql_from_user_input,
18
  inputs=gr.Textbox(label="Enter your natural language query"),
19
+ outputs=gr.Textbox(label="SQL Query from Dataset"),
20
+ title="NL to SQL using WikiSQL Dataset",
21
+ description="This interface returns the SQL query from the WikiSQL dataset that exactly matches your natural language input."
22
  )
23
 
24
  # Launch the app