Merge pull request #4602 from michaelpoluektov/tools-refac-1

refactor, perf: Tools refactor (progress PR 1)
Timothy Jaeryang Baek 2024-08-17 15:50:40 +02:00 committed by GitHub
commit cbb0940ff8
3 changed files with 358 additions and 455 deletions

View File

@@ -1,12 +1,10 @@
from pydantic import BaseModel, ConfigDict, parse_obj_as
from typing import Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import Optional
import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.chats import Chats
####################
@@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel):
class UsersTable:
def insert_new_user(
self,
id: str,
@@ -122,7 +119,6 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except Exception:
@@ -131,7 +127,6 @@ class UsersTable:
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except Exception:
@@ -140,7 +135,6 @@ class UsersTable:
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except Exception:
@@ -195,7 +189,6 @@ class UsersTable:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
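Throughout this file the module-level Session import is dropped in favor of the context-managed get_db() shown above. A minimal, self-contained sketch of that pattern (SQLAlchemy; the engine URL and names below are illustrative, not the app's actual config):
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Illustrative stand-in for apps.webui.internal.db.get_db: hand out one
# session per operation and always close it, rather than sharing a single
# module-level Session across the table classes.
engine = create_engine("sqlite:///:memory:")
SessionLocal = sessionmaker(bind=engine)
@contextmanager
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()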

View File

@@ -51,13 +51,13 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import Optional
from typing import Optional, Callable, Awaitable
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@@ -72,7 +72,7 @@ from utils.utils import (
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
tool_calling_generation_template,
)
from utils.misc import (
get_last_user_message,
@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel
return (function.valves if function.valves else {}).get("priority", 0)
return 0
@@ -282,164 +283,42 @@ def get_filter_function_ids(model):
return filter_ids
async def get_function_call_response(
messages,
files,
tool_id,
template,
task_model_id,
user,
__event_emitter__=None,
__event_call__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
return content
def get_tool_call_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
)
print(prompt)
prompt = f"History:\n{history}\nQuery: {user_message}"
payload = {
return {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"task": str(TASKS.FUNCTION_CALLING),
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
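To make the new helper concrete, here is roughly the History/Query prompt it interpolates into the payload for a short conversation (a toy transcript; get_last_user_message is inlined as messages[-1] for the sketch):
# Toy transcript: the four most recent messages, newest first, are quoted
# into a History block, followed by the query line.
messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
    {"role": "user", "content": "weather in Paris?"},
]
user_message = messages[-1]["content"]  # stands in for get_last_user_message
history = "\n".join(
    f"{m['role'].upper()}: \"\"\"{m['content']}\"\"\"" for m in messages[::-1][:4]
)
prompt = f"History:\n{history}\nQuery: {user_message}"
print(prompt)
# History:
# USER: """weather in Paris?"""
# ASSISTANT: """hello!"""
# USER: """hi"""
# Query: weather in Paris?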
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
response = None
try:
response = await generate_chat_completions(form_data=payload, user=user)
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response
print(f"content: {content}")
result = json.loads(content)
print(result)
citation = None
if "name" not in result:
return None, None, False
# Call the function
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params contained in the function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
return None, None, False
async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
async def chat_completion_inlets_handler(body, model, extra_params):
skip_files = None
filter_ids = get_filter_function_ids(model)
@@ -475,37 +354,20 @@ async def chat_completion_functions_handler(
params = {"body": body}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params contained in the function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
uid = custom_params["__user__"]["id"]
custom_params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
# Add extra params contained in the function signature
for key, value in custom_params.items():
if key in sig.parameters:
params[key] = value
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
@@ -516,74 +378,171 @@ async def chat_completion_functions_handler(
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
del body["files"]
if skip_files and "files" in body:
del body["files"]
return body, {}
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
skip_files = None
def get_tool_with_custom_params(
tool: Callable, custom_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(tool)
extra_params = {
key: value for key, value in custom_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(tool)
async def new_tool(**kwargs):
extra_kwargs = kwargs | extra_params
if is_coroutine:
return await tool(**extra_kwargs)
return tool(**extra_kwargs)
return new_tool
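A self-contained sketch of how this wrapper behaves: only the dunder params that actually appear in the tool's signature are injected, so a tool is never passed arguments it does not declare (the toy lookup tool and user dict below are made up):
import asyncio
import inspect
def get_tool_with_custom_params(tool, custom_params):
    sig = inspect.signature(tool)
    extra_params = {
        key: value for key, value in custom_params.items() if key in sig.parameters
    }
    is_coroutine = inspect.iscoroutinefunction(tool)
    async def new_tool(**kwargs):
        extra_kwargs = kwargs | extra_params
        if is_coroutine:
            return await tool(**extra_kwargs)
        return tool(**extra_kwargs)
    return new_tool
def lookup(city: str, __user__: dict):
    # Declares __user__, so the wrapper injects it; __files__ is filtered out.
    return f"{city} report for {__user__['name']}"
wrapped = get_tool_with_custom_params(
    lookup, {"__user__": {"name": "alice"}, "__files__": []}
)
print(asyncio.run(wrapped(city="Paris")))  # -> Paris report for alice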
# Note: mutates extra_params (sets __id__ and per-user valves for each tool)
def get_configured_tools(
tool_ids: list[str], extra_params: dict, user: UserModel
) -> dict[str, dict]:
tools = {}
for tool_id in tool_ids:
toolkit = Tools.get_tool_by_id(tool_id)
if toolkit is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
has_citation = hasattr(module, "citation") and module.citation
handles_files = hasattr(module, "file_handler") and module.file_handler
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
name = spec["name"]
callable = getattr(module, name)
# convert to function that takes only model params and inserts custom params
custom_callable = get_tool_with_custom_params(callable, extra_params)
# TODO: This needs to be a pydantic model
tool_dict = {
"spec": spec,
"citation": has_citation,
"file_handler": handles_files,
"toolkit_id": tool_id,
"callable": custom_callable,
}
# TODO: if collision, prepend toolkit name
if name in tools:
log.warning(f"Tool {name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{name}")
else:
tools[name] = tool_dict
return tools
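Illustratively, each entry of the returned registry bundles the spec with call-time metadata along these lines (all values made up; as the TODO above notes, a pydantic model would be a better fit than a plain dict):
async def custom_callable(**kwargs):
    # Stand-in for a tool already wrapped by get_tool_with_custom_params.
    return "22C, sunny"
tools = {
    "lookup": {
        "spec": {"name": "lookup", "parameters": {"properties": {}}},
        "citation": False,       # module-level `citation` flag, if any
        "file_handler": False,   # module-level `file_handler` flag, if any
        "toolkit_id": "weather_toolkit",
        "callable": custom_callable,
    }
}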
async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]:
skip_files = False
contexts = []
citations = None
citations = []
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" in body:
print(body["tool_ids"])
for tool_id in body["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = await get_function_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
tool_ids = body.pop("tool_ids", None)
if not tool_ids:
return body, {}
print(file_handler)
if isinstance(response, str):
contexts.append(response)
if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation)
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body:
del body["files"]
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
log.debug(f"{tool_ids=}")
custom_params = {
**extra_params,
"__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"],
"__files__": body.get("files", []),
}
configured_tools = get_configured_tools(tool_ids, custom_params, user)
log.info(f"{configured_tools=}")
async def chat_completion_files_handler(body):
contexts = []
citations = None
specs = [tool["spec"] for tool in configured_tools.values()]
tools_specs = json.dumps(specs)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_call_payload(body["messages"], task_model_id, content)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
if "files" in body:
files = body["files"]
try:
response = await generate_chat_completions(form_data=payload, user=user)
log.debug(f"{response=}")
content = await get_content_from_response(response)
log.debug(f"{content=}")
if content is None:
return body, {}
result = json.loads(content)
tool_name = result.get("name", None)
if tool_name not in configured_tools:
return body, {}
tool_params = result.get("parameters", {})
toolkit_id = configured_tools[tool_name]["toolkit_id"]
try:
tool_output = await configured_tools[tool_name]["callable"](**tool_params)
except Exception as e:
tool_output = str(e)
if configured_tools[tool_name]["citation"]:
citations.append(
{
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
"document": [tool_output],
"metadata": [{"source": tool_name}],
}
)
if configured_tools[tool_name]["file_handler"]:
skip_files = True
if isinstance(tool_output, str):
contexts.append(tool_output)
except Exception as e:
print(f"Error: {e}")
content = None
log.debug(f"tool_contexts: {contexts}")
if skip_files and "files" in body:
del body["files"]
return body, {"contexts": contexts, "citations": citations}
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = []
citations = []
if files := body.pop("files", None):
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
@@ -596,134 +555,130 @@ async def chat_completion_files_handler(body):
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
return body, {"contexts": contexts, "citations": citations}
def is_chat_completion_request(request):
return request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
if not is_chat_completion_request(request):
return await call_next(request)
log.debug(f"request.url.path: {request.url.path}")
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
metadata = {
"chat_id": body.pop("chat_id", None),
"message_id": body.pop("id", None),
"session_id": body.pop("session_id", None),
"valves": body.pop("valves", None),
}
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
extra_params = {
"__user__": __user__,
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
}
# Initialize data_items to store additional data to be sent to the client
# Initialize contexts and citations
data_items = []
contexts = []
citations = []
try:
body, flags = await chat_completion_inlets_handler(
body, model, extra_params
)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
try:
body, flags = await chat_completion_tools_handler(body, user, extra_params)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
log.exception(e)
try:
body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
log.exception(e)
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "\n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
metadata = {
"chat_id": body.pop("chat_id", None),
"message_id": body.pop("id", None),
"session_id": body.pop("session_id", None),
"valves": body.pop("valves", None),
}
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
# Initialize data_items to store additional data to be sent to the client
data_items = []
# Initialize contexts and citations
contexts = []
citations = []
try:
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
try:
body, flags = await chat_completion_tools_handler(
body, user, __event_emitter__, __event_call__
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
try:
body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "\n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
body["metadata"] = metadata
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
]
response = await call_next(request)
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
return response
else:
return response
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
body["metadata"] = metadata
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
# Pass the (possibly modified) request on to the route handler
response = await call_next(request)
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers["Content-Type"]
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
return response
async def _receive(self, body: bytes):
@@ -790,19 +745,21 @@ def filter_pipeline(payload, user):
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
if key == "":
continue
r.raise_for_status()
payload = r.json()
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
@@ -817,44 +774,39 @@ def filter_pipeline(payload, user):
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
if not is_chat_completion_request(request):
return await call_next(request)
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
log.debug(f"request.url.path: {request.url.path}")
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
]
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
]
response = await call_next(request)
return response
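Both middlewares rely on the same body-swap trick seen above: cache the modified bytes on request._body and rebuild the raw header list so content-length matches. A framework-free sketch of just the header rewrite (header values illustrative; request.headers.__dict__["_list"] is a Starlette internal, not public API):
import json
body = json.dumps({"model": "gpt-x", "messages": []}).encode("utf-8")
raw_headers = [
    (b"content-type", b"application/json"),
    (b"content-length", b"2"),  # stale length from the original body
]
# Prepend the corrected content-length and drop the stale entry.
raw_headers = [
    (b"content-length", str(len(body)).encode("utf-8")),
    *[(k, v) for k, v in raw_headers if k.lower() != b"content-length"],
]
print(raw_headers)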
@@ -1019,6 +971,8 @@ async def get_all_models():
model["actions"] = []
for action_id in action_ids:
action = Functions.get_function_by_id(action_id)
if action is None:
raise Exception(f"Action not found: {action_id}")
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
model = app.state.MODELS[model_id]
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task = None
if "task" in form_data:
task = form_data["task"]
del form_data["task"]
if task:
if "metadata" in form_data:
form_data["metadata"]["task"] = task
else:
form_data["metadata"] = {"task": task}
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
print("generate_ollama_chat_completion")
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include valves
return (function.valves if function.valves else {}).get("priority", 0)
return 0
@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.TITLE_GENERATION),
"metadata": {"task": str(TASKS.TITLE_GENERATION)},
}
log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": str(TASKS.QUERY_GENERATION),
"metadata": {"task": str(TASKS.QUERY_GENERATION)},
}
print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.EMOJI_GENERATION),
"metadata": {"task": str(TASKS.EMOJI_GENERATION)},
}
log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
model_id = get_task_model_id(model_id)
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context, _, _ = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
template,
model_id,
user,
)
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
##################################
#
# Pipelines Endpoints
@@ -1689,7 +1596,7 @@ async def upload_pipeline(
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not file.filename.endswith(".py"):
if not (file.filename and file.filename.endswith(".py")):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
@@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request):
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows:

View File

@@ -121,6 +121,6 @@ def search_query_generation_template(
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
def tool_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template
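For completeness, a toy invocation of the renamed helper (the template text here is made up; in the app it comes from TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE):
def tool_calling_generation_template(template: str, tools_specs: str) -> str:
    template = template.replace("{{TOOLS}}", tools_specs)
    return template
specs = '[{"name": "lookup", "parameters": {"properties": {}}}]'
print(tool_calling_generation_template("Available tools: {{TOOLS}}", specs))
# -> Available tools: [{"name": "lookup", "parameters": {"properties": {}}}]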