From a0aee4ff28b964ff839d884695c03608f0d6a881 Mon Sep 17 00:00:00 2001
From: shamil
Date: Mon, 30 Dec 2024 13:45:20 +0300
Subject: [PATCH 1/7] feat: Small optimization

---
 backend/open_webui/retrieval/utils.py | 172 ++++++++++++--------------
 1 file changed, 79 insertions(+), 93 deletions(-)

diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 17f1438da..1d94e58fe 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,9 +1,8 @@
 import logging
 import os
-import uuid
+import heapq
 from typing import Optional, Union
 
-import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -34,8 +33,6 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -47,15 +44,12 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]
 
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            ) for idx in range(len(ids))
+        ]
 
 
 def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.search(
+        if result := VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
             limit=k,
-        )
-
-        if result:
+        ):
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
 
-        return result
+            return result
     except Exception as e:
         print(e)
         raise e
@@ -135,44 +127,38 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
-    # Initialize lists to store combined data
-    combined_distances = []
-    combined_documents = []
-    combined_metadatas = []
-
-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
-        combined_documents.extend(data["documents"][0])
-        combined_metadatas.extend(data["metadatas"][0])
-
-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
-
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
+    if not query_results:
+        return {
+            "distances": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+        }
+
+    combined = (
+        (data.get("distances", [float('inf')])[0],
+         data.get("documents", [None])[0],
+         data.get("metadatas", [{}])[0])
+        for data in query_results
+    )
+
+    if reverse:
+        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
     else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
-
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-
-    # Create the output dictionary
-    result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
-    }
-
-    return result
+        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
+
+    if not top_k:
+        return {
+            "distances": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+        }
+    else:
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
+        return {
+            "distances": [sorted_distances],
+            "documents": [sorted_documents],
+            "metadatas": [sorted_metadatas],
+        }
 
 
 def query_collection(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
     k: int,
 ) -> dict:
     results = []
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+            if not collection_name:
+                continue
+
+            try:
+                if result := query_doc(
+                    collection_name=collection_name,
+                    k=k,
+                    query_embedding=query_embedding,
+                ):
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
 
     return merge_and_sort_query_results(results, k=k)
 
 
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
     k: int,
     reranking_function,
     r: float,
 ) -> dict:
     results = []
     error = False
     for collection_name in collection_names:
-        try:
-            for query in queries:
+        for query in queries:
+            try:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
                     embedding_function=embedding_function,
                     k=k,
                     reranking_function=reranking_function,
                     r=r,
                 )
                 results.append(result)
-        except Exception as e:
-            log.exception(
-                "Error when querying the collection with " f"hybrid_search: {e}"
-            )
-            error = True
+            except Exception as e:
+                log.exception(
+                    "Error when querying the collection with " f"hybrid_search: {e}"
+                )
+                error = True
 
     if error:
         raise Exception(
@@ -259,10 +244,10 @@ def get_embedding_function(
     def generate_multiple(query, func):
         if isinstance(query, list):
-            embeddings = []
-            for i in range(0, len(query), embedding_batch_size):
-                embeddings.extend(func(query[i : i + embedding_batch_size]))
-            return embeddings
+            return [
+                func(query[i : i + embedding_batch_size])
+                for i in range(0, len(query), embedding_batch_size)
+            ]
         else:
             return func(query)
@@ -433,25 +418,26 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
+    r = requests.post(
+        f"{url}/api/embed",
+        headers={
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        },
+        json={"input": texts, "model": model},
+    )
     try:
-        r = requests.post(
-            f"{url}/api/embed",
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {key}",
-            },
-            json={"input": texts, "model": model},
-        )
         r.raise_for_status()
-        data = r.json()
-
-        if "embeddings" in data:
-            return data["embeddings"]
-        else:
-            raise "Something went wrong :/"
     except Exception as e:
         print(e)
         return None
+
+    data = r.json()
+
+    if 'embeddings' not in data:
+        raise "Something went wrong :/"
+
+    return data['embeddings']
 
 
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
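Note (PATCH 1/7): the heapq rewrite of merge_and_sort_query_results selects the top k of n candidates in O(n log k) instead of sorting everything in O(n log n), which is the optimization the subject line refers to. One detail is worth flagging, since it likely explains the revert in PATCH 7/7 below: data.get("distances", ...)[0] is the inner list of distances for an entire result set (the search client returns one nested list per query), so the generator ranks whole result sets against each other, whereas the extend-based code it replaces ranked individual documents. A sketch that keeps the heap but preserves the element-wise semantics could look like this; merge_top_k is a hypothetical name, not part of the codebase:

    import heapq

    def merge_top_k(query_results: list[dict], k: int, reverse: bool = False) -> dict:
        # Flatten every (distance, document, metadata) triple across all result sets.
        combined = (
            triple
            for data in query_results
            for triple in zip(
                data["distances"][0], data["documents"][0], data["metadatas"][0]
            )
        )
        select = heapq.nlargest if reverse else heapq.nsmallest
        top_k = select(k, combined, key=lambda t: t[0])
        if not top_k:
            return {"distances": [[]], "documents": [[]], "metadatas": [[]]}
        distances, documents, metadatas = (list(col) for col in zip(*top_k))
        return {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }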
From ba2964cb015fcf023674816e1faa0ae6d1f77955 Mon Sep 17 00:00:00 2001
From: Gabriel Ecegi
Date: Mon, 30 Dec 2024 17:36:34 +0100
Subject: [PATCH 2/7] fix: missing parameter

---
 backend/open_webui/routers/knowledge.py | 5 ++++-
 backend/open_webui/routers/retrieval.py | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py
index 87e8599ed..12cb2acf3 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -520,6 +520,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
 
 @router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
 def add_files_to_knowledge_batch(
+    request: Request,
     id: str,
     form_data: list[KnowledgeFileIdForm],
     user=Depends(get_verified_user),
@@ -555,7 +556,9 @@ def add_files_to_knowledge_batch(
     # Process files
     try:
         result = process_files_batch(
-            BatchProcessFilesForm(files=files, collection_name=id)
+            request=request,
+            form_data=BatchProcessFilesForm(files=files, collection_name=id),
+            user=user
         )
     except Exception as e:
         log.error(
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index d6ff463a9..c791bde84 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1458,6 +1458,7 @@ class BatchProcessFilesResponse(BaseModel):
 
 @router.post("/process/files/batch")
 def process_files_batch(
+    request: Request,
     form_data: BatchProcessFilesForm,
     user=Depends(get_verified_user),
 ) -> BatchProcessFilesResponse:
@@ -1504,7 +1505,10 @@ def process_files_batch(
     if all_docs:
         try:
             save_docs_to_vector_db(
-                docs=all_docs, collection_name=collection_name, add=True
+                request=request,
+                docs=all_docs,
+                collection_name=collection_name,
+                add=True,
             )
 
             # Update all files with collection name
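Note (PATCH 2/7): FastAPI injects the Request object only into functions it calls itself, i.e. route handlers and dependencies; a plain helper such as process_files_batch has to receive it explicitly from the caller, and the missing argument surfaces as a TypeError at call time, which appears to be the bug fixed here. A minimal sketch of the pattern, with hypothetical names:

    from fastapi import FastAPI, Request

    app = FastAPI()

    def helper(request: Request) -> str:
        # Plain functions see the Request only if the caller forwards it.
        return request.app.title

    @app.get("/ping")
    def ping(request: Request):
        # FastAPI injects `request` here; we pass it down by hand.
        return {"app": helper(request)}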
From 46e57706c18ad5d69a5ae695db5f1a72d03a9a45 Mon Sep 17 00:00:00 2001
From: Gabriel Ecegi
Date: Mon, 30 Dec 2024 17:45:43 +0100
Subject: [PATCH 3/7] refac: formatting

---
 backend/open_webui/routers/knowledge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py
index 12cb2acf3..ad67cc31f 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -558,7 +558,7 @@ def add_files_to_knowledge_batch(
         result = process_files_batch(
             request=request,
             form_data=BatchProcessFilesForm(files=files, collection_name=id),
-            user=user
+            user=user,
         )
     except Exception as e:
         log.error(

From 79ce6e0a3f17dae644b591416ba617f531c892be Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Mon, 30 Dec 2024 11:29:18 -0800
Subject: [PATCH 4/7] refac

---
 backend/open_webui/utils/middleware.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 4741623c4..baaacb2d4 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -750,7 +750,7 @@ async def process_chat_response(
 ):
     async def background_tasks_handler():
         message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
-        message = message_map.get(metadata["message_id"])
+        message = message_map.get(metadata["message_id"]) if message_map else None
 
         if message:
             messages = get_message_list(message_map, message.get("id"))
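Note (PATCH 4/7): the guard matters because Chats.get_messages_by_chat_id can presumably return None (for example when the chat no longer exists), and None.get(...) raises AttributeError inside the background handler. The conditional expression keeps the happy path unchanged while short-circuiting the lookup; the values below are illustrative:

    message_map = None  # e.g. chat not found
    # message_map.get("m1") here would raise AttributeError on NoneType
    message = message_map.get("m1") if message_map else None
    assert message is None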
From 947f5600d656735f338dcc61e1f2a91565b8fc34 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Mon, 30 Dec 2024 15:39:35 -0800
Subject: [PATCH 5/7] refac

---
 backend/open_webui/utils/middleware.py | 79 ++++++++++++++------------
 1 file changed, 42 insertions(+), 37 deletions(-)

diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index baaacb2d4..0e32bf626 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -23,7 +23,7 @@ from open_webui.models.users import Users
 from open_webui.socket.main import (
     get_event_call,
     get_event_emitter,
-    get_user_id_from_session_pool,
+    get_active_status_by_user_id,
 )
 from open_webui.routers.tasks import (
     generate_queries,
@@ -896,7 +896,7 @@ async def process_chat_response(
                     )
 
                 # Send a webhook notification if the user is not active
-                if get_user_id_from_session_pool(metadata["session_id"]) is None:
+                if get_active_status_by_user_id(user.id) is None:
                     webhook_url = Users.get_user_webhook_url_by_id(user.id)
                     if webhook_url:
                         post_webhook(
@@ -1002,51 +1002,56 @@ async def process_chat_response(
                                 "content": content,
                             }
 
+                            await event_emitter(
+                                {
+                                    "type": "chat:completion",
+                                    "data": data,
+                                }
+                            )
+
                 except Exception as e:
                     done = "data: [DONE]" in line
-                    title = Chats.get_chat_title_by_id(metadata["chat_id"])
 
                     if done:
-                        data = {"done": True, "content": content, "title": title}
-
-                        if not ENABLE_REALTIME_CHAT_SAVE:
-                            # Save message in the database
-                            Chats.upsert_message_to_chat_by_id_and_message_id(
-                                metadata["chat_id"],
-                                metadata["message_id"],
-                                {
-                                    "content": content,
-                                },
-                            )
-
-                        # Send a webhook notification if the user is not active
-                        if (
-                            get_user_id_from_session_pool(metadata["session_id"])
-                            is None
-                        ):
-                            webhook_url = Users.get_user_webhook_url_by_id(user.id)
-                            if webhook_url:
-                                post_webhook(
-                                    webhook_url,
-                                    f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
-                                    {
-                                        "action": "chat",
-                                        "message": content,
-                                        "title": title,
-                                        "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
-                                    },
-                                )
-
+                        pass
                     else:
                         continue
 
-            await event_emitter(
-                {
-                    "type": "chat:completion",
-                    "data": data,
-                }
-            )
+            title = Chats.get_chat_title_by_id(metadata["chat_id"])
+            data = {"done": True, "content": content, "title": title}
+
+            if not ENABLE_REALTIME_CHAT_SAVE:
+                # Save message in the database
+                Chats.upsert_message_to_chat_by_id_and_message_id(
+                    metadata["chat_id"],
+                    metadata["message_id"],
+                    {
+                        "content": content,
+                    },
+                )
 
+            # Send a webhook notification if the user is not active
+            if get_active_status_by_user_id(user.id) is None:
+                webhook_url = Users.get_user_webhook_url_by_id(user.id)
+                if webhook_url:
+                    post_webhook(
+                        webhook_url,
+                        f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
+                        {
+                            "action": "chat",
+                            "message": content,
+                            "title": title,
+                            "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
+                        },
+                    )
+
+            await event_emitter(
+                {
+                    "type": "chat:completion",
+                    "data": data,
+                }
+            )
+
             await background_tasks_handler()
         except asyncio.CancelledError:
             print("Task was cancelled!")

From 46bcf98ef27f5d532c5c39fdb9d4f2dbb35fceb9 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Mon, 30 Dec 2024 15:52:07 -0800
Subject: [PATCH 6/7] fix: usage stats

---
 backend/open_webui/utils/response.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
index d429db8aa..d6f7b0ac6 100644
--- a/backend/open_webui/utils/response.py
+++ b/backend/open_webui/utils/response.py
@@ -29,7 +29,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
                 (
                     (
                         data.get("eval_count", 0)
-                        / ((data.get("eval_duration", 0) / 1_000_000))
+                        / ((data.get("eval_duration", 0) / 10_000_000))
                     )
                     * 100
                 ),
@@ -43,7 +43,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
                 (
                     (
                         data.get("prompt_eval_count", 0)
-                        / ((data.get("prompt_eval_duration", 0) / 1_000_000))
+                        / ((data.get("prompt_eval_duration", 0) / 10_000_000))
                    )
                    * 100
                ),
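Note (PATCH 6/7): Ollama reports eval_duration and prompt_eval_duration in nanoseconds, and the intended metric is tokens per second, i.e. eval_count / (eval_duration / 1e9). Combined with the surrounding * 100 factor, a divisor of 10_000_000 yields exactly that scale (10_000_000 * 100 == 1e9); the previous 1_000_000 left the figure ten times too small. A quick check with made-up numbers:

    eval_count = 418               # tokens generated (illustrative)
    eval_duration = 8_450_000_000  # nanoseconds, i.e. 8.45 s

    tps = (eval_count / (eval_duration / 10_000_000)) * 100
    assert abs(tps - eval_count / (eval_duration / 1e9)) < 1e-6
    print(round(tps, 2))  # 49.47 tokens per second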
From fd0170c179ae01dc36056efdca1f46e885d286a4 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Mon, 30 Dec 2024 16:55:29 -0800
Subject: [PATCH 7/7] revert

---
 backend/open_webui/retrieval/utils.py | 172 ++++++++++++++------------
 1 file changed, 93 insertions(+), 79 deletions(-)

diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 7e8771bd6..c95367e6c 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,8 +1,9 @@
 import logging
 import os
-import heapq
+import uuid
 from typing import Optional, Union
 
+import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -33,6 +34,8 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -44,12 +47,15 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]
 
-        return [
-            Document(
-                metadata=metadatas[idx],
-                page_content=documents[idx],
-            ) for idx in range(len(ids))
-        ]
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results
 
 
 def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
 ):
     try:
-        if result := VECTOR_DB_CLIENT.search(
+        result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
             limit=k,
-        ):
+        )
+
+        if result:
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
 
-            return result
+        return result
     except Exception as e:
         print(e)
         raise e
@@ -127,38 +135,44 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
-    if not query_results:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
-
-    combined = (
-        (data.get("distances", [float('inf')])[0],
-         data.get("documents", [None])[0],
-         data.get("metadatas", [{}])[0])
-        for data in query_results
-    )
-
-    if reverse:
-        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
+    # Initialize lists to store combined data
+    combined_distances = []
+    combined_documents = []
+    combined_metadatas = []
+
+    for data in query_results:
+        combined_distances.extend(data["distances"][0])
+        combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+
+    # Create a list of tuples (distance, document, metadata)
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
+
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0], reverse=reverse)
+
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
     else:
-        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
-
-    if not top_k:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
-    else:
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
-        return {
-            "distances": [sorted_distances],
-            "documents": [sorted_documents],
-            "metadatas": [sorted_metadatas],
-        }
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+
+    # Slicing the lists to include only k elements
+    sorted_distances = list(sorted_distances)[:k]
+    sorted_documents = list(sorted_documents)[:k]
+    sorted_metadatas = list(sorted_metadatas)[:k]
+
+    # Create the output dictionary
+    result = {
+        "distances": [sorted_distances],
+        "documents": [sorted_documents],
+        "metadatas": [sorted_metadatas],
+    }
+
+    return result
 
 
 def query_collection(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
     k: int,
 ) -> dict:
     results = []
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if not collection_name:
-                continue
-
-            try:
-                if result := query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                ):
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass
 
     return merge_and_sort_query_results(results, k=k)
 
 
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
     k: int,
     reranking_function,
     r: float,
 ) -> dict:
     results = []
     error = False
     for collection_name in collection_names:
-        for query in queries:
-            try:
+        try:
+            for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
                     embedding_function=embedding_function,
                     k=k,
                     reranking_function=reranking_function,
                     r=r,
                 )
                 results.append(result)
-            except Exception as e:
-                log.exception(
-                    "Error when querying the collection with " f"hybrid_search: {e}"
-                )
-                error = True
+        except Exception as e:
+            log.exception(
+                "Error when querying the collection with " f"hybrid_search: {e}"
+            )
+            error = True
 
     if error:
         raise Exception(
@@ -244,10 +259,10 @@ def get_embedding_function(
     def generate_multiple(query, func):
         if isinstance(query, list):
-            return [
-                func(query[i : i + embedding_batch_size])
-                for i in range(0, len(query), embedding_batch_size)
-            ]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(func(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return func(query)
@@ -418,26 +433,25 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
-    r = requests.post(
-        f"{url}/api/embed",
-        headers={
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}",
-        },
-        json={"input": texts, "model": model},
-    )
     try:
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": texts, "model": model},
+        )
         r.raise_for_status()
+        data = r.json()
+
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise "Something went wrong :/"
     except Exception as e:
         print(e)
         return None
-
-    data = r.json()
-
-    if 'embeddings' not in data:
-        raise "Something went wrong :/"
-
-    return data['embeddings']
 
 
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
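Note (PATCH 7/7): the revert is not explained in the commit message, but at least one regression in PATCH 1/7 is observable in isolation: the rewritten generate_multiple collects each batch result as a single list element, returning a list of batches instead of a flat list of embeddings. A minimal reproduction, where func and the values are illustrative stand-ins rather than code from this repository:

    def func(batch):
        # Stand-in embedding function: one one-element vector per text.
        return [[len(text)] for text in batch]

    query = ["a", "bb", "ccc"]
    embedding_batch_size = 2

    # PATCH 1 version: one element per batch, which is the wrong shape.
    batched = [
        func(query[i : i + embedding_batch_size])
        for i in range(0, len(query), embedding_batch_size)
    ]
    assert batched == [[[1], [2]], [[3]]]

    # Reverted version: one embedding per input text.
    embeddings = []
    for i in range(0, len(query), embedding_batch_size):
        embeddings.extend(func(query[i : i + embedding_batch_size]))
    assert embeddings == [[1], [2], [3]]

Together with the list-versus-scalar ranking issue in merge_and_sort_query_results noted under PATCH 1/7, reverting the whole commit rather than patching it piecemeal looks like the conservative choice.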