Merge pull request #370 from Hexastack/fix/nlu-slot-filler-bert
Some checks are pending
Build and Push Docker Images / paths-filter (push) Waiting to run
Build and Push Docker Images / build-and-push (push) Blocked by required conditions

fix: add missing bert load from env var
This commit is contained in:
Med Marrouchi 2024-11-25 12:47:04 +01:00 committed by GitHub
commit 23b05d7976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 11 deletions

View File

@ -63,11 +63,17 @@ class IntentClassifier(tfbp.Model):
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
# Ensure the parsed JSON is a dictionary
if not isinstance(bert_models, dict):
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
# Retrieve the BERT model name for the specified language
language = getattr(self.hparams, 'language', "en")
try:
bert_model_name = bert_models[self.hparams.language or "en"]
except json.JSONDecodeError as e:
raise ValueError(f"No Bert model is available for the provided language: {e}")
bert_model_name = bert_models[language]
except KeyError as e:
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)

View File

@ -1,3 +1,4 @@
import os
import functools
import json
from transformers import TFBertModel, AutoTokenizer
@ -29,12 +30,6 @@ import boilerplate as tfbp
# the paper with the original dataset.
##
BERT_MODEL_BY_LANGUAGE = {
'en': "bert-base-cased",
'fr': "dbmdz/bert-base-french-europeana-cased",
}
@tfbp.default_export
class SlotFiller(tfbp.Model):
default_hparams = {
@ -53,7 +48,30 @@ class SlotFiller(tfbp.Model):
# Load Tokenizer from transformers
# We will use a pretrained BERT model, selected per language, for both the tokenizer and our classifier.
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
# Read the environment variable
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
# Check if the environment variable is set
if not bert_model_by_language_json:
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
# Parse the JSON string into a Python dictionary
try:
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
# Ensure the parsed JSON is a dictionary
if not isinstance(bert_models, dict):
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
# Retrieve the BERT model name for the specified language
language = getattr(self.hparams, 'language', "en")
try:
bert_model_name = bert_models[language]
except KeyError as e:
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)