From 626eaa513d9765bc0f4bcf32a3b5688257d84ab6 Mon Sep 17 00:00:00 2001 From: hexastack Date: Tue, 22 Oct 2024 11:57:30 +0100 Subject: [PATCH] feat: refactor inference function --- .idea/.gitignore | 3 +++ .idea/Hexabot.iml | 14 +++++++++++++ .../inspectionProfiles/profiles_settings.xml | 6 ++++++ .idea/misc.xml | 7 +++++++ .idea/modules.xml | 8 +++++++ .idea/vcs.xml | 6 ++++++ nlu/data_loaders/tflcdl.py | 8 ------- nlu/models/intent_classifier.py | 21 ++----------------- nlu/models/slot_filler.py | 17 ++++++++------- nlu/models/tflc.py | 16 ++++++++++---- 10 files changed, 68 insertions(+), 38 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/Hexabot.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/Hexabot.iml b/.idea/Hexabot.iml new file mode 100644 index 0000000..e3ccb12 --- /dev/null +++ b/.idea/Hexabot.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..cd633c6 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..41382fd --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/nlu/data_loaders/tflcdl.py b/nlu/data_loaders/tflcdl.py index b765f78..bca3e2d 100644 --- a/nlu/data_loaders/tflcdl.py +++ b/nlu/data_loaders/tflcdl.py @@ -125,14 +125,6 @@ class TFLCDL(tfbp.DataLoader): else: raise ValueError("Unknown method!") - def get_prediction_data(self): - # The predict file contains a single JSON object whose only key is text. - data = self.json_helper.read_dataset_json_file("predict.json") - text = self.strip_numbers(data["text"]) - encoded_texts = np.array(self.tfidf.transform( - [text]).toarray()) # type: ignore - return np.array([text]), encoded_texts - def encode_text(self, text: str): sanitized_text = self.strip_numbers(text) return self.tfidf.transform([sanitized_text]).toarray() # type: ignore diff --git a/nlu/models/intent_classifier.py b/nlu/models/intent_classifier.py index 76314c6..c7a5536 100644 --- a/nlu/models/intent_classifier.py +++ b/nlu/models/intent_classifier.py @@ -211,27 +211,10 @@ class IntentClassifier(tfbp.Model): @tfbp.runnable def predict(self): while True: + text = input("Provide text: ") - inputs = self.data_loader.encode_text(text, self.tokenizer) - intent_probas = self(inputs) # type: ignore - - intent_probas_np = intent_probas.numpy() - - # Get the indices of the maximum values - intent_id = intent_probas_np.argmax(axis=-1)[0] - - # get the confidences for each intent - intent_confidences = intent_probas_np[0] - - weighted_margin = self.compute_normalized_confidence_margin(intent_probas_np) - output = { - "text": text, - "intent": {"name": self.extra_params["intent_names"][intent_id], - "confidence": float(intent_confidences[intent_id])}, - "margin": weighted_margin, - } + output = self.get_prediction(text) print(output) - # Optionally, provide a way to exit the loop if input("Try again? (y/n): ").lower() != 'y': break diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py index 0393fb3..d4d3182 100644 --- a/nlu/models/slot_filler.py +++ b/nlu/models/slot_filler.py @@ -151,16 +151,19 @@ class SlotFiller(tfbp.Model): @tfbp.runnable def predict(self): - text = self.data_loader.get_prediction_data() + while True: + text = input("Provide text: ") + info = self.get_prediction(text) - info = self.get_prediction(text) + print(self.summary()) + print("Text : " + text) + print(info) + + # Optionally, provide a way to exit the loop + if input("Try again? (y/n): ").lower() != 'y': + break - print(self.summary()) - print("Text : " + text) - print(json.dumps(info, indent=2)) - return json.dumps(info, indent=2) - def get_slots_prediction(self, text: str, inputs, slot_probas): slot_probas_np = slot_probas.numpy() # Get the indices of the maximum values diff --git a/nlu/models/tflc.py b/nlu/models/tflc.py index 23ccc20..c4d1046 100644 --- a/nlu/models/tflc.py +++ b/nlu/models/tflc.py @@ -95,19 +95,27 @@ class TFLC(tfbp.Model): self.calculate_metrics(y_test, y_pred, languages) + def preprocess_text(self, text): + # The predict file contains a single JSON object whose only key is text. + stripped_text = self.strip_numbers(text) + encoded_text = np.array(self.tfidf.transform( + [stripped_text]).toarray()) # type: ignore + return np.array([stripped_text]), encoded_text + @tfbp.runnable def predict(self): languages = list(self.extra_params['languages']) - texts, encoded_texts = self.data_loader.get_prediction_data() + input_provided = input("Provide text: ") + text, encoded_text = self.preprocess_text(input_provided) # converting a one hot output to language index - probas = super().predict(encoded_texts) + probas = super().predict(encoded_text) predictions = np.argmax(probas, axis=1) results = [] for idx, prediction in enumerate(predictions): print('The sentence "{}" is in {}.'.format( - texts[idx], languages[prediction].upper())) - results.append({'text': texts[idx], 'language': prediction}) + text[idx], languages[prediction].upper())) + results.append({'text': text[idx], 'language': prediction}) return results def get_prediction(self, text: str):