mirror of
https://github.com/open-webui/open-webui
synced 2025-06-22 18:07:17 +00:00
Merge fc6bded532
into 7513dc7e34
This commit is contained in:
commit
be1f7ee7ec
@ -1,29 +1,25 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import time
|
||||
import re
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.models.tools import (
|
||||
ToolForm,
|
||||
ToolModel,
|
||||
ToolResponse,
|
||||
ToolUserResponse,
|
||||
Tools,
|
||||
ToolUserResponse,
|
||||
)
|
||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.tools import get_tool_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||
from open_webui.utils.tools import get_tool_servers_data, get_tool_specs
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
@ -53,21 +49,25 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
"id": f"server:{server['idx']}",
|
||||
"user_id": f"server:{server['idx']}",
|
||||
"name": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("title", "Tool Server"),
|
||||
"meta": {
|
||||
"description": server.get("openapi", {})
|
||||
"id": f"server:{server['hash']}",
|
||||
"user_id": f"server:{server['hash']}",
|
||||
"name": (
|
||||
server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("description", ""),
|
||||
.get("title", "Tool Server")
|
||||
),
|
||||
"meta": {
|
||||
"description": (
|
||||
server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("description", "")
|
||||
),
|
||||
},
|
||||
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
||||
server["idx"]
|
||||
]
|
||||
.get("config", {})
|
||||
.get("access_control", None),
|
||||
"access_control": (
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
|
||||
.get("config", {})
|
||||
.get("access_control", None)
|
||||
),
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
|
@ -1,48 +1,26 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import inspect
|
||||
from functools import partial, update_wrapper
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, get_type_hints
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import yaml
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
get_type_hints,
|
||||
get_args,
|
||||
get_origin,
|
||||
Dict,
|
||||
List,
|
||||
Tuple,
|
||||
Union,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
from functools import update_wrapper, partial
|
||||
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
from open_webui.models.tools import Tools
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.utils.plugin import load_tool_module_by_id
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
)
|
||||
|
||||
import copy
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@ -76,16 +54,18 @@ def get_tools(
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
if tool is None:
|
||||
if tool_id.startswith("server:"):
|
||||
server_idx = int(tool_id.split(":")[1])
|
||||
tool_server_connection = (
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
|
||||
)
|
||||
server_hash = tool_id.split(":", 1)[1] # Get hash part after "server:"
|
||||
tool_server_data = None
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
if server["idx"] == server_idx:
|
||||
if server["hash"] == server_hash:
|
||||
tool_server_data = server
|
||||
break
|
||||
assert tool_server_data is not None
|
||||
|
||||
# Get connection config directly from server data
|
||||
tool_server_connection = (
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS[server["idx"]]
|
||||
)
|
||||
specs = tool_server_data.get("specs", [])
|
||||
|
||||
for spec in specs:
|
||||
@ -102,7 +82,8 @@ def get_tools(
|
||||
def make_tool_function(function_name, token, tool_server_data):
|
||||
async def tool_function(**kwargs):
|
||||
print(
|
||||
f"Executing tool function {function_name} with params: {kwargs}"
|
||||
f"Executing tool function {function_name} with params:"
|
||||
f" {kwargs}"
|
||||
)
|
||||
return await execute_tool_server(
|
||||
token=token,
|
||||
@ -190,8 +171,9 @@ def get_tools(
|
||||
"spec": spec,
|
||||
# Misc info
|
||||
"metadata": {
|
||||
"file_handler": hasattr(module, "file_handler")
|
||||
and module.file_handler,
|
||||
"file_handler": (
|
||||
hasattr(module, "file_handler") and module.file_handler
|
||||
),
|
||||
"citation": hasattr(module, "citation") and module.citation,
|
||||
},
|
||||
}
|
||||
@ -440,6 +422,26 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
||||
return tool_payload
|
||||
|
||||
|
||||
def compute_server_hash(server_url: str, server_path: str) -> str:
|
||||
"""
|
||||
Compute a hash for a tool server based on its URL and path.
|
||||
|
||||
Creates a consistent string representation by combining the normalized URL and path,
|
||||
then uses Python's built-in hash() function for performance. The result is converted
|
||||
to a positive integer and formatted as an 8-digit hexadecimal string for consistent
|
||||
length and use as a server identifier.
|
||||
|
||||
Args:
|
||||
server_url: The base URL of the server
|
||||
server_path: The path to the OpenAPI spec
|
||||
|
||||
Returns:
|
||||
A hash string to use as the server identifier
|
||||
"""
|
||||
combined = f"{server_url.rstrip('/')}/{server_path.lstrip('/')}"
|
||||
return f"{hash(combined) & 0x7FFFFFFF:08x}"
|
||||
|
||||
|
||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
@ -486,7 +488,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
async def get_tool_servers_data(
|
||||
servers: List[Dict[str, Any]], session_token: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
# Prepare list of enabled servers along with their original index
|
||||
# Prepare list of enabled servers along with their hash-based identifier
|
||||
server_entries = []
|
||||
for idx, server in enumerate(servers):
|
||||
if server.get("config", {}).get("enable"):
|
||||
@ -504,6 +506,8 @@ async def get_tool_servers_data(
|
||||
|
||||
info = server.get("info", {})
|
||||
|
||||
server_hash = compute_server_hash(server.get("url", ""), openapi_path)
|
||||
|
||||
auth_type = server.get("auth_type", "bearer")
|
||||
token = None
|
||||
|
||||
@ -511,19 +515,21 @@ async def get_tool_servers_data(
|
||||
token = server.get("key", "")
|
||||
elif auth_type == "session":
|
||||
token = session_token
|
||||
server_entries.append((idx, server, full_url, info, token))
|
||||
server_entries.append((idx, server_hash, server, full_url, info, token))
|
||||
|
||||
# Create async tasks to fetch data
|
||||
tasks = [
|
||||
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
|
||||
get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries
|
||||
]
|
||||
|
||||
# Execute tasks concurrently
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Build final results with index and server metadata
|
||||
# Build final results with ID, hash, url, openapi, info, and specs
|
||||
results = []
|
||||
for (idx, server, url, info, _), response in zip(server_entries, responses):
|
||||
for (idx, server_hash, server, url, info, _), response in zip(
|
||||
server_entries, responses
|
||||
):
|
||||
if isinstance(response, Exception):
|
||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||
continue
|
||||
@ -540,6 +546,7 @@ async def get_tool_servers_data(
|
||||
results.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"hash": server_hash, # Hash-based ID
|
||||
"url": server.get("url"),
|
||||
"openapi": openapi_data,
|
||||
"info": response.get("info"),
|
||||
|
Loading…
Reference in New Issue
Block a user