mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	fix: address comment in pr #1687
This commit is contained in:
		
							parent
							
								
									d5f60b119c
								
							
						
					
					
						commit
						c9c9660459
					
				@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
 | 
			
		||||
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ollama_endpoint(url_idx: int = 0):
    """Return the Ollama base URL stored at position ``url_idx``.

    The URL list is read from ``app.state.OLLAMA_BASE_URLS``; an index outside
    that list raises ``IndexError`` from the underlying lookup.
    """
    base_urls = app.state.OLLAMA_BASE_URLS
    return base_urls[url_idx]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UrlUpdateForm(BaseModel):
 | 
			
		||||
    urls: List[str]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -39,8 +39,6 @@ import json
 | 
			
		||||
 | 
			
		||||
import sentence_transformers
 | 
			
		||||
 | 
			
		||||
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 | 
			
		||||
 | 
			
		||||
from apps.web.models.documents import (
 | 
			
		||||
    Documents,
 | 
			
		||||
    DocumentForm,
 | 
			
		||||
@ -48,6 +46,7 @@ from apps.web.models.documents import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from apps.rag.utils import (
 | 
			
		||||
    get_model_path,
 | 
			
		||||
    query_embeddings_doc,
 | 
			
		||||
    query_embeddings_function,
 | 
			
		||||
    query_embeddings_collection,
 | 
			
		||||
@ -60,6 +59,7 @@ from utils.misc import (
 | 
			
		||||
    extract_folders_after_data_docs,
 | 
			
		||||
)
 | 
			
		||||
from utils.utils import get_current_user, get_admin_user
 | 
			
		||||
 | 
			
		||||
from config import (
 | 
			
		||||
    SRC_LOG_LEVELS,
 | 
			
		||||
    UPLOAD_DIR,
 | 
			
		||||
@ -68,8 +68,10 @@ from config import (
 | 
			
		||||
    RAG_RELEVANCE_THRESHOLD,
 | 
			
		||||
    RAG_EMBEDDING_ENGINE,
 | 
			
		||||
    RAG_EMBEDDING_MODEL,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    RAG_RERANKING_MODEL,
 | 
			
		||||
    RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
			
		||||
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    RAG_OPENAI_API_BASE_URL,
 | 
			
		||||
    RAG_OPENAI_API_KEY,
 | 
			
		||||
@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.TOP_K = RAG_TOP_K
 | 
			
		||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 | 
			
		||||
app.state.CHUNK_SIZE = CHUNK_SIZE
 | 
			
		||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 | 
			
		||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 | 
			
		||||
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 | 
			
		||||
@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 | 
			
		||||
 | 
			
		||||
app.state.PDF_EXTRACT_IMAGES = False
 | 
			
		||||
 | 
			
		||||
if app.state.RAG_EMBEDDING_ENGINE == "":
 | 
			
		||||
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
 | 
			
		||||
        app.state.RAG_EMBEDDING_MODEL,
 | 
			
		||||
        device=DEVICE_TYPE,
 | 
			
		||||
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    )
 | 
			
		||||
else:
 | 
			
		||||
    app.state.sentence_transformer_ef = None
 | 
			
		||||
 | 
			
		||||
if not app.state.RAG_RERANKING_MODEL == "":
 | 
			
		||||
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
 | 
			
		||||
        app.state.RAG_RERANKING_MODEL,
 | 
			
		||||
        device=DEVICE_TYPE,
 | 
			
		||||
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    )
 | 
			
		||||
else:
 | 
			
		||||
    app.state.sentence_transformer_rf = None
 | 
			
		||||
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Rebuild or clear ``app.state.sentence_transformer_ef``.

    A ``SentenceTransformer`` is constructed only when ``embedding_model`` is
    truthy and ``app.state.RAG_EMBEDDING_ENGINE`` is the empty string; in every
    other case the attribute is reset to ``None``.

    Args:
        embedding_model: Model name or path; resolved through ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    # Guard clause: nothing to load when no model is named or when the
    # engine setting is anything other than the empty string.
    if not embedding_model or app.state.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        get_model_path(embedding_model, update_model),
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Rebuild or clear ``app.state.sentence_transformer_rf``.

    A ``CrossEncoder`` is constructed when ``reranking_model`` is truthy;
    otherwise the attribute is reset to ``None``.

    Args:
        reranking_model: Model name or path; resolved through ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    # Guard clause: a falsy model name means reranking is switched off.
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        get_model_path(reranking_model, update_model),
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
update_embedding_model(
 | 
			
		||||
    app.state.RAG_EMBEDDING_MODEL,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
update_reranking_model(
 | 
			
		||||
    app.state.RAG_RERANKING_MODEL,
 | 
			
		||||
    RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
origins = ["*"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.add_middleware(
 | 
			
		||||
    CORSMiddleware,
 | 
			
		||||
    allow_origins=origins,
 | 
			
		||||
@ -200,15 +221,7 @@ async def update_embedding_config(
 | 
			
		||||
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
 | 
			
		||||
                app.state.OPENAI_API_KEY = form_data.openai_config.key
 | 
			
		||||
 | 
			
		||||
            app.state.sentence_transformer_ef = None
 | 
			
		||||
        else:
 | 
			
		||||
            app.state.sentence_transformer_ef = (
 | 
			
		||||
                sentence_transformers.SentenceTransformer(
 | 
			
		||||
                    app.state.RAG_EMBEDDING_MODEL,
 | 
			
		||||
                    device=DEVICE_TYPE,
 | 
			
		||||
                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "status": True,
 | 
			
		||||
@ -219,7 +232,6 @@ async def update_embedding_config(
 | 
			
		||||
                "key": app.state.OPENAI_API_KEY,
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(f"Problem updating embedding model: {e}")
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -242,13 +254,7 @@ async def update_reranking_config(
 | 
			
		||||
    try:
 | 
			
		||||
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
 | 
			
		||||
 | 
			
		||||
        if app.state.RAG_RERANKING_MODEL == "":
 | 
			
		||||
            app.state.sentence_transformer_rf = None
 | 
			
		||||
        else:
 | 
			
		||||
            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
 | 
			
		||||
                app.state.RAG_RERANKING_MODEL,
 | 
			
		||||
                device=DEVICE_TYPE,
 | 
			
		||||
            )
 | 
			
		||||
        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "status": True,
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
import os
 | 
			
		||||
import logging
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
@ -8,6 +9,8 @@ from apps.ollama.main import (
 | 
			
		||||
    GenerateEmbeddingsForm,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
 | 
			
		||||
from langchain_core.documents import Document
 | 
			
		||||
from langchain_community.retrievers import BM25Retriever
 | 
			
		||||
from langchain.retrievers import (
 | 
			
		||||
@ -282,8 +285,6 @@ def rag_messages(
 | 
			
		||||
 | 
			
		||||
        extracted_collections.extend(collection)
 | 
			
		||||
 | 
			
		||||
    log.debug(f"relevant_contexts: {relevant_contexts}")
 | 
			
		||||
 | 
			
		||||
    context_string = ""
 | 
			
		||||
    for context in relevant_contexts:
 | 
			
		||||
        items = context["documents"][0]
 | 
			
		||||
@ -319,6 +320,44 @@ def rag_messages(
 | 
			
		||||
    return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_model_path(model: str, update_model: bool = False):
    """Resolve ``model`` to a path or repo id usable by sentence_transformers.

    Args:
        model: An existing filesystem path, a path-like string, a
            ``namespace/name`` repo id, or a bare model short-name.
        update_model: When ``False`` (the default) ``snapshot_download`` is
            called with ``local_files_only=True``.

    Returns:
        ``model`` unchanged when it exists on disk, or when it looks like a
        path and ``update_model`` is ``False``. Otherwise the value returned
        by ``snapshot_download``. If that call raises, the error is logged and
        the repo id is returned instead (a bare short-name will by then carry
        the ``sentence-transformers/`` prefix).
    """
    local_files_only = not update_model

    # Keyword arguments for huggingface_hub; ``repo_id`` is added further down
    # once the model name has been normalised.
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_files_only,
    }

    log.debug(f"embedding_model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers.
    # A backslash or more than one "/" cannot be a "namespace/name" repo id,
    # so such a string is treated as a filesystem path.
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        return model

    if "/" not in model:
        # Bare short-name: qualify it into a valid repo id.
        model = f"sentence-transformers/{model}"

    snapshot_kwargs["repo_id"] = model

    # Ask huggingface_hub for the snapshot location (and, when
    # local_files_only is False, allow it to fetch the snapshot).
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        # Best effort: fall back to the repo id and let the caller try it.
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_openai_embeddings(
 | 
			
		||||
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
 | 
			
		||||
)
 | 
			
		||||
# Report the configured embedding model at import time.
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}")

# Both flags are True only when the environment variable equals "true"
# (case-insensitive); unset or any other value yields False.
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
 | 
			
		||||
@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
 | 
			
		||||
# Report the reranking model only when one is configured.
if RAG_RERANKING_MODEL != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL}")

# Both flags are True only when the environment variable equals "true"
# (case-insensitive); unset or any other value yields False.
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user