Add optional passage/query prefixes for RAG embeddings.

Two new settings, RAG_EMBEDDING_PASSAGE_PREFIX and RAG_EMBEDDING_QUERY_PREFIX,
are threaded through every embedding call: the passage prefix when documents
are stored or re-ranked, the query prefix when a search query is embedded.

Review fixes folded into this revision of the patch:
  * generate_multiple() forwarded only `query` to `func` for non-list input,
    which raised TypeError once `func` started taking (query, prefix).
  * Call sites pass the setting's `.value` rather than the PersistentConfig
    wrapper object (which is always truthy and not JSON-serialisable).
  * The environment default is None instead of False.
  * The returned embedding callables default `prefix` to None so callers not
    touched by this patch keep working.
  * The `index` lines are carried over from the original patch and no longer
    match the post-image blobs (only matters for `git apply --3way`).

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index a48b2db05..ac121672e 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1330,6 +1330,20 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
     ),
 )
 
+# Optional prompts/prefixes handed to the embedding backend: one for stored
+# passages, one for search queries. None (the default) means "no prefix".
+RAG_EMBEDDING_PASSAGE_PREFIX = PersistentConfig(
+    "RAG_EMBEDDING_PASSAGE_PREFIX",
+    "rag.embedding_passage_prefix",
+    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", None),
+)
+
+RAG_EMBEDDING_QUERY_PREFIX = PersistentConfig(
+    "RAG_EMBEDDING_QUERY_PREFIX",
+    "rag.embedding_query_prefix",
+    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None),
+)
+
 RAG_RERANKING_MODEL = PersistentConfig(
     "RAG_RERANKING_MODEL",
     "rag.reranking_model",
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index c95367e6c..e420814d8 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -15,7 +15,7 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 
 from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
-
+from open_webui.config import RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_PASSAGE_PREFIX
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -39,7 +39,7 @@ class VectorSearchRetriever(BaseRetriever):
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query)],
+            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX.value)],
             limit=self.top_k,
         )
 
@@ -183,7 +183,7 @@ def query_collection(
 ) -> dict:
     results = []
     for query in queries:
-        query_embedding = embedding_function(query)
+        query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX.value)
         for collection_name in collection_names:
             if collection_name:
                 try:
@@ -247,26 +247,27 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, prefix=None: embedding_function.encode(query, prompt=prefix if prefix else None).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, prefix=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
+            prefix=prefix,
             url=url,
             key=key,
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, prefix, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], prefix))
                 return embeddings
             else:
-                return func(query)
+                return func(query, prefix)  # single string: still forward the prefix
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, prefix=None: generate_multiple(query, prefix, func)
 
 
 def get_sources_from_files(
@@ -411,7 +412,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", prefix: Optional[str] = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -420,7 +421,7 @@ def generate_openai_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
         )
         r.raise_for_status()
         data = r.json()
@@ -434,7 +435,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", prefix: Optional[str] = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -443,7 +444,7 @@ def generate_ollama_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
         )
         r.raise_for_status()
         data = r.json()
@@ -457,25 +458,25 @@ def generate_ollama_batch_embeddings(
         return None
 
 
-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], prefix: Optional[str] = None, **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "prefix": prefix}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "prefix": prefix}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, prefix)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, prefix)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 
@@ -512,9 +513,10 @@ class RerankCompressor(BaseDocumentCompressor):
         else:
             from sentence_transformers import util
 
-            query_embedding = self.embedding_function(query)
+            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX.value)
             document_embedding = self.embedding_function(
-                [doc.page_content for doc in documents]
+                [doc.page_content for doc in documents],
+                RAG_EMBEDDING_PASSAGE_PREFIX.value,
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
 
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index c791bde84..b0c3f8e04 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -79,6 +79,7 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     UPLOAD_DIR,
     DEFAULT_LOCALE,
+    RAG_EMBEDDING_PASSAGE_PREFIX,
 )
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -775,7 +776,7 @@ def save_docs_to_vector_db(
         )
 
         embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
+            list(map(lambda x: x.replace("\n", " "), texts)), RAG_EMBEDDING_PASSAGE_PREFIX.value
         )
 
         items = [