From afd74213cc3af8e9fe90ee7faa6c90853988184e Mon Sep 17 00:00:00 2001 From: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com> Date: Tue, 2 Jul 2024 10:20:50 +0100 Subject: [PATCH 01/11] Update main.py --- backend/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/main.py b/backend/main.py index 0e3986f21..70aa52bba 100644 --- a/backend/main.py +++ b/backend/main.py @@ -311,6 +311,7 @@ async def get_function_call_response( {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, + "function": True, } try: From 09514751b553fd7713970af44033acf8f6e9f75a Mon Sep 17 00:00:00 2001 From: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com> Date: Tue, 2 Jul 2024 10:57:56 +0100 Subject: [PATCH 02/11] Update main.py --- backend/main.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/main.py b/backend/main.py index 70aa52bba..91b2495f6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -834,11 +834,9 @@ def filter_pipeline(payload, user): pass if "pipeline" not in app.state.MODELS[model_id]: - if "title" in payload: - del payload["title"] - - if "task" in payload: - del payload["task"] + for key in ["title", "task", "function"]: + if key in payload: + del payload[key] return payload From 02f242e9e81648ef2f4d6bf529e0821a87b4afb8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 2 Jul 2024 11:52:46 +0100 Subject: [PATCH 03/11] keep title, task, function tags for pipelines --- backend/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 91b2495f6..b2fe38245 100644 --- a/backend/main.py +++ b/backend/main.py @@ -833,7 +833,10 @@ def filter_pipeline(payload, user): else: pass - if "pipeline" not in app.state.MODELS[model_id]: + keep_extras = ( + "pipeline" in app.state.MODELS[model_id] or "pipe" in app.state.MODELS[model_id] + ) + if not keep_extras: for key in ["title", "task", "function"]: if key in payload: del payload[key] From 24c7990fd4440b5c4734d8713302f47cd2a64c67 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 3 Jul 2024 09:12:00 +0100 Subject: [PATCH 04/11] revert not delete on pipe --- backend/main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index b2fe38245..91b2495f6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -833,10 +833,7 @@ def filter_pipeline(payload, user): else: pass - keep_extras = ( - "pipeline" in app.state.MODELS[model_id] or "pipe" in app.state.MODELS[model_id] - ) - if not keep_extras: + if "pipeline" not in app.state.MODELS[model_id]: for key in ["title", "task", "function"]: if key in payload: del payload[key] From e3e02e04e87a39d858f42b588d8fa6aa5eca056c Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 11:51:43 +0100 Subject: [PATCH 05/11] refac: backend/main.py --- backend/main.py | 474 +++++++++++++++++++++--------------------------- 1 file changed, 204 insertions(+), 270 deletions(-) diff --git a/backend/main.py b/backend/main.py index 6ded8e1d2..055bd28c9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,13 +1,10 @@ import base64 import uuid -import subprocess from contextlib import asynccontextmanager from authlib.integrations.starlette_client import OAuth from authlib.oidc.core import UserInfo -from bs4 import BeautifulSoup import json -import markdown import time import os import sys @@ -19,14 +16,11 @@ import shutil import os import uuid import inspect -import asyncio -from fastapi.concurrency import run_in_threadpool from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse from fastapi import HTTPException -from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from sqlalchemy import text from starlette.exceptions import HTTPException as StarletteHTTPException @@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse from apps.socket.main import sio, app as socket_app from apps.ollama.main import ( app as ollama_app, - OpenAIChatCompletionForm, get_all_models as get_ollama_models, generate_openai_chat_completion as generate_ollama_chat_completion, ) @@ -56,14 +49,14 @@ from apps.webui.main import ( get_pipe_models, generate_function_chat_completion, ) -from apps.webui.internal.db import Session, SessionLocal +from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import List, Optional, Iterator, Generator, Union +from typing import List, Optional from apps.webui.models.auths import Auths -from apps.webui.models.models import Models, ModelModel +from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions from apps.webui.models.users import Users @@ -86,14 +79,12 @@ from utils.task import ( from utils.misc import ( get_last_user_message, add_or_update_system_message, - stream_message_template, parse_duration, ) from apps.rag.utils import get_rag_context, rag_template from config import ( - CONFIG_DATA, WEBUI_NAME, WEBUI_URL, WEBUI_AUTH, @@ -101,7 +92,6 @@ from config import ( VERSION, CHANGELOG, FRONTEND_BUILD_DIR, - UPLOAD_DIR, CACHE_DIR, STATIC_DIR, DEFAULT_LOCALE, @@ -128,9 +118,8 @@ from config import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, AppConfig, - BACKEND_DIR, - DATABASE_URL, ) + from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from utils.webhook import post_webhook @@ -355,121 +344,94 @@ async def get_function_call_response( else: content = response["choices"][0]["message"]["content"] + if content is None: + return None, None, False + # Parse the function response - if content is not None: - print(f"content: {content}") - result = json.loads(content) - print(result) + print(f"content: {content}") + result = json.loads(content) + print(result) - citation = None - # Call the function - if "name" in result: - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module + citation = None - file_handler = False - # check if toolkit_module has file_handler self variable - if hasattr(toolkit_module, "file_handler"): - file_handler = True - print("file_handler: ", file_handler) + if "name" not in result: + return None, None, False - if hasattr(toolkit_module, "valves") and hasattr( - toolkit_module, "Valves" - ): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves( - **(valves if valves else {}) - ) + # Call the function + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + file_handler = False + # check if toolkit_module has file_handler self variable + if hasattr(toolkit_module, "file_handler"): + file_handler = True + print("file_handler: ", file_handler) + + if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) + toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) + + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + # Get the signature of the function + sig = inspect.signature(function) + params = result["parameters"] + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": tool_id, + "__messages__": messages, + "__files__": files, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + # Call the function with the '__user__' parameter included + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } - function = getattr(toolkit_module, result["name"]) - function_result = None try: - # Get the signature of the function - sig = inspect.signature(function) - params = result["parameters"] - - if "__user__" in sig.parameters: - # Call the function with the '__user__' parameter included - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id( - tool_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - if "__messages__" in sig.parameters: - # Call the function with the '__messages__' parameter included - params = { - **params, - "__messages__": messages, - } - - if "__files__" in sig.parameters: - # Call the function with the '__files__' parameter included - params = { - **params, - "__files__": files, - } - - if "__model__" in sig.parameters: - # Call the function with the '__model__' parameter included - params = { - **params, - "__model__": model, - } - - if "__id__" in sig.parameters: - # Call the function with the '__id__' parameter included - params = { - **params, - "__id__": tool_id, - } - - if "__event_emitter__" in sig.parameters: - # Call the function with the '__event_emitter__' parameter included - params = { - **params, - "__event_emitter__": __event_emitter__, - } - - if "__event_call__" in sig.parameters: - # Call the function with the '__event_call__' parameter included - params = { - **params, - "__event_call__": __event_call__, - } - - if inspect.iscoroutinefunction(function): - function_result = await function(**params) - else: - function_result = function(**params) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } + if hasattr(toolkit_module, "UserValves"): + __user__["valves"] = toolkit_module.UserValves( + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) except Exception as e: print(e) - # Add the function result to the system prompt - if function_result is not None: - return function_result, citation, file_handler + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(function): + function_result = await function(**params) + else: + function_result = function(**params) + + if hasattr(toolkit_module, "citation") and toolkit_module.citation: + citation = { + "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, + "document": [function_result], + "metadata": [{"source": result["name"]}], + } + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result is not None: + return function_result, citation, file_handler except Exception as e: print(f"Error: {e}") @@ -484,87 +446,74 @@ async def chat_completion_functions_handler( filter_ids = get_filter_function_ids(model) for filter_id in filter_ids: filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type, frontmatter = ( - load_function_module_by_id(filter_id) - ) - webui_app.state.FUNCTIONS[filter_id] = function_module + if not filter: + continue - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module - if hasattr(function_module, "valves") and hasattr( - function_module, "Valves" - ): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler - try: - if hasattr(function_module, "inlet"): - inlet = function_module.inlet + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} + try: + if hasattr(function_module, "inlet"): + inlet = function_module.inlet - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + # Get the signature of the function + sig = inspect.signature(inlet) + params = {"body": body} - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if "__id__" in sig.parameters: - params = { - **params, - "__id__": filter_id, - } + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) - if "__model__" in sig.parameters: - params = { - **params, - "__model__": model, - } - - if "__event_emitter__" in sig.parameters: - params = { - **params, - "__event_emitter__": __event_emitter__, - } - - if "__event_call__" in sig.parameters: - params = { - **params, - "__event_call__": __event_call__, - } - - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) - - except Exception as e: - print(f"Error: {e}") - raise e + except Exception as e: + print(f"Error: {e}") + raise e if skip_files: if "files" in body: @@ -1220,86 +1169,73 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): for filter_id in filter_ids: filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type, frontmatter = ( - load_function_module_by_id(filter_id) - ) - webui_app.state.FUNCTIONS[filter_id] = function_module + if not filter: + continue - if hasattr(function_module, "valves") and hasattr( - function_module, "Valves" - ): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, _, _ = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module - try: - if hasattr(function_module, "outlet"): - outlet = function_module.outlet + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} + try: + if hasattr(function_module, "outlet"): + outlet = function_module.outlet - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if "__id__" in sig.parameters: - params = { - **params, - "__id__": filter_id, - } + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) - if "__model__" in sig.parameters: - params = { - **params, - "__model__": model, - } - - if "__event_emitter__" in sig.parameters: - params = { - **params, - "__event_emitter__": __event_emitter__, - } - - if "__event_call__" in sig.parameters: - params = { - **params, - "__event_call__": __event_call__, - } - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) return data @@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): model_id = task_model_id print(model_id) - model = app.state.MODELS[model_id] template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) model_id = task_model_id print(model_id) - model = app.state.MODELS[model_id] template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE @@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): model_id = task_model_id print(model_id) - model = app.state.MODELS[model_id] template = ''' Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). @@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE try: - context, citation, file_handler = await get_function_call_response( + context, _, _ = await get_function_call_response( form_data["messages"], form_data.get("files", []), form_data["tool_id"], @@ -1647,6 +1580,7 @@ async def upload_pipeline( os.makedirs(upload_folder, exist_ok=True) file_path = os.path.join(upload_folder, file.filename) + r = None try: # Save the uploaded file with open(file_path, "wb") as buffer: @@ -1670,7 +1604,9 @@ async def upload_pipeline( print(f"Connection error: {e}") detail = "Pipeline not found" + status_code = status.HTTP_404_NOT_FOUND if r is not None: + status_code = r.status_code try: res = r.json() if "detail" in res: @@ -1679,7 +1615,7 @@ async def upload_pipeline( pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=status_code, detail=detail, ) finally: @@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): r = None try: - urlIdx - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] From d7dd901f017568b0cd2a4ebfb1e3f0371b5f2fc6 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 12:15:09 +0100 Subject: [PATCH 06/11] refac: remove nesting --- backend/main.py | 151 ++++++++++++++++++++++++------------------------ 1 file changed, 77 insertions(+), 74 deletions(-) diff --git a/backend/main.py b/backend/main.py index 055bd28c9..32a557e9e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -465,51 +465,53 @@ async def chat_completion_functions_handler( **(valves if valves else {}) ) + if not hasattr(function_module, "inlet"): + continue + try: - if hasattr(function_module, "inlet"): - inlet = function_module.inlet + inlet = function_module.inlet - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": body} + # Get the signature of the function + sig = inspect.signature(inlet) + params = {"body": body} - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, } - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if inspect.iscoroutinefunction(inlet): - body = await inlet(**params) - else: - body = inlet(**params) + if inspect.iscoroutinefunction(inlet): + body = await inlet(**params) + else: + body = inlet(**params) except Exception as e: print(f"Error: {e}") @@ -1184,51 +1186,52 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): **(valves if valves else {}) ) + if not hasattr(function_module, "outlet"): + continue try: - if hasattr(function_module, "outlet"): - outlet = function_module.outlet + outlet = function_module.outlet - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": filter_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, } - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) except Exception as e: print(f"Error: {e}") From ff474936f81c921922cbb8718d3215c5cba36d23 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 12:20:28 +0100 Subject: [PATCH 07/11] refac: remove model param --- backend/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index 32a557e9e..d785f302e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -289,7 +289,6 @@ async def get_function_call_response( template, task_model_id, user, - model, __event_emitter__=None, __event_call__=None, ): @@ -525,7 +524,7 @@ async def chat_completion_functions_handler( async def chat_completion_tools_handler( - body, model, user, __event_emitter__, __event_call__ + body, user, __event_emitter__, __event_call__ ): skip_files = None @@ -547,7 +546,6 @@ async def chat_completion_tools_handler( template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, task_model_id=task_model_id, user=user, - model=model, __event_emitter__=__event_emitter__, __event_call__=__event_call__, ) @@ -674,7 +672,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_tools_handler( - body, model, user, __event_emitter__, __event_call__ + body, user, __event_emitter__, __event_call__ ) contexts.extend(flags.get("contexts", [])) From 7ffd75b991782c9a5fe315abd80e1e0fd096d10b Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 12:32:43 +0100 Subject: [PATCH 08/11] refac: black --- backend/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index d785f302e..e7210bc0a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -523,9 +523,7 @@ async def chat_completion_functions_handler( return body, {} -async def chat_completion_tools_handler( - body, user, __event_emitter__, __event_call__ -): +async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): skip_files = None contexts = [] From 144581a7df164b4521d4803c3d558f77a5cf42ac Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 12:51:13 +0100 Subject: [PATCH 09/11] refac: get_sorted_pipelines() --- backend/main.py | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/backend/main.py b/backend/main.py index e7210bc0a..e1890b85f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware) ################################## -def filter_pipeline(payload, user): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] +def get_sorted_filters(model_id): filters = [ model for model in app.state.MODELS.values() @@ -782,6 +780,13 @@ def filter_pipeline(payload, user): ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + return sorted_filters + + +def filter_pipeline(payload, user): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + sorted_filters = get_sorted_filters(model_id) model = app.state.MODELS[model_id] @@ -814,19 +819,12 @@ def filter_pipeline(payload, user): print(f"Connection error: {e}") if r is not None: - try: - res = r.json() - except: - pass + res = r.json() if "detail" in res: raise Exception(r.status_code, res["detail"]) - else: - pass - - if "pipeline" not in app.state.MODELS[model_id]: - if "task" in payload: - del payload["task"] + if "pipeline" not in app.state.MODELS[model_id] and "task" in payload: + del payload["task"] return payload @@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ) model = app.state.MODELS[model_id] - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + sorted_filters = get_sorted_filters(model_id) if "pipeline" in model: sorted_filters = [model] + sorted_filters From 8f23df574919eaa3ac4bf503d02b790033de1fc0 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 15:57:24 +0100 Subject: [PATCH 10/11] fix: outlet __event_emitter__ --- backend/main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/main.py b/backend/main.py index e1890b85f..eb1e3ffb8 100644 --- a/backend/main.py +++ b/backend/main.py @@ -226,7 +226,7 @@ async def get_body_and_model_and_user(request): model_id = body["model"] if model_id not in app.state.MODELS: - raise "Model not found" + raise Exception("Model not found") model = app.state.MODELS[model_id] user = get_current_user( @@ -1107,21 +1107,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): else: pass - async def __event_emitter__(data): + async def __event_emitter__(event_data): await sio.emit( "chat-events", { "chat_id": data["chat_id"], "message_id": data["id"], - "data": data, + "data": event_data, }, to=data["session_id"], ) - async def __event_call__(data): + async def __event_call__(event_data): response = await sio.call( "chat-events", - {"chat_id": data["chat_id"], "message_id": data["id"], "data": data}, + {"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data}, to=data["session_id"], ) return response From 1d20c27553f019477f01d7233ebe40b11d31e479 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 16:08:54 +0100 Subject: [PATCH 11/11] refac: use get_task_model_id() --- backend/main.py | 44 ++++---------------------------------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/backend/main.py b/backend/main.py index eb1e3ffb8..89252e164 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1293,16 +1293,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1361,16 +1352,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1417,16 +1399,7 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1483,16 +1456,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE