mirror of
https://github.com/open-webui/open-webui
synced 2024-11-07 00:59:52 +00:00
typing and tweaks
This commit is contained in:
parent
790bdcf9fc
commit
d598d4bb93
@ -1,12 +1,10 @@
|
||||
from pydantic import BaseModel, ConfigDict, parse_obj_as
|
||||
from typing import Union, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
import time
|
||||
|
||||
from sqlalchemy import String, Column, BigInteger, Text
|
||||
|
||||
from utils.misc import get_gravatar_url
|
||||
|
||||
from apps.webui.internal.db import Base, JSONField, Session, get_db
|
||||
from apps.webui.internal.db import Base, JSONField, get_db
|
||||
from apps.webui.models.chats import Chats
|
||||
|
||||
####################
|
||||
@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
|
||||
|
||||
|
||||
class UsersTable:
|
||||
|
||||
def insert_new_user(
|
||||
self,
|
||||
id: str,
|
||||
@ -122,7 +119,6 @@ class UsersTable:
|
||||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
user = db.query(User).filter_by(api_key=api_key).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
@ -131,7 +127,6 @@ class UsersTable:
|
||||
def get_user_by_email(self, email: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
user = db.query(User).filter_by(email=email).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
@ -140,7 +135,6 @@ class UsersTable:
|
||||
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
user = db.query(User).filter_by(oauth_sub=sub).first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
@ -195,7 +189,6 @@ class UsersTable:
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{"last_active_at": int(time.time())}
|
||||
)
|
||||
|
@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths
|
||||
from apps.webui.models.models import Models
|
||||
from apps.webui.models.tools import Tools
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.users import Users
|
||||
from apps.webui.models.users import Users, User
|
||||
|
||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||
|
||||
@ -322,7 +322,7 @@ async def call_tool_from_completion(
|
||||
return None
|
||||
|
||||
|
||||
def get_tool_calling_payload(messages, task_model_id, content):
|
||||
def get_tool_call_payload(messages, task_model_id, content):
|
||||
user_message = get_last_user_message(messages)
|
||||
history = "\n".join(
|
||||
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
||||
@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content):
|
||||
async def get_tool_call_response(
|
||||
messages, files, tool_id, template, task_model_id, user, extra_params
|
||||
) -> tuple[Optional[str], Optional[dict], bool]:
|
||||
"""
|
||||
return: tuple of (function_result, citation, file_handler) where
|
||||
- function_result: Optional[str] is the result of the tool call if successful
|
||||
- citation: Optional[dict] is the citation object if the tool has citation
|
||||
- file_handler: bool, True if tool handles files
|
||||
"""
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
if tool is None:
|
||||
return None, None, False
|
||||
|
||||
tools_specs = json.dumps(tool.specs, indent=2)
|
||||
content = tool_calling_generation_template(template, tools_specs)
|
||||
payload = get_tool_calling_payload(messages, task_model_id, content)
|
||||
payload = get_tool_call_payload(messages, task_model_id, content)
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params):
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(body, user, extra_params):
|
||||
async def chat_completion_tools_handler(
|
||||
body: dict, user: User, extra_params: dict
|
||||
) -> tuple[dict, dict]:
|
||||
skip_files = None
|
||||
|
||||
contexts = []
|
||||
@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params):
|
||||
if "tool_ids" not in body:
|
||||
return body, {}
|
||||
|
||||
print(body["tool_ids"])
|
||||
log.debug(f"tool_ids: {body['tool_ids']}")
|
||||
kwargs = {
|
||||
"messages": body["messages"],
|
||||
"files": body.get("files", []),
|
||||
"template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
"task_model_id": task_model_id,
|
||||
"user": user,
|
||||
"extra_params": extra_params,
|
||||
}
|
||||
for tool_id in body["tool_ids"]:
|
||||
print(tool_id)
|
||||
log.debug(f"{tool_id=}")
|
||||
try:
|
||||
response, citation, file_handler = await get_tool_call_response(
|
||||
messages=body["messages"],
|
||||
files=body.get("files", []),
|
||||
tool_id=tool_id,
|
||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
task_model_id=task_model_id,
|
||||
user=user,
|
||||
extra_params=extra_params,
|
||||
tool_id=tool_id, **kwargs
|
||||
)
|
||||
|
||||
print(file_handler)
|
||||
if isinstance(response, str):
|
||||
contexts.append(response)
|
||||
|
||||
@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params):
|
||||
skip_files = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
log.exception(f"Error: {e}")
|
||||
|
||||
del body["tool_ids"]
|
||||
print(f"tool_contexts: {contexts}")
|
||||
log.debug(f"tool_contexts: {contexts}")
|
||||
|
||||
if skip_files:
|
||||
if "files" in body:
|
||||
|
Loading…
Reference in New Issue
Block a user