mirror of
https://github.com/hexastack/hexabot
synced 2024-12-28 06:42:23 +00:00
fix: extra refactoring
This commit is contained in:
parent
d39bd145b6
commit
896c546000
@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader):
|
|||||||
|
|
||||||
return encoded_slots
|
return encoded_slots
|
||||||
|
|
||||||
|
def get_synonym_map(self):
|
||||||
|
helper = JsonHelper()
|
||||||
|
helper.read_dataset_json_file('train.json')
|
||||||
|
data = helper.read_dataset_json_file('train.json')
|
||||||
|
synonyms = data["entity_synonyms"]
|
||||||
|
synonym_map = {}
|
||||||
|
for entry in synonyms:
|
||||||
|
value = entry["value"]
|
||||||
|
for synonym in entry["synonyms"]:
|
||||||
|
synonym_map[synonym] = value
|
||||||
|
return synonym_map
|
||||||
|
|
||||||
def parse_dataset_intents(self, data):
|
def parse_dataset_intents(self, data):
|
||||||
|
|
||||||
intents = []
|
intents = []
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from utils.json_helper import JsonHelper
|
|
||||||
from transformers import TFBertModel, AutoTokenizer
|
from transformers import TFBertModel, AutoTokenizer
|
||||||
from keras.layers import Dropout, Dense
|
from keras.layers import Dropout, Dense
|
||||||
from sys import platform
|
from sys import platform
|
||||||
@ -125,8 +124,7 @@ class SlotFiller(tfbp.Model):
|
|||||||
|
|
||||||
# Persist the model
|
# Persist the model
|
||||||
self.extra_params["slot_names"] = slot_names
|
self.extra_params["slot_names"] = slot_names
|
||||||
|
self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
|
||||||
self.save()
|
|
||||||
|
|
||||||
@tfbp.runnable
|
@tfbp.runnable
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
@ -182,18 +180,6 @@ class SlotFiller(tfbp.Model):
|
|||||||
if input("Try again? (y/n): ").lower() != 'y':
|
if input("Try again? (y/n): ").lower() != 'y':
|
||||||
break
|
break
|
||||||
|
|
||||||
def get_synonym_map(self):
|
|
||||||
helper = JsonHelper()
|
|
||||||
data = helper.read_dataset_json_file('train.json')
|
|
||||||
synonyms = data["entity_synonyms"]
|
|
||||||
synonym_map = {}
|
|
||||||
for entry in synonyms:
|
|
||||||
value = entry["value"]
|
|
||||||
for synonym in entry["synonyms"]:
|
|
||||||
synonym_map[synonym] = value
|
|
||||||
return synonym_map
|
|
||||||
|
|
||||||
|
|
||||||
def get_slots_prediction(self, text: str, inputs, slot_probas):
|
def get_slots_prediction(self, text: str, inputs, slot_probas):
|
||||||
slot_probas_np = slot_probas.numpy()
|
slot_probas_np = slot_probas.numpy()
|
||||||
# Get the indices of the maximum values
|
# Get the indices of the maximum values
|
||||||
@ -264,7 +250,7 @@ class SlotFiller(tfbp.Model):
|
|||||||
continue # Skip this entity if not found in text
|
continue # Skip this entity if not found in text
|
||||||
|
|
||||||
# Post Processing
|
# Post Processing
|
||||||
synonym_map = self.get_synonym_map()
|
synonym_map = self.extra_params["synonym_map"]
|
||||||
final_slot_value = synonym_map.get(slot_value)
|
final_slot_value = synonym_map.get(slot_value)
|
||||||
if final_slot_value is None:
|
if final_slot_value is None:
|
||||||
final_slot_value = slot_value
|
final_slot_value = slot_value
|
||||||
|
Loading…
Reference in New Issue
Block a user