Mirror of https://github.com/hexastack/hexabot, synced 2025-06-03 11:06:34 +00:00
feat: enhance intent-classifier
This commit is contained in: parent 4fb1971fdc, commit 15a3787fee
@@ -99,28 +99,28 @@ class JISFDL(tfbp.DataLoader):
         k = 0

         # Filter examples by language
-        lang = self.hparams.language
-        all_examples = data["common_examples"]
-
-        if not bool(lang):
-            examples = all_examples
-        else:
-            examples = filter(lambda exp: any(e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)
+        # lang = self.hparams.language
+        # all_examples = data["common_examples"]
+        #
+        # if not bool(lang):
+        #     examples = all_examples
+        # else:
+        #     examples = filter(lambda exp: any(e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)

         # Parse raw data
-        for exp in examples:
+        for exp in data:
             text = exp["text"]
             intent = exp["intent"]
-            entities = exp["entities"]
+            # entities = exp["entities"]

             # Filter out language entities
-            slot_entities = filter(
-                lambda e: e["entity"] != "language", entities)
-            slots = {e["entity"]: e["value"] for e in slot_entities}
-            positions = [[e.get("start", -1), e.get("end", -1)]
-                         for e in slot_entities]
+            # slot_entities = filter(
+            #     lambda e: e["entity"] != "language", entities)
+            # slots = {e["entity"]: e["value"] for e in slot_entities}
+            # positions = [[e.get("start", -1), e.get("end", -1)]
+            #              for e in slot_entities]

-            temp = JointRawData(k, intent, positions, slots, text)
+            temp = JointRawData(k, intent, None, None, text)
             k += 1
             intents.append(temp)
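Note on the block commented out above: Python's filter() returns a single-use iterator, so the positions list comprehension was already running over an iterator exhausted by the slots dict comprehension and always produced an empty list. A minimal standalone demonstration, with invented entity values:

    # filter() yields each element only once; a second pass sees nothing
    entities = [{"entity": "city", "value": "Paris", "start": 10, "end": 15},
                {"entity": "language", "value": "fr"}]

    slot_entities = filter(lambda e: e["entity"] != "language", entities)
    slots = {e["entity"]: e["value"] for e in slot_entities}  # consumes the iterator
    positions = [[e.get("start", -1), e.get("end", -1)]
                 for e in slot_entities]                      # iterator already exhausted

    print(slots)      # {'city': 'Paris'}
    print(positions)  # [] -- wrapping the filter in list(...) would avoid this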
@@ -133,7 +133,7 @@ class JISFDL(tfbp.DataLoader):
         helper = JsonHelper()

         if self.method in ["fit", "train"]:
-            dataset = helper.read_dataset_json_file('train.json')
+            dataset = helper.read_dataset_json_file('english.json')
             train_data = self.parse_dataset_intents(dataset)
             return self._transform_dataset(train_data, tokenizer)
         elif self.method in ["evaluate"]:
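With the language filter disabled, parse_dataset_intents now iterates the parsed JSON directly (for exp in data), so english.json presumably holds a flat list of example objects rather than the previous {"common_examples": [...]} wrapper; if it still held the wrapper dict, iteration would yield string keys and exp["text"] would fail. A hypothetical sketch of a compatible file, with invented values:

    [
      {"text": "Book a table for two", "intent": "book_table", "entities": []},
      {"text": "What time do you open?", "intent": "opening_hours", "entities": []}
    ]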
@@ -154,14 +154,14 @@ class JISFDL(tfbp.DataLoader):
             intent_names = list(set(intents))
             # Map slots, load from the model (evaluate), recompute from dataset otherwise (train)
             slot_names = set()
-            for td in dataset:
-                slots = td.slots
-                for slot in slots:
-                    slot_names.add(slot)
-            slot_names = list(slot_names)
-            # To pad all the texts to the same length, the tokenizer will use special characters.
-            # To handle those we need to add <PAD> to slots_names. It can be some other symbol as well.
-            slot_names.insert(0, "<PAD>")
+            # for td in dataset:
+            #     slots = td.slots
+            #     for slot in slots:
+            #         slot_names.add(slot)
+            # slot_names = list(slot_names)
+            # # To pad all the texts to the same length, the tokenizer will use special characters.
+            # # To handle those we need to add <PAD> to slots_names. It can be some other symbol as well.
+            # slot_names.insert(0, "<PAD>")
         else:
             if "intent_names" in model_params:
                 intent_names = model_params["intent_names"]
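For context on the recomputation being disabled here: reserving index 0 of slot_names for <PAD> gives the tokenizer's padding positions a dedicated slot label, so real slot labels never collide with padding. A small illustrative sketch, with invented slot names:

    slot_names = ["<PAD>", "city", "time"]
    slot_map = {name: i for i, name in enumerate(slot_names)}

    # a 6-token input where the last two tokens are tokenizer padding
    token_slots = ["city", "city", "time", "time", "<PAD>", "<PAD>"]
    print([slot_map[s] for s in token_slots])  # [1, 1, 2, 2, 0, 0]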
@@ -210,10 +210,6 @@ class JISFDL(tfbp.DataLoader):

         return encoded_texts, encoded_intents, encoded_slots, intent_names, slot_names

-    def get_prediction_data(self) -> str:
-        helper = JsonHelper()
-        dataset = helper.read_dataset_json_file('predict.json')
-        return dataset["text"]

     def encode_text(self, text: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
         return self.encode_texts([text], tokenizer)
@@ -14,6 +14,7 @@ else:

 from keras.losses import SparseCategoricalCrossentropy
 from keras.metrics import SparseCategoricalAccuracy
+from focal_loss import SparseCategoricalFocalLoss
 import numpy as np

 from data_loaders.jisfdl import JISFDL
@@ -128,7 +129,7 @@ class IntentClassifier(tfbp.Model):
         # Hyperparams, Optimizer and Loss function
         opt = Adam(learning_rate=3e-5, epsilon=1e-08)

-        losses = SparseCategoricalCrossentropy()
+        losses = SparseCategoricalFocalLoss(gamma=2.5)

         metrics = [SparseCategoricalAccuracy("accuracy")]
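Replacing plain cross-entropy with focal loss down-weights examples the model already classifies confidently, so the training gradient is dominated by hard or under-represented intents; gamma controls how aggressive the down-weighting is, and gamma=0 recovers ordinary cross-entropy. A minimal sketch using the focal-loss package pinned in requirements, with invented probabilities:

    import numpy as np
    from focal_loss import SparseCategoricalFocalLoss

    y_true = np.array([0, 1])                      # integer class labels
    y_pred = np.array([[0.90, 0.05, 0.05],         # easy, confident prediction
                       [0.40, 0.50, 0.10]])        # hard, uncertain prediction

    ce = SparseCategoricalFocalLoss(gamma=0.0)     # equivalent to cross-entropy
    fl = SparseCategoricalFocalLoss(gamma=2.5)     # the value used in this commit

    print(float(ce(y_true, y_pred)))  # both examples weighted equally
    print(float(fl(y_true, y_pred)))  # the easy example contributes almost nothing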
@@ -172,32 +173,65 @@ class IntentClassifier(tfbp.Model):

         return scores

-    @tfbp.runnable
-    def predict(self):
-        text = self.data_loader.get_prediction_data()
-
-        info = self.get_prediction(text)
-
-        print(self.summary())
-        print("Text : " + text)
-        print(json.dumps(info, indent=2))
-
-        return json.dumps(info, indent=2)
-
     def get_prediction(self, text: str):
         inputs = self.data_loader.encode_text(text, self.tokenizer)
         intent_probas = self(inputs)  # type: ignore

         intent_probas_np = intent_probas.numpy()

         # Get the indices of the maximum values
         intent_id = intent_probas_np.argmax(axis=-1)[0]

         # get the confidences for each intent
         intent_confidences = intent_probas_np[0]

-        return {
+        margin = self.compute_normalized_confidence_margin(intent_probas_np)
+        output = {
             "text": text,
             "intent": {"name": self.extra_params["intent_names"][intent_id],
                        "confidence": float(intent_confidences[intent_id])},
+            "margin": margin,
         }

+        return output
+
+    def compute_top_k_confidence(self, probs, k=3):
+        sorted_probas = np.sort(probs[0])[::-1]  # Sort in descending order
+        top_k_sum = np.sum(sorted_probas[:k])
+        return top_k_sum
+
+    def compute_normalized_confidence_margin(self, probs):
+        highest_proba = np.max(probs[0])
+        sum_of_probas = self.compute_top_k_confidence(probs)
+        # Normalized margin
+        normalized_margin = highest_proba / sum_of_probas
+        return normalized_margin
+
+    @tfbp.runnable
+    def predict(self):
+        while True:
+            text = input("Provide text: ")
+            inputs = self.data_loader.encode_text(text, self.tokenizer)
+            intent_probas = self(inputs)  # type: ignore
+
+            intent_probas_np = intent_probas.numpy()
+
+            # Get the indices of the maximum values
+            intent_id = intent_probas_np.argmax(axis=-1)[0]
+
+            # get the confidences for each intent
+            intent_confidences = intent_probas_np[0]
+
+            weighted_margin = self.compute_normalized_confidence_margin(intent_probas_np)
+            output = {
+                "text": text,
+                "intent": {"name": self.extra_params["intent_names"][intent_id],
+                           "confidence": float(intent_confidences[intent_id])},
+                "margin": weighted_margin,
+            }
+            print(output)
+
+            # Optionally, provide a way to exit the loop
+            if input("Try again? (y/n): ").lower() != 'y':
+                break
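A quick worked example of the normalized confidence margin introduced above: the top probability is divided by the sum of the top k=3 probabilities, so the margin approaches 1.0 when one intent dominates and 1/k when the top candidates are tied. Standalone versions of the two helpers, with made-up distributions:

    import numpy as np

    def top_k_confidence(probs, k=3):
        return np.sum(np.sort(probs[0])[::-1][:k])

    def normalized_confidence_margin(probs):
        return np.max(probs[0]) / top_k_confidence(probs)

    decisive = np.array([[0.70, 0.20, 0.06, 0.04]])   # 0.70 / 0.96 ~= 0.73
    ambiguous = np.array([[0.40, 0.35, 0.20, 0.05]])  # 0.40 / 0.95 ~= 0.42

    print(normalized_confidence_margin(decisive))
    print(normalized_confidence_margin(ambiguous))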
@@ -1,9 +1,37 @@
-tensorflow==2.13.*
-transformers==4.30.2
-keras==2.13.*
-numpy==1.24.*
-scikit_learn==1.2.2
-fastapi==0.100.0
-uvicorn[standard]==0.23.1
-autopep8==2.0.2
-h5py --only-binary=h5py
+absl-py==2.1.0
+astunparse==1.6.3
+certifi==2024.8.30
+charset-normalizer==3.4.0
+flatbuffers==24.3.25
+focal-loss==0.0.7
+gast==0.6.0
+google-pasta==0.2.0
+grpcio==1.67.0
+h5py==3.12.1
+idna==3.10
+keras==3.6.0
+libclang==18.1.1
+Markdown==3.7
+markdown-it-py==3.0.0
+MarkupSafe==3.0.1
+mdurl==0.1.2
+ml-dtypes==0.4.1
+namex==0.0.8
+numpy==1.26.4
+opt_einsum==3.4.0
+optree==0.13.0
+packaging==24.1
+protobuf==4.25.5
+Pygments==2.18.0
+requests==2.32.3
+rich==13.9.2
+six==1.16.0
+tensorboard==2.17.1
+tensorboard-data-server==0.7.2
+tensorflow==2.17.0
+tensorflow-io-gcs-filesystem==0.37.1
+termcolor==2.5.0
+typing_extensions==4.12.2
+urllib3==2.2.3
+Werkzeug==3.0.4
+wrapt==1.16.0
@@ -6,7 +6,8 @@ class JsonHelper:

     def __init__(self, model:str = "intent_classifier"):
         self.data_folder=os.path.join("data",model)
+        # self.data_folder = os.path.join(os.path.dirname(__file__), '..', 'data', model)

     def read_dataset_json_file(self, filename):
         file_path = os.path.join(self.data_folder, filename)
         if os.path.exists(file_path):
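On the commented alternative in __init__: os.path.join("data", model) resolves against the process working directory, while anchoring on __file__ resolves against the module's own location, which survives launching the service from another directory. A small sketch of the difference, with illustrative paths:

    import os

    model = "intent_classifier"

    # Depends on where the process was started:
    cwd_relative = os.path.join("data", model)

    # Fixed relative to this source file:
    file_relative = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "..", "data", model))

    print(cwd_relative, file_relative)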