From 1981f2245ddd11c2dd1247fcf0e90a572c0d3c6d Mon Sep 17 00:00:00 2001
From: Mohamed Marrouchi
Date: Mon, 25 Nov 2024 12:37:35 +0100
Subject: [PATCH] fix: enhance

Harden the BERT model lookup in IntentClassifier and SlotFiller:

- Reject a BERT_MODEL_BY_LANGUAGE_JSON value that parses but is not a
  JSON object (dictionary).
- Catch KeyError, the exception a failed dictionary lookup actually
  raises, instead of json.JSONDecodeError, which the lookup never raises.
- Keep the "en" fallback when hparams.language is missing, None or empty
  (the previous code used `self.hparams.language or "en"`).
- Name the requested language in the error message and chain the
  original KeyError.
---
 nlu/models/intent_classifier.py | 14 ++++++++++----
 nlu/models/slot_filler.py       | 14 ++++++++++----
 2 files changed, 20 insertions(+), 8 deletions(-)

Note: indentation and the trailing whitespace on the removed blank lines
were reconstructed from a whitespace-collapsed copy of this patch. If a
hunk does not apply cleanly, retry with `git apply --ignore-whitespace`.
The post-image blob ids on the `index` lines predate the fallback fix.

diff --git a/nlu/models/intent_classifier.py b/nlu/models/intent_classifier.py
index d96ab82b..61276fdd 100644
--- a/nlu/models/intent_classifier.py
+++ b/nlu/models/intent_classifier.py
@@ -63,11 +63,17 @@ class IntentClassifier(tfbp.Model):
             bert_models = json.loads(bert_model_by_language_json)
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
-        
+
+        # Ensure the parsed JSON is a dictionary
+        if not isinstance(bert_models, dict):
+            raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
+
+        # Retrieve the BERT model name; fall back to "en" when language is missing, None or empty
+        language = getattr(self.hparams, 'language', None) or "en"
         try:
-            bert_model_name = bert_models[self.hparams.language or "en"]
-        except json.JSONDecodeError as e:
-            raise ValueError(f"No Bert model is available for the provided language: {e}")
+            bert_model_name = bert_models[language]
+        except KeyError as e:
+            raise ValueError(f"No BERT model is available for the provided language '{language}': {e}") from e
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             bert_model_name, use_fast=False)
diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py
index 11c509e0..ff3436d7 100644
--- a/nlu/models/slot_filler.py
+++ b/nlu/models/slot_filler.py
@@ -61,11 +61,17 @@ class SlotFiller(tfbp.Model):
             bert_models = json.loads(bert_model_by_language_json)
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
-        
+
+        # Ensure the parsed JSON is a dictionary
+        if not isinstance(bert_models, dict):
+            raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
+
+        # Retrieve the BERT model name; fall back to "en" when language is missing, None or empty
+        language = getattr(self.hparams, 'language', None) or "en"
         try:
-            bert_model_name = bert_models[self.hparams.language or "en"]
-        except json.JSONDecodeError as e:
-            raise ValueError(f"No Bert model is available for the provided language: {e}")
+            bert_model_name = bert_models[language]
+        except KeyError as e:
+            raise ValueError(f"No BERT model is available for the provided language '{language}': {e}") from e
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             bert_model_name, use_fast=False)