diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a2d114844..8184b467b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -516,9 +516,12 @@ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_K app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +app.state.EMBEDDING_FUNCTION = None +app.state.sentence_transformer_ef = None +app.state.sentence_transformer_rf = None app.state.YOUTUBE_LOADER_TRANSLATION = None -app.state.EMBEDDING_FUNCTION = None + ######################################## # @@ -1653,6 +1656,420 @@ async def get_base_models(user=Depends(get_admin_user)): return {"data": models} +################################## +# +# Chat Endpoints +# +################################## + + +@app.post("/api/chat/completions") +async def generate_chat_completions( + request: Request, + form_data: dict, + user=Depends(get_verified_user), + bypass_filter: bool = False, +): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + model_list = request.state.models + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + + # Check if user has access to the model + if not bypass_filter and user.role == "user": + if model.get("arena"): + if not has_access( + user.id, + type="read", + access_control=model.get("info", {}) + .get("meta", {}) + .get("access_control", {}), + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + model_info = Models.get_model_by_id(model_id) + if not model_info: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + elif not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in await get_all_models() + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completions( + form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), media_type="text/event-stream" + ) + else: + return { + **( + await generate_chat_completions(form_data, user, bypass_filter=True) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + form_data = GenerateChatCompletionForm(**form_data) + response = await generate_ollama_chat_completion( + form_data=form_data, user=user, bypass_filter=bypass_filter + ) + if form_data.stream: + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + form_data, user=user, bypass_filter=bypass_filter + ) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": { + "id": user.id, + "name": user.name, + "email": user.email, + "role": user.role, + }, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except Exception: + pass + + else: + pass + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [function.id for function in Functions.get_global_filter_functions()] + if "info" in model and "meta" in model["info"]: + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) + + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) + + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if not filter: + continue + + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + if not hasattr(function_module, "outlet"): + continue + try: + outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + +@app.post("/api/chat/actions/{action_id}") +async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): + if "." in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Action not found", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + ################################## # # Config Endpoints diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index bf939ecf1..9444ade95 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -11,7 +11,7 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py index 528835b56..bf97bc7b1 100644 --- a/backend/open_webui/retrieval/vector/connector.py +++ b/backend/open_webui/retrieval/vector/connector.py @@ -1,22 +1,22 @@ from open_webui.config import VECTOR_DB if VECTOR_DB == "milvus": - from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient + from open_webui.retrieval.vector.dbs.milvus import MilvusClient VECTOR_DB_CLIENT = MilvusClient() elif VECTOR_DB == "qdrant": - from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient + from open_webui.retrieval.vector.dbs.qdrant import QdrantClient VECTOR_DB_CLIENT = QdrantClient() elif VECTOR_DB == "opensearch": - from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient + from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient VECTOR_DB_CLIENT = OpenSearchClient() elif VECTOR_DB == "pgvector": - from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient + from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient VECTOR_DB_CLIENT = PgvectorClient() else: - from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient + from open_webui.retrieval.vector.dbs.chroma import ChromaClient VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index b2fcdd16a..00d73a889 100644 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 5351f860e..31d890664 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 6234b2837..b3d8b5eb8 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -1,7 +1,7 @@ from opensearchpy import OpenSearch from typing import Optional -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( OPENSEARCH_URI, OPENSEARCH_SSL, diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index b8317957e..cb8c545e9 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL VECTOR_LENGTH = 1536 diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index 60c1c3d4d..f077ae45a 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import QDRANT_URI, QDRANT_API_KEY NO_LIMIT = 999999999 diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py index b5f889c54..09beb3460 100644 --- a/backend/open_webui/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -3,7 +3,7 @@ import os from pprint import pprint from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS import argparse diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index f988b3b08..3075db990 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index 11e512296..7c0c3f1c2 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index 61b919583..2c51dd3c9 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index f5e2febbe..3de6c1807 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL diff --git a/backend/open_webui/retrieval/web/kagi.py b/backend/open_webui/retrieval/web/kagi.py index c8c2699ed..0b69da8bc 100644 --- a/backend/open_webui/retrieval/web/kagi.py +++ b/backend/open_webui/retrieval/web/kagi.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -31,17 +31,15 @@ def search_kagi( response.raise_for_status() json_response = response.json() search_results = json_response.get("data", []) - + results = [ SearchResult( - link=result["url"], - title=result["title"], - snippet=result.get("snippet") + link=result["url"], title=result["title"], snippet=result.get("snippet") ) for result in search_results if result["t"] == 0 ] - + print(results) if filter_list: diff --git a/backend/open_webui/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py index f257c92aa..d298b0ee5 100644 --- a/backend/open_webui/retrieval/web/mojeek.py +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py index 412dc6b69..38bc0b574 100644 --- a/backend/open_webui/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py index cb1eaf91d..15e3c098a 100644 --- a/backend/open_webui/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py index 436fa167e..685e34375 100644 --- a/backend/open_webui/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py index 1c2521c47..a9b473eb0 100644 --- a/backend/open_webui/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py index b655934de..d4dbda57c 100644 --- a/backend/open_webui/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index 03b0be75a..cc468725d 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 1617f452e..85a4d30fd 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -11,7 +11,7 @@ from open_webui.models.knowledge import ( KnowledgeUserResponse, ) from open_webui.models.files import Files, FileModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from backend.open_webui.routers.retrieval import process_file, ProcessFileForm diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 7973038c4..e72cf1445 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -4,7 +4,7 @@ import logging from typing import Optional from open_webui.models.memories import Memories, MemoryModel -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 517e2894d..7e0dc6018 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1,5 +1,3 @@ -# TODO: Merge this with the webui_app and make it a single app - import json import logging import mimetypes @@ -11,39 +9,55 @@ from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + UploadFile, + Request, + status, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import tiktoken -from open_webui.storage.provider import Storage +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from langchain_core.documents import Document + +from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.storage.provider import Storage + + +from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT # Document loaders -from open_webui.apps.retrieval.loaders.main import Loader -from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader +from open_webui.retrieval.loaders.main import Loader +from open_webui.retrieval.loaders.youtube import YoutubeLoader # Web search engines -from open_webui.apps.retrieval.web.main import SearchResult -from open_webui.apps.retrieval.web.utils import get_web_loader -from open_webui.apps.retrieval.web.brave import search_brave -from open_webui.apps.retrieval.web.kagi import search_kagi -from open_webui.apps.retrieval.web.mojeek import search_mojeek -from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo -from open_webui.apps.retrieval.web.google_pse import search_google_pse -from open_webui.apps.retrieval.web.jina_search import search_jina -from open_webui.apps.retrieval.web.searchapi import search_searchapi -from open_webui.apps.retrieval.web.searxng import search_searxng -from open_webui.apps.retrieval.web.serper import search_serper -from open_webui.apps.retrieval.web.serply import search_serply -from open_webui.apps.retrieval.web.serpstack import search_serpstack -from open_webui.apps.retrieval.web.tavily import search_tavily -from open_webui.apps.retrieval.web.bing import search_bing +from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.brave import search_brave +from open_webui.retrieval.web.kagi import search_kagi +from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.retrieval.web.google_pse import search_google_pse +from open_webui.retrieval.web.jina_search import search_jina +from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.searxng import search_searxng +from open_webui.retrieval.web.serper import search_serper +from open_webui.retrieval.web.serply import search_serply +from open_webui.retrieval.web.serpstack import search_serpstack +from open_webui.retrieval.web.tavily import search_tavily +from open_webui.retrieval.web.bing import search_bing -from backend.open_webui.retrieval.utils import ( +from open_webui.retrieval.utils import ( get_embedding_function, get_model_path, query_collection, @@ -51,246 +65,132 @@ from backend.open_webui.retrieval.utils import ( query_doc, query_doc_with_hybrid_search, ) +from open_webui.utils.misc import ( + calculate_sha256_string, +) +from open_webui.utils.auth import get_admin_user, get_verified_user + -from open_webui.models.files import Files from open_webui.config import ( - BRAVE_SEARCH_API_KEY, - KAGI_SEARCH_API_KEY, - MOJEEK_SEARCH_API_KEY, - TIKTOKEN_ENCODING_NAME, - RAG_TEXT_SPLITTER, - CHUNK_OVERLAP, - CHUNK_SIZE, - CONTENT_EXTRACTION_ENGINE, - CORS_ALLOW_ORIGIN, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_LOCAL_WEB_FETCH, - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - ENABLE_RAG_WEB_SEARCH, ENV, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - PDF_EXTRACT_IMAGES, - RAG_EMBEDDING_ENGINE, - RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_BATCH_SIZE, - RAG_FILE_MAX_COUNT, - RAG_FILE_MAX_SIZE, - RAG_OPENAI_API_BASE_URL, - RAG_OPENAI_API_KEY, - RAG_OLLAMA_BASE_URL, - RAG_OLLAMA_API_KEY, - RAG_RELEVANCE_THRESHOLD, - RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - DEFAULT_RAG_TEMPLATE, - RAG_TEMPLATE, - RAG_TOP_K, - RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - RAG_WEB_SEARCH_ENGINE, - RAG_WEB_SEARCH_RESULT_COUNT, - JINA_API_KEY, - SEARCHAPI_API_KEY, - SEARCHAPI_ENGINE, - SEARXNG_QUERY_URL, - SERPER_API_KEY, - SERPLY_API_KEY, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - TAVILY_API_KEY, - BING_SEARCH_V7_ENDPOINT, - BING_SEARCH_V7_SUBSCRIPTION_KEY, - TIKA_SERVER_URL, UPLOAD_DIR, - YOUTUBE_LOADER_LANGUAGE, - YOUTUBE_LOADER_PROXY_URL, DEFAULT_LOCALE, - AppConfig, ) -from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, ) -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.auth import get_admin_user, get_verified_user - -from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter -from langchain_core.documents import Document - +from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) - -app.state.config = AppConfig() - -app.state.config.TOP_K = RAG_TOP_K -app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE -app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT - -app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH -app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION -) - -app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE -app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL - -app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER -app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME - -app.state.config.CHUNK_SIZE = CHUNK_SIZE -app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP - -app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE -app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE -app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL -app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - -app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY - -app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL -app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY - -app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - -app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE -app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL -app.state.YOUTUBE_LOADER_TRANSLATION = None - - -app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH -app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE -app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST - -app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL -app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY -app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID -app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY -app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY -app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY -app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY -app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS -app.state.config.SERPER_API_KEY = SERPER_API_KEY -app.state.config.SERPLY_API_KEY = SERPLY_API_KEY -app.state.config.TAVILY_API_KEY = TAVILY_API_KEY -app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY -app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE -app.state.config.JINA_API_KEY = JINA_API_KEY -app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT -app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY - -app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT -app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +########################################## +# +# Utility functions +# +########################################## def update_embedding_model( + request: Request, embedding_model: str, auto_update: bool = False, ): - if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "": from sentence_transformers import SentenceTransformer try: - app.state.sentence_transformer_ef = SentenceTransformer( + request.app.state.sentence_transformer_ef = SentenceTransformer( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") - app.state.sentence_transformer_ef = None + request.app.state.sentence_transformer_ef = None else: - app.state.sentence_transformer_ef = None + request.app.state.sentence_transformer_ef = None def update_reranking_model( + request: Request, reranking_model: str, auto_update: bool = False, ): if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): try: - from open_webui.apps.retrieval.models.colbert import ColBERT + from open_webui.retrieval.models.colbert import ColBERT - app.state.sentence_transformer_rf = ColBERT( + request.app.state.sentence_transformer_rf = ColBERT( get_model_path(reranking_model, auto_update), env="docker" if DOCKER else None, ) except Exception as e: log.error(f"ColBERT: {e}") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + request.app.state.sentence_transformer_rf = None + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False else: import sentence_transformers try: - app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( - get_model_path(reranking_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + request.app.state.sentence_transformer_rf = ( + sentence_transformers.CrossEncoder( + get_model_path(reranking_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + ) ) except: log.error("CrossEncoder error") - app.state.sentence_transformer_rf = None - app.state.config.ENABLE_RAG_HYBRID_SEARCH = False + request.app.state.sentence_transformer_rf = None + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False else: - app.state.sentence_transformer_rf = None + request.app.state.sentence_transformer_rf = None update_embedding_model( - app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) update_reranking_model( - app.state.config.RAG_RERANKING_MODEL, + request.app.state.config.RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, ) -app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL - ), - ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY - ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, -) +########################################## +# +# API routes +# +########################################## -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + +router = APIRouter() + + +request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, + ( + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL + ), + ( + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY + ), + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -306,43 +206,43 @@ class SearchForm(CollectionNameForm): query: str -@app.get("/") -async def get_status(): +@router.get("/") +async def get_status(request: Request): return { "status": True, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, - "template": app.state.config.RAG_TEMPLATE, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, + "template": request.app.state.config.RAG_TEMPLATE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, } -@app.get("/embedding") -async def get_embedding_config(user=Depends(get_admin_user)): +@router.get("/embedding") +async def get_embedding_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.OPENAI_API_BASE_URL, + "key": request.app.state.config.OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.OLLAMA_BASE_URL, + "key": request.app.state.config.OLLAMA_API_KEY, }, } -@app.get("/reranking") -async def get_reraanking_config(user=Depends(get_admin_user)): +@router.get("/reranking") +async def get_reraanking_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } @@ -364,59 +264,63 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_batch_size: Optional[int] = 1 -@app.post("/embedding/update") +@router.post("/embedding/update") async def update_embedding_config( - form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" + f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) try: - app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model + request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine + request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if form_data.openai_config is not None: - app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.config.OPENAI_API_KEY = form_data.openai_config.key + request.app.state.config.OPENAI_API_BASE_URL = ( + form_data.openai_config.url + ) + request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key if form_data.ollama_config is not None: - app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url - app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url + request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key - app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( + form_data.embedding_batch_size + ) - update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) + update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL) - app.state.EMBEDDING_FUNCTION = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.EMBEDDING_FUNCTION = get_embedding_function( + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, - "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, - "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, - "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { - "url": app.state.config.OPENAI_API_BASE_URL, - "key": app.state.config.OPENAI_API_KEY, + "url": request.app.state.config.OPENAI_API_BASE_URL, + "key": request.app.state.config.OPENAI_API_KEY, }, "ollama_config": { - "url": app.state.config.OLLAMA_BASE_URL, - "key": app.state.config.OLLAMA_API_KEY, + "url": request.app.state.config.OLLAMA_BASE_URL, + "key": request.app.state.config.OLLAMA_API_KEY, }, } except Exception as e: @@ -431,21 +335,21 @@ class RerankingModelUpdateForm(BaseModel): reranking_model: str -@app.post("/reranking/update") +@router.post("/reranking/update") async def update_reranking_config( - form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) + request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) ): log.info( - f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" + f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}" ) try: - app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model + request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model - update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True) + update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True) return { "status": True, - "reranking_model": app.state.config.RAG_RERANKING_MODEL, + "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, } except Exception as e: log.exception(f"Problem updating reranking model: {e}") @@ -455,52 +359,52 @@ async def update_reranking_config( ) -@app.get("/config") -async def get_rag_config(user=Depends(get_admin_user)): +@router.get("/config") +async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } @@ -565,139 +469,159 @@ class ConfigUpdateForm(BaseModel): web: Optional[WebConfig] = None -@app.post("/config/update") -async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): - app.state.config.PDF_EXTRACT_IMAGES = ( +@router.post("/config/update") +async def update_rag_config( + request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user) +): + request.app.state.config.PDF_EXTRACT_IMAGES = ( form_data.pdf_extract_images if form_data.pdf_extract_images is not None - else app.state.config.PDF_EXTRACT_IMAGES + else request.app.state.config.PDF_EXTRACT_IMAGES ) if form_data.file is not None: - app.state.config.FILE_MAX_SIZE = form_data.file.max_size - app.state.config.FILE_MAX_COUNT = form_data.file.max_count + request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size + request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count if form_data.content_extraction is not None: log.info(f"Updating text settings: {form_data.content_extraction}") - app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine - app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url + request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( + form_data.content_extraction.engine + ) + request.app.state.config.TIKA_SERVER_URL = ( + form_data.content_extraction.tika_server_url + ) if form_data.chunk is not None: - app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter - app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size - app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap + request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter + request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size + request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap if form_data.youtube is not None: - app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language - app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url - app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation + request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url + request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation if form_data.web is not None: - app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False form_data.web.web_loader_ssl_verification ) - app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled - app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine - app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url - app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key - app.state.config.GOOGLE_PSE_ENGINE_ID = ( + request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled + request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + request.app.state.config.SEARXNG_QUERY_URL = ( + form_data.web.search.searxng_query_url + ) + request.app.state.config.GOOGLE_PSE_API_KEY = ( + form_data.web.search.google_pse_api_key + ) + request.app.state.config.GOOGLE_PSE_ENGINE_ID = ( form_data.web.search.google_pse_engine_id ) - app.state.config.BRAVE_SEARCH_API_KEY = ( + request.app.state.config.BRAVE_SEARCH_API_KEY = ( form_data.web.search.brave_search_api_key ) - app.state.config.KAGI_SEARCH_API_KEY = form_data.web.search.kagi_search_api_key - app.state.config.MOJEEK_SEARCH_API_KEY = ( + request.app.state.config.KAGI_SEARCH_API_KEY = ( + form_data.web.search.kagi_search_api_key + ) + request.app.state.config.MOJEEK_SEARCH_API_KEY = ( form_data.web.search.mojeek_search_api_key ) - app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key - app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https - app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key - app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key - app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key - app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key - app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine + request.app.state.config.SERPSTACK_API_KEY = ( + form_data.web.search.serpstack_api_key + ) + request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https + request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key + request.app.state.config.SEARCHAPI_API_KEY = ( + form_data.web.search.searchapi_api_key + ) + request.app.state.config.SEARCHAPI_ENGINE = ( + form_data.web.search.searchapi_engine + ) - app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key - app.state.config.BING_SEARCH_V7_ENDPOINT = ( + request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key + request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( form_data.web.search.bing_search_v7_endpoint ) - app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = ( form_data.web.search.bing_search_v7_subscription_key ) - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count - app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = ( + form_data.web.search.result_count + ) + request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests ) return { "status": True, - "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES, + "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, "file": { - "max_size": app.state.config.FILE_MAX_SIZE, - "max_count": app.state.config.FILE_MAX_COUNT, + "max_size": request.app.state.config.FILE_MAX_SIZE, + "max_count": request.app.state.config.FILE_MAX_COUNT, }, "content_extraction": { - "engine": app.state.config.CONTENT_EXTRACTION_ENGINE, - "tika_server_url": app.state.config.TIKA_SERVER_URL, + "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, + "tika_server_url": request.app.state.config.TIKA_SERVER_URL, }, "chunk": { - "text_splitter": app.state.config.TEXT_SPLITTER, - "chunk_size": app.state.config.CHUNK_SIZE, - "chunk_overlap": app.state.config.CHUNK_OVERLAP, + "text_splitter": request.app.state.config.TEXT_SPLITTER, + "chunk_size": request.app.state.config.CHUNK_SIZE, + "chunk_overlap": request.app.state.config.CHUNK_OVERLAP, }, "youtube": { - "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, - "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, - "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { - "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH, - "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, - "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, - "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, - "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, - "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, - "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, - "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, - "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, - "serpstack_https": app.state.config.SERPSTACK_HTTPS, - "serper_api_key": app.state.config.SERPER_API_KEY, - "serply_api_key": app.state.config.SERPLY_API_KEY, - "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY, - "searchapi_engine": app.state.config.SEARCHAPI_ENGINE, - "tavily_api_key": app.state.config.TAVILY_API_KEY, - "jina_api_key": app.state.config.JINA_API_KEY, - "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT, - "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY, + "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY, + "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY, + "serpstack_https": request.app.state.config.SERPSTACK_HTTPS, + "serper_api_key": request.app.state.config.SERPER_API_KEY, + "serply_api_key": request.app.state.config.SERPLY_API_KEY, + "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY, + "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE, + "tavily_api_key": request.app.state.config.TAVILY_API_KEY, + "jina_api_key": request.app.state.config.JINA_API_KEY, + "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT, + "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, }, } -@app.get("/template") -async def get_rag_template(user=Depends(get_verified_user)): +@router.get("/template") +async def get_rag_template(request: Request, user=Depends(get_verified_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, + "template": request.app.state.config.RAG_TEMPLATE, } -@app.get("/query/settings") -async def get_query_settings(user=Depends(get_admin_user)): +@router.get("/query/settings") +async def get_query_settings(request: Request, user=Depends(get_admin_user)): return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -708,24 +632,24 @@ class QuerySettingsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/settings/update") +@router.post("/query/settings/update") async def update_query_settings( - form_data: QuerySettingsForm, user=Depends(get_admin_user) + request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user) ): - app.state.config.RAG_TEMPLATE = form_data.template - app.state.config.TOP_K = form_data.k if form_data.k else 4 - app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 + request.app.state.config.RAG_TEMPLATE = form_data.template + request.app.state.config.TOP_K = form_data.k if form_data.k else 4 + request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( form_data.hybrid if form_data.hybrid else False ) return { "status": True, - "template": app.state.config.RAG_TEMPLATE, - "k": app.state.config.TOP_K, - "r": app.state.config.RELEVANCE_THRESHOLD, - "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH, + "template": request.app.state.config.RAG_TEMPLATE, + "k": request.app.state.config.TOP_K, + "r": request.app.state.config.RELEVANCE_THRESHOLD, + "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, } @@ -736,24 +660,8 @@ async def update_query_settings( #################################### -def _get_docs_info(docs: list[Document]) -> str: - docs_info = set() - - # Trying to select relevant metadata identifying the document. - for doc in docs: - metadata = getattr(doc, "metadata", {}) - doc_name = metadata.get("name", "") - if not doc_name: - doc_name = metadata.get("title", "") - if not doc_name: - doc_name = metadata.get("source", "") - if doc_name: - docs_info.add(doc_name) - - return ", ".join(docs_info) - - def save_docs_to_vector_db( + request: Request, docs, collection_name, metadata: Optional[dict] = None, @@ -761,6 +669,22 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, ) -> bool: + def _get_docs_info(docs: list[Document]) -> str: + docs_info = set() + + # Trying to select relevant metadata identifying the document. + for doc in docs: + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") + if not doc_name: + doc_name = metadata.get("title", "") + if not doc_name: + doc_name = metadata.get("source", "") + if doc_name: + docs_info.add(doc_name) + + return ", ".join(docs_info) + log.info( f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" ) @@ -779,22 +703,22 @@ def save_docs_to_vector_db( raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: - if app.state.config.TEXT_SPLITTER in ["", "character"]: + if request.app.state.config.TEXT_SPLITTER in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) - elif app.state.config.TEXT_SPLITTER == "token": + elif request.app.state.config.TEXT_SPLITTER == "token": log.info( - f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}" + f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" ) - tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME)) + tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME)) text_splitter = TokenTextSplitter( - encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME), - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, + encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME), + chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) else: @@ -812,8 +736,8 @@ def save_docs_to_vector_db( **(metadata if metadata else {}), "embedding_config": json.dumps( { - "engine": app.state.config.RAG_EMBEDDING_ENGINE, - "model": app.state.config.RAG_EMBEDDING_MODEL, + "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "model": request.app.state.config.RAG_EMBEDDING_MODEL, } ), } @@ -842,20 +766,20 @@ def save_docs_to_vector_db( log.info(f"adding to collection {collection_name}") embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, + request.app.state.config.RAG_EMBEDDING_ENGINE, + request.app.state.config.RAG_EMBEDDING_MODEL, + request.app.state.sentence_transformer_ef, ( - app.state.config.OPENAI_API_BASE_URL - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_BASE_URL + request.app.state.config.OPENAI_API_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_BASE_URL ), ( - app.state.config.OPENAI_API_KEY - if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.OLLAMA_API_KEY + request.app.state.config.OPENAI_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else request.app.state.config.OLLAMA_API_KEY ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( @@ -889,8 +813,9 @@ class ProcessFileForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/file") +@router.post("/process/file") def process_file( + request: Request, form_data: ProcessFileForm, user=Depends(get_verified_user), ): @@ -960,9 +885,9 @@ def process_file( if file_path: file_path = Storage.get_file(file_path) loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1007,6 +932,7 @@ def process_file( try: result = save_docs_to_vector_db( + request, docs=docs, collection_name=collection_name, metadata={ @@ -1053,8 +979,9 @@ class ProcessTextForm(BaseModel): collection_name: Optional[str] = None -@app.post("/process/text") +@router.post("/process/text") def process_text( + request: Request, form_data: ProcessTextForm, user=Depends(get_verified_user), ): @@ -1071,8 +998,7 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(docs, collection_name) - + result = save_docs_to_vector_db(request, docs, collection_name) if result: return { "status": True, @@ -1086,8 +1012,10 @@ def process_text( ) -@app.post("/process/youtube") -def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/youtube") +def process_youtube_video( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1095,14 +1023,15 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u loader = YoutubeLoader( form_data.url, - language=app.state.config.YOUTUBE_LOADER_LANGUAGE, - proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL, + language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1125,8 +1054,10 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u ) -@app.post("/process/web") -def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): +@router.post("/process/web") +def process_web( + request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) +): try: collection_name = form_data.collection_name if not collection_name: @@ -1134,13 +1065,14 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): loader = get_web_loader( form_data.url, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) + log.debug(f"text_content: {content}") - save_docs_to_vector_db(docs, collection_name, overwrite=True) + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1163,7 +1095,7 @@ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)): ) -def search_web(engine: str, query: str) -> list[SearchResult]: +def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL @@ -1182,150 +1114,151 @@ def search_web(engine: str, query: str) -> list[SearchResult]: # TODO: add playwright to search the web if engine == "searxng": - if app.state.config.SEARXNG_QUERY_URL: + if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( - app.state.config.SEARXNG_QUERY_URL, + request.app.state.config.SEARXNG_QUERY_URL, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARXNG_QUERY_URL found in environment variables") elif engine == "google_pse": if ( - app.state.config.GOOGLE_PSE_API_KEY - and app.state.config.GOOGLE_PSE_ENGINE_ID + request.app.state.config.GOOGLE_PSE_API_KEY + and request.app.state.config.GOOGLE_PSE_ENGINE_ID ): return search_google_pse( - app.state.config.GOOGLE_PSE_API_KEY, - app.state.config.GOOGLE_PSE_ENGINE_ID, + request.app.state.config.GOOGLE_PSE_API_KEY, + request.app.state.config.GOOGLE_PSE_ENGINE_ID, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception( "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" ) elif engine == "brave": - if app.state.config.BRAVE_SEARCH_API_KEY: + if request.app.state.config.BRAVE_SEARCH_API_KEY: return search_brave( - app.state.config.BRAVE_SEARCH_API_KEY, + request.app.state.config.BRAVE_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") elif engine == "kagi": - if app.state.config.KAGI_SEARCH_API_KEY: + if request.app.state.config.KAGI_SEARCH_API_KEY: return search_kagi( - app.state.config.KAGI_SEARCH_API_KEY, + request.app.state.config.KAGI_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") elif engine == "mojeek": - if app.state.config.MOJEEK_SEARCH_API_KEY: + if request.app.state.config.MOJEEK_SEARCH_API_KEY: return search_mojeek( - app.state.config.MOJEEK_SEARCH_API_KEY, + request.app.state.config.MOJEEK_SEARCH_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") elif engine == "serpstack": - if app.state.config.SERPSTACK_API_KEY: + if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( - app.state.config.SERPSTACK_API_KEY, + request.app.state.config.SERPSTACK_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - https_enabled=app.state.config.SERPSTACK_HTTPS, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + https_enabled=request.app.state.config.SERPSTACK_HTTPS, ) else: raise Exception("No SERPSTACK_API_KEY found in environment variables") elif engine == "serper": - if app.state.config.SERPER_API_KEY: + if request.app.state.config.SERPER_API_KEY: return search_serper( - app.state.config.SERPER_API_KEY, + request.app.state.config.SERPER_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPER_API_KEY found in environment variables") elif engine == "serply": - if app.state.config.SERPLY_API_KEY: + if request.app.state.config.SERPLY_API_KEY: return search_serply( - app.state.config.SERPLY_API_KEY, + request.app.state.config.SERPLY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo( query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) elif engine == "tavily": - if app.state.config.TAVILY_API_KEY: + if request.app.state.config.TAVILY_API_KEY: return search_tavily( - app.state.config.TAVILY_API_KEY, + request.app.state.config.TAVILY_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) else: raise Exception("No TAVILY_API_KEY found in environment variables") elif engine == "searchapi": - if app.state.config.SEARCHAPI_API_KEY: + if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( - app.state.config.SEARCHAPI_API_KEY, - app.state.config.SEARCHAPI_ENGINE, + request.app.state.config.SEARCHAPI_API_KEY, + request.app.state.config.SEARCHAPI_ENGINE, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") elif engine == "jina": return search_jina( - app.state.config.JINA_API_KEY, + request.app.state.config.JINA_API_KEY, query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) elif engine == "bing": return search_bing( - app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, - app.state.config.BING_SEARCH_V7_ENDPOINT, + request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, + request.app.state.config.BING_SEARCH_V7_ENDPOINT, str(DEFAULT_LOCALE), query, - app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, - app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No search engine API key found in environment variables") -@app.post("/process/web/search") -def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): +@router.post("/process/web/search") +def process_web_search( + request: Request, form_data: SearchForm, user=Depends(get_verified_user) +): try: logging.info( - f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" + f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" ) web_results = search_web( - app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query + request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query ) except Exception as e: log.exception(e) - print(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), @@ -1334,18 +1267,19 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): try: collection_name = form_data.collection_name if collection_name == "": - collection_name = calculate_sha256_string(form_data.query)[:63] + collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ + :63 + ] urls = [result.link for result in web_results] - loader = get_web_loader( - urls, - verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + urls=urls, + verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.aload() - save_docs_to_vector_db(docs, collection_name, overwrite=True) + save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, @@ -1368,28 +1302,31 @@ class QueryDocForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/doc") +@router.post("/query/doc") def query_doc_handler( + request: Request, form_data: QueryDocForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.sentence_transformer_rf, r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_doc( collection_name=form_data.collection_name, - query_embedding=app.state.EMBEDDING_FUNCTION(form_data.query), - k=form_data.k if form_data.k else app.state.config.TOP_K, + query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query), + k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: log.exception(e) @@ -1407,29 +1344,32 @@ class QueryCollectionsForm(BaseModel): hybrid: Optional[bool] = None -@app.post("/query/collection") +@router.post("/query/collection") def query_collection_handler( + request: Request, form_data: QueryCollectionsForm, user=Depends(get_verified_user), ): try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, + reranking_function=request.app.state.sentence_transformer_rf, r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + form_data.r + if form_data.r + else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, + embedding_function=request.app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: @@ -1452,7 +1392,7 @@ class DeleteForm(BaseModel): file_id: str -@app.post("/delete") +@router.post("/delete") def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): @@ -1471,13 +1411,13 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin return {"status": False} -@app.post("/reset/db") +@router.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset() Knowledges.delete_all_knowledge() -@app.post("/reset/uploads") +@router.post("/reset/uploads") def reset_upload_dir(user=Depends(get_admin_user)) -> bool: folder = f"{UPLOAD_DIR}" try: @@ -1502,10 +1442,6 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: if ENV == "dev": - @app.get("/ef") - async def get_embeddings(): - return {"result": app.state.EMBEDDING_FUNCTION("hello world")} - - @app.get("/ef/{text}") - async def get_embeddings_text(text: str): - return {"result": app.state.EMBEDDING_FUNCTION(text)} + @router.get("/ef/{text}") + async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): + return {"result": request.app.state.EMBEDDING_FUNCTION(text)}