typing and tweaks

This commit is contained in:
Michael Poluektov 2024-08-11 15:16:57 +01:00
parent 790bdcf9fc
commit d598d4bb93
2 changed files with 28 additions and 26 deletions

View File

@ -1,12 +1,10 @@
from pydantic import BaseModel, ConfigDict, parse_obj_as from pydantic import BaseModel, ConfigDict
from typing import Union, Optional from typing import Optional
import time import time
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def insert_new_user( def insert_new_user(
self, self,
id: str, id: str,
@ -122,7 +119,6 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first() user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
@ -131,7 +127,6 @@ class UsersTable:
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(email=email).first() user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
@ -140,7 +135,6 @@ class UsersTable:
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first() user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
@ -195,7 +189,6 @@ class UsersTable:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())} {"last_active_at": int(time.time())}
) )

View File

@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users from apps.webui.models.users import Users, User
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@ -322,7 +322,7 @@ async def call_tool_from_completion(
return None return None
def get_tool_calling_payload(messages, task_model_id, content): def get_tool_call_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages) user_message = get_last_user_message(messages)
history = "\n".join( history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content):
async def get_tool_call_response( async def get_tool_call_response(
messages, files, tool_id, template, task_model_id, user, extra_params messages, files, tool_id, template, task_model_id, user, extra_params
) -> tuple[Optional[str], Optional[dict], bool]: ) -> tuple[Optional[str], Optional[dict], bool]:
"""
return: tuple of (function_result, citation, file_handler) where
- function_result: Optional[str] is the result of the tool call if successful
- citation: Optional[dict] is the citation object if the tool has citation
- file_handler: bool, True if tool handles files
"""
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
if tool is None: if tool is None:
return None, None, False return None, None, False
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tool_calling_generation_template(template, tools_specs) content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_calling_payload(messages, task_model_id, content) payload = get_tool_call_payload(messages, task_model_id, content)
try: try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params):
return body, {} return body, {}
async def chat_completion_tools_handler(body, user, extra_params): async def chat_completion_tools_handler(
body: dict, user: User, extra_params: dict
) -> tuple[dict, dict]:
skip_files = None skip_files = None
contexts = [] contexts = []
@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params):
if "tool_ids" not in body: if "tool_ids" not in body:
return body, {} return body, {}
print(body["tool_ids"]) log.debug(f"tool_ids: {body['tool_ids']}")
kwargs = {
"messages": body["messages"],
"files": body.get("files", []),
"template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
"task_model_id": task_model_id,
"user": user,
"extra_params": extra_params,
}
for tool_id in body["tool_ids"]: for tool_id in body["tool_ids"]:
print(tool_id) log.debug(f"{tool_id=}")
try: try:
response, citation, file_handler = await get_tool_call_response( response, citation, file_handler = await get_tool_call_response(
messages=body["messages"], tool_id=tool_id, **kwargs
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
extra_params=extra_params,
) )
print(file_handler)
if isinstance(response, str): if isinstance(response, str):
contexts.append(response) contexts.append(response)
@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params):
skip_files = True skip_files = True
except Exception as e: except Exception as e:
print(f"Error: {e}") log.exception(f"Error: {e}")
del body["tool_ids"] del body["tool_ids"]
print(f"tool_contexts: {contexts}") log.debug(f"tool_contexts: {contexts}")
if skip_files: if skip_files:
if "files" in body: if "files" in body: