diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 44ce0db86..790f8b9ec 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -15,8 +15,9 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
+from open_webui.models.users import UserModel
 
-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -64,6 +65,7 @@ def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
+    user: UserModel = None,
 ):
     try:
         result = VECTOR_DB_CLIENT.search(
@@ -256,29 +258,32 @@ def get_embedding_function(
     embedding_function,
     url,
     key,
-    embedding_batch_size,
+    embedding_batch_size
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, user=None: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
             url=url,
             key=key,
+            user=user,
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], user=user))
                 return embeddings
             else:
-                return func(query)
+                return func(query, user)
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, user=None: generate_multiple(query, user, func)
+    else:
+        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
 def get_sources_from_files(
@@ -423,7 +428,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -431,6 +436,16 @@ def generate_openai_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
@@ -446,7 +461,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -454,6 +469,16 @@ def generate_ollama_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
@@ -472,22 +497,23 @@
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
+    user = kwargs.get("user")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "user": user}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
         return embeddings[0] if isinstance(text, str) else embeddings
 
 
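Note: both batch-embedding helpers attach the same conditional block of X-OpenWebUI-User-* headers, guarded on both the feature flag and the presence of a user. A minimal runnable sketch of that pattern in isolation; build_user_headers and the SimpleNamespace user are illustrative only, not part of this patch:

from types import SimpleNamespace

ENABLE_FORWARD_USER_INFO_HEADERS = True  # stand-in for the env flag imported above

def build_user_headers(user) -> dict:
    # Returns the forwarded-identity headers, or {} when the flag is off
    # or no user is associated with the request.
    if not (ENABLE_FORWARD_USER_INFO_HEADERS and user):
        return {}
    return {
        "X-OpenWebUI-User-Name": user.name,
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }

user = SimpleNamespace(name="Ada", id="u-1", email="ada@example.com", role="user")
headers = {"Content-Type": "application/json", **build_user_headers(user)}

Guarding on the user as well as the flag matters because internal and anonymous calls reach these helpers with user=None, and attribute access would otherwise raise AttributeError.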
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index b648fccc2..1ab76fac1 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -71,7 +71,7 @@ def upload_file(
     )
 
     try:
-        process_file(request, ProcessFileForm(file_id=id))
+        process_file(request, ProcessFileForm(file_id=id), user=user)
         file_item = Files.get_file_by_id(id=id)
     except Exception as e:
         log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
             process_file(
-                request, ProcessFileForm(file_id=id, content=form_data.content)
+                request,
+                ProcessFileForm(file_id=id, content=form_data.content),
+                user=user,
             )
             file = Files.get_file_by_id(id=id)
         except Exception as e:
diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py
index aac16e851..41061f6e3 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -289,7 +289,9 @@ def add_file_to_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user,
         )
     except Exception as e:
         log.debug(e)
@@ -372,7 +374,9 @@ def update_file_from_knowledge_by_id(
    # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user,
         )
     except Exception as e:
         raise HTTPException(
diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py
index e72cf1445..8ffc67cc0 100644
--- a/backend/open_webui/routers/memories.py
+++ b/backend/open_webui/routers/memories.py
@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -82,7 +82,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
         limit=form_data.k,
     )
 
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
@@ -160,7 +160,7 @@ async def update_memory_by_id(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
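Note: after this change, request.app.state.EMBEDDING_FUNCTION (the callable built by get_embedding_function) has the signature (query, user=None), which is why the routes above pass the current user as a second argument. A self-contained sketch of the batching wrapper it builds; make_embedding_function and dummy_backend are illustrative stand-ins for get_embedding_function and generate_embeddings:

def make_embedding_function(func, batch_size):
    # Mirrors generate_multiple() in retrieval/utils.py: lists are embedded
    # in batch_size chunks; a single string passes straight through.
    def generate_multiple(query, user, func):
        if isinstance(query, list):
            embeddings = []
            for i in range(0, len(query), batch_size):
                embeddings.extend(func(query[i : i + batch_size], user=user))
            return embeddings
        return func(query, user)

    return lambda query, user=None: generate_multiple(query, user, func)

def dummy_backend(text, user=None):
    # Stand-in for generate_embeddings(): one vector per input text.
    if isinstance(text, list):
        return [[float(len(t))] for t in text]
    return [float(len(text))]

embedding_function = make_embedding_function(dummy_backend, batch_size=2)
print(embedding_function(["a", "bb", "ccc"]))  # [[1.0], [2.0], [3.0]]
print(embedding_function("hello"))             # [5.0]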
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 35cea6237..0b87a4adb 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -666,6 +666,7 @@ def save_docs_to_vector_db(
     overwrite: bool = False,
     split: bool = True,
     add: bool = False,
+    user=None,
 ) -> bool:
     def _get_docs_info(docs: list[Document]) -> str:
         docs_info = set()
@@ -781,7 +782,8 @@ def save_docs_to_vector_db(
             )
 
             embeddings = embedding_function(
-                list(map(lambda x: x.replace("\n", " "), texts))
+                list(map(lambda x: x.replace("\n", " "), texts)),
+                user=user,
             )
 
             items = [
@@ -939,6 +941,7 @@ def process_file(
                     "hash": hash,
                 },
                 add=(True if form_data.collection_name else False),
+                user=user,
             )
 
             if result:
@@ -996,7 +999,7 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")
 
-    result = save_docs_to_vector_db(request, docs, collection_name)
+    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
     if result:
         return {
             "status": True,
@@ -1029,7 +1032,7 @@ def process_youtube_video(
 
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1070,7 +1073,7 @@ def process_web(
 
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1286,7 +1289,7 @@ def process_web_search(
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
         docs = loader.load()
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1320,7 +1323,7 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1328,12 +1331,14 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                user=user,
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+                user=user,
             )
     except Exception as e:
         log.exception(e)
@@ -1362,7 +1367,7 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1375,7 +1380,7 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
 
@@ -1523,6 +1528,7 @@ def process_files_batch(
             docs=all_docs,
             collection_name=collection_name,
             add=True,
+            user=user,
         )
 
         # Update all files with collection name
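Note: query_doc_with_hybrid_search, query_collection, and get_sources_from_files still expect a one-argument embedding_function, so the handlers above capture the per-request user in a closure rather than widening every downstream signature. A self-contained sketch of that adapter pattern; search_once and embedding_function_full are hypothetical stand-ins for the search helpers and request.app.state.EMBEDDING_FUNCTION:

from typing import Callable

def search_once(embedding_function: Callable[[str], list], query: str):
    # Downstream code knows nothing about users; it just embeds the query.
    return embedding_function(query)

def embedding_function_full(query: str, user=None) -> list:
    # Stand-in for EMBEDDING_FUNCTION(query, user=user); a real call would
    # forward the user's identity headers to the embedding backend.
    return [float(ord(c)) for c in query[:3]]

user = "current-request-user"  # hypothetical per-request identity
print(search_once(lambda q: embedding_function_full(q, user=user), "tea"))

Binding user when the lambda is created keeps each request's identity fixed even when the search helper invokes the function later, e.g. from a worker thread.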
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 0dedbfa4b..e356fd118 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -634,7 +634,7 @@ async def chat_completion_files_handler(
                 lambda: get_sources_from_files(
                     files=files,
                     queries=queries,
-                    embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                    embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                     k=request.app.state.config.TOP_K,
                     reranking_function=request.app.state.rf,
                     r=request.app.state.config.RELEVANCE_THRESHOLD,
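Note: on the receiving side, any OpenAI-compatible endpoint can key rate limiting, auditing, or billing off the forwarded identity. A hedged sketch of an upstream reading the headers; this FastAPI stub is an assumption for illustration and not part of Open WebUI:

from typing import Optional
from fastapi import FastAPI, Header

app = FastAPI()

@app.post("/v1/embeddings")
async def embeddings(
    body: dict,
    x_openwebui_user_id: Optional[str] = Header(None),
    x_openwebui_user_role: Optional[str] = Header(None),
):
    # FastAPI maps the X-OpenWebUI-User-Id header to x_openwebui_user_id.
    # A real server would embed body["input"] with body["model"] here.
    return {"object": "list", "data": [], "forwarded_user": x_openwebui_user_id}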