This commit is contained in:
Timothy J. Baek 2024-09-17 22:58:06 +02:00
parent d1dbb9a3be
commit 67f95ddfdc

View File

@ -180,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
auto_update: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
@ -196,17 +196,18 @@ def update_embedding_model(
def update_reranking_model(
reranking_model: str,
update_model: bool = False,
auto_update: bool = False,
):
if reranking_model:
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
class Colbert:
class ColBERT:
def __init__(self, name) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
self.device
)
self.ckpt = Checkpoint(
get_model_path(name, auto_update),
colbert_config=ColBERTConfig(),
).to(self.device)
pass
def calculate_similarity_scores(
@ -264,13 +265,13 @@ def update_reranking_model(
return scores
app.state.sentence_transformer_rf = Colbert(reranking_model)
app.state.sentence_transformer_rf = ColBERT(reranking_model)
else:
import sentence_transformers
try:
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
get_model_path(reranking_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)