mirror of
https://github.com/open-webui/open-webui
synced 2025-05-19 12:51:35 +00:00
Merge pull request #4602 from michaelpoluektov/tools-refac-1
refactor, perf: Tools refactor (progress PR 1)
This commit is contained in:
commit
cbb0940ff8
@ -1,12 +1,10 @@
|
|||||||
from pydantic import BaseModel, ConfigDict, parse_obj_as
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing import Union, Optional
|
from typing import Optional
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
from sqlalchemy import String, Column, BigInteger, Text
|
||||||
|
|
||||||
from utils.misc import get_gravatar_url
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, JSONField, Session, get_db
|
|
||||||
from apps.webui.models.chats import Chats
|
from apps.webui.models.chats import Chats
|
||||||
|
|
||||||
####################
|
####################
|
||||||
@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UsersTable:
|
class UsersTable:
|
||||||
|
|
||||||
def insert_new_user(
|
def insert_new_user(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
@ -122,7 +119,6 @@ class UsersTable:
|
|||||||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
user = db.query(User).filter_by(api_key=api_key).first()
|
user = db.query(User).filter_by(api_key=api_key).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -131,7 +127,6 @@ class UsersTable:
|
|||||||
def get_user_by_email(self, email: str) -> Optional[UserModel]:
|
def get_user_by_email(self, email: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
user = db.query(User).filter_by(email=email).first()
|
user = db.query(User).filter_by(email=email).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -140,7 +135,6 @@ class UsersTable:
|
|||||||
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
user = db.query(User).filter_by(oauth_sub=sub).first()
|
user = db.query(User).filter_by(oauth_sub=sub).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -195,7 +189,6 @@ class UsersTable:
|
|||||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(User).filter_by(id=id).update(
|
db.query(User).filter_by(id=id).update(
|
||||||
{"last_active_at": int(time.time())}
|
{"last_active_at": int(time.time())}
|
||||||
)
|
)
|
||||||
|
798
backend/main.py
798
backend/main.py
@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
|
|||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional, Callable, Awaitable
|
||||||
|
|
||||||
from apps.webui.models.auths import Auths
|
from apps.webui.models.auths import Auths
|
||||||
from apps.webui.models.models import Models
|
from apps.webui.models.models import Models
|
||||||
from apps.webui.models.tools import Tools
|
from apps.webui.models.tools import Tools
|
||||||
from apps.webui.models.functions import Functions
|
from apps.webui.models.functions import Functions
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.users import Users, UserModel
|
||||||
|
|
||||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ from utils.utils import (
|
|||||||
from utils.task import (
|
from utils.task import (
|
||||||
title_generation_template,
|
title_generation_template,
|
||||||
search_query_generation_template,
|
search_query_generation_template,
|
||||||
tools_function_calling_generation_template,
|
tool_calling_generation_template,
|
||||||
)
|
)
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
get_last_user_message,
|
get_last_user_message,
|
||||||
@ -261,6 +261,7 @@ def get_filter_function_ids(model):
|
|||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None and hasattr(function, "valves"):
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -282,164 +283,42 @@ def get_filter_function_ids(model):
|
|||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_content_from_response(response) -> Optional[str]:
|
||||||
messages,
|
content = None
|
||||||
files,
|
if hasattr(response, "body_iterator"):
|
||||||
tool_id,
|
async for chunk in response.body_iterator:
|
||||||
template,
|
data = json.loads(chunk.decode("utf-8"))
|
||||||
task_model_id,
|
content = data["choices"][0]["message"]["content"]
|
||||||
user,
|
|
||||||
__event_emitter__=None,
|
|
||||||
__event_call__=None,
|
|
||||||
):
|
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
|
||||||
tools_specs = json.dumps(tool.specs, indent=2)
|
|
||||||
content = tools_function_calling_generation_template(template, tools_specs)
|
|
||||||
|
|
||||||
|
# Cleanup any remaining background tasks if necessary
|
||||||
|
if response.background is not None:
|
||||||
|
await response.background()
|
||||||
|
else:
|
||||||
|
content = response["choices"][0]["message"]["content"]
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_call_payload(messages, task_model_id, content):
|
||||||
user_message = get_last_user_message(messages)
|
user_message = get_last_user_message(messages)
|
||||||
prompt = (
|
history = "\n".join(
|
||||||
"History:\n"
|
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
||||||
+ "\n".join(
|
for message in messages[::-1][:4]
|
||||||
[
|
|
||||||
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
|
||||||
for message in messages[::-1][:4]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
+ f"\nQuery: {user_message}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(prompt)
|
prompt = f"History:\n{history}\nQuery: {user_message}"
|
||||||
|
|
||||||
payload = {
|
return {
|
||||||
"model": task_model_id,
|
"model": task_model_id,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": content},
|
{"role": "system", "content": content},
|
||||||
{"role": "user", "content": f"Query: {prompt}"},
|
{"role": "user", "content": f"Query: {prompt}"},
|
||||||
],
|
],
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"task": str(TASKS.FUNCTION_CALLING),
|
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
|
||||||
payload = filter_pipeline(payload, user)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
model = app.state.MODELS[task_model_id]
|
async def chat_completion_inlets_handler(body, model, extra_params):
|
||||||
|
|
||||||
response = None
|
|
||||||
try:
|
|
||||||
response = await generate_chat_completions(form_data=payload, user=user)
|
|
||||||
content = None
|
|
||||||
|
|
||||||
if hasattr(response, "body_iterator"):
|
|
||||||
async for chunk in response.body_iterator:
|
|
||||||
data = json.loads(chunk.decode("utf-8"))
|
|
||||||
content = data["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
# Cleanup any remaining background tasks if necessary
|
|
||||||
if response.background is not None:
|
|
||||||
await response.background()
|
|
||||||
else:
|
|
||||||
content = response["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
if content is None:
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
# Parse the function response
|
|
||||||
print(f"content: {content}")
|
|
||||||
result = json.loads(content)
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
citation = None
|
|
||||||
|
|
||||||
if "name" not in result:
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
# Call the function
|
|
||||||
if tool_id in webui_app.state.TOOLS:
|
|
||||||
toolkit_module = webui_app.state.TOOLS[tool_id]
|
|
||||||
else:
|
|
||||||
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
|
|
||||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
|
||||||
|
|
||||||
file_handler = False
|
|
||||||
# check if toolkit_module has file_handler self variable
|
|
||||||
if hasattr(toolkit_module, "file_handler"):
|
|
||||||
file_handler = True
|
|
||||||
print("file_handler: ", file_handler)
|
|
||||||
|
|
||||||
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
|
|
||||||
valves = Tools.get_tool_valves_by_id(tool_id)
|
|
||||||
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
|
|
||||||
|
|
||||||
function = getattr(toolkit_module, result["name"])
|
|
||||||
function_result = None
|
|
||||||
try:
|
|
||||||
# Get the signature of the function
|
|
||||||
sig = inspect.signature(function)
|
|
||||||
params = result["parameters"]
|
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
|
||||||
extra_params = {
|
|
||||||
"__model__": model,
|
|
||||||
"__id__": tool_id,
|
|
||||||
"__messages__": messages,
|
|
||||||
"__files__": files,
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
|
||||||
for key, value in extra_params.items():
|
|
||||||
if key in sig.parameters:
|
|
||||||
params[key] = value
|
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
|
||||||
# Call the function with the '__user__' parameter included
|
|
||||||
__user__ = {
|
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(toolkit_module, "UserValves"):
|
|
||||||
__user__["valves"] = toolkit_module.UserValves(
|
|
||||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(function):
|
|
||||||
function_result = await function(**params)
|
|
||||||
else:
|
|
||||||
function_result = function(**params)
|
|
||||||
|
|
||||||
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
|
||||||
citation = {
|
|
||||||
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
|
|
||||||
"document": [function_result],
|
|
||||||
"metadata": [{"source": result["name"]}],
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
# Add the function result to the system prompt
|
|
||||||
if function_result is not None:
|
|
||||||
return function_result, citation, file_handler
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_functions_handler(
|
|
||||||
body, model, user, __event_emitter__, __event_call__
|
|
||||||
):
|
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
filter_ids = get_filter_function_ids(model)
|
filter_ids = get_filter_function_ids(model)
|
||||||
@ -475,37 +354,20 @@ async def chat_completion_functions_handler(
|
|||||||
params = {"body": body}
|
params = {"body": body}
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
# Extra parameters to be passed to the function
|
||||||
extra_params = {
|
custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
|
||||||
"__model__": model,
|
if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
|
||||||
"__id__": filter_id,
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
|
||||||
for key, value in extra_params.items():
|
|
||||||
if key in sig.parameters:
|
|
||||||
params[key] = value
|
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
|
||||||
__user__ = {
|
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(function_module, "UserValves"):
|
uid = custom_params["__user__"]["id"]
|
||||||
__user__["valves"] = function_module.UserValves(
|
custom_params["__user__"]["valves"] = function_module.UserValves(
|
||||||
**Functions.get_user_valves_by_id_and_user_id(
|
**Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
|
||||||
filter_id, user.id
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
# Add extra params in contained in function signature
|
||||||
|
for key, value in custom_params.items():
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
if inspect.iscoroutinefunction(inlet):
|
||||||
body = await inlet(**params)
|
body = await inlet(**params)
|
||||||
@ -516,74 +378,171 @@ async def chat_completion_functions_handler(
|
|||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if skip_files:
|
if skip_files and "files" in body:
|
||||||
if "files" in body:
|
del body["files"]
|
||||||
del body["files"]
|
|
||||||
|
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
|
def get_tool_with_custom_params(
|
||||||
skip_files = None
|
tool: Callable, custom_params: dict
|
||||||
|
) -> Callable[..., Awaitable]:
|
||||||
|
sig = inspect.signature(tool)
|
||||||
|
extra_params = {
|
||||||
|
key: value for key, value in custom_params.items() if key in sig.parameters
|
||||||
|
}
|
||||||
|
is_coroutine = inspect.iscoroutinefunction(tool)
|
||||||
|
|
||||||
|
async def new_tool(**kwargs):
|
||||||
|
extra_kwargs = kwargs | extra_params
|
||||||
|
if is_coroutine:
|
||||||
|
return await tool(**extra_kwargs)
|
||||||
|
return tool(**extra_kwargs)
|
||||||
|
|
||||||
|
return new_tool
|
||||||
|
|
||||||
|
|
||||||
|
# Mutation on extra_params
|
||||||
|
def get_configured_tools(
|
||||||
|
tool_ids: list[str], extra_params: dict, user: UserModel
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
tools = {}
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
toolkit = Tools.get_tool_by_id(tool_id)
|
||||||
|
if toolkit is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module = webui_app.state.TOOLS.get(tool_id, None)
|
||||||
|
if module is None:
|
||||||
|
module, _ = load_toolkit_module_by_id(tool_id)
|
||||||
|
webui_app.state.TOOLS[tool_id] = module
|
||||||
|
|
||||||
|
extra_params["__id__"] = tool_id
|
||||||
|
has_citation = hasattr(module, "citation") and module.citation
|
||||||
|
handles_files = hasattr(module, "file_handler") and module.file_handler
|
||||||
|
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||||
|
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||||
|
module.valves = module.Valves(**valves)
|
||||||
|
|
||||||
|
if hasattr(module, "UserValves"):
|
||||||
|
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||||
|
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
for spec in toolkit.specs:
|
||||||
|
# TODO: Fix hack for OpenAI API
|
||||||
|
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||||
|
if val["type"] == "str":
|
||||||
|
val["type"] = "string"
|
||||||
|
name = spec["name"]
|
||||||
|
callable = getattr(module, name)
|
||||||
|
|
||||||
|
# convert to function that takes only model params and inserts custom params
|
||||||
|
custom_callable = get_tool_with_custom_params(callable, extra_params)
|
||||||
|
|
||||||
|
# TODO: This needs to be a pydantic model
|
||||||
|
tool_dict = {
|
||||||
|
"spec": spec,
|
||||||
|
"citation": has_citation,
|
||||||
|
"file_handler": handles_files,
|
||||||
|
"toolkit_id": tool_id,
|
||||||
|
"callable": custom_callable,
|
||||||
|
}
|
||||||
|
# TODO: if collision, prepend toolkit name
|
||||||
|
if name in tools:
|
||||||
|
log.warning(f"Tool {name} already exists in another toolkit!")
|
||||||
|
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
||||||
|
log.warning(f"Discarding {toolkit}.{name}")
|
||||||
|
else:
|
||||||
|
tools[name] = tool_dict
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_tools_handler(
|
||||||
|
body: dict, user: UserModel, extra_params: dict
|
||||||
|
) -> tuple[dict, dict]:
|
||||||
|
skip_files = False
|
||||||
contexts = []
|
contexts = []
|
||||||
citations = None
|
citations = []
|
||||||
|
|
||||||
task_model_id = get_task_model_id(body["model"])
|
task_model_id = get_task_model_id(body["model"])
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
if "tool_ids" in body:
|
tool_ids = body.pop("tool_ids", None)
|
||||||
print(body["tool_ids"])
|
if not tool_ids:
|
||||||
for tool_id in body["tool_ids"]:
|
return body, {}
|
||||||
print(tool_id)
|
|
||||||
try:
|
|
||||||
response, citation, file_handler = await get_function_call_response(
|
|
||||||
messages=body["messages"],
|
|
||||||
files=body.get("files", []),
|
|
||||||
tool_id=tool_id,
|
|
||||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
||||||
task_model_id=task_model_id,
|
|
||||||
user=user,
|
|
||||||
__event_emitter__=__event_emitter__,
|
|
||||||
__event_call__=__event_call__,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(file_handler)
|
log.debug(f"{tool_ids=}")
|
||||||
if isinstance(response, str):
|
custom_params = {
|
||||||
contexts.append(response)
|
**extra_params,
|
||||||
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
if citation:
|
"__messages__": body["messages"],
|
||||||
if citations is None:
|
"__files__": body.get("files", []),
|
||||||
citations = [citation]
|
|
||||||
else:
|
|
||||||
citations.append(citation)
|
|
||||||
|
|
||||||
if file_handler:
|
|
||||||
skip_files = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
del body["tool_ids"]
|
|
||||||
print(f"tool_contexts: {contexts}")
|
|
||||||
|
|
||||||
if skip_files:
|
|
||||||
if "files" in body:
|
|
||||||
del body["files"]
|
|
||||||
|
|
||||||
return body, {
|
|
||||||
**({"contexts": contexts} if contexts is not None else {}),
|
|
||||||
**({"citations": citations} if citations is not None else {}),
|
|
||||||
}
|
}
|
||||||
|
configured_tools = get_configured_tools(tool_ids, custom_params, user)
|
||||||
|
|
||||||
|
log.info(f"{configured_tools=}")
|
||||||
|
|
||||||
async def chat_completion_files_handler(body):
|
specs = [tool["spec"] for tool in configured_tools.values()]
|
||||||
contexts = []
|
tools_specs = json.dumps(specs)
|
||||||
citations = None
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||||
|
content = tool_calling_generation_template(template, tools_specs)
|
||||||
|
payload = get_tool_call_payload(body["messages"], task_model_id, content)
|
||||||
|
try:
|
||||||
|
payload = filter_pipeline(payload, user)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
if "files" in body:
|
try:
|
||||||
files = body["files"]
|
response = await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
log.debug(f"{response=}")
|
||||||
|
content = await get_content_from_response(response)
|
||||||
|
log.debug(f"{content=}")
|
||||||
|
if content is None:
|
||||||
|
return body, {}
|
||||||
|
|
||||||
|
result = json.loads(content)
|
||||||
|
tool_name = result.get("name", None)
|
||||||
|
if tool_name not in configured_tools:
|
||||||
|
return body, {}
|
||||||
|
|
||||||
|
tool_params = result.get("parameters", {})
|
||||||
|
toolkit_id = configured_tools[tool_name]["toolkit_id"]
|
||||||
|
try:
|
||||||
|
tool_output = await configured_tools[tool_name]["callable"](**tool_params)
|
||||||
|
except Exception as e:
|
||||||
|
tool_output = str(e)
|
||||||
|
if configured_tools[tool_name]["citation"]:
|
||||||
|
citations.append(
|
||||||
|
{
|
||||||
|
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
|
||||||
|
"document": [tool_output],
|
||||||
|
"metadata": [{"source": tool_name}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if configured_tools[tool_name]["file_handler"]:
|
||||||
|
skip_files = True
|
||||||
|
|
||||||
|
if isinstance(tool_output, str):
|
||||||
|
contexts.append(tool_output)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
content = None
|
||||||
|
|
||||||
|
log.debug(f"tool_contexts: {contexts}")
|
||||||
|
|
||||||
|
if skip_files and "files" in body:
|
||||||
del body["files"]
|
del body["files"]
|
||||||
|
|
||||||
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
||||||
|
contexts = []
|
||||||
|
citations = []
|
||||||
|
|
||||||
|
if files := body.pop("files", None):
|
||||||
contexts, citations = get_rag_context(
|
contexts, citations = get_rag_context(
|
||||||
files=files,
|
files=files,
|
||||||
messages=body["messages"],
|
messages=body["messages"],
|
||||||
@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):
|
|||||||
|
|
||||||
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
||||||
|
|
||||||
return body, {
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
**({"contexts": contexts} if contexts is not None else {}),
|
|
||||||
**({"citations": citations} if citations is not None else {}),
|
|
||||||
}
|
def is_chat_completion_request(request):
|
||||||
|
return request.method == "POST" and any(
|
||||||
|
endpoint in request.url.path
|
||||||
|
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.method == "POST" and any(
|
if not is_chat_completion_request(request):
|
||||||
endpoint in request.url.path
|
return await call_next(request)
|
||||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
):
|
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, model, user = await get_body_and_model_and_user(request)
|
body, model, user = await get_body_and_model_and_user(request)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
content={"detail": str(e)},
|
content={"detail": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"chat_id": body.pop("chat_id", None),
|
||||||
|
"message_id": body.pop("id", None),
|
||||||
|
"session_id": body.pop("session_id", None),
|
||||||
|
"valves": body.pop("valves", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
__user__ = {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
extra_params = {
|
||||||
|
"__user__": __user__,
|
||||||
|
"__event_emitter__": get_event_emitter(metadata),
|
||||||
|
"__event_call__": get_event_call(metadata),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
|
# Initalize contexts and citation
|
||||||
|
data_items = []
|
||||||
|
contexts = []
|
||||||
|
citations = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_inlets_handler(
|
||||||
|
body, model, extra_params
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={"detail": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_tools_handler(body, user, extra_params)
|
||||||
|
contexts.extend(flags.get("contexts", []))
|
||||||
|
citations.extend(flags.get("citations", []))
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_files_handler(body)
|
||||||
|
contexts.extend(flags.get("contexts", []))
|
||||||
|
citations.extend(flags.get("citations", []))
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
|
# If context is not empty, insert it into the messages
|
||||||
|
if len(contexts) > 0:
|
||||||
|
context_string = "/n".join(contexts).strip()
|
||||||
|
prompt = get_last_user_message(body["messages"])
|
||||||
|
if prompt is None:
|
||||||
|
raise Exception("No user message found")
|
||||||
|
# Workaround for Ollama 2.0+ system prompt issue
|
||||||
|
# TODO: replace with add_or_update_system_message
|
||||||
|
if model["owned_by"] == "ollama":
|
||||||
|
body["messages"] = prepend_to_first_user_message_content(
|
||||||
|
rag_template(
|
||||||
|
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||||
|
),
|
||||||
|
body["messages"],
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"chat_id": body.pop("chat_id", None),
|
|
||||||
"message_id": body.pop("id", None),
|
|
||||||
"session_id": body.pop("session_id", None),
|
|
||||||
"valves": body.pop("valves", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
__event_emitter__ = get_event_emitter(metadata)
|
|
||||||
__event_call__ = get_event_call(metadata)
|
|
||||||
|
|
||||||
# Initialize data_items to store additional data to be sent to the client
|
|
||||||
data_items = []
|
|
||||||
|
|
||||||
# Initialize context, and citations
|
|
||||||
contexts = []
|
|
||||||
citations = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_functions_handler(
|
|
||||||
body, model, user, __event_emitter__, __event_call__
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={"detail": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_tools_handler(
|
|
||||||
body, user, __event_emitter__, __event_call__
|
|
||||||
)
|
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
|
||||||
citations.extend(flags.get("citations", []))
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
body, flags = await chat_completion_files_handler(body)
|
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
|
||||||
citations.extend(flags.get("citations", []))
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# If context is not empty, insert it into the messages
|
|
||||||
if len(contexts) > 0:
|
|
||||||
context_string = "/n".join(contexts).strip()
|
|
||||||
prompt = get_last_user_message(body["messages"])
|
|
||||||
|
|
||||||
# Workaround for Ollama 2.0+ system prompt issue
|
|
||||||
# TODO: replace with add_or_update_system_message
|
|
||||||
if model["owned_by"] == "ollama":
|
|
||||||
body["messages"] = prepend_to_first_user_message_content(
|
|
||||||
rag_template(
|
|
||||||
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
body["messages"] = add_or_update_system_message(
|
|
||||||
rag_template(
|
|
||||||
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# If there are citations, add them to the data_items
|
|
||||||
if len(citations) > 0:
|
|
||||||
data_items.append({"citations": citations})
|
|
||||||
|
|
||||||
body["metadata"] = metadata
|
|
||||||
modified_body_bytes = json.dumps(body).encode("utf-8")
|
|
||||||
# Replace the request body with the modified one
|
|
||||||
request._body = modified_body_bytes
|
|
||||||
# Set custom header to ensure content-length matches new body length
|
|
||||||
request.headers.__dict__["_list"] = [
|
|
||||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
||||||
*[
|
|
||||||
(k, v)
|
|
||||||
for k, v in request.headers.raw
|
|
||||||
if k.lower() != b"content-length"
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
if isinstance(response, StreamingResponse):
|
|
||||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
|
||||||
content_type = response.headers.get("Content-Type")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
if "application/x-ndjson" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
else:
|
else:
|
||||||
return response
|
body["messages"] = add_or_update_system_message(
|
||||||
|
rag_template(
|
||||||
|
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||||
|
),
|
||||||
|
body["messages"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there are citations, add them to the data_items
|
||||||
|
if len(citations) > 0:
|
||||||
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
|
body["metadata"] = metadata
|
||||||
|
modified_body_bytes = json.dumps(body).encode("utf-8")
|
||||||
|
# Replace the request body with the modified one
|
||||||
|
request._body = modified_body_bytes
|
||||||
|
# Set custom header to ensure content-length matches new body length
|
||||||
|
request.headers.__dict__["_list"] = [
|
||||||
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||||
|
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
|
||||||
|
]
|
||||||
|
|
||||||
# If it's not a chat completion request, just pass it through
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
if isinstance(response, StreamingResponse):
|
||||||
|
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||||
|
content_type = response.headers["Content-Type"]
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
return StreamingResponse(
|
||||||
|
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||||
|
)
|
||||||
|
if "application/x-ndjson" in content_type:
|
||||||
|
return StreamingResponse(
|
||||||
|
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
|
|||||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
if key != "":
|
if key == "":
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
continue
|
||||||
r = requests.post(
|
|
||||||
f"{url}/{filter['id']}/filter/inlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": user,
|
|
||||||
"body": payload,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
payload = r.json()
|
r = requests.post(
|
||||||
|
f"{url}/{filter['id']}/filter/inlet",
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"user": user,
|
||||||
|
"body": payload,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
payload = r.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle connection error here
|
# Handle connection error here
|
||||||
print(f"Connection error: {e}")
|
print(f"Connection error: {e}")
|
||||||
@ -817,44 +774,39 @@ def filter_pipeline(payload, user):
|
|||||||
|
|
||||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.method == "POST" and (
|
if not is_chat_completion_request(request):
|
||||||
"/ollama/api/chat" in request.url.path
|
return await call_next(request)
|
||||||
or "/chat/completions" in request.url.path
|
|
||||||
):
|
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
|
||||||
|
|
||||||
# Read the original request body
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
body = await request.body()
|
|
||||||
# Decode body to string
|
|
||||||
body_str = body.decode("utf-8")
|
|
||||||
# Parse string to JSON
|
|
||||||
data = json.loads(body_str) if body_str else {}
|
|
||||||
|
|
||||||
user = get_current_user(
|
# Read the original request body
|
||||||
request,
|
body = await request.body()
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
# Decode body to string
|
||||||
|
body_str = body.decode("utf-8")
|
||||||
|
# Parse string to JSON
|
||||||
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
user = get_current_user(
|
||||||
|
request,
|
||||||
|
get_http_authorization_cred(request.headers["Authorization"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = filter_pipeline(data, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
data = filter_pipeline(data, user)
|
# Replace the request body with the modified one
|
||||||
except Exception as e:
|
request._body = modified_body_bytes
|
||||||
return JSONResponse(
|
# Set custom header to ensure content-length matches new body length
|
||||||
status_code=e.args[0],
|
request.headers.__dict__["_list"] = [
|
||||||
content={"detail": e.args[1]},
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||||
)
|
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
|
||||||
|
]
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
||||||
# Replace the request body with the modified one
|
|
||||||
request._body = modified_body_bytes
|
|
||||||
# Set custom header to ensure content-length matches new body length
|
|
||||||
request.headers.__dict__["_list"] = [
|
|
||||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
||||||
*[
|
|
||||||
(k, v)
|
|
||||||
for k, v in request.headers.raw
|
|
||||||
if k.lower() != b"content-length"
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
@ -1019,6 +971,8 @@ async def get_all_models():
|
|||||||
model["actions"] = []
|
model["actions"] = []
|
||||||
for action_id in action_ids:
|
for action_id in action_ids:
|
||||||
action = Functions.get_function_by_id(action_id)
|
action = Functions.get_function_by_id(action_id)
|
||||||
|
if action is None:
|
||||||
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
if action_id in webui_app.state.FUNCTIONS:
|
if action_id in webui_app.state.FUNCTIONS:
|
||||||
function_module = webui_app.state.FUNCTIONS[action_id]
|
function_module = webui_app.state.FUNCTIONS[action_id]
|
||||||
@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|||||||
)
|
)
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
|
|
||||||
task = None
|
|
||||||
if "task" in form_data:
|
|
||||||
task = form_data["task"]
|
|
||||||
del form_data["task"]
|
|
||||||
|
|
||||||
if task:
|
|
||||||
if "metadata" in form_data:
|
|
||||||
form_data["metadata"]["task"] = task
|
|
||||||
else:
|
|
||||||
form_data["metadata"] = {"task": task}
|
|
||||||
|
|
||||||
if model.get("pipe"):
|
if model.get("pipe"):
|
||||||
return await generate_function_chat_completion(form_data, user=user)
|
return await generate_function_chat_completion(form_data, user=user)
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
print("generate_ollama_chat_completion")
|
|
||||||
return await generate_ollama_chat_completion(form_data, user=user)
|
return await generate_ollama_chat_completion(form_data, user=user)
|
||||||
else:
|
else:
|
||||||
return await generate_openai_chat_completion(form_data, user=user)
|
return await generate_openai_chat_completion(form_data, user=user)
|
||||||
@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None and hasattr(function, "valves"):
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel to include vavles
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 50,
|
"max_tokens": 50,
|
||||||
"chat_id": form_data.get("chat_id", None),
|
"chat_id": form_data.get("chat_id", None),
|
||||||
"task": str(TASKS.TITLE_GENERATION),
|
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(payload)
|
log.debug(payload)
|
||||||
@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|||||||
"messages": [{"role": "user", "content": content}],
|
"messages": [{"role": "user", "content": content}],
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 30,
|
"max_tokens": 30,
|
||||||
"task": str(TASKS.QUERY_GENERATION),
|
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
|
||||||
}
|
}
|
||||||
|
|
||||||
print(payload)
|
print(payload)
|
||||||
@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 4,
|
"max_tokens": 4,
|
||||||
"chat_id": form_data.get("chat_id", None),
|
"chat_id": form_data.get("chat_id", None),
|
||||||
"task": str(TASKS.EMOJI_GENERATION),
|
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(payload)
|
log.debug(payload)
|
||||||
@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
|
|||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/task/tools/completions")
|
|
||||||
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
|
|
||||||
print("get_tools_function_calling")
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
|
||||||
if model_id not in app.state.MODELS:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Model not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the user has a custom task model
|
|
||||||
# If the user has a custom task model, use that model
|
|
||||||
model_id = get_task_model_id(model_id)
|
|
||||||
|
|
||||||
print(model_id)
|
|
||||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
||||||
|
|
||||||
try:
|
|
||||||
context, _, _ = await get_function_call_response(
|
|
||||||
form_data["messages"],
|
|
||||||
form_data.get("files", []),
|
|
||||||
form_data["tool_id"],
|
|
||||||
template,
|
|
||||||
model_id,
|
|
||||||
user,
|
|
||||||
)
|
|
||||||
return context
|
|
||||||
except Exception as e:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=e.args[0],
|
|
||||||
content={"detail": e.args[1]},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
#
|
#
|
||||||
# Pipelines Endpoints
|
# Pipelines Endpoints
|
||||||
@ -1689,7 +1596,7 @@ async def upload_pipeline(
|
|||||||
):
|
):
|
||||||
print("upload_pipeline", urlIdx, file.filename)
|
print("upload_pipeline", urlIdx, file.filename)
|
||||||
# Check if the uploaded file is a python file
|
# Check if the uploaded file is a python file
|
||||||
if not file.filename.endswith(".py"):
|
if not (file.filename and file.filename.endswith(".py")):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Only Python (.py) files are allowed.",
|
detail="Only Python (.py) files are allowed.",
|
||||||
@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
|
|||||||
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
||||||
"oauth_callback", provider=provider
|
"oauth_callback", provider=provider
|
||||||
)
|
)
|
||||||
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
client = oauth.create_client(provider)
|
||||||
|
if client is None:
|
||||||
|
raise HTTPException(404)
|
||||||
|
return await client.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
# OAuth login logic is as follows:
|
# OAuth login logic is as follows:
|
||||||
|
@ -121,6 +121,6 @@ def search_query_generation_template(
|
|||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
|
def tool_calling_generation_template(template: str, tools_specs: str) -> str:
|
||||||
template = template.replace("{{TOOLS}}", tools_specs)
|
template = template.replace("{{TOOLS}}", tools_specs)
|
||||||
return template
|
return template
|
||||||
|
Loading…
Reference in New Issue
Block a user