fix: add missing bert load from env var

This commit is contained in:
Mohamed Marrouchi 2024-11-25 12:00:05 +01:00
parent 3dbb500a5f
commit 16b11cb93d

View File

@ -1,3 +1,4 @@
import os
import functools
import json
from transformers import TFBertModel, AutoTokenizer
@ -29,12 +30,6 @@ import boilerplate as tfbp
# the paper with the original dataset.
##
BERT_MODEL_BY_LANGUAGE = {
'en': "bert-base-cased",
'fr': "dbmdz/bert-base-french-europeana-cased",
}
@tfbp.default_export
class SlotFiller(tfbp.Model):
default_hparams = {
@ -53,7 +48,24 @@ class SlotFiller(tfbp.Model):
# Load Tokenizer from transformers
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
# Read the environment variable
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
# Check if the environment variable is set
if not bert_model_by_language_json:
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
# Parse the JSON string into a Python dictionary
try:
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
try:
bert_model_name = bert_models[self.hparams.language or "en"]
except json.JSONDecodeError as e:
raise ValueError(f"No Bert model is available for the provided language: {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)