fix: extra refactoring

This commit is contained in:
hexastack 2024-11-28 20:48:01 +01:00
parent d39bd145b6
commit 896c546000
2 changed files with 14 additions and 16 deletions

View File

@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader):
return encoded_slots return encoded_slots
def get_synonym_map(self):
    """Build a lookup from each synonym to its canonical entity value.

    The mapping is read from the ``entity_synonyms`` section of the
    ``train.json`` dataset file. Each entry there provides a ``value``
    and a list of ``synonyms``; every synonym becomes a key that maps
    to that entry's value. If the same synonym appears in several
    entries, the last entry wins.

    Returns:
        dict: synonym -> canonical value.
    """
    helper = JsonHelper()
    # Read the dataset exactly once; the previous version issued the same
    # read twice and threw the first result away.
    data = helper.read_dataset_json_file('train.json')
    synonyms = data["entity_synonyms"]
    synonym_map = {}
    for entry in synonyms:
        value = entry["value"]
        for synonym in entry["synonyms"]:
            synonym_map[synonym] = value
    return synonym_map
def parse_dataset_intents(self, data): def parse_dataset_intents(self, data):
intents = [] intents = []

View File

@ -2,7 +2,6 @@ import os
import functools import functools
import json import json
import re import re
from utils.json_helper import JsonHelper
from transformers import TFBertModel, AutoTokenizer from transformers import TFBertModel, AutoTokenizer
from keras.layers import Dropout, Dense from keras.layers import Dropout, Dense
from sys import platform from sys import platform
@ -125,8 +124,7 @@ class SlotFiller(tfbp.Model):
# Persist the model # Persist the model
self.extra_params["slot_names"] = slot_names self.extra_params["slot_names"] = slot_names
self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
self.save()
@tfbp.runnable @tfbp.runnable
def evaluate(self): def evaluate(self):
@ -182,18 +180,6 @@ class SlotFiller(tfbp.Model):
if input("Try again? (y/n): ").lower() != 'y': if input("Try again? (y/n): ").lower() != 'y':
break break
def get_synonym_map(self):
    """Return a dict that maps every known synonym to its canonical value.

    The data comes from the ``entity_synonyms`` list of ``train.json``,
    where each item carries a ``value`` and its ``synonyms``. A synonym
    listed under several items ends up mapped to the last one seen.
    """
    dataset = JsonHelper().read_dataset_json_file('train.json')
    lookup = {}
    for item in dataset["entity_synonyms"]:
        canonical = item["value"]
        # Point every synonym of this item at the same canonical value.
        lookup.update(dict.fromkeys(item["synonyms"], canonical))
    return lookup
def get_slots_prediction(self, text: str, inputs, slot_probas): def get_slots_prediction(self, text: str, inputs, slot_probas):
slot_probas_np = slot_probas.numpy() slot_probas_np = slot_probas.numpy()
# Get the indices of the maximum values # Get the indices of the maximum values
@ -264,7 +250,7 @@ class SlotFiller(tfbp.Model):
continue # Skip this entity if not found in text continue # Skip this entity if not found in text
# Post Processing # Post Processing
synonym_map = self.get_synonym_map() synonym_map = self.extra_params["synonym_map"]
final_slot_value = synonym_map.get(slot_value) final_slot_value = synonym_map.get(slot_value)
if final_slot_value is None: if final_slot_value is None:
final_slot_value = slot_value final_slot_value = slot_value