import os
import re

from tqdm import tqdm
from datasets import Dataset, DatasetDict

from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json

####################################################################################################
### GAZETTEERS EMBEDDINGS ##########################################################################
####################################################################################################

def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    i = 0
    n = len(tokens)
    assert n == len(looking_tokens)
    while i < n:
        for length in range(min(5, n - i), 0, -1):  # assuming maximum entity length is 5 tokens
            phrase = ' '.join(looking_tokens[i:i+length])
            for gazetteer in gazetteers:
                if phrase in gazetteer:
                    match_type = gazetteer[phrase]
                    for index in range(i, i + length):
                        matches.setdefault(tokens[index], []).append((phrase, match_type))
        i += 1
    return matches


def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
    n = len(tokens)
    assert n == len(looking_tokens)
    for index in range(n):
        word = looking_tokens[index]
        if len(word) < 3:  # skip very short words to reduce false positives
            continue
        for gazetteer in gazetteers:
            if word in gazetteer:
                match_type = gazetteer[word]
                matches.setdefault(tokens[index], []).append((word, match_type))
    return matches


def gazetteer_matching(words, gazetteers_for_matching, args=None):
    ending_ova = True
    method_for_gazetteers_matching = "single"
    apply_lemmatizing = True

    if method_for_gazetteers_matching == "single":
        matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
    else:  # multi-token matching
        matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)

    result = []
    for word in words:
        # prefer the longest matched phrases (most spaces) for this word
        mid_res = sorted(matches.get(word, []), key=lambda x: x[0].count(" "), reverse=True)
        per, org, loc = 0, 0, 0
        for res in mid_res:
            if mid_res[0][0].count(" ") == res[0].count(" "):
                if res[1] == "PER":
                    per = 5
                elif res[1] == "ORG":
                    org = 5
                elif res[1] == "LOC":
                    loc = 5
        # Czech female surnames ending in "ová" are treated as person names
        if ending_ova and word.endswith("ová") and word[0].isupper():
            per = 5
        result.append([per, org, loc])
    return result
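
# A minimal usage sketch (not part of the original pipeline): gazetteer_matching expects a list of
# dictionaries mapping surface forms to entity types ("PER"/"ORG"/"LOC") and returns one
# [per, org, loc] score triple per token. The toy gazetteers and tokens below are invented for
# illustration; in the real pipeline the dictionaries come from build_reverse_dictionary.
def _example_gazetteer_matching():
    toy_gazetteers = [
        {"Praha": "LOC", "Brno": "LOC"},
        {"Novák": "PER"},
        {"Seznam": "ORG"},
    ]
    words = ["Pan", "Novák", "pracuje", "v", "Praze"]
    scores = gazetteer_matching(words, toy_gazetteers)
    # "Novák" gets a PER score of 5; "Praze" gets a LOC score of 5 only if lemmatizing()
    # normalizes it to "Praha"; all other tokens stay at [0, 0, 0].
    print(scores)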

####################################################################################################
### CNEC DATASET ###################################################################################
####################################################################################################

def get_dataset_from_cnec(label_mapper: dict, xml_file_path, args):
    """label_mapper: maps CNEC labels to ints."""
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5]  # remove unnecessary characters at both ends
    # make sure <ne ...> markup is separated from the surrounding text by spaces
    plain_text = re.sub(r'([a-zA-Z.])<ne type="', r'\1 <ne type="', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    ne_pattern = r'<ne type="([^"]+)">([^<]+)</ne>'
    data = []
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:
            # take the next <ne type="...">...</ne> markup and record its entity, label and span
            match = re.search(ne_pattern, sentence)
            if match is None:
                break
            label, entity = match.group(1), match.group(2)
            pattern = f'<ne type="{label}">{entity}</ne>'
            index = sentence.index(pattern)
            temp_index = index
            sentence = sentence.replace(pattern, entity, 1)
            # adjust the start index for the tags and spaces that occur before the entity
            temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([^"]+)">', sentence[:index])])
            temp_index -= sentence[:index].count("</ne>") * len("</ne>")
            temp_index -= (re.sub(r'<ne type="[^"]+">', "", sentence[:index]).replace("</ne>", "")).count(" ")
            index = temp_index
            entity_mapping.append((entity, label, index, index + len(entity)))
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # no-label tag for this word
            for index_entity in range(len(entities)):
                if not(sentence_counter - len(word) >= entities[index_entity][2] and
                       sentence_counter <= entities[index_entity][3] and
                       word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # no-label tag for this word
                    continue
                if args.division_to_BI_tags:
                    if sentence_counter - len(word) == entities[index_entity][2]:
                        tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                    else:
                        tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                else:
                    tags_per_word.append(entities[index_entity][0])
                break
        if args.contain_only_label_sentences and tags_per_word.count(0) == len(tags_per_word):
            continue
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words), "gazetteers": matching})
        else:
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words)})
        id_ += 1
    return data


def get_default_dataset_from_cnec(label_mapper: dict, xml_file_path):
    """label_mapper: maps CNEC labels to ints."""
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5]  # remove unnecessary characters at both ends
    # make sure <ne ...> markup is separated from the surrounding text by spaces
    plain_text = re.sub(r'([a-zA-Z.])<ne type="', r'\1 <ne type="', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    ne_pattern = r'<ne type="([^"]+)">([^<]+)</ne>'
    data = []
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:
            # take the next <ne type="...">...</ne> markup and record its entity, label and span
            match = re.search(ne_pattern, sentence)
            if match is None:
                break
            label, entity = match.group(1), match.group(2)
            pattern = f'<ne type="{label}">{entity}</ne>'
            index = sentence.index(pattern)
            temp_index = index
            sentence = sentence.replace(pattern, entity, 1)
            # adjust the start index for the tags and spaces that occur before the entity
            temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([^"]+)">', sentence[:index])])
            temp_index -= sentence[:index].count("</ne>") * len("</ne>")
            temp_index -= (re.sub(r'<ne type="[^"]+">', "", sentence[:index]).replace("</ne>", "")).count(" ")
            index = temp_index
            entity_mapping.append((entity, label, index, index + len(entity)))
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # no-label tag for this word
            for index_entity in range(len(entities)):
                if not(sentence_counter - len(word) >= entities[index_entity][2] and
                       sentence_counter <= entities[index_entity][3] and
                       word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # no-label tag for this word
                    continue
                if sentence_counter - len(word) == entities[index_entity][2]:
                    tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                else:
                    tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                break  # stop after the first matching entity
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                     "sentence": " ".join(words)})
        id_ += 1
    return data


def create_cnec_dataset(label_mapper: dict, args):
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
        temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
        dataset[part] = Dataset.from_list(temp_dataset)
    return dataset
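
# Illustrative sketch of how the CNEC loaders above might be driven. The label_mapper keys are
# matched as lower-cased prefixes of CNEC type codes (e.g. "p" ~ person, "i" ~ institution,
# "g" ~ geography); the exact mapping, the path and the args values below are assumptions for
# the demo, not part of this module.
def _example_create_cnec_dataset():
    from types import SimpleNamespace

    label_mapper = {"p": 1, "i": 2, "g": 3}  # id 0 is reserved for "no entity"
    args = SimpleNamespace(
        cnec_dataset_dir_path="data/cnec2.0",          # hypothetical directory
        apply_extended_embeddings=False,
        extended_embeddings_gazetteers_path=None,
        division_to_BI_tags=True,                       # id k becomes B=2k-1, I=2k
        contain_only_label_sentences=False,
    )
    dataset = create_cnec_dataset(label_mapper, args)
    print(dataset)
    print(dataset["train"][0])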

####################################################################################################
### WIKIANN DATASET ################################################################################
####################################################################################################

def load_wikiann_testing_dataset(args):
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    dataset = []
    index = 0
    sentences = load_tagged_sentences(args.wikiann_dataset_path)
    for sentence in sentences:
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
        else:
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags})
        index += 1
    test = Dataset.from_list(dataset)
    # placeholder train/validation splits with a single empty example; only "test" carries data
    dataset = DatasetDict({
        "train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
        "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
        "test": test,
    })
    # dataset = DatasetDict({"test": test})
    return dataset


def load_tagged_sentences(file_path):
    sentences = []         # list to hold all sentences
    current_sentence = []  # list to hold the current sentence's (token, tag) pairs
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # remove any extra whitespace from the line
            if line:
                # split the line into token and tag
                token_tag_pair = line.split()
                if len(token_tag_pair) == 2:
                    # add the (token, tag) tuple to the current sentence
                    current_sentence.append((token_tag_pair[0].split(':')[1], token_tag_pair[1]))
            else:
                # if the line is empty and the current sentence is not, add it to sentences
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []  # reset for the next sentence
    # add the last sentence if the file doesn't end with a blank line
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
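
# Illustrative sketch of the input format load_tagged_sentences expects: one "<lang>:<token> <tag>"
# pair per line, blank lines between sentences, which matches the WikiAnn dump layout this loader
# appears to assume. The tokens, tags and file name below are made up for the demo.
def _example_load_tagged_sentences(path="wikiann_cs_sample.txt"):
    sample = (
        "cs:Karel B-PER\n"
        "cs:Čapek I-PER\n"
        "cs:psal O\n"
        "cs:v O\n"
        "cs:Praze B-LOC\n"
        "\n"
    )
    with open(path, "w", encoding="utf-8") as f:
        f.write(sample)
    print(load_tagged_sentences(path))
    # expected: [[('Karel', 'B-PER'), ('Čapek', 'I-PER'), ('psal', 'O'), ('v', 'O'), ('Praze', 'B-LOC')]]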

####################################################################################################
### TOKENIZE DATASET ###############################################################################
####################################################################################################

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # start of a new word
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # special token
            new_labels.append(-100)
        else:
            # same word as the previous token
            label = labels[word_id]
            # if the label is B-XXX, change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels


def align_gazetteers_with_tokens(gazetteers, word_ids):
    aligned_gazetteers = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # start of a new word
            current_word = word_id
            gazetteer = [0, 0, 0] if word_id is None else gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
        elif word_id is None:
            # special token
            aligned_gazetteers.append([0, 0, 0])
        else:
            # same word as the previous token
            gazetteer = gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
    return aligned_gazetteers


def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )
        all_labels = examples["ner_tags"]
        new_labels = []
        for i, labels in enumerate(all_labels):
            word_ids = tokenized_inputs.word_ids(i)
            new_labels.append(align_labels_with_tokens(labels, word_ids))
        tokenized_inputs["labels"] = new_labels

        if apply_extended_embeddings:
            matches = examples["gazetteers"]
            aligned_matches = []
            for i, match in enumerate(matches):
                word_ids = tokenized_inputs.word_ids(i)
                aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
            per, org, loc = [], [], []
            for i in aligned_matches:
                per.append([x[0] for x in i])
                org.append([x[1] for x in i])
                loc.append([x[2] for x in i])
            tokenized_inputs["per"] = per
            tokenized_inputs["org"] = org
            tokenized_inputs["loc"] = loc
        return tokenized_inputs

    dataset = raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        # remove_columns=raw_dataset["train"].column_names
    )
    return dataset
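
# Illustrative sketch of the tokenization step for a DatasetDict such as the one produced by
# create_cnec_dataset. Any fast Hugging Face tokenizer works, since word_ids() is only available
# on fast tokenizers; the checkpoint name below is an assumption, swap in the model actually used.
def _example_create_tokenized_dataset(raw_dataset):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # hypothetical choice
    tokenized = create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=False)
    # labels are per sub-token, with -100 on special tokens and continuation pieces of B-tags
    # remapped to the matching I-tag by align_labels_with_tokens
    print(tokenized["train"][0]["labels"])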