feat: refactor inference function
parent 15a3787fee · commit 626eaa513d
.idea/.gitignore (new file, vendored)
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
.idea/Hexabot.iml (new file)
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="jdk" jdkName="Python 3.10 (Hexabot)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
.idea/inspectionProfiles/profiles_settings.xml (new file)
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml (new file)
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.10" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Hexabot)" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml (new file)
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/Hexabot.iml" filepath="$PROJECT_DIR$/.idea/Hexabot.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml (new file)
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
@@ -125,14 +125,6 @@ class TFLCDL(tfbp.DataLoader):
         else:
             raise ValueError("Unknown method!")

-    def get_prediction_data(self):
-        # The predict file contains a single JSON object whose only key is text.
-        data = self.json_helper.read_dataset_json_file("predict.json")
-        text = self.strip_numbers(data["text"])
-        encoded_texts = np.array(self.tfidf.transform(
-            [text]).toarray())  # type: ignore
-        return np.array([text]), encoded_texts
-
     def encode_text(self, text: str):
         sanitized_text = self.strip_numbers(text)
         return self.tfidf.transform([sanitized_text]).toarray()  # type: ignore
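Note on this hunk: with get_prediction_data gone, the loader no longer reads predict.json from disk; encode_text is the single encoding entry point and callers pass raw text themselves. A rough, self-contained sketch of that path, assuming self.tfidf is a fitted scikit-learn TfidfVectorizer (the toarray() call suggests it) and using a toy stand-in for strip_numbers:

# Illustrative only: mimics the encode_text path with a stand-alone
# scikit-learn TfidfVectorizer fitted on a toy corpus (all names assumed).
import re
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

def strip_numbers(text: str) -> str:
    # Stand-in for the loader's strip_numbers: drop digits before vectorizing.
    return re.sub(r"\d+", "", text)

tfidf = TfidfVectorizer()
tfidf.fit(["hello world", "how are you", "goodbye now"])  # toy training corpus

def encode_text(text: str) -> np.ndarray:
    sanitized_text = strip_numbers(text)
    return tfidf.transform([sanitized_text]).toarray()

print(encode_text("hello 123 world").shape)  # -> (1, vocabulary size)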
@@ -211,27 +211,10 @@ class IntentClassifier(tfbp.Model):
     @tfbp.runnable
     def predict(self):
         while True:

             text = input("Provide text: ")
-            inputs = self.data_loader.encode_text(text, self.tokenizer)
-            intent_probas = self(inputs)  # type: ignore
-
-            intent_probas_np = intent_probas.numpy()
-
-            # Get the indices of the maximum values
-            intent_id = intent_probas_np.argmax(axis=-1)[0]
-
-            # get the confidences for each intent
-            intent_confidences = intent_probas_np[0]
-
-            weighted_margin = self.compute_normalized_confidence_margin(intent_probas_np)
-            output = {
-                "text": text,
-                "intent": {"name": self.extra_params["intent_names"][intent_id],
-                           "confidence": float(intent_confidences[intent_id])},
-                "margin": weighted_margin,
-            }
+            output = self.get_prediction(text)
             print(output)

             # Optionally, provide a way to exit the loop
             if input("Try again? (y/n): ").lower() != 'y':
                 break
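The deleted block is presumably what self.get_prediction(text) now encapsulates; its body is not part of this hunk. A runnable sketch of that post-processing, with a stand-in margin formula (the real compute_normalized_confidence_margin is not shown here):

# Runnable sketch of the post-processing removed above; names and the margin
# formula are illustrative, not the repo's actual helper.
import numpy as np

def summarize_intent_probas(intent_probas_np: np.ndarray,
                            intent_names: list,
                            text: str) -> dict:
    intent_id = int(intent_probas_np.argmax(axis=-1)[0])
    intent_confidences = intent_probas_np[0]
    # Gap between the two best scores, normalized by the best one (assumed).
    top_two = np.sort(intent_confidences)[-2:]
    weighted_margin = float((top_two[1] - top_two[0]) / (top_two[1] + 1e-9))
    return {
        "text": text,
        "intent": {"name": intent_names[intent_id],
                   "confidence": float(intent_confidences[intent_id])},
        "margin": weighted_margin,
    }

probas = np.array([[0.10, 0.70, 0.20]])  # one input, three intents
print(summarize_intent_probas(probas, ["greet", "order", "bye"], "I want a pizza"))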
@@ -151,16 +151,19 @@ class SlotFiller(tfbp.Model):

     @tfbp.runnable
     def predict(self):
-        text = self.data_loader.get_prediction_data()
-
-        info = self.get_prediction(text)
-
-        print(self.summary())
-        print("Text : " + text)
-        print(json.dumps(info, indent=2))
-
-        return json.dumps(info, indent=2)
+        while True:
+            text = input("Provide text: ")
+            info = self.get_prediction(text)
+
+            print(self.summary())
+            print("Text : " + text)
+            print(info)
+
+            # Optionally, provide a way to exit the loop
+            if input("Try again? (y/n): ").lower() != 'y':
+                break

     def get_slots_prediction(self, text: str, inputs, slot_probas):
         slot_probas_np = slot_probas.numpy()
         # Get the indices of the maximum values
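After this change, IntentClassifier.predict and SlotFiller.predict repeat the same prompt/print/try-again loop. If that duplication grows, the loop could be lifted into a shared helper; a minimal sketch (helper name and callback signature are invented for illustration, not from the repo):

# Illustrative helper: the "Provide text" / "Try again?" loop that both
# refactored predict() methods now repeat, lifted into one function.
from typing import Any, Callable

def interactive_predict(get_prediction: Callable[[str], Any]) -> None:
    while True:
        text = input("Provide text: ")
        print("Text : " + text)
        print(get_prediction(text))
        # Same exit convention as the refactored methods.
        if input("Try again? (y/n): ").lower() != 'y':
            break

# Inside a model class this would be called as:
#     interactive_predict(self.get_prediction)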
@@ -95,19 +95,27 @@ class TFLC(tfbp.Model):

         self.calculate_metrics(y_test, y_pred, languages)

+    def preprocess_text(self, text):
+        # The predict file contains a single JSON object whose only key is text.
+        stripped_text = self.strip_numbers(text)
+        encoded_text = np.array(self.tfidf.transform(
+            [stripped_text]).toarray())  # type: ignore
+        return np.array([stripped_text]), encoded_text
+
     @tfbp.runnable
     def predict(self):
         languages = list(self.extra_params['languages'])
-        texts, encoded_texts = self.data_loader.get_prediction_data()
+        input_provided = input("Provide text: ")
+        text, encoded_text = self.preprocess_text(input_provided)
         # converting a one hot output to language index
-        probas = super().predict(encoded_texts)
+        probas = super().predict(encoded_text)
         predictions = np.argmax(probas, axis=1)

         results = []
         for idx, prediction in enumerate(predictions):
             print('The sentence "{}" is in {}.'.format(
-                texts[idx], languages[prediction].upper()))
-            results.append({'text': texts[idx], 'language': prediction})
+                text[idx], languages[prediction].upper()))
+            results.append({'text': text[idx], 'language': prediction})
         return results

     def get_prediction(self, text: str):
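For reference, the decoding half of TFLC.predict (class probabilities back to a language label) boils down to an argmax over axis 1 and a list lookup. A runnable sketch with made-up languages and probabilities:

# Sketch of the decoding step in TFLC.predict: one row of class probabilities
# per input, argmax over axis 1, then a lookup into the languages list.
import numpy as np

languages = ["en", "fr", "ar"]                  # made-up label set
probas = np.array([[0.05, 0.90, 0.05]])         # one text, three classes
predictions = np.argmax(probas, axis=1)

text = np.array(["bonjour tout le monde"])      # mirrors np.array([stripped_text])
results = []
for idx, prediction in enumerate(predictions):
    print('The sentence "{}" is in {}.'.format(
        text[idx], languages[prediction].upper()))
    results.append({'text': str(text[idx]), 'language': int(prediction)})
print(results)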