diff --git a/nlu/data_loaders/jisfdl.py b/nlu/data_loaders/jisfdl.py
index ce497918..babec3b1 100644
--- a/nlu/data_loaders/jisfdl.py
+++ b/nlu/data_loaders/jisfdl.py
@@ -93,6 +93,17 @@ class JISFDL(tfbp.DataLoader):
 
         return encoded_slots
 
+    def get_synonym_map(self):
+        # Map every declared synonym to its canonical value ("entity_synonyms" in train.json).
+        helper = JsonHelper()
+        data = helper.read_dataset_json_file('train.json')
+        synonym_map = {}
+        for entry in data["entity_synonyms"]:
+            value = entry["value"]
+            for synonym in entry["synonyms"]:
+                synonym_map[synonym] = value
+        return synonym_map
+
     def parse_dataset_intents(self, data):
 
         intents = []
@@ -109,14 +120,23 @@ class JISFDL(tfbp.DataLoader):
 
         # Parse raw data
         for exp in examples:
-            text = exp["text"]
+            text = exp["text"].lower()
             intent = exp["intent"]
             entities = exp["entities"]
 
             # Filter out language entities
             slot_entities = filter(
                 lambda e: e["entity"] != "language", entities)
-            slots = {e["entity"]: e["value"] for e in slot_entities}
+            slots = {}
+            for e in slot_entities:
+                # Prefer the surface form taken from the text over the annotated
+                # value, so training sees what actually appears in the utterance.
+                if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int):
+                    original_value = text[e["start"]:e["end"]]
+                    entity_value = e["value"]
+                    if entity_value != original_value:
+                        entity_value = original_value.lower()
+                    slots[e["entity"]] = entity_value
             positions = [[e.get("start", -1), e.get("end", -1)]
                          for e in slot_entities]
 
diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py
index b1929a3d..60fbfee6 100644
--- a/nlu/models/slot_filler.py
+++ b/nlu/models/slot_filler.py
@@ -1,6 +1,7 @@
 import os
 import functools
 import json
+import re
 from transformers import TFBertModel, AutoTokenizer
 from keras.layers import Dropout, Dense
 from sys import platform
@@ -123,7 +124,7 @@ class SlotFiller(tfbp.Model):
 
         # Persist the model
         self.extra_params["slot_names"] = slot_names
-
+        self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
 
         self.save()
 
@@ -170,7 +171,7 @@ class SlotFiller(tfbp.Model):
     def predict(self):
         while True:
             text = input("Provide text: ")
-            info = self.get_prediction(text)
+            info = self.get_prediction(text.lower())
 
             print(self.summary())
             print("Text : " + text)
@@ -180,7 +181,6 @@ class SlotFiller(tfbp.Model):
 
             if input("Try again? (y/n): ").lower() != 'y':
                 break
-
     def get_slots_prediction(self, text: str, inputs, slot_probas):
         slot_probas_np = slot_probas.numpy()
         # Get the indices of the maximum values
@@ -202,7 +202,6 @@ class SlotFiller(tfbp.Model):
             token = tokens[idx]
             slot_id = slot_ids[idx]
 
-            # Get slot name
             slot_name = self.extra_params["slot_names"][slot_id]
 
             if slot_name == "":
@@ -243,13 +242,24 @@ class SlotFiller(tfbp.Model):
 
             # Convert tokens to string
             slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
+            slot_value = re.sub(r'\s+', '', slot_value)
+
+            # Ensure the slot value exists in the text (avoid -1 for start index)
+            start_idx = text.find(slot_value)
+            if start_idx == -1:
+                print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.")
+                continue  # Skip this entity if not found in text
+
+            # Post Processing: resolve synonyms to their canonical value,
+            # falling back to the extracted value when no synonym is declared.
+            synonym_map = self.extra_params["synonym_map"]
+            final_slot_value = synonym_map.get(slot_value, slot_value)
 
             # Calculate entity start and end indices
             entity = {
                 "entity": slot_name,
-                "value": slot_value,
-                "start": text.find(slot_value),
-                "end": text.find(slot_value) + len(slot_value),
+                "value": final_slot_value,
+                "start": start_idx,
+                "end": start_idx + len(slot_value),
                 "confidence": 0,
             }
 