mirror of
https://github.com/hexastack/hexabot
synced 2024-12-27 14:22:41 +00:00
Merge pull request #389 from Hexastack/fix/new-nlu-fixes
fix: fix inference, retrieve synonym map & fix slot names
This commit is contained in:
commit
dbc651a314
@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader):
|
|||||||
|
|
||||||
return encoded_slots
|
return encoded_slots
|
||||||
|
|
||||||
|
def get_synonym_map(self):
    """Build a lookup from each synonym to its canonical entity value.

    The mapping is derived from the ``entity_synonyms`` section of the
    data returned by ``JsonHelper.read_dataset_json_file('train.json')``.
    Each entry there provides a ``value`` and a list of ``synonyms``;
    every synonym becomes a key pointing at that entry's ``value``.

    Returns:
        dict: ``{synonym: value}``. When the same synonym appears under
        several entries, the last entry wins.

    Raises:
        KeyError: if ``entity_synonyms`` is absent from the dataset, or
            an entry lacks ``value`` / ``synonyms`` (presumably the
            loaded data is a dict parsed from JSON; verify against
            ``JsonHelper``).
    """
    helper = JsonHelper()
    # Read the dataset once; the previous version called the reader twice
    # and threw the first result away.
    data = helper.read_dataset_json_file('train.json')
    synonyms = data["entity_synonyms"]

    synonym_map = {}
    for entry in synonyms:
        value = entry["value"]
        for synonym in entry["synonyms"]:
            synonym_map[synonym] = value
    return synonym_map
|
||||||
|
|
||||||
def parse_dataset_intents(self, data):
|
def parse_dataset_intents(self, data):
|
||||||
|
|
||||||
intents = []
|
intents = []
|
||||||
@ -109,14 +121,24 @@ class JISFDL(tfbp.DataLoader):
|
|||||||
|
|
||||||
# Parse raw data
|
# Parse raw data
|
||||||
for exp in examples:
|
for exp in examples:
|
||||||
text = exp["text"]
|
text = exp["text"].lower()
|
||||||
intent = exp["intent"]
|
intent = exp["intent"]
|
||||||
entities = exp["entities"]
|
entities = exp["entities"]
|
||||||
|
|
||||||
# Filter out language entities
|
# Filter out language entities
|
||||||
slot_entities = filter(
|
slot_entities = filter(
|
||||||
lambda e: e["entity"] != "language", entities)
|
lambda e: e["entity"] != "language", entities)
|
||||||
slots = {e["entity"]: e["value"] for e in slot_entities}
|
slots = {}
|
||||||
|
for e in slot_entities:
|
||||||
|
# Create slots with entity values and resolve synonyms
|
||||||
|
if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int):
|
||||||
|
original_value = text[e["start"]:e["end"]]
|
||||||
|
entity_value = e["value"]
|
||||||
|
if entity_value != original_value:
|
||||||
|
entity_value = original_value.lower()
|
||||||
|
slots[e["entity"]] = entity_value
|
||||||
|
else:
|
||||||
|
continue
|
||||||
positions = [[e.get("start", -1), e.get("end", -1)]
|
positions = [[e.get("start", -1), e.get("end", -1)]
|
||||||
for e in slot_entities]
|
for e in slot_entities]
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from transformers import TFBertModel, AutoTokenizer
|
from transformers import TFBertModel, AutoTokenizer
|
||||||
from keras.layers import Dropout, Dense
|
from keras.layers import Dropout, Dense
|
||||||
from sys import platform
|
from sys import platform
|
||||||
@ -123,7 +124,7 @@ class SlotFiller(tfbp.Model):
|
|||||||
|
|
||||||
# Persist the model
|
# Persist the model
|
||||||
self.extra_params["slot_names"] = slot_names
|
self.extra_params["slot_names"] = slot_names
|
||||||
|
self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
@tfbp.runnable
|
@tfbp.runnable
|
||||||
@ -170,7 +171,7 @@ class SlotFiller(tfbp.Model):
|
|||||||
def predict(self):
|
def predict(self):
|
||||||
while True:
|
while True:
|
||||||
text = input("Provide text: ")
|
text = input("Provide text: ")
|
||||||
info = self.get_prediction(text)
|
info = self.get_prediction(text.lower())
|
||||||
|
|
||||||
print(self.summary())
|
print(self.summary())
|
||||||
print("Text : " + text)
|
print("Text : " + text)
|
||||||
@ -180,7 +181,6 @@ class SlotFiller(tfbp.Model):
|
|||||||
if input("Try again? (y/n): ").lower() != 'y':
|
if input("Try again? (y/n): ").lower() != 'y':
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def get_slots_prediction(self, text: str, inputs, slot_probas):
|
def get_slots_prediction(self, text: str, inputs, slot_probas):
|
||||||
slot_probas_np = slot_probas.numpy()
|
slot_probas_np = slot_probas.numpy()
|
||||||
# Get the indices of the maximum values
|
# Get the indices of the maximum values
|
||||||
@ -202,7 +202,6 @@ class SlotFiller(tfbp.Model):
|
|||||||
token = tokens[idx]
|
token = tokens[idx]
|
||||||
slot_id = slot_ids[idx]
|
slot_id = slot_ids[idx]
|
||||||
|
|
||||||
|
|
||||||
# Get slot name
|
# Get slot name
|
||||||
slot_name = self.extra_params["slot_names"][slot_id]
|
slot_name = self.extra_params["slot_names"][slot_id]
|
||||||
if slot_name == "<PAD>":
|
if slot_name == "<PAD>":
|
||||||
@ -243,13 +242,26 @@ class SlotFiller(tfbp.Model):
|
|||||||
|
|
||||||
# Convert tokens to string
|
# Convert tokens to string
|
||||||
slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
|
slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
|
||||||
|
slot_value = re.sub(r'\s+', '', slot_value)
|
||||||
|
|
||||||
|
# Ensure the slot value exists in the text (avoid -1 for start index)
|
||||||
|
start_idx = text.find(slot_value)
|
||||||
|
if start_idx == -1:
|
||||||
|
print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.")
|
||||||
|
continue # Skip this entity if not found in text
|
||||||
|
|
||||||
|
# Post Processing
|
||||||
|
synonym_map = self.extra_params["synonym_map"]
|
||||||
|
final_slot_value = synonym_map.get(slot_value)
|
||||||
|
if final_slot_value is None:
|
||||||
|
final_slot_value = slot_value
|
||||||
|
|
||||||
# Calculate entity start and end indices
|
# Calculate entity start and end indices
|
||||||
entity = {
|
entity = {
|
||||||
"entity": slot_name,
|
"entity": slot_name,
|
||||||
"value": slot_value,
|
"value": final_slot_value,
|
||||||
"start": text.find(slot_value),
|
"start": start_idx,
|
||||||
"end": text.find(slot_value) + len(slot_value),
|
"end": start_idx + len(slot_value),
|
||||||
"confidence": 0,
|
"confidence": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user