mirror of
https://github.com/hexastack/hexabot
synced 2025-01-22 10:35:37 +00:00
fix: dynamic bert/language mapping
This commit is contained in:
parent
76b6cb1d8d
commit
1d656fcd31
@ -48,6 +48,12 @@ TFLC_REPO_ID=Hexastack/tflc
|
|||||||
INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
|
INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
|
||||||
SLOT_FILLER_REPO_ID=Hexastack/slot-filler
|
SLOT_FILLER_REPO_ID=Hexastack/slot-filler
|
||||||
NLU_ENGINE_PORT=5000
|
NLU_ENGINE_PORT=5000
|
||||||
|
BERT_MODEL_BY_LANGUAGE_JSON='{
|
||||||
|
"en": "bert-base-cased",
|
||||||
|
"fr": "dbmdz/bert-base-french-europeana-cased"
|
||||||
|
}'
|
||||||
|
# Huggingface Access token to download private models for NLU inference
|
||||||
|
HF_AUTH_TOKEN=
|
||||||
|
|
||||||
# Frontend (Next.js)
|
# Frontend (Next.js)
|
||||||
APP_FRONTEND_PORT=8080
|
APP_FRONTEND_PORT=8080
|
||||||
@ -65,5 +71,3 @@ REDIS_ENABLED=false
|
|||||||
REDIS_HOST=redis
|
REDIS_HOST=redis
|
||||||
REDIS_PORT=6379
|
REDIS_PORT=6379
|
||||||
|
|
||||||
# Huggingface Access token
|
|
||||||
HF_AUTH_TOKEN=
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
from typing import Tuple, Dict, List
|
from typing import Tuple, Dict, List
|
||||||
@ -29,12 +30,6 @@ import boilerplate as tfbp
|
|||||||
# the paper with the original dataset.
|
# the paper with the original dataset.
|
||||||
##
|
##
|
||||||
|
|
||||||
BERT_MODEL_BY_LANGUAGE = {
|
|
||||||
'en': "bert-base-cased",
|
|
||||||
'fr': "dbmdz/bert-base-french-europeana-cased",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@tfbp.default_export
|
@tfbp.default_export
|
||||||
class IntentClassifier(tfbp.Model):
|
class IntentClassifier(tfbp.Model):
|
||||||
default_hparams = {
|
default_hparams = {
|
||||||
@ -55,7 +50,24 @@ class IntentClassifier(tfbp.Model):
|
|||||||
|
|
||||||
# Load Tokenizer from transformers
|
# Load Tokenizer from transformers
|
||||||
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
||||||
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
|
|
||||||
|
# Read the environment variable
|
||||||
|
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
|
||||||
|
|
||||||
|
# Check if the environment variable is set
|
||||||
|
if not bert_model_by_language_json:
|
||||||
|
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
|
||||||
|
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
try:
|
||||||
|
bert_models = json.loads(bert_model_by_language_json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
bert_model_name = bert_models[self.hparams.language or "en"]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"No Bert model is available for the provided language: {e}")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
bert_model_name, use_fast=False)
|
bert_model_name, use_fast=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user