mirror of
https://github.com/hexastack/hexabot
synced 2024-11-24 04:53:41 +00:00
Merge pull request #361 from Hexastack/fix/bert-language-env-var
fix: dynamic bert/language mapping
This commit is contained in:
commit
2d4b00b9c0
@ -48,6 +48,12 @@ TFLC_REPO_ID=Hexastack/tflc
|
||||
INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
|
||||
SLOT_FILLER_REPO_ID=Hexastack/slot-filler
|
||||
NLU_ENGINE_PORT=5000
|
||||
BERT_MODEL_BY_LANGUAGE_JSON='{
|
||||
"en": "bert-base-cased",
|
||||
"fr": "dbmdz/bert-base-french-europeana-cased"
|
||||
}'
|
||||
# Huggingface Access token to download private models for NLU inference
|
||||
HF_AUTH_TOKEN=
|
||||
|
||||
# Frontend (Next.js)
|
||||
APP_FRONTEND_PORT=8080
|
||||
@ -65,5 +71,3 @@ REDIS_ENABLED=false
|
||||
REDIS_HOST=redis
|
||||
REDIS_PORT=6379
|
||||
|
||||
# Huggingface Access token
|
||||
HF_AUTH_TOKEN=
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
from typing import Tuple, Dict, List
|
||||
@ -29,12 +30,6 @@ import boilerplate as tfbp
|
||||
# the paper with the original dataset.
|
||||
##
|
||||
|
||||
BERT_MODEL_BY_LANGUAGE = {
|
||||
'en': "bert-base-cased",
|
||||
'fr': "dbmdz/bert-base-french-europeana-cased",
|
||||
}
|
||||
|
||||
|
||||
@tfbp.default_export
|
||||
class IntentClassifier(tfbp.Model):
|
||||
default_hparams = {
|
||||
@ -55,7 +50,24 @@ class IntentClassifier(tfbp.Model):
|
||||
|
||||
# Load Tokenizer from transformers
|
||||
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
||||
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
|
||||
|
||||
# Read the environment variable
|
||||
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
|
||||
|
||||
# Check if the environment variable is set
|
||||
if not bert_model_by_language_json:
|
||||
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
try:
|
||||
bert_models = json.loads(bert_model_by_language_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||
|
||||
try:
|
||||
bert_model_name = bert_models[self.hparams.language or "en"]
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"No Bert model is available for the provided language: {e}")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
bert_model_name, use_fast=False)
|
||||
|
Loading…
Reference in New Issue
Block a user