Mirror of https://github.com/hexastack/hexabot
fix: restore data loader

commit 13df530881 (parent 9b61e36c88)
@@ -99,28 +99,28 @@ class JISFDL(tfbp.DataLoader):
         k = 0

         # Filter examples by language
-        # lang = self.hparams.language
-        # all_examples = data["common_examples"]
-        #
-        # if not bool(lang):
-        #     examples = all_examples
-        # else:
-        #     examples = filter(lambda exp: any(e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)
+        lang = self.hparams.language
+        all_examples = data["common_examples"]
+
+        if not bool(lang):
+            examples = all_examples
+        else:
+            examples = filter(lambda exp: any(e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)

         # Parse raw data
-        for exp in data:
+        for exp in examples:
             text = exp["text"]
             intent = exp["intent"]
-            # entities = exp["entities"]
+            entities = exp["entities"]

             # Filter out language entities
-            # slot_entities = filter(
-            #     lambda e: e["entity"] != "language", entities)
-            # slots = {e["entity"]: e["value"] for e in slot_entities}
-            # positions = [[e.get("start", -1), e.get("end", -1)]
-            #              for e in slot_entities]
+            slot_entities = filter(
+                lambda e: e["entity"] != "language", entities)
+            slots = {e["entity"]: e["value"] for e in slot_entities}
+            positions = [[e.get("start", -1), e.get("end", -1)]
+                         for e in slot_entities]

-            temp = JointRawData(k, intent, None, None, text)
+            temp = JointRawData(k, intent, positions, slots, text)
             k += 1
             intents.append(temp)

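For context, the restored block expects a Rasa-style payload whose `common_examples` entries carry a `language` entity alongside the actual slot entities. A minimal, self-contained sketch of the same flow (the payload and `lang` value are made up for illustration, and `lang` stands in for `self.hparams.language`):

# Hypothetical Rasa-style payload; field names mirror the ones the loader reads.
data = {
    "common_examples": [
        {
            "text": "book a flight to Paris",
            "intent": "book_flight",
            "entities": [
                {"entity": "language", "value": "en"},
                {"entity": "destination", "value": "Paris", "start": 17, "end": 22},
            ],
        },
        {
            "text": "réserver un vol",
            "intent": "book_flight",
            "entities": [{"entity": "language", "value": "fr"}],
        },
    ]
}

lang = "en"  # stands in for self.hparams.language
all_examples = data["common_examples"]

if not bool(lang):
    examples = all_examples
else:
    examples = [exp for exp in all_examples
                if any(e["entity"] == "language" and e["value"] == lang
                       for e in exp["entities"])]

for exp in examples:
    # Materialize the entity filter once so both comprehensions below see it.
    slot_entities = [e for e in exp["entities"] if e["entity"] != "language"]
    slots = {e["entity"]: e["value"] for e in slot_entities}
    positions = [[e.get("start", -1), e.get("end", -1)] for e in slot_entities]
    print(exp["text"], slots, positions)
# book a flight to Paris {'destination': 'Paris'} [[17, 22]]

One caveat the sketch works around: in the restored code `slot_entities` is a lazy `filter` object, so the `slots` comprehension exhausts it and `positions` would come out empty; materializing it as a list lets both passes see the entities.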
@@ -133,7 +133,7 @@ class JISFDL(tfbp.DataLoader):
         helper = JsonHelper()

         if self.method in ["fit", "train"]:
-            dataset = helper.read_dataset_json_file('english.json')
+            dataset = helper.read_dataset_json_file('train.json')
             train_data = self.parse_dataset_intents(dataset)
             return self._transform_dataset(train_data, tokenizer)
         elif self.method in ["evaluate"]:
@@ -154,14 +154,14 @@ class JISFDL(tfbp.DataLoader):
             intent_names = list(set(intents))
             # Map slots, load from the model (evaluate), recompute from dataset otherwise (train)
             slot_names = set()
-            # for td in dataset:
-            #     slots = td.slots
-            #     for slot in slots:
-            #         slot_names.add(slot)
-            # slot_names = list(slot_names)
-            # # To pad all the texts to the same length, the tokenizer will use special characters.
-            # # To handle those we need to add <PAD> to slots_names. It can be some other symbol as well.
-            # slot_names.insert(0, "<PAD>")
+            for td in dataset:
+                slots = td.slots
+                for slot in slots:
+                    slot_names.add(slot)
+            slot_names = list(slot_names)
+            # To pad all the texts to the same length, the tokenizer will use special characters.
+            # To handle those we need to add <PAD> to slots_names. It can be some other symbol as well.
+            slot_names.insert(0, "<PAD>")
         else:
             if "intent_names" in model_params:
                 intent_names = model_params["intent_names"]
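The restored loop recomputes the slot label inventory from the parsed training records and reserves index 0 for a padding label, since the tokenizer pads every text to a common length. A toy illustration of the resulting label map (the namedtuple is a stand-in for the loader's records, and `sorted` replaces the original's plain `list()` only to make the output deterministic):

from collections import namedtuple

# Stand-in for the loader's parsed records; only .slots matters here.
Record = namedtuple("Record", ["slots"])
dataset = [
    Record(slots={"destination": "Paris"}),
    Record(slots={"destination": "Rome", "date": "tomorrow"}),
]

slot_names = set()
for td in dataset:
    for slot in td.slots:
        slot_names.add(slot)

slot_names = sorted(slot_names)  # the original uses list(); sorted() just fixes the order
slot_names.insert(0, "<PAD>")    # padded token positions map to label id 0
print({name: i for i, name in enumerate(slot_names)})
# {'<PAD>': 0, 'date': 1, 'destination': 2}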
@@ -43,6 +43,8 @@ class IntentClassifier(tfbp.Model):
         "num_epochs": 2,
         "dropout_prob": 0.1,
         "intent_num_labels": 7,
+        "gamma": 2,
+        "k": 3
     }
     data_loader: JISFDL

@@ -129,7 +131,7 @@ class IntentClassifier(tfbp.Model):
         # Hyperparams, Optimizer and Loss function
         opt = Adam(learning_rate=3e-5, epsilon=1e-08)

-        losses = SparseCategoricalFocalLoss(gamma=2.5)
+        losses = SparseCategoricalFocalLoss(gamma=self.hparams.gamma)

         metrics = [SparseCategoricalAccuracy("accuracy")]

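The hard-coded `gamma=2.5` now comes from the new `gamma` hyperparameter (default 2 above). In focal loss, cross-entropy is scaled by (1 - p_t)**gamma, so a larger gamma shrinks the contribution of examples the model already gets right and focuses training on the hard ones. A small NumPy sketch of the per-example formula (just the math, not the focal-loss package's implementation):

import numpy as np

def sparse_focal_loss(probs, true_idx, gamma=2.0):
    """Per-example focal loss: -(1 - p_t)**gamma * log(p_t).
    gamma=0 recovers plain sparse categorical cross-entropy."""
    p_t = probs[true_idx]
    return -((1.0 - p_t) ** gamma) * np.log(p_t)

probs = np.array([0.7, 0.2, 0.1])  # already fairly confident prediction
for gamma in (0.0, 2.0):
    print(gamma, round(sparse_focal_loss(probs, 0, gamma), 4))
# 0.0 0.3567  <- plain cross-entropy
# 2.0 0.0321  <- focal loss shrinks the easy example's contribution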
@@ -203,7 +205,7 @@ class IntentClassifier(tfbp.Model):

     def compute_normalized_confidence_margin(self, probs):
         highest_proba = np.max(probs[0])
-        sum_of_probas = self.compute_top_k_confidence(probs)
+        sum_of_probas = self.compute_top_k_confidence(probs, self.hparams.k)
         # Normalized margin
         normalized_margin = highest_proba / sum_of_probas
         return normalized_margin

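Likewise, `k` is no longer baked into the top-k confidence sum. The margin divides the winning probability by the combined mass of the k best candidates, so it tends toward 1.0 when one intent clearly dominates. A standalone sketch (this `compute_top_k_confidence` body is an assumption; the diff does not show the helper's implementation):

import numpy as np

def compute_top_k_confidence(probs, k=3):
    # Assumed behaviour: sum of the k largest probabilities in the first row.
    return float(np.sort(probs[0])[-k:].sum())

def compute_normalized_confidence_margin(probs, k=3):
    highest_proba = float(np.max(probs[0]))
    return highest_proba / compute_top_k_confidence(probs, k)

probs = np.array([[0.55, 0.25, 0.15, 0.05]])
print(round(compute_normalized_confidence_margin(probs, k=3), 3))
# 0.579 -> 0.55 / (0.55 + 0.25 + 0.15); values near 1.0 mean a clear winner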
@@ -6,7 +6,6 @@ class JsonHelper:

     def __init__(self, model:str = "intent_classifier"):
         self.data_folder=os.path.join("data",model)
-        # self.data_folder = os.path.join(os.path.dirname(__file__), '..', 'data', model)

     def read_dataset_json_file(self, filename):
         file_path = os.path.join(self.data_folder, filename)
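The deleted comment was an alternative that anchored the data folder to the module's location instead of the process's working directory, which matters when scripts are launched from elsewhere. For reference, a plausible shape of the helper after this change (the `json.load` body and the usage line are assumptions; the diff only shows the first line of `read_dataset_json_file`):

import json
import os

class JsonHelper:
    def __init__(self, model: str = "intent_classifier"):
        # Resolved against the current working directory, so the process must
        # be launched from the project root for "data/<model>" to exist.
        self.data_folder = os.path.join("data", model)

    def read_dataset_json_file(self, filename):
        file_path = os.path.join(self.data_folder, filename)
        with open(file_path, encoding="utf-8") as f:  # assumed body
            return json.load(f)

helper = JsonHelper()
# helper.read_dataset_json_file('train.json') would load data/intent_classifier/train.json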