mirror of
https://github.com/open-webui/open-webui
synced 2025-06-22 18:07:17 +00:00
Merge fc6bded532
into 7513dc7e34
This commit is contained in:
commit
be1f7ee7ec
@ -1,29 +1,25 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
import time
|
|
||||||
import re
|
import re
|
||||||
import aiohttp
|
import time
|
||||||
from pydantic import BaseModel, HttpUrl
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from open_webui.config import CACHE_DIR
|
||||||
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.models.tools import (
|
from open_webui.models.tools import (
|
||||||
ToolForm,
|
ToolForm,
|
||||||
ToolModel,
|
ToolModel,
|
||||||
ToolResponse,
|
ToolResponse,
|
||||||
ToolUserResponse,
|
|
||||||
Tools,
|
Tools,
|
||||||
|
ToolUserResponse,
|
||||||
)
|
)
|
||||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
|
||||||
from open_webui.config import CACHE_DIR
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
||||||
from open_webui.utils.tools import get_tool_specs
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||||
from open_webui.utils.tools import get_tool_servers_data
|
from open_webui.utils.tools import get_tool_servers_data, get_tool_specs
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
@ -53,21 +49,25 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||||||
tools.append(
|
tools.append(
|
||||||
ToolUserResponse(
|
ToolUserResponse(
|
||||||
**{
|
**{
|
||||||
"id": f"server:{server['idx']}",
|
"id": f"server:{server['hash']}",
|
||||||
"user_id": f"server:{server['idx']}",
|
"user_id": f"server:{server['hash']}",
|
||||||
"name": server.get("openapi", {})
|
"name": (
|
||||||
.get("info", {})
|
server.get("openapi", {})
|
||||||
.get("title", "Tool Server"),
|
|
||||||
"meta": {
|
|
||||||
"description": server.get("openapi", {})
|
|
||||||
.get("info", {})
|
.get("info", {})
|
||||||
.get("description", ""),
|
.get("title", "Tool Server")
|
||||||
|
),
|
||||||
|
"meta": {
|
||||||
|
"description": (
|
||||||
|
server.get("openapi", {})
|
||||||
|
.get("info", {})
|
||||||
|
.get("description", "")
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
"access_control": (
|
||||||
server["idx"]
|
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
|
||||||
]
|
.get("config", {})
|
||||||
.get("config", {})
|
.get("access_control", None)
|
||||||
.get("access_control", None),
|
),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
}
|
}
|
||||||
|
@ -1,48 +1,26 @@
|
|||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import inspect
|
from functools import partial, update_wrapper
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, get_type_hints
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.fields import FieldInfo
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
get_type_hints,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
)
|
|
||||||
from functools import update_wrapper, partial
|
|
||||||
|
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from pydantic import BaseModel, Field, create_model
|
|
||||||
|
|
||||||
from langchain_core.utils.function_calling import (
|
from langchain_core.utils.function_calling import (
|
||||||
convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
|
convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
|
||||||
)
|
)
|
||||||
|
from open_webui.env import (
|
||||||
|
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||||
|
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
|
)
|
||||||
from open_webui.models.tools import Tools
|
from open_webui.models.tools import Tools
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
from open_webui.utils.plugin import load_tool_module_by_id
|
from open_webui.utils.plugin import load_tool_module_by_id
|
||||||
from open_webui.env import (
|
from pydantic import BaseModel, Field, create_model
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
|
||||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
|
||||||
)
|
|
||||||
|
|
||||||
import copy
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -76,16 +54,18 @@ def get_tools(
|
|||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
if tool_id.startswith("server:"):
|
if tool_id.startswith("server:"):
|
||||||
server_idx = int(tool_id.split(":")[1])
|
server_hash = tool_id.split(":", 1)[1] # Get hash part after "server:"
|
||||||
tool_server_connection = (
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
|
|
||||||
)
|
|
||||||
tool_server_data = None
|
tool_server_data = None
|
||||||
for server in request.app.state.TOOL_SERVERS:
|
for server in request.app.state.TOOL_SERVERS:
|
||||||
if server["idx"] == server_idx:
|
if server["hash"] == server_hash:
|
||||||
tool_server_data = server
|
tool_server_data = server
|
||||||
break
|
break
|
||||||
assert tool_server_data is not None
|
assert tool_server_data is not None
|
||||||
|
|
||||||
|
# Get connection config directly from server data
|
||||||
|
tool_server_connection = (
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
|
||||||
|
)
|
||||||
specs = tool_server_data.get("specs", [])
|
specs = tool_server_data.get("specs", [])
|
||||||
|
|
||||||
for spec in specs:
|
for spec in specs:
|
||||||
@ -102,7 +82,8 @@ def get_tools(
|
|||||||
def make_tool_function(function_name, token, tool_server_data):
|
def make_tool_function(function_name, token, tool_server_data):
|
||||||
async def tool_function(**kwargs):
|
async def tool_function(**kwargs):
|
||||||
print(
|
print(
|
||||||
f"Executing tool function {function_name} with params: {kwargs}"
|
f"Executing tool function {function_name} with params:"
|
||||||
|
f" {kwargs}"
|
||||||
)
|
)
|
||||||
return await execute_tool_server(
|
return await execute_tool_server(
|
||||||
token=token,
|
token=token,
|
||||||
@ -190,8 +171,9 @@ def get_tools(
|
|||||||
"spec": spec,
|
"spec": spec,
|
||||||
# Misc info
|
# Misc info
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"file_handler": hasattr(module, "file_handler")
|
"file_handler": (
|
||||||
and module.file_handler,
|
hasattr(module, "file_handler") and module.file_handler
|
||||||
|
),
|
||||||
"citation": hasattr(module, "citation") and module.citation,
|
"citation": hasattr(module, "citation") and module.citation,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -440,6 +422,26 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
|||||||
return tool_payload
|
return tool_payload
|
||||||
|
|
||||||
|
|
||||||
|
def compute_server_hash(server_url: str, server_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Compute a hash for a tool server based on its URL and path.
|
||||||
|
|
||||||
|
Creates a consistent string representation by combining the normalized URL and path,
|
||||||
|
then uses Python's built-in hash() function for performance. The result is converted
|
||||||
|
to a positive integer and formatted as an 8-digit hexadecimal string for consistent
|
||||||
|
length and use as a server identifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_url: The base URL of the server
|
||||||
|
server_path: The path to the OpenAPI spec
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A hash string to use as the server identifier
|
||||||
|
"""
|
||||||
|
combined = f"{server_url.rstrip('/')}/{server_path.lstrip('/')}"
|
||||||
|
return f"{hash(combined) & 0x7FFFFFFF:08x}"
|
||||||
|
|
||||||
|
|
||||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
@ -486,7 +488,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
|||||||
async def get_tool_servers_data(
|
async def get_tool_servers_data(
|
||||||
servers: List[Dict[str, Any]], session_token: Optional[str] = None
|
servers: List[Dict[str, Any]], session_token: Optional[str] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
# Prepare list of enabled servers along with their original index
|
# Prepare list of enabled servers along with their hash-based identifier
|
||||||
server_entries = []
|
server_entries = []
|
||||||
for idx, server in enumerate(servers):
|
for idx, server in enumerate(servers):
|
||||||
if server.get("config", {}).get("enable"):
|
if server.get("config", {}).get("enable"):
|
||||||
@ -504,6 +506,8 @@ async def get_tool_servers_data(
|
|||||||
|
|
||||||
info = server.get("info", {})
|
info = server.get("info", {})
|
||||||
|
|
||||||
|
server_hash = compute_server_hash(server.get("url", ""), openapi_path)
|
||||||
|
|
||||||
auth_type = server.get("auth_type", "bearer")
|
auth_type = server.get("auth_type", "bearer")
|
||||||
token = None
|
token = None
|
||||||
|
|
||||||
@ -511,19 +515,21 @@ async def get_tool_servers_data(
|
|||||||
token = server.get("key", "")
|
token = server.get("key", "")
|
||||||
elif auth_type == "session":
|
elif auth_type == "session":
|
||||||
token = session_token
|
token = session_token
|
||||||
server_entries.append((idx, server, full_url, info, token))
|
server_entries.append((idx, server_hash, server, full_url, info, token))
|
||||||
|
|
||||||
# Create async tasks to fetch data
|
# Create async tasks to fetch data
|
||||||
tasks = [
|
tasks = [
|
||||||
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
|
get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries
|
||||||
]
|
]
|
||||||
|
|
||||||
# Execute tasks concurrently
|
# Execute tasks concurrently
|
||||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# Build final results with index and server metadata
|
# Build final results with ID, hash, url, openapi, info, and specs
|
||||||
results = []
|
results = []
|
||||||
for (idx, server, url, info, _), response in zip(server_entries, responses):
|
for (idx, server_hash, server, url, info, _), response in zip(
|
||||||
|
server_entries, responses
|
||||||
|
):
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||||
continue
|
continue
|
||||||
@ -540,6 +546,7 @@ async def get_tool_servers_data(
|
|||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"idx": idx,
|
"idx": idx,
|
||||||
|
"hash": server_hash, # Hash-based ID
|
||||||
"url": server.get("url"),
|
"url": server.get("url"),
|
||||||
"openapi": openapi_data,
|
"openapi": openapi_data,
|
||||||
"info": response.get("info"),
|
"info": response.get("info"),
|
||||||
|
Loading…
Reference in New Issue
Block a user