diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py index ff3436d7..b1929a3d 100644 --- a/nlu/models/slot_filler.py +++ b/nlu/models/slot_filler.py @@ -10,7 +10,7 @@ if platform == "darwin": else: from keras.optimizers import Adam -from keras.losses import SparseCategoricalCrossentropy +from focal_loss import SparseCategoricalFocalLoss from keras.metrics import SparseCategoricalAccuracy import numpy as np @@ -36,7 +36,8 @@ class SlotFiller(tfbp.Model): "language": "", "num_epochs": 2, "dropout_prob": 0.1, - "slot_num_labels": 40 + "slot_num_labels": 40, + "gamma": 2.0 } data_loader: JISFDL @@ -106,11 +107,9 @@ class SlotFiller(tfbp.Model): # Hyperparams, Optimizer and Loss function opt = Adam(learning_rate=3e-5, epsilon=1e-08) - - # two outputs, one for slots, another for intents - # we have to fine tune for both - losses = SparseCategoricalCrossentropy() - + + losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma) + metrics = [SparseCategoricalAccuracy("accuracy")] # Compile model @@ -187,71 +186,78 @@ class SlotFiller(tfbp.Model): # Get the indices of the maximum values slot_ids = slot_probas_np.argmax(axis=-1)[0, :] - # get all slot names and add to out_dict as keys + # Initialize the output dictionary out_dict = {} - predicted_slots = set([self.extra_params["slot_names"][s] - for s in slot_ids if s != 0]) + predicted_slots = set([self.extra_params["slot_names"][s] for s in slot_ids if s != 0]) for ps in predicted_slots: out_dict[ps] = [] - - # retrieving the tokenization that was used in the predictions + tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) - # We'd like to eliminate all special tokens from our output - special_tokens = self.tokenizer.special_tokens_map.values() + # Special tokens to exclude + special_tokens = set(self.tokenizer.special_tokens_map.values()) - for token, slot_id in zip(tokens, slot_ids): - if token in special_tokens: - continue - # add all to out_dict + idx = 0 # Initialize index explicitly for token tracking + while idx < len(tokens): + token = tokens[idx] + slot_id = slot_ids[idx] + + + # Get slot name slot_name = self.extra_params["slot_names"][slot_id] - if slot_name == "": + idx += 1 continue - # collect tokens - collected_tokens = [token] - idx = tokens.index(token) + # Collect tokens for the current slot + collected_tokens = [] - # see if it starts with ## - # then it belongs to the previous token - if token.startswith("##"): - # check if the token already exists or not - if tokens[idx - 1] not in out_dict[slot_name]: - collected_tokens.insert(0, tokens[idx - 1]) + # Handle regular tokens and sub-tokens + if not token.startswith("##"): + collected_tokens = [token] + else: + # Collect sub-tokens + while idx > 0 and tokens[idx - 1].startswith("##"): + idx -= 1 + collected_tokens.insert(0, tokens[idx]) + collected_tokens.append(token) - # add collected tokens to slots - out_dict[slot_name].extend(collected_tokens) + # Handle subsequent sub-tokens + while idx + 1 < len(tokens) and tokens[idx + 1].startswith("##"): + idx += 1 + collected_tokens.append(tokens[idx]) - slot_names_to_ids = {value: key for key, value in enumerate( - self.extra_params["slot_names"])} + # Add collected tokens to the appropriate slot + if slot_name in out_dict: + out_dict[slot_name].extend(collected_tokens) + idx += 1 # Move to the next token + + # Map slot names to IDs + slot_names_to_ids = {value: key for key, value in enumerate(self.extra_params["slot_names"])} + + # Create entities from the out_dict entities = [] - # process out_dict - for slot_name in out_dict: + for slot_name, slot_tokens in out_dict.items(): slot_id = slot_names_to_ids[slot_name] - slot_tokens = out_dict[slot_name] - slot_value = self.tokenizer.convert_tokens_to_string( - slot_tokens).strip() + # Convert tokens to string + slot_value = self.tokenizer.convert_tokens_to_string(slot_tokens).strip() + # Calculate entity start and end indices entity = { "entity": slot_name, "value": slot_value, "start": text.find(slot_value), - "end": text.find(slot_value) + len(slot_value), + "end": text.find(slot_value) + len(slot_value), "confidence": 0, } - # The confidence of a slot is the average confidence of tokens in that slot. + # Calculate confidence as the average of token probabilities indices = [tokens.index(token) for token in slot_tokens] - if len(slot_tokens) > 0: - total = functools.reduce( - lambda proba1, proba2: proba1+proba2, slot_probas_np[0, indices, slot_id], 0) - entity["confidence"] = total / len(slot_tokens) - else: - entity["confidence"] = 0 - + if slot_tokens: + total_confidence = sum(slot_probas_np[0, idx, slot_id] for idx in indices) + entity["confidence"] = total_confidence / len(slot_tokens) entities.append(entity) return entities