mirror of
https://github.com/open-webui/open-webui
synced 2025-06-09 07:56:42 +00:00
refac
This commit is contained in:
parent
d1dbb9a3be
commit
67f95ddfdc
@ -180,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_
|
|||||||
|
|
||||||
def update_embedding_model(
|
def update_embedding_model(
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
update_model: bool = False,
|
auto_update: bool = False,
|
||||||
):
|
):
|
||||||
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||||
get_model_path(embedding_model, update_model),
|
get_model_path(embedding_model, auto_update),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
)
|
)
|
||||||
@ -196,17 +196,18 @@ def update_embedding_model(
|
|||||||
|
|
||||||
def update_reranking_model(
|
def update_reranking_model(
|
||||||
reranking_model: str,
|
reranking_model: str,
|
||||||
update_model: bool = False,
|
auto_update: bool = False,
|
||||||
):
|
):
|
||||||
if reranking_model:
|
if reranking_model:
|
||||||
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
||||||
|
|
||||||
class Colbert:
|
class ColBERT:
|
||||||
def __init__(self, name) -> None:
|
def __init__(self, name) -> None:
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
|
self.ckpt = Checkpoint(
|
||||||
self.device
|
get_model_path(name, auto_update),
|
||||||
)
|
colbert_config=ColBERTConfig(),
|
||||||
|
).to(self.device)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def calculate_similarity_scores(
|
def calculate_similarity_scores(
|
||||||
@ -264,13 +265,13 @@ def update_reranking_model(
|
|||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
app.state.sentence_transformer_rf = Colbert(reranking_model)
|
app.state.sentence_transformer_rf = ColBERT(reranking_model)
|
||||||
else:
|
else:
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||||
get_model_path(reranking_model, update_model),
|
get_model_path(reranking_model, auto_update),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user