Merge pull request #4602 from michaelpoluektov/tools-refac-1

refactor, perf: Tools refactor (progress PR 1)
Authored by Timothy Jaeryang Baek on 2024-08-17 15:50:40 +02:00, committed by GitHub
commit cbb0940ff8
3 changed files with 358 additions and 455 deletions

View File: apps/webui/models/users.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional
 import time
 from sqlalchemy import String, Column, BigInteger, Text
-from utils.misc import get_gravatar_url
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats
 
 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ class UsersTable:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
-
                db.query(User).filter_by(id=id).update(
                    {"last_active_at": int(time.time())}
                )
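Note: the queries above all go through the `get_db` context manager imported from `apps.webui.internal.db`. For readers outside the codebase, a minimal sketch of that pattern (an assumed implementation, not the actual module) looks like:

```python
# Sketch of the get_db pattern assumed by the queries above: a context
# manager that yields a SQLAlchemy session and always closes it.
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///webui.db")  # placeholder URL
SessionLocal = sessionmaker(bind=engine, autoflush=False)

@contextmanager
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
```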

View File: main.py

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel
 
 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@@ -72,7 +72,7 @@ from utils.utils import (
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
-    tools_function_calling_generation_template,
+    tool_calling_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
-
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
+
+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
     )
 
-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"
 
-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }
 
-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False
-
-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None
 
     filter_ids = get_filter_function_ids(model)
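Note: the old `get_function_call_response` monolith is split so that prompt construction and response draining become reusable helpers. A hedged sketch of how the two extracted functions compose (all names come from this diff; `pick_tool` itself is hypothetical, and the snippet assumes the surrounding module context of main.py):

```python
import json

# Sketch only: ask the task model which tool to call, then parse its answer.
async def pick_tool(messages, task_model_id, system_prompt, user):
    payload = get_tool_call_payload(messages, task_model_id, system_prompt)
    response = await generate_chat_completions(form_data=payload, user=user)
    content = await get_content_from_response(response)
    # The task model is expected to reply with JSON like
    # {"name": "<tool name>", "parameters": {...}}
    return json.loads(content) if content else None
```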
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler(
 
             params = {"body": body}
             # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                 try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
+                    uid = custom_params["__user__"]["id"]
+                    custom_params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    )
                 except Exception as e:
                     print(e)
 
-                params = {**params, "__user__": __user__}
+            # Add extra params in contained in function signature
+            for key, value in custom_params.items():
+                if key in sig.parameters:
+                    params[key] = value
 
             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
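Note: the inlet handler keeps the signature-filtering idiom: the special `__*__` arguments are only passed to a filter function that actually declares them. A standalone sketch of that idiom with toy names (nothing here is from the codebase):

```python
import inspect

def inject_known_params(func, params: dict, extra: dict) -> dict:
    # Merge in only the extra kwargs that func actually declares.
    sig = inspect.signature(func)
    return {**params, **{k: v for k, v in extra.items() if k in sig.parameters}}

def inlet(body: dict, __user__: dict):  # toy filter declaring __user__ only
    return {**body, "seen_by": __user__["id"]}

kwargs = inject_known_params(
    inlet,
    {"body": {"messages": []}},
    {"__user__": {"id": "u1"}, "__model__": {"id": "m1"}},  # __model__ is dropped
)
print(inlet(**kwargs))  # {'messages': [], 'seen_by': 'u1'}
```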
@@ -516,74 +378,171 @@ async def chat_completion_functions_handler(
             print(f"Error: {e}")
             raise e
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]
 
     return body, {}
 
 
-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
-
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)
+
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+# Mutation on extra_params
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            name = spec["name"]
+            callable = getattr(module, name)
+
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_id": tool_id,
+                "callable": custom_callable,
+            }
+
+            # TODO: if collision, prepend toolkit name
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
+async def chat_completion_tools_handler(
+    body: dict, user: UserModel, extra_params: dict
+) -> tuple[dict, dict]:
+    skip_files = False
     contexts = []
-    citations = None
+    citations = []
 
     task_model_id = get_task_model_id(body["model"])
 
     # If tool_ids field is present, call the functions
-    if "tool_ids" in body:
-        print(body["tool_ids"])
-        for tool_id in body["tool_ids"]:
-            print(tool_id)
-            try:
-                response, citation, file_handler = await get_function_call_response(
-                    messages=body["messages"],
-                    files=body.get("files", []),
-                    tool_id=tool_id,
-                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                    task_model_id=task_model_id,
-                    user=user,
-                    __event_emitter__=__event_emitter__,
-                    __event_call__=__event_call__,
-                )
-
-                print(file_handler)
-                if isinstance(response, str):
-                    contexts.append(response)
-
-                if citation:
-                    if citations is None:
-                        citations = [citation]
-                    else:
-                        citations.append(citation)
-
-                if file_handler:
-                    skip_files = True
-
-            except Exception as e:
-                print(f"Error: {e}")
-
-        del body["tool_ids"]
-        print(f"tool_contexts: {contexts}")
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
+    }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)
+
+    log.info(f"{configured_tools=}")
+
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e
+
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}
+
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}
+
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True
+
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)
+    except Exception as e:
+        print(f"Error: {e}")
+        content = None
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    log.debug(f"tool_contexts: {contexts}")
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    if skip_files and "files" in body:
+        del body["files"]
 
+    return body, {"contexts": contexts, "citations": citations}
 
-async def chat_completion_files_handler(body):
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
-    citations = None
+    citations = []
 
-    if "files" in body:
-        files = body["files"]
-        del body["files"]
-
+    if files := body.pop("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
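Note: `get_tool_with_custom_params` is the heart of this refactor: each toolkit method is wrapped into an awaitable that accepts only the model-facing parameters, with the `__*__` extras pre-bound at configuration time. A toy demonstration, assuming the definitions added above are in scope (`get_weather` is hypothetical):

```python
import asyncio

def get_weather(city: str, __user__: dict) -> str:  # hypothetical tool
    return f"{city}: sunny (requested by {__user__['name']})"

wrapped = get_tool_with_custom_params(
    get_weather, {"__user__": {"name": "Ada"}, "__id__": "demo"}  # __id__ not in the
)                                                                 # signature, so dropped

# The LLM only ever supplies the spec'd parameters:
print(asyncio.run(wrapped(city="Paris")))  # Paris: sunny (requested by Ada)
```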
@@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):
 
     log.debug(f"rag_contexts: {contexts}, citations: {citations}")
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}
+
+
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
 
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
 
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
 
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
 
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }
 
-            # Initialize context, and citations
-            contexts = []
-            citations = []
+        # Initialize data_items to store additional data to be sent to the client
+        # Initalize contexts and citation
+        data_items = []
+        contexts = []
+        citations = []
 
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
 
-            try:
-                body, flags = await chat_completion_files_handler(body)
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
 
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "/n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+            else:
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
 
-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
 
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
-                return response
-            else:
-                return response
-
-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers["Content-Type"]
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
         return response
 
     async def _receive(self, body: bytes):
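Note: both middlewares mutate the cached request body, so they must also rewrite `Content-Length`; otherwise anything downstream that trusts the header would truncate or over-read the modified JSON. The pattern, distilled from the code above as a sketch (`request._body` and `headers.__dict__["_list"]` are Starlette internals, not public API):

```python
import json

def rewrite_request_body(request, data: dict) -> None:
    modified = json.dumps(data).encode("utf-8")
    request._body = modified  # Starlette serves later body() calls from this cache
    request.headers.__dict__["_list"] = [
        (b"content-length", str(len(modified)).encode("utf-8")),
        *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
    ]
```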
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
 
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
-
-                r.raise_for_status()
-                payload = r.json()
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -817,44 +774,39 @@ def filter_pipeline(payload, user):
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+        if not is_chat_completion_request(request):
+            return await call_next(request)
 
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
+        log.debug(f"request.url.path: {request.url.path}")
 
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
 
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers["Authorization"]),
+        )
 
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
         response = await call_next(request)
         return response
@@ -1019,6 +971,8 @@ async def get_all_models():
         model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")
 
             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
 
     model = app.state.MODELS[model_id]
 
-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
        return await generate_openai_chat_completion(form_data, user=user)
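Note: `generate_chat_completions` no longer lifts a top-level `task` field into `metadata`; the task-endpoint payloads in the hunks below now carry it directly, e.g. (shape taken from this diff):

```python
payload = {
    "model": task_model_id,
    "messages": [{"role": "user", "content": content}],
    "stream": False,
    "max_tokens": 50,
    "metadata": {"task": str(TASKS.TITLE_GENERATION)},
}
```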
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include vavles
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
 
     log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
 
     print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
 
     log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
 
     return await generate_chat_completions(form_data=payload, user=user)
 
 
-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
-
-    model_id = form_data["model"]
-
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    # Check if the user has a custom task model
-    # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
-
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-
-    try:
-        context, _, _ = await get_function_call_response(
-            form_data["messages"],
-            form_data.get("files", []),
-            form_data["tool_id"],
-            template,
-            model_id,
-            user,
-        )
-        return context
-    except Exception as e:
-        return JSONResponse(
-            status_code=e.args[0],
-            content={"detail": e.args[1]},
-        )
-
-
 ##################################
 #
 # Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
 
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)
 
 
 # OAuth login logic is as follows:

View File: utils/task.py

@@ -121,6 +121,6 @@ def search_query_generation_template(
     return template
 
 
-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tool_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
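Note: the rename drops the redundant "function" wording; behavior is unchanged. A usage sketch with a hypothetical template string (the real one comes from `TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE`):

```python
import json

template = "You may call one of these tools:\n{{TOOLS}}\nReply with JSON."
specs = [
    {"name": "get_weather", "parameters": {"properties": {"city": {"type": "string"}}}}
]
print(tool_calling_generation_template(template, json.dumps(specs)))
```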