mirror of
https://github.com/hexastack/hexabot
synced 2024-11-29 07:21:29 +00:00
fix: add missing bert load from env var
This commit is contained in:
parent
3dbb500a5f
commit
16b11cb93d
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import functools
|
||||
import json
|
||||
from transformers import TFBertModel, AutoTokenizer
|
||||
@ -29,12 +30,6 @@ import boilerplate as tfbp
|
||||
# the paper with the original dataset.
|
||||
##
|
||||
|
||||
BERT_MODEL_BY_LANGUAGE = {
|
||||
'en': "bert-base-cased",
|
||||
'fr': "dbmdz/bert-base-french-europeana-cased",
|
||||
}
|
||||
|
||||
|
||||
@tfbp.default_export
|
||||
class SlotFiller(tfbp.Model):
|
||||
default_hparams = {
|
||||
@ -53,7 +48,24 @@ class SlotFiller(tfbp.Model):
|
||||
|
||||
# Load Tokenizer from transformers
|
||||
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
||||
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
|
||||
|
||||
# Read the environment variable
|
||||
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
|
||||
|
||||
# Check if the environment variable is set
|
||||
if not bert_model_by_language_json:
|
||||
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
try:
|
||||
bert_models = json.loads(bert_model_by_language_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||
|
||||
try:
|
||||
bert_model_name = bert_models[self.hparams.language or "en"]
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"No Bert model is available for the provided language: {e}")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
bert_model_name, use_fast=False)
|
||||
|
Loading…
Reference in New Issue
Block a user