typing and tweaks

This commit is contained in:
Michael Poluektov 2024-08-11 15:16:57 +01:00
parent 790bdcf9fc
commit d598d4bb93
2 changed files with 28 additions and 26 deletions

View File

@ -1,12 +1,10 @@
from pydantic import BaseModel, ConfigDict, parse_obj_as
from typing import Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import Optional
import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.chats import Chats
####################
@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
class UsersTable:
def insert_new_user(
self,
id: str,
@ -122,7 +119,6 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except Exception:
@ -131,7 +127,6 @@ class UsersTable:
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except Exception:
@ -140,7 +135,6 @@ class UsersTable:
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except Exception:
@ -195,7 +189,6 @@ class UsersTable:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)

View File

@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
from apps.webui.models.users import Users, User
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@ -322,7 +322,7 @@ async def call_tool_from_completion(
return None
def get_tool_calling_payload(messages, task_model_id, content):
def get_tool_call_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content):
async def get_tool_call_response(
messages, files, tool_id, template, task_model_id, user, extra_params
) -> tuple[Optional[str], Optional[dict], bool]:
"""
return: tuple of (function_result, citation, file_handler) where
- function_result: Optional[str] is the result of the tool call if successful
- citation: Optional[dict] is the citation object if the tool has citation
- file_handler: bool, True if tool handles files
"""
tool = Tools.get_tool_by_id(tool_id)
if tool is None:
return None, None, False
tools_specs = json.dumps(tool.specs, indent=2)
content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_calling_payload(messages, task_model_id, content)
payload = get_tool_call_payload(messages, task_model_id, content)
try:
payload = filter_pipeline(payload, user)
@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params):
return body, {}
async def chat_completion_tools_handler(body, user, extra_params):
async def chat_completion_tools_handler(
body: dict, user: User, extra_params: dict
) -> tuple[dict, dict]:
skip_files = None
contexts = []
@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params):
if "tool_ids" not in body:
return body, {}
print(body["tool_ids"])
log.debug(f"tool_ids: {body['tool_ids']}")
kwargs = {
"messages": body["messages"],
"files": body.get("files", []),
"template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
"task_model_id": task_model_id,
"user": user,
"extra_params": extra_params,
}
for tool_id in body["tool_ids"]:
print(tool_id)
log.debug(f"{tool_id=}")
try:
response, citation, file_handler = await get_tool_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
extra_params=extra_params,
tool_id=tool_id, **kwargs
)
print(file_handler)
if isinstance(response, str):
contexts.append(response)
@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params):
skip_files = True
except Exception as e:
print(f"Error: {e}")
log.exception(f"Error: {e}")
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
log.debug(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body: