mirror of
https://github.com/open-webui/open-webui
synced 2025-01-18 00:30:51 +00:00
Merge pull request #4602 from michaelpoluektov/tools-refac-1
refactor, perf: Tools refactor (progress PR 1)
This commit is contained in:
commit cbb0940ff8
backend/apps/webui/models/users.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict, parse_obj_as
-from typing import Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import Optional
 import time

 from sqlalchemy import String, Column, BigInteger, Text
-
 from utils.misc import get_gravatar_url
-
-from apps.webui.internal.db import Base, JSONField, Session, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats

 ####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):


 class UsersTable:
-
     def insert_new_user(
         self,
         id: str,
@@ -122,7 +119,6 @@ class UsersTable:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(api_key=api_key).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(email=email).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 user = db.query(User).filter_by(oauth_sub=sub).first()
                 return UserModel.model_validate(user)
         except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
-
                 db.query(User).filter_by(id=id).update(
                     {"last_active_at": int(time.time())}
                 )
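The users.py hunks are mechanical cleanup (unused imports such as parse_obj_as, Union, and Session, plus stray blank lines), while every accessor keeps the same shape: open a session via the get_db() context manager, query, and validate the row into a Pydantic model. A minimal runnable sketch of that shape, assuming Pydantic v2, with a dict standing in for the real SQLAlchemy session:

from contextlib import contextmanager
from typing import Optional

from pydantic import BaseModel, ConfigDict


class UserModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    email: str


class FakeUserRow:
    # stand-in for the SQLAlchemy User row
    def __init__(self, id: str, email: str):
        self.id, self.email = id, email


_ROWS = {"user@example.com": FakeUserRow("u1", "user@example.com")}


@contextmanager
def get_db():
    # the real get_db() yields a SQLAlchemy session; a dict keeps the sketch runnable
    yield _ROWS


def get_user_by_email(email: str) -> Optional[UserModel]:
    try:
        with get_db() as db:
            user = db[email]  # real code: db.query(User).filter_by(email=email).first()
            return UserModel.model_validate(user)
    except Exception:
        return None


print(get_user_by_email("user@example.com"))
print(get_user_by_email("missing@example.com"))

The try/except-returns-None contract is what lets callers treat "not found" and "DB error" uniformly.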
798	backend/main.py
@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session


 from pydantic import BaseModel
-from typing import Optional
+from typing import Optional, Callable, Awaitable

 from apps.webui.models.auths import Auths
 from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users
+from apps.webui.models.users import Users, UserModel

 from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id

@@ -72,7 +72,7 @@ from utils.utils import (
 from utils.task import (
     title_generation_template,
     search_query_generation_template,
-    tools_function_calling_generation_template,
+    tool_calling_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
            return (function.valves if function.valves else {}).get("priority", 0)
         return 0

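For context on the TODO hunk above: get_priority() reads a priority out of a function's valves, defaulting to 0, and the surrounding get_filter_function_ids() uses it as a sort key so lower-priority filters run first. A hedged sketch with stand-in valves dicts (the real code reads FunctionModel rows, hence the TODO):

filters = {
    "censor": {"priority": 10},
    "rewrite": {},          # no valves -> default priority 0
    "logger": {"priority": 5},
}

def get_priority(function_id: str) -> int:
    valves = filters.get(function_id) or {}
    return valves.get("priority", 0)

filter_ids = sorted(filters, key=get_priority)
print(filter_ids)  # ['rewrite', 'logger', 'censor']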
@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
     return filter_ids


-async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
-):
-    tool = Tools.get_tool_by_id(tool_id)
-    tools_specs = json.dumps(tool.specs, indent=2)
-    content = tools_function_calling_generation_template(template, tools_specs)
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+

+def get_tool_call_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)
-    prompt = (
-        "History:\n"
-        + "\n".join(
-            [
-                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-                for message in messages[::-1][:4]
-            ]
-        )
-        + f"\nQuery: {user_message}"
-    )
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
+    )

-    print(prompt)
+    prompt = f"History:\n{history}\nQuery: {user_message}"

-    payload = {
+    return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": str(TASKS.FUNCTION_CALLING),
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
     }

-    try:
-        payload = filter_pipeline(payload, user)
-    except Exception as e:
-        raise e
-
-    model = app.state.MODELS[task_model_id]
-
-    response = None
-    try:
-        response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
-
-        if content is None:
-            return None, None, False
-
-        # Parse the function response
-        print(f"content: {content}")
-        result = json.loads(content)
-        print(result)
-
-        citation = None
-
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
-
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
-
-        # Add the function result to the system prompt
-        if function_result is not None:
-            return function_result, citation, file_handler
-    except Exception as e:
-        print(f"Error: {e}")
-
-    return None, None, False
-
-
-async def chat_completion_functions_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None

     filter_ids = get_filter_function_ids(model)
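The big deletion above splits the old monolithic get_function_call_response() into two small helpers: get_content_from_response() drains a (possibly streaming) completion, and get_tool_call_payload() builds the one-shot function-calling request. Roughly how they compose, as a runnable sketch with stand-ins for TASKS and the task model id:

import json

def tool_calling_generation_template(template: str, tools_specs: str) -> str:
    # same substitution as the utils/task.py hunk at the end of this diff
    return template.replace("{{TOOLS}}", tools_specs)

def get_last_user_message(messages):
    # stand-in for utils.misc.get_last_user_message
    return next((m["content"] for m in reversed(messages) if m["role"] == "user"), None)

def get_tool_call_payload(messages, task_model_id, content):
    user_message = get_last_user_message(messages)
    history = "\n".join(
        f'{m["role"].upper()}: """{m["content"]}"""' for m in messages[::-1][:4]
    )
    prompt = f"History:\n{history}\nQuery: {user_message}"
    return {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": content},
            {"role": "user", "content": f"Query: {prompt}"},
        ],
        "stream": False,
        "metadata": {"task": "function_calling"},  # stand-in for str(TASKS.FUNCTION_CALLING)
    }

specs = [{"name": "get_weather", "parameters": {"properties": {"city": {"type": "string"}}}}]
content = tool_calling_generation_template("Tools: {{TOOLS}}", json.dumps(specs))
payload = get_tool_call_payload(
    [{"role": "user", "content": "Weather in Paris?"}], "task-model", content
)
print(json.dumps(payload, indent=2))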
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler
             params = {"body": body}

-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
                 try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
+                    uid = custom_params["__user__"]["id"]
+                    custom_params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    )
                 except Exception as e:
                     print(e)

-                params = {**params, "__user__": __user__}
+            # Add extra params in contained in function signature
+            for key, value in custom_params.items():
+                if key in sig.parameters:
+                    params[key] = value

             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
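The inlet hunk above collapses two ad-hoc param dicts into one custom_params dict filtered against the inlet's signature, so an inlet only receives the magic arguments it declares. The mechanism, as a self-contained sketch (the inlet and the param values are illustrative):

import inspect

def inlet(body: dict, __user__: dict):
    # an illustrative inlet that declares only the extras it wants
    return {**body, "seen_by": __user__["id"]}

custom_params = {
    "__user__": {"id": "u1"},
    "__model__": {"id": "m1"},  # not in inlet's signature, so it is dropped
}

params = {"body": {"messages": []}}
sig = inspect.signature(inlet)
for key, value in custom_params.items():
    if key in sig.parameters:
        params[key] = value

print(inlet(**params))  # {'messages': [], 'seen_by': 'u1'}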
@@ -516,74 +378,171 @@ async def chat_completion_functions_handler
             print(f"Error: {e}")
             raise e

-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]

     return body, {}


-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
-    skip_files = None
+def get_tool_with_custom_params(
+    tool: Callable, custom_params: dict
+) -> Callable[..., Awaitable]:
+    sig = inspect.signature(tool)
+    extra_params = {
+        key: value for key, value in custom_params.items() if key in sig.parameters
+    }
+    is_coroutine = inspect.iscoroutinefunction(tool)
+
+    async def new_tool(**kwargs):
+        extra_kwargs = kwargs | extra_params
+        if is_coroutine:
+            return await tool(**extra_kwargs)
+        return tool(**extra_kwargs)
+
+    return new_tool
+
+
+# Mutation on extra_params
+def get_configured_tools(
+    tool_ids: list[str], extra_params: dict, user: UserModel
+) -> dict[str, dict]:
+    tools = {}
+    for tool_id in tool_ids:
+        toolkit = Tools.get_tool_by_id(tool_id)
+        if toolkit is None:
+            continue
+
+        module = webui_app.state.TOOLS.get(tool_id, None)
+        if module is None:
+            module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = module
+
+        extra_params["__id__"] = tool_id
+        has_citation = hasattr(module, "citation") and module.citation
+        handles_files = hasattr(module, "file_handler") and module.file_handler
+        if hasattr(module, "valves") and hasattr(module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id) or {}
+            module.valves = module.Valves(**valves)
+
+        if hasattr(module, "UserValves"):
+            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+        for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
+            name = spec["name"]
+            callable = getattr(module, name)
+
+            # convert to function that takes only model params and inserts custom params
+            custom_callable = get_tool_with_custom_params(callable, extra_params)
+
+            # TODO: This needs to be a pydantic model
+            tool_dict = {
+                "spec": spec,
+                "citation": has_citation,
+                "file_handler": handles_files,
+                "toolkit_id": tool_id,
+                "callable": custom_callable,
+            }
+            # TODO: if collision, prepend toolkit name
+            if name in tools:
+                log.warning(f"Tool {name} already exists in another toolkit!")
+                log.warning(f"Collision between {toolkit} and {tool_id}.")
+                log.warning(f"Discarding {toolkit}.{name}")
+            else:
+                tools[name] = tool_dict
+
+    return tools
+
+
+async def chat_completion_tools_handler(
+    body: dict, user: UserModel, extra_params: dict
+) -> tuple[dict, dict]:
+    skip_files = False
     contexts = []
-    citations = None
-
+    citations = []
     task_model_id = get_task_model_id(body["model"])

-    # If tool_ids field is present, call the functions
-    if "tool_ids" in body:
-        print(body["tool_ids"])
-        for tool_id in body["tool_ids"]:
-            print(tool_id)
-            try:
-                response, citation, file_handler = await get_function_call_response(
-                    messages=body["messages"],
-                    files=body.get("files", []),
-                    tool_id=tool_id,
-                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                    task_model_id=task_model_id,
-                    user=user,
-                    __event_emitter__=__event_emitter__,
-                    __event_call__=__event_call__,
-                )
-
-                print(file_handler)
-                if isinstance(response, str):
-                    contexts.append(response)
-
-                if citation:
-                    if citations is None:
-                        citations = [citation]
-                    else:
-                        citations.append(citation)
-
-                if file_handler:
-                    skip_files = True
-
-            except Exception as e:
-                print(f"Error: {e}")
-        del body["tool_ids"]
-        print(f"tool_contexts: {contexts}")
-
-    if skip_files:
-        if "files" in body:
-            del body["files"]
-
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
-
-
-async def chat_completion_files_handler(body):
-    contexts = []
-    citations = None
-
-    if "files" in body:
-        files = body["files"]
+    tool_ids = body.pop("tool_ids", None)
+    if not tool_ids:
+        return body, {}
+
+    log.debug(f"{tool_ids=}")
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__messages__": body["messages"],
+        "__files__": body.get("files", []),
+    }
+    configured_tools = get_configured_tools(tool_ids, custom_params, user)
+
+    log.info(f"{configured_tools=}")
+
+    specs = [tool["spec"] for tool in configured_tools.values()]
+    tools_specs = json.dumps(specs)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    content = tool_calling_generation_template(template, tools_specs)
+    payload = get_tool_call_payload(body["messages"], task_model_id, content)
+
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e
+
+    try:
+        response = await generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}
+
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}
+
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True
+
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)
+
+    except Exception as e:
+        print(f"Error: {e}")
+        content = None
+
+    log.debug(f"tool_contexts: {contexts}")
+
+    if skip_files and "files" in body:
+        del body["files"]
+
+    return body, {"contexts": contexts, "citations": citations}
+
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+    contexts = []
+    citations = []
+
+    if files := body.pop("files", None):
+        contexts, citations = get_rag_context(
+            files=files,
+            messages=body["messages"],
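get_tool_with_custom_params() (added above, reproduced here verbatim) is the piece that lets the rest of the handler call every tool uniformly with only the model-supplied arguments: extras such as __user__ are bound at configuration time and filtered against the tool's signature. A runnable usage sketch with a toy tool:

import asyncio
import inspect
from typing import Awaitable, Callable


def get_tool_with_custom_params(
    tool: Callable, custom_params: dict
) -> Callable[..., Awaitable]:
    sig = inspect.signature(tool)
    extra_params = {
        key: value for key, value in custom_params.items() if key in sig.parameters
    }
    is_coroutine = inspect.iscoroutinefunction(tool)

    async def new_tool(**kwargs):
        extra_kwargs = kwargs | extra_params
        if is_coroutine:
            return await tool(**extra_kwargs)
        return tool(**extra_kwargs)

    return new_tool


def lookup(query: str, __user__: dict) -> str:
    # a toy tool: __user__ is injected, the model only supplies query
    return f"{__user__['id']} searched {query!r}"


wrapped = get_tool_with_custom_params(lookup, {"__user__": {"id": "u1"}, "__id__": "t1"})
print(asyncio.run(wrapped(query="tides")))  # u1 searched 'tides'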
@@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):

     log.debug(f"rag_contexts: {contexts}, citations: {citations}")

-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}


+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
+
+
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
-
-            # Initialize context, and citations
-            contexts = []
-            citations = []
-
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            try:
-                body, flags = await chat_completion_files_handler(body)
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-
-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
-
-                return response
-            else:
-                return response
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
+
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
+
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }
+
+        # Initialize data_items to store additional data to be sent to the client
+        # Initalize contexts and citation
+        data_items = []
+        contexts = []
+        citations = []
+
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        try:
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "/n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+            else:
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
+
+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers["Content-Type"]
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
+        return response

     async def _receive(self, body: bytes):
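One detail worth keeping in mind when reviewing both middleware rewrites: after mutating the JSON body they must also rewrite the content-length header, since the downstream ASGI app reads exactly that many bytes and a longer body would otherwise be truncated. The recomputation in isolation, with toy header values:

import json

body = {"messages": [], "metadata": {"chat_id": None}}
modified_body_bytes = json.dumps(body).encode("utf-8")

raw_headers = [(b"content-type", b"application/json"), (b"content-length", b"2")]
raw_headers = [
    (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
    *[(k, v) for k, v in raw_headers if k.lower() != b"content-length"],
]
print(raw_headers)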
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
             url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
             key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/inlet",
-                    headers=headers,
-                    json={
-                        "user": user,
-                        "body": payload,
-                    },
-                )
-
-                r.raise_for_status()
-                payload = r.json()
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
         except Exception as e:
             # Handle connection error here
             print(f"Connection error: {e}")
@@ -817,44 +774,39 @@ def filter_pipeline(payload, user):

 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
-
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+
+        log.debug(f"request.url.path: {request.url.path}")
+
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers["Authorization"]),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]

         response = await call_next(request)
         return response
@@ -1019,6 +971,8 @@ async def get_all_models():
         model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")

             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
     model = app.state.MODELS[model_id]

-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
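The task plumbing removed here moves to the call sites: each task endpoint now puts the task type straight into the payload's metadata (see the title/query/emoji hunks below) instead of a top-level "task" key that generate_chat_completions had to pop and relocate. The shape change, schematically:

payload_before = {"model": "m", "task": "title_generation"}                 # old shape
payload_after = {"model": "m", "metadata": {"task": "title_generation"}}    # new shape

def task_of(form_data: dict):
    # how downstream code can read the task type after this change
    return form_data.get("metadata", {}).get("task")

print(task_of(payload_after))  # title_generation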
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include vavles
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }

     log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }

     print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }

     log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/tools/completions")
-async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
-    print("get_tools_function_calling")
-
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    # Check if the user has a custom task model
-    # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
-    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
-
-    try:
-        context, _, _ = await get_function_call_response(
-            form_data["messages"],
-            form_data.get("files", []),
-            form_data["tool_id"],
-            template,
-            model_id,
-            user,
-        )
-        return context
-    except Exception as e:
-        return JSONResponse(
-            status_code=e.args[0],
-            content={"detail": e.args[1]},
-        )
-
-
 ##################################
 #
 # Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
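The tightened check also covers a missing filename (UploadFile.filename can be None for multipart uploads), not just a wrong extension; both now fail before anything is written. The predicate in isolation:

def is_python_upload(filename) -> bool:
    # mirrors the new guard: reject None as well as non-.py names
    return bool(filename and filename.endswith(".py"))

print(is_python_upload("pipe.py"), is_python_upload(None), is_python_upload("x.txt"))
# True False False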
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)


 # OAuth login logic is as follows:
backend/utils/task.py

@@ -121,6 +121,6 @@ def search_query_generation_template(
     return template


-def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+def tool_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
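The rename drops the "tools_function_calling" mouthful; behavior is unchanged, a plain placeholder substitution. Usage with a toy template:

import json

def tool_calling_generation_template(template: str, tools_specs: str) -> str:
    template = template.replace("{{TOOLS}}", tools_specs)
    return template

specs = json.dumps([{"name": "get_weather"}])
print(tool_calling_generation_template("Available tools: {{TOOLS}}", specs))
# Available tools: [{"name": "get_weather"}]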