diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index ac121672e..f1b1c14a5 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1330,16 +1330,16 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
     ),
 )
 
-RAG_EMBEDDING_PASSAGE_PREFIX = PersistentConfig(
-    "RAG_EMBEDDING_PASSAGE_PREFIX",
-    "rag.embedding_passage_prefix",
-    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", False),
+RAG_EMBEDDING_QUERY_PREFIX = (
+    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)
 )
 
-RAG_EMBEDDING_QUERY_PREFIX = PersistentConfig(
-    "RAG_EMBEDDING_QUERY_PREFIX",
-    "rag.embedding_query_prefix",
-    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", False),
+RAG_EMBEDDING_PASSAGE_PREFIX = (
+    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", None)
+)
+
+RAG_EMBEDDING_PREFIX_FIELD_NAME = (
+    os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", "input_type")
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index e420814d8..544a65a89 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -15,7 +15,11 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
 
-from open_webui.config import RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_PASSAGE_PREFIX
+from open_webui.config import (
+    RAG_EMBEDDING_QUERY_PREFIX,
+    RAG_EMBEDDING_PASSAGE_PREFIX,
+    RAG_EMBEDDING_PREFIX_FIELD_NAME
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -265,7 +269,7 @@ def get_embedding_function(
                 embeddings.extend(func(query[i : i + embedding_batch_size], prefix))
             return embeddings
         else:
-            return func(query)
+            return func(query, prefix)
 
     return lambda query, prefix: generate_multiple(query, prefix, func)
 
@@ -421,7 +425,7 @@ def generate_openai_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix},
         )
         r.raise_for_status()
         data = r.json()
@@ -444,7 +448,7 @@ def generate_ollama_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix},
         )
         r.raise_for_status()
         data = r.json()
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index b0c3f8e04..255cff112 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -70,7 +70,6 @@ from open_webui.utils.misc import (
 )
 from open_webui.utils.auth import get_admin_user, get_verified_user
 
-
 from open_webui.config import (
     ENV,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
@@ -79,7 +78,8 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     UPLOAD_DIR,
     DEFAULT_LOCALE,
-    RAG_EMBEDDING_PASSAGE_PREFIX
+    RAG_EMBEDDING_PASSAGE_PREFIX,
+    RAG_EMBEDDING_QUERY_PREFIX
 )
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -1319,7 +1319,7 @@ def query_doc_handler(
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, RAG_EMBEDDING_QUERY_PREFIX),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
     except Exception as e:
@@ -1438,7 +1438,7 @@ if ENV == "dev":
 
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
-        return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
+        return {"result": request.app.state.EMBEDDING_FUNCTION(text, RAG_EMBEDDING_QUERY_PREFIX)}
 
 
 class BatchProcessFilesForm(BaseModel):