mirror of
https://github.com/hexastack/hexabot
synced 2024-11-25 21:37:59 +00:00
fix: add missing bert load from env var
This commit is contained in:
parent
3dbb500a5f
commit
16b11cb93d
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
from transformers import TFBertModel, AutoTokenizer
|
from transformers import TFBertModel, AutoTokenizer
|
||||||
@ -29,12 +30,6 @@ import boilerplate as tfbp
|
|||||||
# the paper with the original dataset.
|
# the paper with the original dataset.
|
||||||
##
|
##
|
||||||
|
|
||||||
BERT_MODEL_BY_LANGUAGE = {
|
|
||||||
'en': "bert-base-cased",
|
|
||||||
'fr': "dbmdz/bert-base-french-europeana-cased",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@tfbp.default_export
|
@tfbp.default_export
|
||||||
class SlotFiller(tfbp.Model):
|
class SlotFiller(tfbp.Model):
|
||||||
default_hparams = {
|
default_hparams = {
|
||||||
@ -53,7 +48,24 @@ class SlotFiller(tfbp.Model):
|
|||||||
|
|
||||||
# Load Tokenizer from transformers
|
# Load Tokenizer from transformers
|
||||||
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
# We will use a pretrained bert model bert-base-cased for both Tokenizer and our classifier.
|
||||||
bert_model_name = BERT_MODEL_BY_LANGUAGE[self.hparams.language or "en"]
|
|
||||||
|
# Read the environment variable
|
||||||
|
bert_model_by_language_json = os.getenv('BERT_MODEL_BY_LANGUAGE_JSON')
|
||||||
|
|
||||||
|
# Check if the environment variable is set
|
||||||
|
if not bert_model_by_language_json:
|
||||||
|
raise ValueError("The BERT_MODEL_BY_LANGUAGE_JSON environment variable is not set.")
|
||||||
|
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
try:
|
||||||
|
bert_models = json.loads(bert_model_by_language_json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse BERT_MODEL_BY_LANGUAGE_JSON: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
bert_model_name = bert_models[self.hparams.language or "en"]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"No Bert model is available for the provided language: {e}")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
bert_model_name, use_fast=False)
|
bert_model_name, use_fast=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user