Merge pull request #380 from Hexastack/fix/slots-prediction-multitoken-account

fix: fix slots predictions to account for multiple tokens and handle …
Med Marrouchi 2024-11-28 19:47:24 +01:00 committed by GitHub
commit b9ea504ed1


@@ -10,7 +10,7 @@ if platform == "darwin":
 else:
     from keras.optimizers import Adam
-from keras.losses import SparseCategoricalCrossentropy
+from focal_loss import SparseCategoricalFocalLoss
 from keras.metrics import SparseCategoricalAccuracy
 import numpy as np
@@ -36,7 +36,8 @@ class SlotFiller(tfbp.Model):
         "language": "",
         "num_epochs": 2,
         "dropout_prob": 0.1,
-        "slot_num_labels": 40
+        "slot_num_labels": 40,
+        "gamma": 2.0
     }

     data_loader: JISFDL
@@ -107,9 +108,7 @@ class SlotFiller(tfbp.Model):
         # Hyperparams, Optimizer and Loss function
         opt = Adam(learning_rate=3e-5, epsilon=1e-08)

-        # two outputs, one for slots, another for intents
-        # we have to fine tune for both
-        losses = SparseCategoricalCrossentropy()
+        losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma)

         metrics = [SparseCategoricalAccuracy("accuracy")]
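Note on the loss change: a focal loss down-weights tokens the model already classifies confidently (typically the dominant padding / non-slot label) and concentrates training on the rarer slot labels, with `gamma` controlling how strongly easy examples are discounted. Below is a minimal sketch of how such a loss plugs into a Keras `compile` step, assuming the `focal-loss` PyPI package; the toy model, the `from_logits=True` flag, and the layer sizes are illustrative and not taken from this patch.

```python
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

num_labels = 40   # hypothetical, mirrors the "slot_num_labels" hyperparameter
gamma = 2.0       # matches the default added in this patch

# Toy model producing per-class logits; the real SlotFiller is a BERT-based model.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(16,)),
    tf.keras.layers.Dense(num_labels),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08),
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t): with gamma > 0, well-classified
    # (easy) tokens contribute little, so training focuses on hard, rare labels.
    loss=SparseCategoricalFocalLoss(gamma=gamma, from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
)
```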
@@ -187,71 +186,78 @@ class SlotFiller(tfbp.Model):
         # Get the indices of the maximum values
         slot_ids = slot_probas_np.argmax(axis=-1)[0, :]

-        # get all slot names and add to out_dict as keys
+        # Initialize the output dictionary
         out_dict = {}
-        predicted_slots = set([self.extra_params["slot_names"][s]
-                              for s in slot_ids if s != 0])
+        predicted_slots = set([self.extra_params["slot_names"][s] for s in slot_ids if s != 0])
         for ps in predicted_slots:
             out_dict[ps] = []

-        # retrieving the tokenization that was used in the predictions
         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

-        # We'd like to eliminate all special tokens from our output
-        special_tokens = self.tokenizer.special_tokens_map.values()
+        # Special tokens to exclude
+        special_tokens = set(self.tokenizer.special_tokens_map.values())

-        for token, slot_id in zip(tokens, slot_ids):
-            if token in special_tokens:
-                continue
-            # add all to out_dict
+        idx = 0  # Initialize index explicitly for token tracking
+        while idx < len(tokens):
+            token = tokens[idx]
+            slot_id = slot_ids[idx]
+
+            # Get slot name
             slot_name = self.extra_params["slot_names"][slot_id]

             if slot_name == "<PAD>":
+                idx += 1
                 continue

-            # collect tokens
-            collected_tokens = [token]
-            idx = tokens.index(token)
+            # Collect tokens for the current slot
+            collected_tokens = []

-            # see if it starts with ##
-            # then it belongs to the previous token
-            if token.startswith("##"):
-                # check if the token already exists or not
-                if tokens[idx - 1] not in out_dict[slot_name]:
-                    collected_tokens.insert(0, tokens[idx - 1])
+            # Handle regular tokens and sub-tokens
+            if not token.startswith("##"):
+                collected_tokens = [token]
+            else:
+                # Collect sub-tokens
+                while idx > 0 and tokens[idx - 1].startswith("##"):
+                    idx -= 1
+                    collected_tokens.insert(0, tokens[idx])
+                collected_tokens.append(token)

-            # add collected tokens to slots
-            out_dict[slot_name].extend(collected_tokens)
+            # Handle subsequent sub-tokens
+            while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
+                idx += 1
+                collected_tokens.append(tokens[idx])

-        slot_names_to_ids = {value: key for key, value in enumerate(
-            self.extra_params["slot_names"])}
+            # Add collected tokens to the appropriate slot
+            if slot_name in out_dict:
+                out_dict[slot_name].extend(collected_tokens)
+
+            idx += 1  # Move to the next token
+
+        # Map slot names to IDs
+        slot_names_to_ids = {value: key for key, value in enumerate(self.extra_params["slot_names"])}

+        # Create entities from the out_dict
         entities = []
-        # process out_dict
-        for slot_name in out_dict:
+        for slot_name, slot_tokens in out_dict.items():
             slot_id = slot_names_to_ids[slot_name]
-            slot_tokens = out_dict[slot_name]

-            slot_value = self.tokenizer.convert_tokens_to_string(
-                slot_tokens).strip()
+            # Convert tokens to string
+            slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()

+            # Calculate entity start and end indices
             entity = {
                 "entity": slot_name,
                 "value": slot_value,
                 "start": text.find(slot_value),
                 "end": text.find(slot_value) + len(slot_value),
                 "confidence": 0,
             }
-            # The confidence of a slot is the average confidence of tokens in that slot.
+            # Calculate confidence as the average of token probabilities
             indices = [tokens.index(token) for token in slot_tokens]
-            if len(slot_tokens) > 0:
-                total = functools.reduce(
-                    lambda proba1, proba2: proba1+proba2, slot_probas_np[0, indices, slot_id], 0)
-                entity["confidence"] = total / len(slot_tokens)
-            else:
-                entity["confidence"] = 0
+            if slot_tokens:
+                total_confidence = sum(slot_probas_np[0, idx, slot_id] for idx in indices)
+                entity["confidence"] = total_confidence / len(slot_tokens)
             entities.append(entity)
         return entities
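The rewritten prediction loop is the multi-token fix the PR title refers to: BERT's WordPiece tokenizer splits a word into a leading piece plus "##" continuations, and the old loop only re-attached a single preceding piece, so longer slot values came back truncated. Below is a self-contained sketch of the same grouping and confidence-averaging idea, using hypothetical names and made-up example data (`group_slot_tokens`, the token list, and the probabilities are illustrative, not part of the project):

```python
from typing import Dict, List, Tuple

def group_slot_tokens(
    tokens: List[str],
    slot_labels: List[str],   # predicted slot name per token, "<PAD>" for none
    probas: List[float],      # probability of the predicted label per token
) -> Dict[str, Tuple[List[str], float]]:
    """Group WordPiece tokens into slots, merging "##" continuation pieces."""
    slot_tokens: Dict[str, List[str]] = {}
    slot_probas: Dict[str, List[float]] = {}
    idx = 0
    while idx < len(tokens):
        slot = slot_labels[idx]
        if slot == "<PAD>":
            idx += 1
            continue
        # The current token plus any following "##" pieces form one word,
        # so they all belong to the same slot value.
        end = idx + 1
        while end < len(tokens) and tokens[end].startswith("##"):
            end += 1
        slot_tokens.setdefault(slot, []).extend(tokens[idx:end])
        slot_probas.setdefault(slot, []).extend(probas[idx:end])
        idx = end
    # A slot's confidence is the average probability of its tokens.
    return {
        slot: (pieces, sum(slot_probas[slot]) / len(pieces))
        for slot, pieces in slot_tokens.items()
    }

# Toy example: "kuala lumpur" tokenized into four WordPiece tokens.
tokens = ["[CLS]", "fly", "to", "kual", "##a", "lump", "##ur", "[SEP]"]
labels = ["<PAD>", "<PAD>", "<PAD>", "to_city", "to_city", "to_city", "to_city", "<PAD>"]
probas = [0.99, 0.98, 0.97, 0.91, 0.88, 0.93, 0.90, 0.99]
print(group_slot_tokens(tokens, labels, probas))
# -> {'to_city': (['kual', '##a', 'lump', '##ur'], ~0.905)}
```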