fix: fix inference, retrieve synonym map & fix slot names

This commit is contained in:
hexastack 2024-11-28 19:56:28 +01:00
parent b9ea504ed1
commit 34628a256c
2 changed files with 44 additions and 4 deletions

View File

@ -116,7 +116,17 @@ class JISFDL(tfbp.DataLoader):
# Filter out language entities # Filter out language entities
slot_entities = filter( slot_entities = filter(
lambda e: e["entity"] != "language", entities) lambda e: e["entity"] != "language", entities)
slots = {e["entity"]: e["value"] for e in slot_entities} slots = {}
for e in slot_entities:
# Create slots with entity values and resolve synonyms
if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int):
original_value = text[e["start"]:e["end"]]
entity_value = e["value"]
if entity_value != original_value:
entity_value = original_value.lower()
slots[e["entity"]] = entity_value
else:
continue
positions = [[e.get("start", -1), e.get("end", -1)] positions = [[e.get("start", -1), e.get("end", -1)]
for e in slot_entities] for e in slot_entities]

View File

@ -1,6 +1,8 @@
import os import os
import functools import functools
import json import json
import re
from utils.json_helper import JsonHelper
from transformers import TFBertModel, AutoTokenizer from transformers import TFBertModel, AutoTokenizer
from keras.layers import Dropout, Dense from keras.layers import Dropout, Dense
from sys import platform from sys import platform
@ -180,6 +182,19 @@ class SlotFiller(tfbp.Model):
if input("Try again? (y/n): ").lower() != 'y': if input("Try again? (y/n): ").lower() != 'y':
break break
def get_synonym_map(self):
helper = JsonHelper()
helper.read_dataset_json_file('train.json')
data = helper.read_dataset_json_file('train.json')
synonyms = data["entity_synonyms"]
synonym_map = {}
for entry in data["entity_synonyms"]:
value = entry["value"]
for synonym in entry["synonyms"]:
synonym_map[synonym] = value
return synonym_map
def get_slots_prediction(self, text: str, inputs, slot_probas): def get_slots_prediction(self, text: str, inputs, slot_probas):
slot_probas_np = slot_probas.numpy() slot_probas_np = slot_probas.numpy()
@ -202,6 +217,10 @@ class SlotFiller(tfbp.Model):
token = tokens[idx] token = tokens[idx]
slot_id = slot_ids[idx] slot_id = slot_ids[idx]
# Skip special tokens
# if token in special_tokens:
# idx += 1
# continue
# Get slot name # Get slot name
slot_name = self.extra_params["slot_names"][slot_id] slot_name = self.extra_params["slot_names"][slot_id]
@ -243,13 +262,24 @@ class SlotFiller(tfbp.Model):
# Convert tokens to string # Convert tokens to string
slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip() slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
# slot_value = re.sub(r'\s+', '', slot_value)
# Ensure the slot value exists in the text (avoid -1 for start index)
start_idx = text.find(slot_value)
if start_idx == -1:
print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.")
continue # Skip this entity if not found in text
#Post Processing
synonym_map = self.get_synonym_map()
final_slot_value = synonym_map.get(slot_value)
# Calculate entity start and end indices # Calculate entity start and end indices
entity = { entity = {
"entity": slot_name, "entity": slot_name,
"value": slot_value, "value": final_slot_value,
"start": text.find(slot_value), "start": start_idx,
"end": text.find(slot_value) + len(slot_value), "end": start_idx + len(slot_value),
"confidence": 0, "confidence": 0,
} }