fix: dynamic bert/language mapping

This commit is contained in:
Mohamed Marrouchi 2024-11-22 11:46:18 +01:00
parent 76b6cb1d8d
commit 1d656fcd31
2 changed files with 25 additions and 9 deletions

View File

@ -48,6 +48,12 @@ TFLC_REPO_ID=Hexastack/tflc
INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
SLOT_FILLER_REPO_ID=Hexastack/slot-filler
NLU_ENGINE_PORT=5000
BERT_MODEL_BY_LANGUAGE_JSON='{
"en": "bert-base-cased",
"fr": "dbmdz/bert-base-french-europeana-cased"
}'
# Huggingface Access token to download private models for NLU inference
HF_AUTH_TOKEN=
# Frontend (Next.js)
APP_FRONTEND_PORT=8080
@ -65,5 +71,3 @@ REDIS_ENABLED=false
REDIS_HOST=redis
REDIS_PORT=6379
# Huggingface Access token
HF_AUTH_TOKEN=

View File

@ -1,3 +1,4 @@
import os
import json
import math
from typing import Tuple, Dict, List
@ -29,12 +30,6 @@ import boilerplate as tfbp
# the paper with the original dataset.
##
BERT_MODEL_BY_LANGUAGE = {
'en': "bert-base-cased",
'fr': "dbmdz/bert-base-french-europeana-cased",
}
@tfbp.default_export
class IntentClassifier(tfbp.Model):
default_hparams = {
@ -55,7 +50,24 @@ class IntentClassifier(tfbp.Model):
# Load Tokenizer from transformers
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
# Read the environment variable
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
# Check if the environment variable is set
if not bert_model_by_language_json:
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
# Parse the JSON string into a Python dictionary
try:
bert_models = json.loads(bert_model_by_language_json)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
try:
bert_model_name = bert_models[self.hparams.language or "en"]
except json.JSONDecodeError as e:
raise ValueError(f"No Bert model is available for the provided language: {e}")
self.tokenizer = AutoTokenizer.from_pretrained(
bert_model_name, use_fast=False)