From 1d656fcd31300fba8e2e3561912ef721626c4ecc Mon Sep 17 00:00:00 2001 From: Mohamed Marrouchi Date: Fri, 22 Nov 2024 11:46:18 +0100 Subject: [PATCH] fix: dynamic bert/language mapping --- docker/.env.example | 8 ++++++-- nlu/models/intent_classifier.py | 26 +++++++++++++++++++------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index b268a760..9ca144c4 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -48,6 +48,12 @@ TFLC_REPO_ID=Hexastack/tflc INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier SLOT_FILLER_REPO_ID=Hexastack/slot-filler NLU_ENGINE_PORT=5000 +BERT_MODEL_BY_LANGUAGE_JSON='{ + "en": "bert-base-cased", + "fr": "dbmdz/bert-base-french-europeana-cased" +}' +# Huggingface Access token to download private models for NLU inference +HF_AUTH_TOKEN= # Frontend (Next.js) APP_FRONTEND_PORT=8080 @@ -65,5 +71,3 @@ REDIS_ENABLED=false REDIS_HOST=redis REDIS_PORT=6379 -# Huggingface Access token -HF_AUTH_TOKEN= diff --git a/nlu/models/intent_classifier.py b/nlu/models/intent_classifier.py index 0863dd3b..d96ab82b 100644 --- a/nlu/models/intent_classifier.py +++ b/nlu/models/intent_classifier.py @@ -1,3 +1,4 @@ +import os import json import math from typing import Tuple, Dict, List @@ -29,12 +30,6 @@ import boilerplate as tfbp # the paper with the original dataset. ## -BERT_MODEL_BY_LANGUAGE = { - 'en': "bert-base-cased", - 'fr': "dbmdz/bert-base-french-europeana-cased", -} - - @tfbp.default_export class IntentClassifier(tfbp.Model): default_hparams = { @@ -55,7 +50,24 @@ class IntentClassifier(tfbp.Model): # Load Tokenizer from transformers # We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier. - bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"] + + # Read the environment variable + bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON') + + # Check if the environment variable is set + if not bert_model_by_language_json: + raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.") + + # Parse the JSON string into a Python dictionary + try: + bert_models = json.loads(bert_model_by_language_json) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}") + + try: + bert_model_name = bert_models[self.hparams.language or "en"] + except json.JSONDecodeError as e: + raise ValueError(f"No Bert model is available for the provided language: {e}") self.tokenizer = AutoTokenizer.from_pretrained( bert_model_name, use_fast=False)