Merge pull request #380 from Hexastack/fix/slots-prediction-multitoken-account

fix: fix slots predictions to account for multiple tokens and handle …
Med Marrouchi 2024-11-28 19:47:24 +01:00 committed by GitHub
commit b9ea504ed1


@@ -10,7 +10,7 @@ if platform == "darwin":
else:
    from keras.optimizers import Adam
-from keras.losses import SparseCategoricalCrossentropy
+from focal_loss import SparseCategoricalFocalLoss
from keras.metrics import SparseCategoricalAccuracy
import numpy as np
@@ -36,7 +36,8 @@ class SlotFiller(tfbp.Model):
        "language": "",
        "num_epochs": 2,
        "dropout_prob": 0.1,
-        "slot_num_labels": 40
+        "slot_num_labels": 40,
+        "gamma": 2.0
    }

    data_loader: JISFDL
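For context on the new "gamma" hyperparameter: it is the focusing parameter of the focal loss wired in below, which scales each token's cross-entropy by (1 - p)^gamma so that tokens the model already classifies confidently (typically the dominant "<PAD>" label) contribute less to training. A small illustrative sketch of that weighting, not taken from the repository:

import numpy as np

# Illustration only: how the focal-loss modulating factor (1 - p)**gamma
# down-weights easy tokens relative to plain cross-entropy.
gamma = 2.0
p_true = np.array([0.5, 0.9, 0.99])   # predicted probability of the true slot label
cross_entropy = -np.log(p_true)       # per-token cross-entropy
focal = (1.0 - p_true) ** gamma * cross_entropy

for p, ce, fl in zip(p_true, cross_entropy, focal):
    print(f"p={p:.2f}  CE={ce:.4f}  focal={fl:.4f}")
# At p=0.99 the focal term is 10,000x smaller than CE, so training
# concentrates on rare/hard slot labels instead of the easy majority class.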
@@ -106,11 +107,9 @@ class SlotFiller(tfbp.Model):
        # Hyperparams, Optimizer and Loss function
        opt = Adam(learning_rate=3e-5, epsilon=1e-08)

-        # two outputs, one for slots, another for intents
-        # we have to fine tune for both
-        losses = SparseCategoricalCrossentropy()
+        losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma)

        metrics = [SparseCategoricalAccuracy("accuracy")]

        # Compile model
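As a rough, self-contained sketch of how this loss plugs into Keras compilation (the two-layer network below is a stand-in for illustration, not the project's BERT-based slot filler):

import keras
from keras.optimizers import Adam
from keras.metrics import SparseCategoricalAccuracy
from focal_loss import SparseCategoricalFocalLoss

num_labels = 40  # mirrors the "slot_num_labels" hyperparameter

# Placeholder network; the real model feeds BERT token representations into the slot head.
model = keras.Sequential([
    keras.layers.Input(shape=(16,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(num_labels, activation="softmax"),
])

model.compile(
    optimizer=Adam(learning_rate=3e-5, epsilon=1e-08),
    loss=SparseCategoricalFocalLoss(gamma=2.0),
    metrics=[SparseCategoricalAccuracy("accuracy")],
)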
@@ -187,71 +186,78 @@ class SlotFiller(tfbp.Model):
        # Get the indices of the maximum values
        slot_ids = slot_probas_np.argmax(axis=-1)[0, :]

-        # get all slot names and add to out_dict as keys
+        # Initialize the output dictionary
        out_dict = {}
-        predicted_slots = set([self.extra_params["slot_names"][s]
-                              for s in slot_ids if s != 0])
+        predicted_slots = set([self.extra_params["slot_names"][s] for s in slot_ids if s != 0])
        for ps in predicted_slots:
            out_dict[ps] = []

        # retrieving the tokenization that was used in the predictions
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

-        # We'd like to eliminate all special tokens from our output
-        special_tokens = self.tokenizer.special_tokens_map.values()
+        # Special tokens to exclude
+        special_tokens = set(self.tokenizer.special_tokens_map.values())

-        for token, slot_id in zip(tokens, slot_ids):
-            if token in special_tokens:
-                continue
-            # add all to out_dict
+        idx = 0  # Initialize index explicitly for token tracking
+        while idx < len(tokens):
+            token = tokens[idx]
+            slot_id = slot_ids[idx]
+            # Get slot name
            slot_name = self.extra_params["slot_names"][slot_id]
            if slot_name == "<PAD>":
+                idx += 1
                continue

-            # collect tokens
-            collected_tokens = [token]
-            idx = tokens.index(token)
+            # Collect tokens for the current slot
+            collected_tokens = []

-            # see if it starts with ##
-            # then it belongs to the previous token
-            if token.startswith("##"):
-                # check if the token already exists or not
-                if tokens[idx - 1] not in out_dict[slot_name]:
-                    collected_tokens.insert(0, tokens[idx - 1])
+            # Handle regular tokens and sub-tokens
+            if not token.startswith("##"):
+                collected_tokens = [token]
+            else:
+                # Collect sub-tokens
+                while idx > 0 and tokens[idx - 1].startswith("##"):
+                    idx -= 1
+                    collected_tokens.insert(0, tokens[idx])
+                collected_tokens.append(token)

-            # add collected tokens to slots
-            out_dict[slot_name].extend(collected_tokens)
+            # Handle subsequent sub-tokens
+            while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
+                idx += 1
+                collected_tokens.append(tokens[idx])

-        slot_names_to_ids = {value: key for key, value in enumerate(
-            self.extra_params["slot_names"])}
+            # Add collected tokens to the appropriate slot
+            if slot_name in out_dict:
+                out_dict[slot_name].extend(collected_tokens)
+            idx += 1  # Move to the next token
+
+        # Map slot names to IDs
+        slot_names_to_ids = {value: key for key, value in enumerate(self.extra_params["slot_names"])}

+        # Create entities from the out_dict
        entities = []
-        # process out_dict
-        for slot_name in out_dict:
+        for slot_name, slot_tokens in out_dict.items():
            slot_id = slot_names_to_ids[slot_name]
-            slot_tokens = out_dict[slot_name]
-            slot_value = self.tokenizer.convert_tokens_to_string(
-                slot_tokens).strip()
+            # Convert tokens to string
+            slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
+            # Calculate entity start and end indices
            entity = {
                "entity": slot_name,
                "value": slot_value,
                "start": text.find(slot_value),
-                "end": text.find(slot_value) + len(slot_value),
+                "end": text.find(slot_value) + len(slot_value),
                "confidence": 0,
            }
-            # The confidence of a slot is the average confidence of tokens in that slot.
+            # Calculate confidence as the average of token probabilities
            indices = [tokens.index(token) for token in slot_tokens]
-            if len(slot_tokens) > 0:
-                total = functools.reduce(
-                    lambda proba1, proba2: proba1+proba2, slot_probas_np[0, indices, slot_id], 0)
-                entity["confidence"] = total / len(slot_tokens)
-            else:
-                entity["confidence"] = 0
+            if slot_tokens:
+                total_confidence = sum(slot_probas_np[0, idx, slot_id] for idx in indices)
+                entity["confidence"] = total_confidence / len(slot_tokens)
            entities.append(entity)

        return entities
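For readers skimming the diff, here is a standalone sketch of the sub-token grouping idea the new while-loop implements, run on toy WordPiece tokens; the helper and data are hypothetical and omit the special-token filtering, confidence averaging, and character-offset handling done above:

# Illustration only: merge WordPiece continuations ("##...") back into whole
# words, grouped by predicted slot name (toy labels, no BIO scheme).
def group_subtokens(tokens, slot_names):
    out = {}
    idx = 0
    while idx < len(tokens):
        token, slot = tokens[idx], slot_names[idx]
        if slot == "<PAD>" or token.startswith("##"):
            # continuations are absorbed together with their head token below
            idx += 1
            continue
        collected = [token]
        while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
            idx += 1
            collected.append(tokens[idx])
        out.setdefault(slot, []).append("".join(t.lstrip("#") for t in collected))
        idx += 1
    return out


tokens = ["book", "a", "room", "in", "marra", "##ke", "##ch"]
slots = ["<PAD>", "<PAD>", "<PAD>", "<PAD>", "city", "city", "city"]
print(group_subtokens(tokens, slots))  # {'city': ['marrakech']}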