diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index bcffbc139..114bca437 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -256,7 +256,7 @@ def query_collection( ) -> dict: results = [] for query in queries: - query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) + query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX) for collection_name in collection_names: if collection_name: try: @@ -334,11 +334,11 @@ def get_embedding_function( embedding_batch_size, ): if embedding_engine == "": - return lambda query, prefix, user=None: embedding_function.encode( + return lambda query, prefix=None, user=None: embedding_function.encode( query, prompt=prefix if prefix else None ).tolist() elif embedding_engine in ["ollama", "openai"]: - func = lambda query, prefix, user=None: generate_embeddings( + func = lambda query, prefix=None, user=None: generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, @@ -363,7 +363,7 @@ def get_embedding_function( else: return func(query, prefix, user) - return lambda query, prefix, user=None: generate_multiple( + return lambda query, prefix=None, user=None: generate_multiple( query, prefix, user, func ) else: diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index c55a6a9cc..e660ef852 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -57,7 +57,9 @@ async def add_memory( { "id": memory.id, "text": memory.content, - "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), + "vector": request.app.state.EMBEDDING_FUNCTION( + memory.content, user=user + ), "metadata": {"created_at": memory.created_at}, } ], @@ -82,7 +84,7 @@ async def query_memory( ): results = VECTOR_DB_CLIENT.search( collection_name=f"user-memory-{user.id}", - vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)], + vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)], limit=form_data.k, ) @@ -105,7 +107,9 @@ async def reset_memory_from_vector_db( { "id": memory.id, "text": memory.content, - "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), + "vector": request.app.state.EMBEDDING_FUNCTION( + memory.content, user=user + ), "metadata": { "created_at": memory.created_at, "updated_at": memory.updated_at, @@ -161,7 +165,7 @@ async def update_memory_by_id( "id": memory.id, "text": memory.content, "vector": request.app.state.EMBEDDING_FUNCTION( - memory.content, user + memory.content, user=user ), "metadata": { "created_at": memory.created_at, diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index abca72f11..2bd908606 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1518,8 +1518,8 @@ def query_doc_handler( return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, - embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( - query, user=user + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=request.app.state.rf, @@ -1569,8 +1569,8 @@ def query_collection_handler( return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( - query, user=user + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=request.app.state.rf, @@ -1586,8 +1586,8 @@ def query_collection_handler( return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( - query, user=user + embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( + query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) @@ -1666,7 +1666,7 @@ if ENV == "dev": async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): return { "result": request.app.state.EMBEDDING_FUNCTION( - text, RAG_EMBEDDING_QUERY_PREFIX + text, prefix=RAG_EMBEDDING_QUERY_PREFIX ) }