mirror of
https://github.com/hexastack/hexabot
synced 2025-04-02 12:21:19 +00:00
Merge pull request #370 from Hexastack/fix/nlu-slot-filler-bert
fix: add missing bert load from env var
This commit is contained in:
commit
23b05d7976
@ -63,11 +63,17 @@ class IntentClassifier(tfbp.Model):
|
|||||||
bert_models = json.loads(bert_model_by_language_json)
|
bert_models = json.loads(bert_model_by_language_json)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||||
|
|
||||||
|
# Ensure the parsed JSON is a dictionary
|
||||||
|
if not isinstance(bert_models, dict):
|
||||||
|
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
|
||||||
|
|
||||||
|
# Retrieve the BERT model name for the specified language
|
||||||
|
language = getattr(self.hparams, 'language', "en")
|
||||||
try:
|
try:
|
||||||
bert_model_name = bert_models[self.hparams.language or "en"]
|
bert_model_name = bert_models[language]
|
||||||
except json.JSONDecodeError as e:
|
except KeyError as e:
|
||||||
raise ValueError(f"No Bert model is available for the provided language: {e}")
|
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
bert_model_name, use_fast=False)
|
bert_model_name, use_fast=False)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
from transformers import TFBertModel, AutoTokenizer
|
from transformers import TFBertModel, AutoTokenizer
|
||||||
@ -29,12 +30,6 @@ import boilerplate as tfbp
|
|||||||
# the paper with the original dataset.
|
# the paper with the original dataset.
|
||||||
##
|
##
|
||||||
|
|
||||||
BERT_MODEL_BY_LANGUAGE = {
|
|
||||||
'en': "bert-base-cased",
|
|
||||||
'fr': "dbmdz/bert-base-french-europeana-cased",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@tfbp.default_export
|
@tfbp.default_export
|
||||||
class SlotFiller(tfbp.Model):
|
class SlotFiller(tfbp.Model):
|
||||||
default_hparams = {
|
default_hparams = {
|
||||||
@ -53,7 +48,30 @@ class SlotFiller(tfbp.Model):
|
|||||||
|
|
||||||
# Load Tokenizer from transformers
|
# Load Tokenizer from transformers
|
||||||
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
||||||
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
|
|
||||||
|
# Read the environment variable
|
||||||
|
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
|
||||||
|
|
||||||
|
# Check if the environment variable is set
|
||||||
|
if not bert_model_by_language_json:
|
||||||
|
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
|
||||||
|
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
try:
|
||||||
|
bert_models = json.loads(bert_model_by_language_json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||||
|
|
||||||
|
# Ensure the parsed JSON is a dictionary
|
||||||
|
if not isinstance(bert_models, dict):
|
||||||
|
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON must be a valid JSON object (dictionary).")
|
||||||
|
|
||||||
|
# Retrieve the BERT model name for the specified language
|
||||||
|
language = getattr(self.hparams, 'language', "en")
|
||||||
|
try:
|
||||||
|
bert_model_name = bert_models[language]
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f"No BERT model is available for the provided language '{language}': {e}")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
bert_model_name, use_fast=False)
|
bert_model_name, use_fast=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user