Merge pull request #389 from Hexastack/fix/new-nlu-fixes

fix: fix inference, retrieve synonym map & fix slot names
Med Marrouchi 2024-11-29 07:44:52 +01:00 committed by GitHub
commit dbc651a314
2 changed files with 43 additions and 9 deletions
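
In short, the commit does three things: the data loader now lowercases training text and takes slot values from the raw text span, a synonym-to-canonical-value map built from `train.json` is persisted alongside the model in `extra_params`, and inference looks predicted slot values up in that map before emitting entities. A toy illustration of the lookup step (the map contents here are invented, not from the dataset):

    # Toy illustration of the persisted synonym lookup; sample values invented.
    extra_params = {"synonym_map": {"nyc": "new york", "big apple": "new york"}}

    def resolve(slot_value: str) -> str:
        # Fall back to the raw value when no canonical form is known,
        # mirroring the synonym_map.get(...) fallback in the diff below.
        return extra_params["synonym_map"].get(slot_value, slot_value)

    assert resolve("nyc") == "new york"
    assert resolve("paris") == "paris"  # unknown values pass through unchanged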


@@ -93,6 +93,18 @@ class JISFDL(tfbp.DataLoader):
         return encoded_slots
 
+    def get_synonym_map(self):
+        helper = JsonHelper()
+        helper.read_dataset_json_file('train.json')
+        data = helper.read_dataset_json_file('train.json')
+        synonyms = data["entity_synonyms"]
+        synonym_map = {}
+        for entry in synonyms:
+            value = entry["value"]
+            for synonym in entry["synonyms"]:
+                synonym_map[synonym] = value
+        return synonym_map
+
     def parse_dataset_intents(self, data):
         intents = []
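
Worth noting: the helper above reads `train.json` twice and discards the first result, so a single read suffices. A leaner equivalent as a sketch (assuming `JsonHelper.read_dataset_json_file` simply returns the parsed JSON, as the diff suggests):

    def get_synonym_map(self):
        helper = JsonHelper()
        # A single read is enough; the duplicated call in the diff is redundant.
        data = helper.read_dataset_json_file('train.json')
        return {
            synonym: entry["value"]
            for entry in data["entity_synonyms"]
            for synonym in entry["synonyms"]
        }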
@@ -109,14 +121,24 @@ class JISFDL(tfbp.DataLoader):
         # Parse raw data
         for exp in examples:
-            text = exp["text"]
+            text = exp["text"].lower()
             intent = exp["intent"]
             entities = exp["entities"]
             # Filter out language entities
             slot_entities = filter(
                 lambda e: e["entity"] != "language", entities)
-            slots = {e["entity"]: e["value"] for e in slot_entities}
+            slots = {}
+            for e in slot_entities:
+                # Create slots with entity values and resolve synonyms
+                if "start" in e and "end" in e and isinstance(e["start"], int) and isinstance(e["end"], int):
+                    original_value = text[e["start"]:e["end"]]
+                    entity_value = e["value"]
+                    if entity_value != original_value:
+                        entity_value = original_value.lower()
+                    slots[e["entity"]] = entity_value
+                else:
+                    continue
             positions = [[e.get("start", -1), e.get("end", -1)]
                          for e in slot_entities]
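
One caveat with this hunk: `filter(...)` returns a single-use iterator in Python 3, so once the new `for e in slot_entities:` loop consumes it, the `positions` comprehension below it sees an empty sequence (the replaced dict comprehension had the same effect). A standalone sketch of the parsing step with the filter materialized first; the sample example is invented:

    # `exp` mimics one training example; values are invented.
    exp = {
        "text": "Fly to NYC",
        "entities": [{"entity": "city", "value": "new york", "start": 7, "end": 10}],
    }

    text = exp["text"].lower()
    # A list can be iterated more than once, unlike a bare filter() object.
    slot_entities = [e for e in exp["entities"] if e["entity"] != "language"]

    slots = {}
    for e in slot_entities:
        # Prefer the literal (lowercased) span when offsets are usable.
        if isinstance(e.get("start"), int) and isinstance(e.get("end"), int):
            slots[e["entity"]] = text[e["start"]:e["end"]]

    positions = [[e.get("start", -1), e.get("end", -1)] for e in slot_entities]
    assert slots == {"city": "nyc"} and positions == [[7, 10]]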


@@ -1,6 +1,7 @@
 import os
 import functools
 import json
+import re
 from transformers import TFBertModel, AutoTokenizer
 from keras.layers import Dropout, Dense
 from sys import platform
@@ -123,7 +124,7 @@ class SlotFiller(tfbp.Model):
         # Persist the model
         self.extra_params["slot_names"] = slot_names
-
+        self.extra_params["synonym_map"] = self.data_loader.get_synonym_map()
 
         self.save()
 
     @tfbp.runnable
@@ -170,7 +171,7 @@ class SlotFiller(tfbp.Model):
     def predict(self):
         while True:
             text = input("Provide text: ")
-            info = self.get_prediction(text)
+            info = self.get_prediction(text.lower())
             print(self.summary())
             print("Text : " + text)
@@ -180,7 +181,6 @@ class SlotFiller(tfbp.Model):
             if input("Try again? (y/n): ").lower() != 'y':
                 break
-
     def get_slots_prediction(self, text: str, inputs, slot_probas):
         slot_probas_np = slot_probas.numpy()
         # Get the indices of the maximum values
@@ -202,7 +202,6 @@
             token = tokens[idx]
             slot_id = slot_ids[idx]
-
             # Get slot name
             slot_name = self.extra_params["slot_names"][slot_id]
             if slot_name == "<PAD>":
@@ -243,13 +242,26 @@
             # Convert tokens to string
             slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
+            slot_value = re.sub(r'\s+', '', slot_value)
+
+            # Ensure the slot value exists in the text (avoid -1 for start index)
+            start_idx = text.find(slot_value)
+            if start_idx == -1:
+                print(f"Skipping entity for '{slot_name}' because '{slot_value}' was not found in text.")
+                continue  # Skip this entity if not found in text
+
+            # Post Processing
+            synonym_map = self.extra_params["synonym_map"]
+            final_slot_value = synonym_map.get(slot_value)
+            if final_slot_value is None:
+                final_slot_value = slot_value
 
             # Calculate entity start and end indices
             entity = {
                 "entity": slot_name,
-                "value": slot_value,
-                "start": text.find(slot_value),
-                "end": text.find(slot_value) + len(slot_value),
+                "value": final_slot_value,
+                "start": start_idx,
+                "end": start_idx + len(slot_value),
                 "confidence": 0,
             }
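
Putting the new post-processing together: collapse whitespace in the decoded value, locate it in the (lowercased) input, resolve synonyms, then emit the entity. A condensed standalone sketch (the function name and sample data are illustrative, not from the repo):

    import re

    def postprocess(text: str, slot_name: str, raw_value: str, synonym_map: dict):
        # Collapse whitespace the tokenizer may leave between sub-tokens.
        slot_value = re.sub(r'\s+', '', raw_value.strip())

        # Skip entities whose value cannot be located in the input text.
        start_idx = text.find(slot_value)
        if start_idx == -1:
            return None

        # Map the surface form to its canonical value when a synonym is known.
        final_value = synonym_map.get(slot_value, slot_value)
        return {
            "entity": slot_name,
            "value": final_value,
            "start": start_idx,
            "end": start_idx + len(slot_value),
            "confidence": 0,
        }

    print(postprocess("fly to nyc", "city", "ny c", {"nyc": "new york"}))
    # {'entity': 'city', 'value': 'new york', 'start': 7, 'end': 10, 'confidence': 0}

Note that stripping all whitespace also collapses multi-word values ("new york" would become "newyork"), so the `text.find` match implicitly assumes slot values without internal spaces.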