"""
Parses user raw user prompts to jobs to run against DB
Classes:
Semantic Parser: object to parse the user prompts
"""
from typing import List, Union
import transformers
class SentenceParser:
    """
    Parses user prompts.

    Attributes
    ----------
    pipe: pipeline that converts prompts into jobs for the matcher

    Methods
    -------
    from_huggingface(pipe_path)
    get_ner(sentences)
    get_jobs(ners)
    """
mappings = {
"Amenity": "Semantic",
"Price": "Price",
"Hours": "Hours",
"Dish": "Menu",
"Restaurant_Name": "Name",
"Location": "Location",
"Cuisine": "Semantic",
}
    def __init__(self, pipe):
        """
        Constructs an NLU parser.

        Parameters
        ----------
        pipe: pipeline that converts prompts into jobs
        """
        self.pipe = pipe
    @classmethod
    def from_huggingface(cls, pipe_path: str):
        """
        Builds an NLU parser from a Hugging Face pipeline.

        Parameters:
            pipe_path: Hugging Face model name or path

        Returns:
            a SentenceParser wrapping the token-classification pipeline
        """
token_classifier = transformers.pipeline(
"token-classification",
model=pipe_path,
aggregation_strategy="simple",
)
return cls(token_classifier)
    def get_ner(self, sentences: Union[str, List[str]]) -> List[dict]:
        """
        Runs named-entity recognition on the user prompt(s).

        Parameters:
            sentences: a user prompt or a list of user prompts

        Returns:
            list_out: one dict per prompt mapping each entity group to a
                list of recognized words
        """
if isinstance(sentences, str):
sentences = [sentences]
# format output
list_out = []
        for items in self.pipe(sentences):
            sentence_ner = {}
            for recognized_token in items:
                # group recognized words under their entity label
                sentence_ner.setdefault(recognized_token["entity_group"], []).append(
                    recognized_token["word"]
                )
            list_out.append(sentence_ner)
return list_out
    def get_jobs(self, ners: List[dict]) -> List[dict]:
        """
        Maps the recognized entities to Matcher jobs.

        Parameters:
            ners: list of entity dicts, one per prompt

        Returns:
            list_out: list of job dicts mapping job name to words
        """
list_out = []
for item in ners:
jobs = {}
for ner, value in item.items():
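                # skip entity groups that have no Matcher job mapping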
if ner in self.mappings:
jobs[self.mappings[ner]] = value
list_out.append(jobs)
return list_out
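

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper. The model path
    # below is a placeholder: any Hugging Face token-classification model
    # whose labels match the keys of SentenceParser.mappings should work.
    parser = SentenceParser.from_huggingface("path/to/ner-model")
    ners = parser.get_ner("cheap italian place open late near downtown")
    # e.g. [{"Price": ["cheap"], "Cuisine": ["italian"], "Hours": ["late"],
    #        "Location": ["near downtown"]}] (exact output depends on the model)
    jobs = parser.get_jobs(ners)
    # e.g. [{"Price": ["cheap"], "Semantic": ["italian"], "Hours": ["late"],
    #        "Location": ["near downtown"]}]
    print(jobs)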