From d598d4bb9397effab7df147ca3ef0913a787c028 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 15:16:57 +0100 Subject: [PATCH] typing and tweaks --- backend/apps/webui/models/users.py | 13 +++------- backend/main.py | 41 ++++++++++++++++++------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 36dfa4f85..b6e85e2ca 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, ConfigDict, parse_obj_as -from typing import Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import Optional import time from sqlalchemy import String, Column, BigInteger, Text -from utils.misc import get_gravatar_url - -from apps.webui.internal.db import Base, JSONField, Session, get_db +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.chats import Chats #################### @@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel): class UsersTable: - def insert_new_user( self, id: str, @@ -122,7 +119,6 @@ class UsersTable: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) except Exception: @@ -131,7 +127,6 @@ class UsersTable: def get_user_by_email(self, email: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: @@ -140,7 +135,6 @@ class UsersTable: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) except Exception: @@ -195,7 +189,6 @@ class UsersTable: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: - db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) diff --git a/backend/main.py b/backend/main.py index c992f175f..0a1768fd6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users +from apps.webui.models.users import Users, User from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id @@ -322,7 +322,7 @@ async def call_tool_from_completion( return None -def get_tool_calling_payload(messages, task_model_id, content): +def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" @@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content): async def get_tool_call_response( messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: + """ + return: tuple of (function_result, citation, file_handler) where + - function_result: Optional[str] is the result of the tool call if successful + - citation: Optional[dict] is the citation object if the tool has citation + - file_handler: bool, True if tool handles files + """ tool = Tools.get_tool_by_id(tool_id) if tool is None: return None, None, False tools_specs = json.dumps(tool.specs, indent=2) content = tool_calling_generation_template(template, tools_specs) - payload = get_tool_calling_payload(messages, task_model_id, content) + payload = get_tool_call_payload(messages, task_model_id, content) try: payload = filter_pipeline(payload, user) @@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params): return body, {} -async def chat_completion_tools_handler(body, user, extra_params): +async def chat_completion_tools_handler( + body: dict, user: User, extra_params: dict +) -> tuple[dict, dict]: skip_files = None contexts = [] @@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params): if "tool_ids" not in body: return body, {} - print(body["tool_ids"]) + log.debug(f"tool_ids: {body['tool_ids']}") + kwargs = { + "messages": body["messages"], + "files": body.get("files", []), + "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + "task_model_id": task_model_id, + "user": user, + "extra_params": extra_params, + } for tool_id in body["tool_ids"]: - print(tool_id) + log.debug(f"{tool_id=}") try: response, citation, file_handler = await get_tool_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - extra_params=extra_params, + tool_id=tool_id, **kwargs ) - print(file_handler) if isinstance(response, str): contexts.append(response) @@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params): skip_files = True except Exception as e: - print(f"Error: {e}") + log.exception(f"Error: {e}") del body["tool_ids"] - print(f"tool_contexts: {contexts}") + log.debug(f"tool_contexts: {contexts}") if skip_files: if "files" in body: