Mirror of https://github.com/hexastack/hexabot (synced 2024-12-28 23:02:03 +00:00)
Merge pull request #380 from Hexastack/fix/slots-prediction-multitoken-account

fix: fix slots predictions to account for multiple tokens and handle …

This commit is contained in: b9ea504ed1
@@ -10,7 +10,7 @@ if platform == "darwin":
 else:
     from keras.optimizers import Adam
 
-from keras.losses import SparseCategoricalCrossentropy
+from focal_loss import SparseCategoricalFocalLoss
 from keras.metrics import SparseCategoricalAccuracy
 import numpy as np
 
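Note: the new import comes from the standalone focal-loss package on PyPI (pip install focal-loss), not from Keras itself, so it needs to be present in the project's requirements. A minimal sketch of the drop-in swap, assuming that package:

from focal_loss import SparseCategoricalFocalLoss

# Behaves like SparseCategoricalCrossentropy with integer labels, but
# down-weights well-classified examples by a factor of (1 - p)^gamma.
loss_fn = SparseCategoricalFocalLoss(gamma=2.0)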
@@ -36,7 +36,8 @@ class SlotFiller(tfbp.Model):
         "language": "",
         "num_epochs": 2,
         "dropout_prob": 0.1,
-        "slot_num_labels": 40
+        "slot_num_labels": 40,
+        "gamma": 2.0
     }
     data_loader: JISFDL
 
@@ -107,9 +108,7 @@ class SlotFiller(tfbp.Model):
         # Hyperparams, Optimizer and Loss function
         opt = Adam(learning_rate=3e-5, epsilon=1e-08)
 
-        # two outputs, one for slots, another for intents
-        # we have to fine tune for both
-        losses = SparseCategoricalCrossentropy()
+        losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma)
 
         metrics = [SparseCategoricalAccuracy("accuracy")]
 
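Note: focal loss scales each token's cross-entropy by (1 - p)^gamma, so confidently classified tokens contribute little and the rare, hard slot labels dominate training; gamma = 0 recovers plain cross-entropy. A small numeric sketch, with all values made up for illustration:

import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

# Three tokens, three slot classes; predictions are probabilities.
y_true = tf.constant([0, 2, 1])
y_pred = tf.constant([[0.90, 0.05, 0.05],   # easy, correct
                      [0.20, 0.20, 0.60],   # harder, correct
                      [0.40, 0.30, 0.30]])  # hard, wrong

ce = tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred)
fl = SparseCategoricalFocalLoss(gamma=2.0)(y_true, y_pred)
print(float(ce), float(fl))  # the focal loss discounts the easy tokens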
@@ -187,54 +186,65 @@ class SlotFiller(tfbp.Model):
         # Get the indices of the maximum values
         slot_ids = slot_probas_np.argmax(axis=-1)[0, :]
 
-        # get all slot names and add to out_dict as keys
+        # Initialize the output dictionary
         out_dict = {}
-        predicted_slots = set([self.extra_params["slot_names"][s]
-                               for s in slot_ids if s != 0])
+        predicted_slots = set([self.extra_params["slot_names"][s] for s in slot_ids if s != 0])
         for ps in predicted_slots:
             out_dict[ps] = []
 
-        # retrieving the tokenization that was used in the predictions
         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
 
-        # We'd like to eliminate all special tokens from our output
-        special_tokens = self.tokenizer.special_tokens_map.values()
+        # Special tokens to exclude
+        special_tokens = set(self.tokenizer.special_tokens_map.values())
 
-        for token, slot_id in zip(tokens, slot_ids):
-            if token in special_tokens:
-                continue
-            # add all to out_dict
+        idx = 0  # Initialize index explicitly for token tracking
+        while idx < len(tokens):
+            token = tokens[idx]
+            slot_id = slot_ids[idx]
+
+            # Get slot name
             slot_name = self.extra_params["slot_names"][slot_id]
 
             if slot_name == "<PAD>":
+                idx += 1
                 continue
 
-            # collect tokens
-            collected_tokens = [token]
-            idx = tokens.index(token)
+            # Collect tokens for the current slot
+            collected_tokens = []
+
+            # Handle regular tokens and sub-tokens
+            if not token.startswith("##"):
+                collected_tokens = [token]
+            else:
+                # Collect sub-tokens
+                while idx > 0 and tokens[idx - 1].startswith("##"):
+                    idx -= 1
+                    collected_tokens.insert(0, tokens[idx])
+                collected_tokens.append(token)
 
-            # see if it starts with ##
-            # then it belongs to the previous token
-            if token.startswith("##"):
-                # check if the token already exists or not
-                if tokens[idx - 1] not in out_dict[slot_name]:
-                    collected_tokens.insert(0, tokens[idx - 1])
+            # Handle subsequent sub-tokens
+            while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
+                idx += 1
+                collected_tokens.append(tokens[idx])
 
-            # add collected tokens to slots
-            out_dict[slot_name].extend(collected_tokens)
+            # Add collected tokens to the appropriate slot
+            if slot_name in out_dict:
+                out_dict[slot_name].extend(collected_tokens)
+
+            idx += 1  # Move to the next token
 
-        slot_names_to_ids = {value: key for key, value in enumerate(
-            self.extra_params["slot_names"])}
+        # Map slot names to IDs
+        slot_names_to_ids = {value: key for key, value in enumerate(self.extra_params["slot_names"])}
 
+        # Create entities from the out_dict
         entities = []
-        # process out_dict
-        for slot_name in out_dict:
+        for slot_name, slot_tokens in out_dict.items():
             slot_id = slot_names_to_ids[slot_name]
-            slot_tokens = out_dict[slot_name]
 
-            slot_value = self.tokenizer.convert_tokens_to_string(
-                slot_tokens).strip()
+            # Convert tokens to string
+            slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()
 
+            # Calculate entity start and end indices
             entity = {
                 "entity": slot_name,
                 "value": slot_value,
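Note: the main behavioural change is that the loop now walks token positions explicitly instead of calling tokens.index(token), which always returns the first occurrence and so mis-indexes repeated tokens, and it gathers the "##"-prefixed WordPiece sub-tokens that follow a slot token so that multi-piece words end up in a single slot value. A simplified standalone sketch of the forward merge, on made-up tokens and labels (the backward walk for a token that itself starts with "##" is omitted):

tokens = ["[CLS]", "new", "york", "##er", "magazine", "[SEP]"]
slot_ids = [0, 1, 1, 1, 0, 0]        # hypothetical predictions; 0 == "<PAD>"
slot_names = ["<PAD>", "city"]       # hypothetical label set

out_dict = {"city": []}
idx = 0
while idx < len(tokens):
    slot_name = slot_names[slot_ids[idx]]
    if slot_name == "<PAD>":
        idx += 1
        continue
    collected = [tokens[idx]]
    # Pull in the following "##" sub-tokens so they stay with their word.
    while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
        idx += 1
        collected.append(tokens[idx])
    out_dict[slot_name].extend(collected)
    idx += 1

print(out_dict)  # {'city': ['new', 'york', '##er']}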
@@ -243,15 +253,11 @@ class SlotFiller(tfbp.Model):
                 "confidence": 0,
             }
 
-            # The confidence of a slot is the average confidence of tokens in that slot.
+            # Calculate confidence as the average of token probabilities
             indices = [tokens.index(token) for token in slot_tokens]
-            if len(slot_tokens) > 0:
-                total = functools.reduce(
-                    lambda proba1, proba2: proba1+proba2, slot_probas_np[0, indices, slot_id], 0)
-                entity["confidence"] = total / len(slot_tokens)
-            else:
-                entity["confidence"] = 0
+            if slot_tokens:
+                total_confidence = sum(slot_probas_np[0, idx, slot_id] for idx in indices)
+                entity["confidence"] = total_confidence / len(slot_tokens)
 
             entities.append(entity)
 
         return entities
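Note: the functools.reduce over probabilities is replaced by a plain sum, and the redundant else branch is dropped since "confidence" is already initialised to 0. The confidence is still the mean predicted probability of the slot label over the slot's tokens; tokens.index(token) still picks the first occurrence of a repeated token, which the diff leaves as is. A sketch of the computation with made-up numbers:

import numpy as np

# slot_probas_np[batch, token_position, slot_id]; values are made up.
slot_probas_np = np.array([[[0.1, 0.9],
                            [0.3, 0.7],
                            [0.8, 0.2]]])
indices = [0, 1]   # positions of the slot's tokens
slot_id = 1

confidence = sum(slot_probas_np[0, i, slot_id] for i in indices) / len(indices)
print(confidence)  # (0.9 + 0.7) / 2 = 0.8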