diff --git a/nlu/models/intent_classifier.py b/nlu/models/intent_classifier.py
index d96ab82b..61276fdd 100644
--- a/nlu/models/intent_classifier.py
+++ b/nlu/models/intent_classifier.py
@@ -63,11 +63,17 @@ class IntentClassifier(tfbp.Model):
             bert_models = json.loads(bert_model_by_language_json)
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
-
+
+        # Ensure the parsed JSON is a dictionary
+        if not isinstance(bert_models, dict):
+            raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
+
+        # Retrieve the BERT model name for the specified language
+        language = getattr(self.hparams, 'language', "en")
         try:
-            bert_model_name = bert_models[self.hparams.language or "en"]
-        except json.JSONDecodeError as e:
-            raise ValueError(f"No Bert model is available for the provided language: {e}")
+            bert_model_name = bert_models[language]
+        except KeyError as e:
+            raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             bert_model_name, use_fast=False)
diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py
index d4d31820..ff3436d7 100644
--- a/nlu/models/slot_filler.py
+++ b/nlu/models/slot_filler.py
@@ -1,3 +1,4 @@
+import os
 import functools
 import json
 from transformers import TFBertModel, AutoTokenizer
@@ -29,12 +30,6 @@ import boilerplate as tfbp
 # the paper with the original dataset.
 ##
 
-BERT_MODEL_BY_LANGUAGE = {
-    'en': "bert-base-cased",
-    'fr': "dbmdz/bert-base-french-europeana-cased",
-}
-
-
 @tfbp.default_export
 class SlotFiller(tfbp.Model):
     default_hparams = {
@@ -53,7 +48,30 @@ class SlotFiller(tfbp.Model):
         # Load Tokenizer from transformers
         # We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
-        bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
+
+        # Read the environment variable
+        bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
+
+        # Check if the environment variable is set
+        if not bert_model_by_language_json:
+            raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
+
+        # Parse the JSON string into a Python dictionary
+        try:
+            bert_models = json.loads(bert_model_by_language_json)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
+
+        # Ensure the parsed JSON is a dictionary
+        if not isinstance(bert_models, dict):
+            raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
+
+        # Retrieve the BERT model name for the specified language
+        language = getattr(self.hparams, 'language', "en")
+        try:
+            bert_model_name = bert_models[language]
+        except KeyError as e:
+            raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             bert_model_name, use_fast=False)
 
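
Note (not part of the patch): after this change, both IntentClassifier and SlotFiller expect the BERT_MODEL_BY_LANGUAGE_JSON environment variable to hold a JSON object mapping language codes to Hugging Face model names, i.e. the mapping that slot_filler.py previously hard-coded. Below is a minimal, hypothetical sketch of populating that variable and reproducing the lookup; the mapping values are copied from the removed dictionary, and where the variable gets set in deployment is an assumption.

# Hypothetical sketch, not part of the diff: populate the environment variable
# that the patched models read, then reproduce their resolution logic.
import json
import os

# Mapping copied from the dictionary removed from slot_filler.py.
os.environ["BERT_MODEL_BY_LANGUAGE_JSON"] = json.dumps({
    "en": "bert-base-cased",
    "fr": "dbmdz/bert-base-french-europeana-cased",
})

# Same steps as the patched code paths: parse, validate, look up by language.
bert_models = json.loads(os.environ["BERT_MODEL_BY_LANGUAGE_JSON"])
assert isinstance(bert_models, dict)
assert bert_models.get("fr") == "dbmdz/bert-base-french-europeana-cased"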