From 16b11cb93db07faec3f039c4d1fd0750b29580bd Mon Sep 17 00:00:00 2001 From: Mohamed Marrouchi Date: Mon, 25 Nov 2024 12:00:05 +0100 Subject: [PATCH] fix: add missing bert load from env var --- nlu/models/slot_filler.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/nlu/models/slot_filler.py b/nlu/models/slot_filler.py index d4d3182..11c509e 100644 --- a/nlu/models/slot_filler.py +++ b/nlu/models/slot_filler.py @@ -1,3 +1,4 @@ +import os import functools import json from transformers import TFBertModel, AutoTokenizer @@ -29,12 +30,6 @@ import boilerplate as tfbp # the paper with the original dataset. ## -BERT_MODEL_BY_LANGUAGE = { - 'en': "bert-base-cased", - 'fr': "dbmdz/bert-base-french-europeana-cased", -} - - @tfbp.default_export class SlotFiller(tfbp.Model): default_hparams = { @@ -53,7 +48,24 @@ class SlotFiller(tfbp.Model): # Load Tokenizer from transformers # We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier. - bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"] + + # Read the environment variable + bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON') + + # Check if the environment variable is set + if not bert_model_by_language_json: + raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.") + + # Parse the JSON string into a Python dictionary + try: + bert_models = json.loads(bert_model_by_language_json) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}") + + try: + bert_model_name = bert_models[self.hparams.language or "en"] + except json.JSONDecodeError as e: + raise ValueError(f"No Bert model is available for the provided language: {e}") self.tokenizer = AutoTokenizer.from_pretrained( bert_model_name, use_fast=False)