Mirror of https://github.com/hexastack/hexabot, synced 2024-12-27 14:22:41 +00:00

Merge pull request #380 from Hexastack/fix/slots-prediction-multitoken-account
fix: fix slots predictions to account for multiple tokens and handle …

This commit is contained in commit b9ea504ed1
@@ -10,7 +10,7 @@ if platform == "darwin":
 else:
     from keras.optimizers import Adam

-from keras.losses import SparseCategoricalCrossentropy
+from focal_loss import SparseCategoricalFocalLoss
 from keras.metrics import SparseCategoricalAccuracy
 import numpy as np

@@ -36,7 +36,8 @@ class SlotFiller(tfbp.Model):
         "language": "",
         "num_epochs": 2,
         "dropout_prob": 0.1,
-        "slot_num_labels": 40
+        "slot_num_labels": 40,
+        "gamma": 2.0
     }

     data_loader: JISFDL
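
Context for the new "gamma" hyperparameter: it is the focusing parameter of the focal loss adopted below. A minimal numpy sketch of the underlying formula, FL(p) = -(1 - p)^gamma * log(p), on illustrative probabilities (not code from this diff):

    import numpy as np

    def sparse_focal_loss(p_true_class, gamma=2.0):
        # (1 - p)^gamma down-weights examples the model already gets right
        return -((1.0 - p_true_class) ** gamma) * np.log(p_true_class)

    print(sparse_focal_loss(0.9))          # easy example: ~0.001
    print(sparse_focal_loss(0.1))          # hard example: ~1.87
    print(-np.log(0.9), -np.log(0.1))      # plain cross-entropy: ~0.105, ~2.303

With gamma = 0 this reduces to ordinary cross-entropy; larger gamma shifts training effort toward rare or hard slot labels.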
@@ -106,11 +107,9 @@ class SlotFiller(tfbp.Model):

         # Hyperparams, Optimizer and Loss function
         opt = Adam(learning_rate=3e-5, epsilon=1e-08)

-        # two outputs, one for slots, another for intents
-        # we have to fine tune for both
-        losses = SparseCategoricalCrossentropy()
+        losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma)

         metrics = [SparseCategoricalAccuracy("accuracy")]

         # Compile model
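
The loss swap in isolation: a minimal sketch of compiling a Keras model with SparseCategoricalFocalLoss, assuming the focal-loss PyPI package that the new import points to; the toy model below is illustrative, not the SlotFiller architecture:

    import tensorflow as tf
    from focal_loss import SparseCategoricalFocalLoss

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(40, activation="softmax", input_shape=(16,))]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08),
        # Sparse variant: integer labels, drop-in for SparseCategoricalCrossentropy
        loss=SparseCategoricalFocalLoss(gamma=2.0),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
    )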
@@ -187,71 +186,78 @@ class SlotFiller(tfbp.Model):
         # Get the indices of the maximum values
         slot_ids = slot_probas_np.argmax(axis=-1)[0, :]

-        # get all slot names and add to out_dict as keys
+        # Initialize the output dictionary
         out_dict = {}
-        predicted_slots = set([self.extra_params["slot_names"][s]
-                               for s in slot_ids if s != 0])
+        predicted_slots = set([self.extra_params["slot_names"][s] for s in slot_ids if s != 0])
         for ps in predicted_slots:
             out_dict[ps] = []

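How the argmax step above turns probabilities into slot names, on hypothetical numbers (slot id 0 is the no-slot "<PAD>" label and is filtered out):

    import numpy as np

    slot_names = ["<PAD>", "city", "date"]
    slot_probas_np = np.array([[    # shape: (batch=1, seq_len=3, num_labels=3)
        [0.90, 0.05, 0.05],         # token 0 -> id 0 (<PAD>)
        [0.10, 0.80, 0.10],         # token 1 -> id 1 (city)
        [0.20, 0.10, 0.70],         # token 2 -> id 2 (date)
    ]])
    slot_ids = slot_probas_np.argmax(axis=-1)[0, :]   # array([0, 1, 2])
    predicted_slots = set(slot_names[s] for s in slot_ids if s != 0)
    print(predicted_slots)                            # {'city', 'date'}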
-        # retrieving the tokenization that was used in the predictions
         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

-        # We'd like to eliminate all special tokens from our output
-        special_tokens = self.tokenizer.special_tokens_map.values()
+        # Special tokens to exclude
+        special_tokens = set(self.tokenizer.special_tokens_map.values())

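For reference, what the set() wraps: a BERT-style tokenizer's special_tokens_map is typically a dict like the sample below (values vary by tokenizer; this one is illustrative, not taken from the diff):

    special_tokens_map = {
        "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]",
        "cls_token": "[CLS]", "mask_token": "[MASK]",
    }
    # set() gives O(1) membership tests; the old dict_values view also
    # supports `in`, but scans linearly on every lookup.
    special_tokens = set(special_tokens_map.values())
    print("[CLS]" in special_tokens)  # True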
-        for token, slot_id in zip(tokens, slot_ids):
-            if token in special_tokens:
-                continue
-
-            # add all to out_dict
+        idx = 0  # Initialize index explicitly for token tracking
+        while idx < len(tokens):
+            token = tokens[idx]
+            slot_id = slot_ids[idx]
+
+            # Get slot name
             slot_name = self.extra_params["slot_names"][slot_id]

             if slot_name == "<PAD>":
+                idx += 1
                 continue

-            # collect tokens
-            collected_tokens = [token]
-            idx = tokens.index(token)
+            # Collect tokens for the current slot
+            collected_tokens = []

-            # see if it starts with ##
-            # then it belongs to the previous token
-            if token.startswith("##"):
-                # check if the token already exists or not
-                if tokens[idx - 1] not in out_dict[slot_name]:
-                    collected_tokens.insert(0, tokens[idx - 1])
+            # Handle regular tokens and sub-tokens
+            if not token.startswith("##"):
+                collected_tokens = [token]
+            else:
+                # Collect sub-tokens
+                while idx > 0 and tokens[idx - 1].startswith("##"):
+                    idx -= 1
+                collected_tokens.insert(0, tokens[idx])
+                collected_tokens.append(token)

-            # add collected tokens to slots
-            out_dict[slot_name].extend(collected_tokens)
+            # Handle subsequent sub-tokens
+            while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"):
+                idx += 1
+                collected_tokens.append(tokens[idx])
+
+            # Add collected tokens to the appropriate slot
+            if slot_name in out_dict:
+                out_dict[slot_name].extend(collected_tokens)
+
+            idx += 1  # Move to the next token

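Why the loop walks over "##" neighbors: WordPiece splits out-of-vocabulary words into sub-tokens, and a slot value spanning them is only recovered once they are stitched back together. A standalone sketch of that recombination idea on toy tokens (simplified, not the PR's exact loop):

    tokens = ["[CLS]", "fly", "to", "kath", "##man", "##du", "[SEP]"]

    def merge_subtokens(tokens):
        words, current = [], []
        for tok in tokens:
            if tok.startswith("##"):
                current.append(tok[2:])   # continuation of the previous word
            else:
                if current:
                    words.append("".join(current))
                current = [tok]
        if current:
            words.append("".join(current))
        return words

    print(merge_subtokens(tokens))
    # ['[CLS]', 'fly', 'to', 'kathmandu', '[SEP]']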
-        slot_names_to_ids = {value: key for key, value in enumerate(
-            self.extra_params["slot_names"])}
+        # Map slot names to IDs
+        slot_names_to_ids = {value: key for key, value in enumerate(self.extra_params["slot_names"])}

+        # Create entities from the out_dict
         entities = []
-        # process out_dict
-        for slot_name in out_dict:
+        for slot_name, slot_tokens in out_dict.items():
             slot_id = slot_names_to_ids[slot_name]
-            slot_tokens = out_dict[slot_name]

-            slot_value = self.tokenizer.convert_tokens_to_string(
-                slot_tokens).strip()
+            # Convert tokens to string
+            slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip()

+            # Calculate entity start and end indices
             entity = {
                 "entity": slot_name,
                 "value": slot_value,
                 "start": text.find(slot_value),
                 "end": text.find(slot_value) + len(slot_value),
                 "confidence": 0,
             }

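One caveat the start/end computation inherits: str.find returns -1 when the detokenized value does not occur verbatim in the input (casing or spacing differences), so downstream consumers may want to guard against negative offsets. A quick illustration with made-up text:

    text = "Book a flight to Kathmandu tomorrow"
    slot_value = "Kathmandu"
    start = text.find(slot_value)
    end = start + len(slot_value)
    print(start, end, text[start:end])   # 17 26 Kathmandu
    print(text.find("kathmandu"))        # -1 -> would yield a bogus span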
-            # The confidence of a slot is the average confidence of tokens in that slot.
+            # Calculate confidence as the average of token probabilities
             indices = [tokens.index(token) for token in slot_tokens]
-            if len(slot_tokens) > 0:
-                total = functools.reduce(
-                    lambda proba1, proba2: proba1+proba2, slot_probas_np[0, indices, slot_id], 0)
-                entity["confidence"] = total / len(slot_tokens)
-            else:
-                entity["confidence"] = 0
+            if slot_tokens:
+                total_confidence = sum(slot_probas_np[0, idx, slot_id] for idx in indices)
+                entity["confidence"] = total_confidence / len(slot_tokens)

             entities.append(entity)

         return entities
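
The replaced functools.reduce and the new sum() compute the same per-slot mean confidence; a compact equivalence check on toy probabilities:

    import functools
    import numpy as np

    slot_probas_np = np.array([[[0.1, 0.9], [0.3, 0.7]]])   # (1, seq_len=2, labels=2)
    indices, slot_id = [0, 1], 1
    old = functools.reduce(lambda a, b: a + b, slot_probas_np[0, indices, slot_id], 0) / 2
    new = sum(slot_probas_np[0, i, slot_id] for i in indices) / 2
    print(old, new)   # 0.8 0.8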