From fe5519e0a2ab884a685e9d02d9f2695bae8c8a9e Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Wed, 11 Dec 2024 19:52:46 -0800
Subject: [PATCH] wip

---
 backend/open_webui/main.py              | 242 ++++++------------------
 backend/open_webui/routers/ollama.py    |  11 +-
 backend/open_webui/routers/pipelines.py | 140 +++++++++++++-
 backend/open_webui/routers/tasks.py     |  39 ++--
 backend/open_webui/utils/task.py        |  16 ++
 5 files changed, 236 insertions(+), 212 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 84af1685a..dbb9518af 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -75,6 +75,11 @@ from open_webui.routers.retrieval import (
     get_ef,
     get_rf,
 )
+from open_webui.routers.pipelines import (
+    process_pipeline_inlet_filter,
+    process_pipeline_outlet_filter,
+)
+
 
 from open_webui.retrieval.utils import get_sources_from_files
 
@@ -290,6 +295,7 @@ from open_webui.utils.response import (
 )
 
 from open_webui.utils.task import (
+    get_task_model_id,
     rag_template,
     tools_function_calling_generation_template,
 )
@@ -662,35 +668,36 @@ app.state.MODELS = {}
 ##################################
 
 
-def get_filter_function_ids(model):
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            # TODO: Fix FunctionModel
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-    filter_ids = list(set(filter_ids))
-
-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
-
-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
-
-    filter_ids.sort(key=get_priority)
-    return filter_ids
-
-
 async def chat_completion_filter_functions_handler(body, model, extra_params):
     skip_files = None
 
+    def get_filter_function_ids(model):
+        def get_priority(function_id):
+            function = Functions.get_function_by_id(function_id)
+            if function is not None and hasattr(function, "valves"):
+                # TODO: Fix FunctionModel
+                return (function.valves if function.valves else {}).get("priority", 0)
+            return 0
+
+        filter_ids = [
+            function.id for function in Functions.get_global_filter_functions()
+        ]
+        if "info" in model and "meta" in model["info"]:
+            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+        enabled_filter_ids = [
+            function.id
+            for function in Functions.get_functions_by_type("filter", active_only=True)
+        ]
+
+        filter_ids = [
+            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+        ]
+
+        filter_ids.sort(key=get_priority)
+        return filter_ids
+
     filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)
@@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]:
     return content
 
 
-def get_task_model_id(
-    default_model_id: str, task_model: str, task_model_external: str, models
-) -> str:
-    # Set the task model
-    task_model_id = default_model_id
-    # Check if the user has a custom task model and use that model
-    if models[task_model_id]["owned_by"] == "ollama":
-        if task_model and task_model in models:
-            task_model_id = task_model
-    else:
-        if task_model_external and task_model_external in models:
-            task_model_id = task_model_external
-
-    return task_model_id
-
-
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:
@@ -857,7 +848,7 @@ async def chat_completion_tools_handler(
         )
 
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         raise e
 
@@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if prompt is None:
                 raise Exception("No user message found")
             if (
-                retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
+                app.state.config.RELEVANCE_THRESHOLD == 0
                 and context_string.strip() == ""
             ):
                 log.debug(
@@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # TODO: replace with add_or_update_system_message
             if model["owned_by"] == "ollama":
                 body["messages"] = prepend_to_first_user_message_content(
-                    rag_template(
-                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                     body["messages"],
                 )
             else:
                 body["messages"] = add_or_update_system_message(
-                    rag_template(
-                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                     body["messages"],
                 )
 
@@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 app.add_middleware(ChatCompletionMiddleware)
 
 
-##################################
-#
-# Pipeline Middleware
-#
-##################################
-
-
-def get_sorted_filters(model_id, models):
-    filters = [
-        model
-        for model in models.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-    return sorted_filters
-
-
-def filter_pipeline(payload, user, models):
-    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
-    model_id = payload["model"]
-
-    sorted_filters = get_sorted_filters(model_id, models)
-    model = models[model_id]
-
-    if "pipeline" in model:
-        sorted_filters.append(model)
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key == "":
-                continue
-
-            headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                res = r.json()
-                if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
-
-    return payload
-
-
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not request.method == "POST" and any(
@@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                     content={"detail": e.detail},
                 )
 
-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models()
+        models = app.state.MODELS
 
         try:
-            data = filter_pipeline(data, user, models)
+            data = process_pipeline_inlet_filter(request, data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
@@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"])
 app.include_router(openai.router, prefix="/openai", tags=["openai"])
 
-app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"])
-app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"])
+app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"])
 
 app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
@@ -2105,7 +2021,6 @@ async def generate_chat_completions(
         if model["owned_by"] == "ollama":
             # Using /ollama/api/chat endpoint
             form_data = convert_payload_openai_to_ollama(form_data)
-            form_data = GenerateChatCompletionForm(**form_data)
             response = await generate_ollama_chat_completion(
                 form_data=form_data, user=user, bypass_filter=bypass_filter
             )
@@ -2124,7 +2039,9 @@
 
 
 @app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+async def chat_completed(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
 
@@ -2137,53 +2054,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         )
     model = models[model_id]
 
-    sorted_filters = get_sorted_filters(model_id, models)
-    if "pipeline" in model:
-        sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {
-                            "id": user.id,
-                            "name": user.name,
-                            "email": user.email,
-                            "role": user.role,
-                        },
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except Exception:
-                    pass
-
-            else:
-                pass
+    try:
+        data = process_pipeline_outlet_filter(request, data, user, models)
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
 
     __event_emitter__ = get_event_emitter(
         {
@@ -2455,8 +2333,8 @@ async def get_app_config(request: Request):
         "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
         **(
             {
-                "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
-                "enable_image_generation": images_app.state.config.ENABLED,
+                "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
+                "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                 "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                 "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                 "enable_admin_export": ENABLE_ADMIN_EXPORT,
@@ -2472,17 +2350,17 @@ async def get_app_config(request: Request):
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
         "audio": {
             "tts": {
-                "engine": audio_app.state.config.TTS_ENGINE,
-                "voice": audio_app.state.config.TTS_VOICE,
-                "split_on": audio_app.state.config.TTS_SPLIT_ON,
+                "engine": app.state.config.TTS_ENGINE,
+                "voice": app.state.config.TTS_VOICE,
+                "split_on": app.state.config.TTS_SPLIT_ON,
             },
             "stt": {
-                "engine": audio_app.state.config.STT_ENGINE,
+                "engine": app.state.config.STT_ENGINE,
             },
         },
         "file": {
-            "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
-            "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
+            "max_size": app.state.config.FILE_MAX_SIZE,
+            "max_count": app.state.config.FILE_MAX_COUNT,
         },
         "permissions": {**app.state.config.USER_PERMISSIONS},
     }
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index b217b8f45..c36c2d730 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -941,7 +941,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] =
 @router.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
     request: Request,
-    form_data: GenerateChatCompletionForm,
+    form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     bypass_filter: Optional[bool] = False,
@@ -949,6 +949,15 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
+    try:
+        form_data = GenerateChatCompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=400,
+            detail=str(e),
+        )
+
     payload = {**form_data.model_dump(exclude_none=True)}
     if "metadata" in payload:
         del payload["metadata"]
diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py
index f1cdae140..258c10ee6 100644
--- a/backend/open_webui/routers/pipelines.py
+++ b/backend/open_webui/routers/pipelines.py
@@ -30,6 +30,130 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
+
+def get_sorted_filters(model_id, models):
+    filters = [
+        model
+        for model in models.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    return sorted_filters
+
+
+def process_pipeline_inlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters.append(model)
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                res = r.json()
+                if "detail" in res:
+                    raise Exception(r.status_code, res["detail"])
+
+    return payload
+
+
+def process_pipeline_outlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers={"Authorization": f"Bearer {key}"},
+                    json={
+                        "user": {
+                            "id": user["id"],
+                            "name": user["name"],
+                            "email": user["email"],
+                            "role": user["role"],
+                        },
+                        "body": payload,
+                    },
+                )
+
+                r.raise_for_status()
+                payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                except Exception:
+                    res = None
+
+                if res is not None and "detail" in res:
+                    raise Exception(r.status_code, res)
+            else:
+                pass
+
+    return payload
+
+
 ##################################
 #
 # Pipelines Endpoints
@@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 router = APIRouter()
 
 
-@router.get("/api/pipelines/list")
+@router.get("/list")
 async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     responses = await get_all_models_responses(request)
     log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
@@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     }
 
 
-@router.post("/api/pipelines/upload")
+@router.post("/upload")
 async def upload_pipeline(
     request: Request,
     urlIdx: int = Form(...),
@@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel):
     urlIdx: int
 
 
-@router.post("/api/pipelines/add")
+@router.post("/add")
 async def add_pipeline(
     request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
 ):
@@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel):
     urlIdx: int
 
 
-@router.delete("/api/pipelines/delete")
+@router.delete("/delete")
 async def delete_pipeline(
     request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
 ):
@@ -216,7 +340,7 @@ async def delete_pipeline(
     )
 
 
-@router.get("/api/pipelines")
+@router.get("/")
 async def get_pipelines(
     request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
 ):
@@ -250,7 +374,7 @@ async def get_pipelines(
     )
 
 
-@router.get("/api/pipelines/{pipeline_id}/valves")
+@router.get("/{pipeline_id}/valves")
 async def get_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],
@@ -289,7 +413,7 @@ async def get_pipeline_valves(
     )
 
 
-@router.get("/api/pipelines/{pipeline_id}/valves/spec")
+@router.get("/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
     request: Request,
     urlIdx: Optional[int],
@@ -329,7 +453,7 @@ async def get_pipeline_valves_spec(
     )
 
 
-@router.post("/api/pipelines/{pipeline_id}/valves/update")
+@router.post("/{pipeline_id}/valves/update")
 async def update_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
index 4af25d4d3..13e5a95a3 100644
--- a/backend/open_webui/routers/tasks.py
+++ b/backend/open_webui/routers/tasks.py
@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
+from fastapi.responses import JSONResponse, RedirectResponse
+
 from pydantic import BaseModel
-from starlette.responses import FileResponse
 from typing import Optional
 import logging
 
@@ -16,6 +17,9 @@ from open_webui.utils.task import (
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.constants import TASKS
 
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+from open_webui.utils.task import get_task_model_id
+
 from open_webui.config import (
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
@@ -121,9 +125,7 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -220,8 +222,7 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -318,8 +319,7 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -363,7 +363,7 @@ async def generate_queries(
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -405,8 +405,7 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
    if model_id not in models:
@@ -450,7 +449,7 @@ async def generate_autocompletion(
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -473,8 +472,7 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -525,7 +523,7 @@ Message: """{{prompt}}"""
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -548,10 +546,9 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
-
+    models = request.app.state.MODELS
     model_id = form_data["model"]
+
     if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -593,7 +590,7 @@ Responses from models: {{responses}}"""
     }
 
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py
index 604161a31..ebb7483ba 100644
--- a/backend/open_webui/utils/task.py
+++ b/backend/open_webui/utils/task.py
@@ -16,6 +16,22 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def get_task_model_id(
+    default_model_id: str, task_model: str, task_model_external: str, models
+) -> str:
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if models[task_model_id]["owned_by"] == "ollama":
+        if task_model and task_model in models:
+            task_model_id = task_model
+    else:
+        if task_model_external and task_model_external in models:
+            task_model_id = task_model_external
+
+    return task_model_id
+
+
 def prompt_template(
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str:
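
A minimal sketch (not part of the patch) of the calling convention this refactor establishes: the filter helpers now read `OPENAI_API_BASE_URLS`/`OPENAI_API_KEYS` from `request.app.state.config` instead of a module-level `app`, so every caller threads the FastAPI `Request` through. The endpoint below is hypothetical and exists only for illustration; the two imported helpers, `get_verified_user`, and `request.app.state.MODELS` are the pieces taken from the patch.

```python
from fastapi import Depends, FastAPI, Request

from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
    process_pipeline_outlet_filter,
)
from open_webui.utils.auth import get_verified_user

app = FastAPI()


@app.post("/api/example/completed")  # hypothetical route, for illustration only
async def example_completed(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    # Assumes get_all_models() has already populated the cache on app.state.
    models = request.app.state.MODELS

    # Inlet filters run before the model call; each filter returns the
    # (possibly rewritten) payload, so the result must be reassigned.
    data = process_pipeline_inlet_filter(request, form_data, user, models)

    # ... model inference would happen here ...

    # Outlet filters run on the response, with the same pass-the-request pattern.
    data = process_pipeline_outlet_filter(request, data, user, models)
    return data
```

Taking the `Request` as a parameter, rather than importing `app` from `main`, is presumably what allows these helpers to move into `routers/pipelines.py` without creating a circular import.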