fix: extra refactoring

This commit is contained in:
hexastack 2024-11-28 20:48:01 +01:00
parent d39bd145b6
commit 896c546000
2 changed files with 14 additions and 16 deletions

View File

@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader):
return encoded_slots
def get_synonym_map(self):
helper = JsonHelper()
helper.read_dataset_json_file('train.json')
data = helper.read_dataset_json_file('train.json')
synonyms = data["entity_synonyms"]
synonym_map = {}
for entry in synonyms:
value = entry["value"]
for synonym in entry["synonyms"]:
synonym_map[synonym] = value
return synonym_map
def parse_dataset_intents(self, data):
intents = []

View File

@ -2,7 +2,6 @@ import os
import functools
import json
import re
from utils.json_helper import JsonHelper
from transformers import TFBertModel, AutoTokenizer
from keras.layers import Dropout, Dense
from sys import platform
@ -125,8 +124,7 @@ class SlotFiller(tfbp.Model):
# Persist the model
self.extra_params["slot_names"] = slot_names
self.save()
self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
@tfbp.runnable
def evaluate(self):
@ -182,18 +180,6 @@ class SlotFiller(tfbp.Model):
if input("Try again? (y/n): ").lower() != 'y':
break
def get_synonym_map(self):
helper = JsonHelper()
data = helper.read_dataset_json_file('train.json')
synonyms = data["entity_synonyms"]
synonym_map = {}
for entry in synonyms:
value = entry["value"]
for synonym in entry["synonyms"]:
synonym_map[synonym] = value
return synonym_map
def get_slots_prediction(self, text: str, inputs, slot_probas):
slot_probas_np = slot_probas.numpy()
# Get the indices of the maximum values
@ -264,7 +250,7 @@ class SlotFiller(tfbp.Model):
continue # Skip this entity if not found in text
# Post Processing
synonym_map = self.get_synonym_map()
synonym_map = self.extra_params["synonym_map"]
final_slot_value = synonym_map.get(slot_value)
if final_slot_value is None:
final_slot_value = slot_value