mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into pyodide-files
This commit is contained in:
@@ -13,6 +13,8 @@ import pytz
|
||||
from pytz import UTC
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@@ -194,7 +196,17 @@ def get_current_user(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
return get_current_user_by_api_key(token)
|
||||
user = get_current_user_by_api_key(token)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
return user
|
||||
|
||||
# auth by jwt token
|
||||
try:
|
||||
@@ -213,6 +225,14 @@ def get_current_user(
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "jwt")
|
||||
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
@@ -234,6 +254,14 @@ def get_current_user_by_api_key(api_key: str):
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
|
||||
return user
|
||||
|
||||
@@ -309,6 +309,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"filter_ids": data.get("filter_ids", []),
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
@@ -330,7 +331,9 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
result, _ = await process_filter_functions(
|
||||
@@ -389,11 +392,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
}
|
||||
)
|
||||
|
||||
if action_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[action_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(action_id)
|
||||
request.app.state.FUNCTIONS[action_id] = function_module
|
||||
function_module, _, _ = load_function_module_by_id(action_id)
|
||||
request.app.state.FUNCTIONS[action_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
|
||||
@@ -9,7 +9,18 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model: dict):
|
||||
def get_function_module(request, function_id):
|
||||
"""
|
||||
Get the function module by its ID.
|
||||
"""
|
||||
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
|
||||
return function_module
|
||||
|
||||
|
||||
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None:
|
||||
@@ -21,14 +32,23 @@ def get_sorted_filter_ids(model: dict):
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
active_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||
for filter_id in active_filter_ids:
|
||||
function_module = get_function_module(request, filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None) and (
|
||||
filter_id not in enabled_filter_ids
|
||||
):
|
||||
active_filter_ids.remove(filter_id)
|
||||
continue
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
return filter_ids
|
||||
|
||||
|
||||
@@ -43,12 +63,7 @@ async def process_filter_functions(
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
function_module = get_function_module(request, filter_id)
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
|
||||
@@ -43,6 +43,7 @@ from open_webui.routers.pipelines import (
|
||||
process_pipeline_outlet_filter,
|
||||
)
|
||||
from open_webui.routers.files import upload_file
|
||||
from open_webui.routers.memories import query_memory, QueryMemoryForm
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
@@ -253,7 +254,12 @@ async def chat_completion_tools_handler(
|
||||
"name": (f"TOOL:{tool_name}"),
|
||||
},
|
||||
"document": [tool_result],
|
||||
"metadata": [{"source": (f"TOOL:{tool_name}")}],
|
||||
"metadata": [
|
||||
{
|
||||
"source": (f"TOOL:{tool_name}"),
|
||||
"parameters": tool_function_params,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -292,6 +298,38 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_memory_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
results = await query_memory(
|
||||
request,
|
||||
QueryMemoryForm(
|
||||
**{"content": get_last_user_message(form_data["messages"]), "k": 3}
|
||||
),
|
||||
user,
|
||||
)
|
||||
|
||||
user_context = ""
|
||||
if results and hasattr(results, "documents"):
|
||||
if results.documents and len(results.documents) > 0:
|
||||
for doc_idx, doc in enumerate(results.documents[0]):
|
||||
created_at_date = "Unknown Date"
|
||||
|
||||
if results.metadatas[0][doc_idx].get("created_at"):
|
||||
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
|
||||
created_at_date = time.strftime(
|
||||
"%Y-%m-%d", time.localtime(created_at_timestamp)
|
||||
)
|
||||
|
||||
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
@@ -342,6 +380,11 @@ async def chat_web_search_handler(
|
||||
log.exception(e)
|
||||
queries = [user_message]
|
||||
|
||||
# Check if generated queries are empty
|
||||
if len(queries) == 1 and queries[0].strip() == "":
|
||||
queries = [user_message]
|
||||
|
||||
# Check if queries are not found
|
||||
if len(queries) == 0:
|
||||
await event_emitter(
|
||||
{
|
||||
@@ -653,7 +696,7 @@ def apply_params_to_form_data(form_data, model):
|
||||
convert_logit_bias_input_to_json(params["logit_bias"])
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing logit_bias: {e}")
|
||||
log.exception(f"Error parsing logit_bias: {e}")
|
||||
|
||||
return form_data
|
||||
|
||||
@@ -751,9 +794,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
raise e
|
||||
|
||||
try:
|
||||
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
form_data, flags = await process_filter_functions(
|
||||
@@ -768,6 +814,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "memory" in features and features["memory"]:
|
||||
form_data = await chat_memory_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
@@ -870,6 +921,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
for doc_context, doc_meta in zip(
|
||||
source["document"], source["metadata"]
|
||||
):
|
||||
source_name = source.get("source", {}).get("name", None)
|
||||
citation_id = (
|
||||
doc_meta.get("source", None)
|
||||
or source.get("source", {}).get("id", None)
|
||||
@@ -877,7 +929,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
)
|
||||
if citation_id not in citation_idx:
|
||||
citation_idx[citation_id] = len(citation_idx) + 1
|
||||
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
|
||||
context_string += (
|
||||
f'<source id="{citation_idx[citation_id]}"'
|
||||
+ (f' name="{source_name}"' if source_name else "")
|
||||
+ f">{doc_context}</source>\n"
|
||||
)
|
||||
|
||||
context_string = context_string.strip()
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
@@ -944,21 +1000,36 @@ async def process_chat_response(
|
||||
message = message_map.get(metadata["message_id"]) if message_map else None
|
||||
|
||||
if message:
|
||||
messages = get_message_list(message_map, message.get("id"))
|
||||
message_list = get_message_list(message_map, message.get("id"))
|
||||
|
||||
# Remove reasoning details and files from the messages.
|
||||
# Remove details tags and files from the messages.
|
||||
# as get_message_list creates a new list, it does not affect
|
||||
# the original messages outside of this handler
|
||||
for message in messages:
|
||||
message["content"] = re.sub(
|
||||
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
|
||||
"",
|
||||
message["content"],
|
||||
flags=re.S,
|
||||
).strip()
|
||||
|
||||
if message.get("files"):
|
||||
message["files"] = []
|
||||
messages = []
|
||||
for message in message_list:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
content = item["text"]
|
||||
break
|
||||
|
||||
if isinstance(content, str):
|
||||
content = re.sub(
|
||||
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
|
||||
"",
|
||||
content,
|
||||
flags=re.S | re.I,
|
||||
).strip()
|
||||
|
||||
messages.append(
|
||||
{
|
||||
**message,
|
||||
"role": message["role"],
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
if tasks and messages:
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
@@ -1171,7 +1242,9 @@ async def process_chat_response(
|
||||
}
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
# Streaming response
|
||||
|
||||
@@ -130,7 +130,9 @@ def prepend_to_first_user_message_content(
|
||||
return messages
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
def add_or_update_system_message(
|
||||
content: str, messages: list[dict], append: bool = False
|
||||
):
|
||||
"""
|
||||
Adds a new system message at the beginning of the messages list
|
||||
or updates the existing system message at the beginning.
|
||||
@@ -141,7 +143,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
"""
|
||||
|
||||
if messages and messages[0].get("role") == "system":
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
if append:
|
||||
messages[0]["content"] = f"{messages[0]['content']}\n{content}"
|
||||
else:
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
else:
|
||||
# Insert at the beginning
|
||||
messages.insert(0, {"role": "system", "content": content})
|
||||
|
||||
@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
|
||||
"created": int(time.time()),
|
||||
"owned_by": "ollama",
|
||||
"ollama": model,
|
||||
"connection_type": model.get("connection_type", "local"),
|
||||
"tags": model.get("tags", []),
|
||||
}
|
||||
for model in ollama_models["models"]
|
||||
@@ -110,6 +111,14 @@ async def get_all_models(request, user: UserModel = None):
|
||||
for function in Functions.get_functions_by_type("action", active_only=True)
|
||||
]
|
||||
|
||||
global_filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id is None:
|
||||
@@ -125,13 +134,20 @@ async def get_all_models(request, user: UserModel = None):
|
||||
model["name"] = custom_model.name
|
||||
model["info"] = custom_model.model_dump()
|
||||
|
||||
# Set action_ids and filter_ids
|
||||
action_ids = []
|
||||
filter_ids = []
|
||||
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
action_ids.extend(
|
||||
model["info"]["meta"].get("actionIds", [])
|
||||
)
|
||||
filter_ids.extend(
|
||||
model["info"]["meta"].get("filterIds", [])
|
||||
)
|
||||
|
||||
model["action_ids"] = action_ids
|
||||
model["filter_ids"] = filter_ids
|
||||
else:
|
||||
models.remove(model)
|
||||
|
||||
@@ -140,7 +156,9 @@ async def get_all_models(request, user: UserModel = None):
|
||||
):
|
||||
owned_by = "openai"
|
||||
pipe = None
|
||||
|
||||
action_ids = []
|
||||
filter_ids = []
|
||||
|
||||
for model in models:
|
||||
if (
|
||||
@@ -154,9 +172,13 @@ async def get_all_models(request, user: UserModel = None):
|
||||
|
||||
if custom_model.meta:
|
||||
meta = custom_model.meta.model_dump()
|
||||
|
||||
if "actionIds" in meta:
|
||||
action_ids.extend(meta["actionIds"])
|
||||
|
||||
if "filterIds" in meta:
|
||||
filter_ids.extend(meta["filterIds"])
|
||||
|
||||
models.append(
|
||||
{
|
||||
"id": f"{custom_model.id}",
|
||||
@@ -168,6 +190,7 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"preset": True,
|
||||
**({"pipe": pipe} if pipe is not None else {}),
|
||||
"action_ids": action_ids,
|
||||
"filter_ids": filter_ids,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -181,8 +204,11 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"id": f"{function.id}.{action['id']}",
|
||||
"name": action.get("name", f"{function.name} ({action['id']})"),
|
||||
"description": function.meta.description,
|
||||
"icon_url": action.get(
|
||||
"icon_url", function.meta.manifest.get("icon_url", None)
|
||||
"icon": action.get(
|
||||
"icon_url",
|
||||
function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
),
|
||||
}
|
||||
for action in actions
|
||||
@@ -193,16 +219,28 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"id": function.id,
|
||||
"name": function.name,
|
||||
"description": function.meta.description,
|
||||
"icon_url": function.meta.manifest.get("icon_url", None),
|
||||
"icon": function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
}
|
||||
]
|
||||
|
||||
# Process filter_ids to get the filters
|
||||
def get_filter_items_from_module(function, module):
|
||||
return [
|
||||
{
|
||||
"id": function.id,
|
||||
"name": function.name,
|
||||
"description": function.meta.description,
|
||||
"icon": function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
}
|
||||
]
|
||||
|
||||
def get_function_module_by_id(function_id):
|
||||
if function_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[function_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
return function_module
|
||||
|
||||
for model in models:
|
||||
@@ -211,6 +249,11 @@ async def get_all_models(request, user: UserModel = None):
|
||||
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
|
||||
if action_id in enabled_action_ids
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id
|
||||
for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
|
||||
if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
model["actions"] = []
|
||||
for action_id in action_ids:
|
||||
@@ -222,6 +265,20 @@ async def get_all_models(request, user: UserModel = None):
|
||||
model["actions"].extend(
|
||||
get_action_items_from_module(action_function, function_module)
|
||||
)
|
||||
|
||||
model["filters"] = []
|
||||
for filter_id in filter_ids:
|
||||
filter_function = Functions.get_function_by_id(filter_id)
|
||||
if filter_function is None:
|
||||
raise Exception(f"Filter not found: {filter_id}")
|
||||
|
||||
function_module = get_function_module_by_id(filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None):
|
||||
model["filters"].extend(
|
||||
get_filter_items_from_module(filter_function, function_module)
|
||||
)
|
||||
|
||||
log.debug(f"get_all_models() returned {len(models)} models")
|
||||
|
||||
request.app.state.MODELS = {model["id"]: model for model in models}
|
||||
|
||||
@@ -41,6 +41,7 @@ from open_webui.config import (
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
WEBUI_NAME,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
@@ -305,8 +306,10 @@ class OAuthManager:
|
||||
get_kwargs["headers"] = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||
@@ -371,7 +374,9 @@ class OAuthManager:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/user/emails", headers=headers
|
||||
"https://api.github.com/user/emails",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
emails = await resp.json()
|
||||
@@ -531,5 +536,10 @@ class OAuthManager:
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
|
||||
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
|
||||
if redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
@@ -57,6 +57,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": float,
|
||||
"min_p": float,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": float,
|
||||
"presence_penalty": float,
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_task_model_id(
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if models[task_model_id].get("connection_type") == "local":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
|
||||
@@ -37,6 +37,7 @@ from open_webui.models.tools import Tools
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.utils.plugin import load_tool_module_by_id
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
)
|
||||
@@ -44,6 +45,7 @@ from open_webui.env import (
|
||||
import copy
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
def get_async_tool_function_and_apply_extra_params(
|
||||
@@ -158,7 +160,7 @@ def get_tools(
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
if val.get("type") == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal reserved parameters (e.g. __id__, __user__)
|
||||
@@ -477,7 +479,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
"specs": convert_openapi_to_tool_payload(res),
|
||||
}
|
||||
|
||||
print("Fetched data:", data)
|
||||
log.info("Fetched data:", data)
|
||||
return data
|
||||
|
||||
|
||||
@@ -510,7 +512,7 @@ async def get_tool_servers_data(
|
||||
results = []
|
||||
for (idx, server, url, _), response in zip(server_entries, responses):
|
||||
if isinstance(response, Exception):
|
||||
print(f"Failed to connect to {url} OpenAPI tool server")
|
||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||
continue
|
||||
|
||||
results.append(
|
||||
@@ -620,5 +622,5 @@ async def execute_tool_server(
|
||||
|
||||
except Exception as err:
|
||||
error = str(err)
|
||||
print("API Request Error:", error)
|
||||
log.exception("API Request Error:", error)
|
||||
return {"error": error}
|
||||
|
||||
Reference in New Issue
Block a user