From 896c54600056743bb02d5071bfb1036302e6ddff Mon Sep 17 00:00:00 2001 From: hexastack Date: Thu, 28 Nov 2024 20:48:01 +0100 Subject: [PATCH] fix: extra refactoring --- nlu/data_loaders/jisfdl.py | 12 ++++++++++++ nlu/models/slot_filler.py | 18 ++---------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/nlu/data_loaders/jisfdl.py b/nlu/data_loaders/jisfdl.py index 2025c961..babec3b1 100644 --- a/nlu/data_loaders/jisfdl.py +++ b/nlu/data_loaders/jisfdl.py @@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader): return encoded_slots + def get_synonym_map(self): + helper = JsonHelper() + helper.read_dataset_json_file('train.json') + data = helper.read_dataset_json_file('train.json') + synonyms = data["entity_synonyms"] + synonym_map = {} + for entry in synonyms: + value = entry["value"] + for synonym in entry["synonyms"]: + synonym_map[synonym] = value + return synonym_map + def parse_dataset_intents(self, data): intents = [] diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py index c18a4840..d7f4dfb5 100644 --- a/nlu/models/slot_filler.py +++ b/nlu/models/slot_filler.py @@ -2,7 +2,6 @@ import os import functools import json import re -from utils.json_helper import JsonHelper from transformers import TFBertModel, AutoTokenizer from keras.layers import Dropout, Dense from sys import platform @@ -125,8 +124,7 @@ class SlotFiller(tfbp.Model): # Persist the model self.extra_params["slot_names"] = slot_names - - self.save() + self.extra_params["synonym_map"] = self.data_loader.get_synonym_map() @tfbp.runnable def evaluate(self): @@ -181,18 +179,6 @@ class SlotFiller(tfbp.Model): # Optionally, provide a way to exit the loop if input("Try again? (y/n): ").lower() != 'y': break - - def get_synonym_map(self): - helper = JsonHelper() - data = helper.read_dataset_json_file('train.json') - synonyms = data["entity_synonyms"] - synonym_map = {} - for entry in synonyms: - value = entry["value"] - for synonym in entry["synonyms"]: - synonym_map[synonym] = value - return synonym_map - def get_slots_prediction(self, text: str, inputs, slot_probas): slot_probas_np = slot_probas.numpy() @@ -264,7 +250,7 @@ class SlotFiller(tfbp.Model): continue # Skip this entity if not found in text # Post Processing - synonym_map = self.get_synonym_map() + synonym_map = self.extra_params["synonym_map"] final_slot_value = synonym_map.get(slot_value) if final_slot_value is None: final_slot_value = slot_value