mirror of
https://github.com/hexastack/hexabot
synced 2024-12-28 23:02:03 +00:00
fix: fix inference, retrieve synonym map & fix slot names
This commit is contained in:
parent
b9ea504ed1
commit
34628a256c
@ -116,7 +116,17 @@ class JISFDL(tfbp.DataLoader):
|
||||
# Filter out language entities
|
||||
slot_entities = filter(
|
||||
lambda e: e["entity"] != "language", entities)
|
||||
slots = {e["entity"]: e["value"] for e in slot_entities}
|
||||
slots = {}
|
||||
for e in slot_entities:
|
||||
# Create slots with entity values and resolve synonyms
|
||||
if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int):
|
||||
original_value = text[e["start"]:e["end"]]
|
||||
entity_value = e["value"]
|
||||
if entity_value != original_value:
|
||||
entity_value = original_value.lower()
|
||||
slots[e["entity"]] = entity_value
|
||||
else:
|
||||
continue
|
||||
positions = [[e.get("start", -1), e.get("end", -1)]
|
||||
for e in slot_entities]
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import functools
|
||||
import json
|
||||
import re
|
||||
from utils.json_helper import JsonHelper
|
||||
from transformers import TFBertModel, AutoTokenizer
|
||||
from keras.layers import Dropout, Dense
|
||||
from sys import platform
|
||||
@ -180,6 +182,19 @@ class SlotFiller(tfbp.Model):
|
||||
if input("Try again? (y/n): ").lower() != 'y':
|
||||
break
|
||||
|
||||
def get_synonym_map(self):
    """Build a synonym -> canonical-value lookup table from the training data.

    Reads ``train.json`` via :class:`JsonHelper` and inverts its
    ``entity_synonyms`` entries, so that each synonym string maps to the
    canonical entity value it should be normalized to.

    Returns:
        dict: mapping of synonym (str) to canonical value (str). Empty if
        the dataset declares no synonyms.
    """
    helper = JsonHelper()

    # Load the dataset exactly once. (The previous version called
    # read_dataset_json_file twice, discarding the first result.)
    data = helper.read_dataset_json_file('train.json')

    synonym_map = {}
    for entry in data["entity_synonyms"]:
        canonical = entry["value"]
        for synonym in entry["synonyms"]:
            synonym_map[synonym] = canonical
    return synonym_map
|
||||
|
||||
|
||||
def get_slots_prediction(self, text: str, inputs, slot_probas):
|
||||
slot_probas_np = slot_probas.numpy()
|
||||
@ -202,6 +217,10 @@ class SlotFiller(tfbp.Model):
|
||||
token = tokens[idx]
|
||||
slot_id = slot_ids[idx]
|
||||
|
||||
# Skip special tokens
|
||||
# if token in special_tokens:
|
||||
# idx += 1
|
||||
# continue
|
||||
|
||||
# Get slot name
|
||||
slot_name = self.extra_params["slot_names"][slot_id]
|
||||
@ -243,13 +262,24 @@ class SlotFiller(tfbp.Model):
|
||||
|
||||
# Convert tokens to string
|
||||
slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
|
||||
# slot_value = re.sub(r'\s+', '', slot_value)
|
||||
|
||||
# Ensure the slot value exists in the text (avoid -1 for start index)
|
||||
start_idx = text.find(slot_value)
|
||||
if start_idx == -1:
|
||||
print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.")
|
||||
continue # Skip this entity if not found in text
|
||||
|
||||
#Post Processing
|
||||
synonym_map = self.get_synonym_map()
|
||||
final_slot_value = synonym_map.get(slot_value)
|
||||
|
||||
# Calculate entity start and end indices
|
||||
entity = {
|
||||
"entity": slot_name,
|
||||
"value": slot_value,
|
||||
"start": text.find(slot_value),
|
||||
"end": text.find(slot_value) + len(slot_value),
|
||||
"value": final_slot_value,
|
||||
"start": start_idx,
|
||||
"end": start_idx + len(slot_value),
|
||||
"confidence": 0,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user