fix: enhance

This commit is contained in:
Mohamed Marrouchi 2024-11-25 12:37:35 +01:00
parent 16b11cb93d
commit 1981f2245d
2 changed files with 20 additions and 8 deletions

View File

@ -63,11 +63,17 @@ class IntentClassifier(tfbp.Model):
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
# Ensure the parsed JSON is a dictionary
if not isinstance(bert_models, dict):
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
# Retrieve the BERT model name for the specified language
language = getattr(self.hparams, 'language', "en")
try:
bert_model_name = bert_models[self.hparams.language or "en"]
except json.JSONDecodeError as e:
raise ValueError(f"No Bert model is available for the provided language: {e}")
bert_model_name = bert_models[language]
except KeyError as e:
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)

View File

@ -61,11 +61,17 @@ class SlotFiller(tfbp.Model):
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
# Ensure the parsed JSON is a dictionary
if not isinstance(bert_models, dict):
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
# Retrieve the BERT model name for the specified language
language = getattr(self.hparams, 'language', "en")
try:
bert_model_name = bert_models[self.hparams.language or "en"]
except json.JSONDecodeError as e:
raise ValueError(f"No Bert model is available for the provided language: {e}")
bert_model_name = bert_models[language]
except KeyError as e:
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)