From 34628a256ca66466c19e5e3bc134b4d33b32dbe3 Mon Sep 17 00:00:00 2001 From: hexastack Date: Thu, 28 Nov 2024 19:56:28 +0100 Subject: [PATCH] fix: fix inference, retrieve synonym map & fix slot names --- nlu/data_loaders/jisfdl.py | 12 +++++++++++- nlu/models/slot_filler.py | 36 +++++++++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/nlu/data_loaders/jisfdl.py b/nlu/data_loaders/jisfdl.py index ce497918..b9ae7397 100644 --- a/nlu/data_loaders/jisfdl.py +++ b/nlu/data_loaders/jisfdl.py @@ -116,7 +116,17 @@ class JISFDL(tfbp.DataLoader): # Filter out language entities slot_entities = filter( lambda e: e["entity"] != "language", entities) - slots = {e["entity"]: e["value"] for e in slot_entities} + slots = {} + for e in slot_entities: + # Create slots with entity values and resolve synonyms + if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int): + original_value = text[e["start"]:e["end"]] + entity_value = e["value"] + if entity_value != original_value: + entity_value = original_value.lower() + slots[e["entity"]] = entity_value + else: + continue positions = [[e.get("start", -1), e.get("end", -1)] for e in slot_entities] diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py index b1929a3d..e6445c78 100644 --- a/nlu/models/slot_filler.py +++ b/nlu/models/slot_filler.py @@ -1,6 +1,8 @@ import os import functools import json +import re +from utils.json_helper import JsonHelper from transformers import TFBertModel, AutoTokenizer from keras.layers import Dropout, Dense from sys import platform @@ -179,6 +181,19 @@ class SlotFiller(tfbp.Model): # Optionally, provide a way to exit the loop if input("Try again? (y/n): ").lower() != 'y': break + + def get_synonym_map(self): + helper = JsonHelper() + + helper.read_dataset_json_file('train.json') + data = helper.read_dataset_json_file('train.json') + synonyms = data["entity_synonyms"] + synonym_map = {} + for entry in data["entity_synonyms"]: + value = entry["value"] + for synonym in entry["synonyms"]: + synonym_map[synonym] = value + return synonym_map def get_slots_prediction(self, text: str, inputs, slot_probas): @@ -202,6 +217,10 @@ class SlotFiller(tfbp.Model): token = tokens[idx] slot_id = slot_ids[idx] + # Skip special tokens + # if token in special_tokens: + # idx += 1 + # continue # Get slot name slot_name = self.extra_params["slot_names"][slot_id] @@ -243,13 +262,24 @@ class SlotFiller(tfbp.Model): # Convert tokens to string slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip() + # slot_value = re.sub(r'\s+', '', slot_value) + + # Ensure the slot value exists in the text (avoid -1 for start index) + start_idx = text.find(slot_value) + if start_idx == -1: + print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.") + continue # Skip this entity if not found in text + + #Post Processing + synonym_map = self.get_synonym_map() + final_slot_value = synonym_map.get(slot_value) # Calculate entity start and end indices entity = { "entity": slot_name, - "value": slot_value, - "start": text.find(slot_value), - "end": text.find(slot_value) + len(slot_value), + "value": final_slot_value, + "start": start_idx, + "end": start_idx + len(slot_value), "confidence": 0, }