""" from https://github.com/keithito/tacotron """ # ***************************************************************************** # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of the NVIDIA CORPORATION nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # ***************************************************************************** import re valid_symbols = [ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2", "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH", ] """ Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ _pad = "_" _punctuation = "!'(),.:;? " _special = "-" _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz^*" # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same # as uppercase letters): _arpabet = ["@" + s for s in valid_symbols] # Export all symbols: symbols = ( [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet ) # Mappings from symbol to numeric ID and vice versa: _symbol_to_id = {s: i for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)} # Regular expression matching text enclosed in curly braces: _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") # Regular expression matching whitespace: _whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: _abbreviations = [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ ("mrs", "misess"), ("mr", "mister"), ("dr", "doctor"), ("st", "saint"), ("co", "company"), ("jr", "junior"), ("maj", "major"), ("gen", "general"), ("drs", "doctors"), ("rev", "reverend"), ("lt", "lieutenant"), ("hon", "honorable"), ("sgt", "sergeant"), ("capt", "captain"), ("esq", "esquire"), ("ltd", "limited"), ("col", "colonel"), ("ft", "fort"), ] ] def expand_abbreviations(text): """expand abbreviations pre-defined """ for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) return text # def expand_numbers(text): # return normalize_numbers(text) def lowercase(text): """lowercase the text """ return text.lower() def collapse_whitespace(text): """Replaces whitespace by " " in the text """ return re.sub(_whitespace_re, " ", text) def convert_to_ascii(text): """Converts text to ascii """ text_encoded = text.encode("ascii", "ignore") return text_encoded.decode() def basic_cleaners(text): """Basic pipeline that lowercases and collapses whitespace without transliteration. """ text = lowercase(text) text = collapse_whitespace(text) return text def transliteration_cleaners(text): """Pipeline for non-English text that transliterates to ASCII. """ text = convert_to_ascii(text) text = lowercase(text) text = collapse_whitespace(text) return text def english_cleaners(text): """Pipeline for English text, including number and abbreviation expansion. """ text = convert_to_ascii(text) text = lowercase(text) text = expand_abbreviations(text) text = collapse_whitespace(text) return text def text_to_sequence(text, cleaner_names): """Returns a list of integers corresponding to the symbols in the text. Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." Arguments --------- text : str string to convert to a sequence cleaner_names : list names of the cleaner functions to run the text through """ sequence = [] # Check for curly braces and treat their contents as ARPAbet: while len(text): m = _curly_re.match(text) if not m: sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) break sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) return sequence def sequence_to_text(sequence): """Converts a sequence of IDs back to a string """ result = "" for symbol_id in sequence: if symbol_id in _id_to_symbol: s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: if len(s) > 1 and s[0] == "@": s = "{%s}" % s[1:] result += s return result.replace("}{", " ") def _clean_text(text, cleaner_names): """apply different cleaning pipeline according to cleaner_names """ for name in cleaner_names: if name == "english_cleaners": cleaner = english_cleaners if name == "transliteration_cleaners": cleaner = transliteration_cleaners if name == "basic_cleaners": cleaner = basic_cleaners if not cleaner: raise Exception("Unknown cleaner: %s" % name) text = cleaner(text) return text def _symbols_to_sequence(symbols): """convert symbols to sequence """ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] def _arpabet_to_sequence(text): """Prepend "@" to ensure uniqueness """ return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): """whether to keep a certain symbol """ return s in _symbol_to_id and s != "_" and s != "~"