mirror of
https://github.com/open-webui/open-webui
synced 2025-05-25 07:14:43 +00:00
Merge pull request #4724 from michaelpoluektov/tools-refac-2.1
feat: Add `__tools__` optional param for function pipes
This commit is contained in:
commit
ee526b4b07
@ -731,11 +731,8 @@ async def generate_chat_completion(
|
|||||||
url_idx: Optional[int] = None,
|
url_idx: Optional[int] = None,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
|
payload = {**form_data.model_dump(exclude_none=True)}
|
||||||
|
log.debug(f"{payload = }")
|
||||||
payload = {
|
|
||||||
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
|
|
||||||
}
|
|
||||||
if "metadata" in payload:
|
if "metadata" in payload:
|
||||||
del payload["metadata"]
|
del payload["metadata"]
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ from utils.misc import (
|
|||||||
apply_model_system_prompt_to_body,
|
apply_model_system_prompt_to_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from utils.tools import get_tools
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
SHOW_ADMIN_DETAILS,
|
SHOW_ADMIN_DETAILS,
|
||||||
@ -275,7 +276,9 @@ def get_function_params(function_module, form_data, user, extra_params={}):
|
|||||||
async def generate_function_chat_completion(form_data, user):
|
async def generate_function_chat_completion(form_data, user):
|
||||||
model_id = form_data.get("model")
|
model_id = form_data.get("model")
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = Models.get_model_by_id(model_id)
|
||||||
metadata = form_data.pop("metadata", None)
|
metadata = form_data.pop("metadata", {})
|
||||||
|
files = metadata.get("files", [])
|
||||||
|
tool_ids = metadata.get("tool_ids", [])
|
||||||
|
|
||||||
__event_emitter__ = None
|
__event_emitter__ = None
|
||||||
__event_call__ = None
|
__event_call__ = None
|
||||||
@ -287,6 +290,20 @@ async def generate_function_chat_completion(form_data, user):
|
|||||||
__event_call__ = get_event_call(metadata)
|
__event_call__ = get_event_call(metadata)
|
||||||
__task__ = metadata.get("task", None)
|
__task__ = metadata.get("task", None)
|
||||||
|
|
||||||
|
extra_params = {
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
"__task__": __task__,
|
||||||
|
}
|
||||||
|
tools_params = {
|
||||||
|
**extra_params,
|
||||||
|
"__model__": app.state.MODELS[form_data["model"]],
|
||||||
|
"__messages__": form_data["messages"],
|
||||||
|
"__files__": files,
|
||||||
|
}
|
||||||
|
configured_tools = get_tools(app, tool_ids, user, tools_params)
|
||||||
|
|
||||||
|
extra_params["__tools__"] = configured_tools
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
form_data["model"] = model_info.base_model_id
|
form_data["model"] = model_info.base_model_id
|
||||||
@ -299,16 +316,7 @@ async def generate_function_chat_completion(form_data, user):
|
|||||||
function_module = get_function_module(pipe_id)
|
function_module = get_function_module(pipe_id)
|
||||||
|
|
||||||
pipe = function_module.pipe
|
pipe = function_module.pipe
|
||||||
params = get_function_params(
|
params = get_function_params(function_module, form_data, user, extra_params)
|
||||||
function_module,
|
|
||||||
form_data,
|
|
||||||
user,
|
|
||||||
{
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
"__task__": __task__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if form_data["stream"]:
|
if form_data["stream"]:
|
||||||
|
|
||||||
|
@ -176,7 +176,6 @@ for version in soup.find_all("h2"):
|
|||||||
|
|
||||||
CHANGELOG = changelog_json
|
CHANGELOG = changelog_json
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# SAFE_MODE
|
# SAFE_MODE
|
||||||
####################################
|
####################################
|
||||||
|
157
backend/main.py
157
backend/main.py
@ -14,6 +14,7 @@ import requests
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import shutil
|
import shutil
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
@ -51,15 +52,13 @@ from apps.webui.internal.db import Session
|
|||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, Callable, Awaitable
|
|
||||||
|
|
||||||
from apps.webui.models.auths import Auths
|
from apps.webui.models.auths import Auths
|
||||||
from apps.webui.models.models import Models
|
from apps.webui.models.models import Models
|
||||||
from apps.webui.models.tools import Tools
|
|
||||||
from apps.webui.models.functions import Functions
|
from apps.webui.models.functions import Functions
|
||||||
from apps.webui.models.users import Users, UserModel
|
from apps.webui.models.users import Users, UserModel
|
||||||
|
|
||||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
from apps.webui.utils import load_function_module_by_id
|
||||||
|
|
||||||
from utils.utils import (
|
from utils.utils import (
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
@ -76,6 +75,8 @@ from utils.task import (
|
|||||||
tools_function_calling_generation_template,
|
tools_function_calling_generation_template,
|
||||||
moa_response_generation_template,
|
moa_response_generation_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from utils.tools import get_tools
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
get_last_user_message,
|
get_last_user_message,
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
@ -325,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
|
|||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if skip_files and "files" in body:
|
if skip_files and "files" in body.get("metadata", {}):
|
||||||
del body["files"]
|
del body["metadata"]["files"]
|
||||||
|
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
@ -351,80 +352,6 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def apply_extra_params_to_tool_function(
|
|
||||||
function: Callable, extra_params: dict
|
|
||||||
) -> Callable[..., Awaitable]:
|
|
||||||
sig = inspect.signature(function)
|
|
||||||
extra_params = {
|
|
||||||
key: value for key, value in extra_params.items() if key in sig.parameters
|
|
||||||
}
|
|
||||||
is_coroutine = inspect.iscoroutinefunction(function)
|
|
||||||
|
|
||||||
async def new_function(**kwargs):
|
|
||||||
extra_kwargs = kwargs | extra_params
|
|
||||||
if is_coroutine:
|
|
||||||
return await function(**extra_kwargs)
|
|
||||||
return function(**extra_kwargs)
|
|
||||||
|
|
||||||
return new_function
|
|
||||||
|
|
||||||
|
|
||||||
# Mutation on extra_params
|
|
||||||
def get_tools(
|
|
||||||
tool_ids: list[str], user: UserModel, extra_params: dict
|
|
||||||
) -> dict[str, dict]:
|
|
||||||
tools = {}
|
|
||||||
for tool_id in tool_ids:
|
|
||||||
toolkit = Tools.get_tool_by_id(tool_id)
|
|
||||||
if toolkit is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
module = webui_app.state.TOOLS.get(tool_id, None)
|
|
||||||
if module is None:
|
|
||||||
module, _ = load_toolkit_module_by_id(tool_id)
|
|
||||||
webui_app.state.TOOLS[tool_id] = module
|
|
||||||
|
|
||||||
extra_params["__id__"] = tool_id
|
|
||||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
|
||||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
|
||||||
module.valves = module.Valves(**valves)
|
|
||||||
|
|
||||||
if hasattr(module, "UserValves"):
|
|
||||||
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
|
||||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
for spec in toolkit.specs:
|
|
||||||
# TODO: Fix hack for OpenAI API
|
|
||||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
|
||||||
if val["type"] == "str":
|
|
||||||
val["type"] = "string"
|
|
||||||
function_name = spec["name"]
|
|
||||||
|
|
||||||
# convert to function that takes only model params and inserts custom params
|
|
||||||
callable = apply_extra_params_to_tool_function(
|
|
||||||
getattr(module, function_name), extra_params
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: This needs to be a pydantic model
|
|
||||||
tool_dict = {
|
|
||||||
"toolkit_id": tool_id,
|
|
||||||
"callable": callable,
|
|
||||||
"spec": spec,
|
|
||||||
"file_handler": hasattr(module, "file_handler") and module.file_handler,
|
|
||||||
"citation": hasattr(module, "citation") and module.citation,
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: if collision, prepend toolkit name
|
|
||||||
if function_name in tools:
|
|
||||||
log.warning(f"Tool {function_name} already exists in another toolkit!")
|
|
||||||
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
|
||||||
log.warning(f"Discarding {toolkit}.{function_name}")
|
|
||||||
else:
|
|
||||||
tools[function_name] = tool_dict
|
|
||||||
return tools
|
|
||||||
|
|
||||||
|
|
||||||
async def get_content_from_response(response) -> Optional[str]:
|
async def get_content_from_response(response) -> Optional[str]:
|
||||||
content = None
|
content = None
|
||||||
if hasattr(response, "body_iterator"):
|
if hasattr(response, "body_iterator"):
|
||||||
@ -443,15 +370,17 @@ async def get_content_from_response(response) -> Optional[str]:
|
|||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: UserModel, extra_params: dict
|
body: dict, user: UserModel, extra_params: dict
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
|
# If tool_ids field is present, call the functions
|
||||||
|
metadata = body.get("metadata", {})
|
||||||
|
tool_ids = metadata.get("tool_ids", None)
|
||||||
|
if not tool_ids:
|
||||||
|
return body, {}
|
||||||
|
|
||||||
skip_files = False
|
skip_files = False
|
||||||
contexts = []
|
contexts = []
|
||||||
citations = []
|
citations = []
|
||||||
|
|
||||||
task_model_id = get_task_model_id(body["model"])
|
task_model_id = get_task_model_id(body["model"])
|
||||||
# If tool_ids field is present, call the functions
|
|
||||||
tool_ids = body.pop("tool_ids", None)
|
|
||||||
if not tool_ids:
|
|
||||||
return body, {}
|
|
||||||
|
|
||||||
log.debug(f"{tool_ids=}")
|
log.debug(f"{tool_ids=}")
|
||||||
|
|
||||||
@ -459,9 +388,9 @@ async def chat_completion_tools_handler(
|
|||||||
**extra_params,
|
**extra_params,
|
||||||
"__model__": app.state.MODELS[task_model_id],
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
"__messages__": body["messages"],
|
"__messages__": body["messages"],
|
||||||
"__files__": body.get("files", []),
|
"__files__": metadata.get("files", []),
|
||||||
}
|
}
|
||||||
tools = get_tools(tool_ids, user, custom_params)
|
tools = get_tools(webui_app, tool_ids, user, custom_params)
|
||||||
log.info(f"{tools=}")
|
log.info(f"{tools=}")
|
||||||
|
|
||||||
specs = [tool["spec"] for tool in tools.values()]
|
specs = [tool["spec"] for tool in tools.values()]
|
||||||
@ -486,7 +415,7 @@ async def chat_completion_tools_handler(
|
|||||||
content = await get_content_from_response(response)
|
content = await get_content_from_response(response)
|
||||||
log.debug(f"{content=}")
|
log.debug(f"{content=}")
|
||||||
|
|
||||||
if content is None:
|
if not content:
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
@ -521,13 +450,13 @@ async def chat_completion_tools_handler(
|
|||||||
contexts.append(tool_output)
|
contexts.append(tool_output)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
log.exception(f"Error: {e}")
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
log.debug(f"tool_contexts: {contexts}")
|
log.debug(f"tool_contexts: {contexts}")
|
||||||
|
|
||||||
if skip_files and "files" in body:
|
if skip_files and "files" in body.get("metadata", {}):
|
||||||
del body["files"]
|
del body["metadata"]["files"]
|
||||||
|
|
||||||
return body, {"contexts": contexts, "citations": citations}
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
|
|
||||||
@ -536,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
|||||||
contexts = []
|
contexts = []
|
||||||
citations = []
|
citations = []
|
||||||
|
|
||||||
if files := body.pop("files", None):
|
if files := body.get("metadata", {}).get("files", None):
|
||||||
contexts, citations = get_rag_context(
|
contexts, citations = get_rag_context(
|
||||||
files=files,
|
files=files,
|
||||||
messages=body["messages"],
|
messages=body["messages"],
|
||||||
@ -597,6 +526,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
"message_id": body.pop("id", None),
|
"message_id": body.pop("id", None),
|
||||||
"session_id": body.pop("session_id", None),
|
"session_id": body.pop("session_id", None),
|
||||||
"valves": body.pop("valves", None),
|
"valves": body.pop("valves", None),
|
||||||
|
"tool_ids": body.pop("tool_ids", None),
|
||||||
|
"files": body.pop("files", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
__user__ = {
|
__user__ = {
|
||||||
@ -680,37 +611,30 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
if isinstance(response, StreamingResponse):
|
if not isinstance(response, StreamingResponse):
|
||||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
|
||||||
content_type = response.headers["Content-Type"]
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
if "application/x-ndjson" in content_type:
|
|
||||||
return StreamingResponse(
|
|
||||||
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
content_type = response.headers["Content-Type"]
|
||||||
|
is_openai = "text/event-stream" in content_type
|
||||||
|
is_ollama = "application/x-ndjson" in content_type
|
||||||
|
if not is_openai and not is_ollama:
|
||||||
|
return response
|
||||||
|
|
||||||
|
def wrap_item(item):
|
||||||
|
return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
||||||
|
|
||||||
|
async def stream_wrapper(original_generator, data_items):
|
||||||
|
for item in data_items:
|
||||||
|
yield wrap_item(json.dumps(item))
|
||||||
|
|
||||||
|
async for data in original_generator:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
|
||||||
|
|
||||||
async def _receive(self, body: bytes):
|
async def _receive(self, body: bytes):
|
||||||
return {"type": "http.request", "body": body, "more_body": False}
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
|
||||||
async def openai_stream_wrapper(self, original_generator, data_items):
|
|
||||||
for item in data_items:
|
|
||||||
yield f"data: {json.dumps(item)}\n\n"
|
|
||||||
|
|
||||||
async for data in original_generator:
|
|
||||||
yield data
|
|
||||||
|
|
||||||
async def ollama_stream_wrapper(self, original_generator, data_items):
|
|
||||||
for item in data_items:
|
|
||||||
yield f"{json.dumps(item)}\n"
|
|
||||||
|
|
||||||
async for data in original_generator:
|
|
||||||
yield data
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(ChatCompletionMiddleware)
|
app.add_middleware(ChatCompletionMiddleware)
|
||||||
|
|
||||||
@ -1065,7 +989,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|||||||
detail="Model not found",
|
detail="Model not found",
|
||||||
)
|
)
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
if model.get("pipe"):
|
if model.get("pipe"):
|
||||||
return await generate_function_chat_completion(form_data, user=user)
|
return await generate_function_chat_completion(form_data, user=user)
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
|
104
backend/utils/schemas.py
Normal file
104
backend/utils/schemas.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
|
||||||
|
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
|
||||||
|
"""
|
||||||
|
Converts a JSON schema to a Pydantic BaseModel class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_schema: The JSON schema to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic BaseModel class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Extract the model name from the schema title.
|
||||||
|
model_name = tool_dict["name"]
|
||||||
|
schema = tool_dict["parameters"]
|
||||||
|
|
||||||
|
# Extract the field definitions from the schema properties.
|
||||||
|
field_definitions = {
|
||||||
|
name: json_schema_to_pydantic_field(name, prop, schema.get("required", []))
|
||||||
|
for name, prop in schema.get("properties", {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create the BaseModel class using create_model().
|
||||||
|
return create_model(model_name, **field_definitions)
|
||||||
|
|
||||||
|
|
||||||
|
def json_schema_to_pydantic_field(
|
||||||
|
name: str, json_schema: dict[str, Any], required: list[str]
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Converts a JSON schema property to a Pydantic field definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The field name.
|
||||||
|
json_schema: The JSON schema property.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic field definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the field type.
|
||||||
|
type_ = json_schema_to_pydantic_type(json_schema)
|
||||||
|
|
||||||
|
# Get the field description.
|
||||||
|
description = json_schema.get("description")
|
||||||
|
|
||||||
|
# Get the field examples.
|
||||||
|
examples = json_schema.get("examples")
|
||||||
|
|
||||||
|
# Create a Field object with the type, description, and examples.
|
||||||
|
# The 'required' flag will be set later when creating the model.
|
||||||
|
return (
|
||||||
|
type_,
|
||||||
|
Field(
|
||||||
|
description=description,
|
||||||
|
examples=examples,
|
||||||
|
default=... if name in required else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Converts a JSON schema type to a Pydantic type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_schema: The JSON schema to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type_ = json_schema.get("type")
|
||||||
|
|
||||||
|
if type_ == "string" or type_ == "str":
|
||||||
|
return str
|
||||||
|
elif type_ == "integer" or type_ == "int":
|
||||||
|
return int
|
||||||
|
elif type_ == "number" or type_ == "float":
|
||||||
|
return float
|
||||||
|
elif type_ == "boolean" or type_ == "bool":
|
||||||
|
return bool
|
||||||
|
elif type_ == "array":
|
||||||
|
items_schema = json_schema.get("items")
|
||||||
|
if items_schema:
|
||||||
|
item_type = json_schema_to_pydantic_type(items_schema)
|
||||||
|
return list[item_type]
|
||||||
|
else:
|
||||||
|
return list
|
||||||
|
elif type_ == "object":
|
||||||
|
# Handle nested models.
|
||||||
|
properties = json_schema.get("properties")
|
||||||
|
if properties:
|
||||||
|
nested_model = json_schema_to_model(json_schema)
|
||||||
|
return nested_model
|
||||||
|
else:
|
||||||
|
return dict
|
||||||
|
elif type_ == "null":
|
||||||
|
return Optional[Any] # Use Optional[Any] for nullable fields
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported JSON schema type: {type_}")
|
@ -1,5 +1,90 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import get_type_hints
|
import logging
|
||||||
|
from typing import Awaitable, Callable, get_type_hints
|
||||||
|
|
||||||
|
from apps.webui.models.tools import Tools
|
||||||
|
from apps.webui.models.users import UserModel
|
||||||
|
from apps.webui.utils import load_toolkit_module_by_id
|
||||||
|
|
||||||
|
from utils.schemas import json_schema_to_model
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_extra_params_to_tool_function(
|
||||||
|
function: Callable, extra_params: dict
|
||||||
|
) -> Callable[..., Awaitable]:
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
extra_params = {
|
||||||
|
key: value for key, value in extra_params.items() if key in sig.parameters
|
||||||
|
}
|
||||||
|
is_coroutine = inspect.iscoroutinefunction(function)
|
||||||
|
|
||||||
|
async def new_function(**kwargs):
|
||||||
|
extra_kwargs = kwargs | extra_params
|
||||||
|
if is_coroutine:
|
||||||
|
return await function(**extra_kwargs)
|
||||||
|
return function(**extra_kwargs)
|
||||||
|
|
||||||
|
return new_function
|
||||||
|
|
||||||
|
|
||||||
|
# Mutation on extra_params
|
||||||
|
def get_tools(
|
||||||
|
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
tools = {}
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
toolkit = Tools.get_tool_by_id(tool_id)
|
||||||
|
if toolkit is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module = webui_app.state.TOOLS.get(tool_id, None)
|
||||||
|
if module is None:
|
||||||
|
module, _ = load_toolkit_module_by_id(tool_id)
|
||||||
|
webui_app.state.TOOLS[tool_id] = module
|
||||||
|
|
||||||
|
extra_params["__id__"] = tool_id
|
||||||
|
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||||
|
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||||
|
module.valves = module.Valves(**valves)
|
||||||
|
|
||||||
|
if hasattr(module, "UserValves"):
|
||||||
|
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||||
|
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
for spec in toolkit.specs:
|
||||||
|
# TODO: Fix hack for OpenAI API
|
||||||
|
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||||
|
if val["type"] == "str":
|
||||||
|
val["type"] = "string"
|
||||||
|
function_name = spec["name"]
|
||||||
|
|
||||||
|
# convert to function that takes only model params and inserts custom params
|
||||||
|
original_func = getattr(module, function_name)
|
||||||
|
callable = apply_extra_params_to_tool_function(original_func, extra_params)
|
||||||
|
if hasattr(original_func, "__doc__"):
|
||||||
|
callable.__doc__ = original_func.__doc__
|
||||||
|
|
||||||
|
# TODO: This needs to be a pydantic model
|
||||||
|
tool_dict = {
|
||||||
|
"toolkit_id": tool_id,
|
||||||
|
"callable": callable,
|
||||||
|
"spec": spec,
|
||||||
|
"pydantic_model": json_schema_to_model(spec),
|
||||||
|
"file_handler": hasattr(module, "file_handler") and module.file_handler,
|
||||||
|
"citation": hasattr(module, "citation") and module.citation,
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: if collision, prepend toolkit name
|
||||||
|
if function_name in tools:
|
||||||
|
log.warning(f"Tool {function_name} already exists in another toolkit!")
|
||||||
|
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
||||||
|
log.warning(f"Discarding {toolkit}.{function_name}")
|
||||||
|
else:
|
||||||
|
tools[function_name] = tool_dict
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def doc_to_dict(docstring):
|
def doc_to_dict(docstring):
|
||||||
|
Loading…
Reference in New Issue
Block a user