Merge branch 'dev' into k_reranker

This commit is contained in:
Timothy Jaeryang Baek
2025-03-26 20:50:31 -07:00
committed by GitHub
147 changed files with 6065 additions and 1350 deletions

View File

@@ -106,6 +106,7 @@ async def process_filter_functions(
# Handle file cleanup for inlet
if skip_files and "files" in form_data.get("metadata", {}):
del form_data["files"]
del form_data["metadata"]["files"]
return form_data, {}

View File

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@@ -189,19 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})
try:
required_params = (
tools[tool_function_name]
.get("spec", {})
.get("parameters", {})
.get("required", [])
tool = tools[tool_function_name]
spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)
if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)
except Exception as e:
tool_output = str(e)
@@ -767,12 +784,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")
tools_dict = {}
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@@ -783,20 +806,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")
if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
@@ -815,7 +848,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for source_idx, source in enumerate(sources):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string += f"<source><source_id>{source_idx + 1}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -1082,8 +1115,6 @@ async def process_chat_response(
for filter_id in get_sorted_filter_ids(model)
]
print(f"{filter_functions=}")
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1563,7 +1594,9 @@ async def process_chat_response(
value = delta.get("content")
reasoning_content = delta.get("reasoning_content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
if reasoning_content:
if (
not content_blocks
@@ -1766,18 +1799,36 @@ async def process_chat_response(
spec = tool.get("spec", {})
try:
required_params = spec.get("parameters", {}).get(
"required", []
allowed_params = (
spec.get("parameters", {})
.get("properties", {})
.keys()
)
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)
if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)

View File

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
]

View File

@@ -94,7 +94,7 @@ class OAuthManager:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
oauth_roles = []
# Default/fallback role if no matching roles are found
role = auth_manager_config.DEFAULT_USER_ROLE
@@ -104,7 +104,7 @@ class OAuthManager:
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
oauth_roles = claim_data if isinstance(claim_data, list) else []
log.debug(f"Oauth Roles claim: {oauth_claim}")
log.debug(f"User roles from oauth: {oauth_roles}")
@@ -140,6 +140,7 @@ class OAuthManager:
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups = []
# Nested claim search for groups claim
if oauth_claim:
claim_data = user_data
@@ -160,7 +161,7 @@ class OAuthManager:
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
if user_oauth_groups and group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
@@ -186,8 +187,10 @@ class OAuthManager:
# Add user to new groups
for group_model in all_available_groups:
if group_model.name in user_oauth_groups and not any(
gm.name == group_model.name for gm in user_current_groups
if (
user_oauth_groups
and group_model.name in user_oauth_groups
and not any(gm.name == group_model.name for gm in user_current_groups)
):
# Add user to group
log.debug(

View File

@@ -110,6 +110,11 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
"num_thread": int,
}
# Extract keep_alive from options if it exists
if "options" in form_data and "keep_alive" in form_data["options"]:
form_data["keep_alive"] = form_data["options"]["keep_alive"]
del form_data["options"]["keep_alive"]
return apply_model_params_to_body(params, form_data, mappings)
@@ -231,6 +236,11 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
"system"
] # To prevent Ollama warning of invalid option provided
# Extract keep_alive from options if it exists
if "keep_alive" in ollama_options:
ollama_payload["keep_alive"] = ollama_options["keep_alive"]
del ollama_options["keep_alive"]
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
if "stop" in openai_payload:
ollama_options = ollama_payload.get("options", {})

View File

@@ -7,7 +7,7 @@ import types
import tempfile
import logging
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS
from open_webui.models.functions import Functions
from open_webui.models.tools import Tools
@@ -165,15 +165,19 @@ def load_function_module_by_id(function_id, content=None):
os.unlink(temp_file.name)
def install_frontmatter_requirements(requirements):
def install_frontmatter_requirements(requirements: str):
if requirements:
try:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
log.info(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
log.info(f"Installing requirements: {' '.join(req_list)}")
subprocess.check_call(
[sys.executable, "-m", "pip", "install"]
+ PIP_OPTIONS
+ req_list
+ PIP_PACKAGE_INDEX_OPTIONS
)
except Exception as e:
log.error(f"Error installing package: {req}")
log.error(f"Error installing packages: {' '.join(req_list)}")
raise e
else:

View File

@@ -0,0 +1,26 @@
from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes
# Span Tags
SPAN_DB_TYPE = "mysql"
SPAN_REDIS_TYPE = "redis"
SPAN_DURATION = "duration"
SPAN_SQL_STR = "sql"
SPAN_SQL_EXPLAIN = "explain"
SPAN_ERROR_TYPE = "error"
class SpanAttributes(_SpanAttributes):
"""
Span Attributes
"""
DB_INSTANCE = "db.instance"
DB_TYPE = "db.type"
DB_IP = "db.ip"
DB_PORT = "db.port"
ERROR_KIND = "error.kind"
ERROR_OBJECT = "error.object"
ERROR_MESSAGE = "error.message"
RESULT_CODE = "result.code"
RESULT_MESSAGE = "result.message"
RESULT_ERRORS = "result.errors"

View File

@@ -0,0 +1,31 @@
import threading
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import BatchSpanProcessor
class LazyBatchSpanProcessor(BatchSpanProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.done = True
with self.condition:
self.condition.notify_all()
self.worker_thread.join()
self.done = False
self.worker_thread = None
def on_end(self, span: ReadableSpan) -> None:
if self.worker_thread is None:
self.worker_thread = threading.Thread(
name=self.__class__.__name__, target=self.worker, daemon=True
)
self.worker_thread.start()
super().on_end(span)
def shutdown(self) -> None:
self.done = True
with self.condition:
self.condition.notify_all()
if self.worker_thread:
self.worker_thread.join()
self.span_exporter.shutdown()

View File

@@ -0,0 +1,202 @@
import logging
import traceback
from typing import Collection, Union
from aiohttp import (
TraceRequestStartParams,
TraceRequestEndParams,
TraceRequestExceptionParams,
)
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
from fastapi import FastAPI
from opentelemetry.instrumentation.httpx import (
HTTPXClientInstrumentor,
RequestInfo,
ResponseInfo,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.trace import Span, StatusCode
from redis import Redis
from requests import PreparedRequest, Response
from sqlalchemy import Engine
from fastapi import status
from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes
from open_webui.env import SRC_LOG_LEVELS
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
def requests_hook(span: Span, request: PreparedRequest):
"""
Http Request Hook
"""
span.update_name(f"{request.method} {request.url}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: request.url,
SpanAttributes.HTTP_METHOD: request.method,
}
)
def response_hook(span: Span, request: PreparedRequest, response: Response):
"""
HTTP Response Hook
"""
span.set_attributes(
attributes={
SpanAttributes.HTTP_STATUS_CODE: response.status_code,
}
)
span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)
def redis_request_hook(span: Span, instance: Redis, args, kwargs):
"""
Redis Request Hook
"""
try:
connection_kwargs: dict = instance.connection_pool.connection_kwargs
host = connection_kwargs.get("host")
port = connection_kwargs.get("port")
db = connection_kwargs.get("db")
span.set_attributes(
{
SpanAttributes.DB_INSTANCE: f"{host}/{db}",
SpanAttributes.DB_NAME: f"{host}/{db}",
SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE,
SpanAttributes.DB_PORT: port,
SpanAttributes.DB_IP: host,
SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]),
SpanAttributes.DB_OPERATION: str(args[0]),
}
)
except Exception: # pylint: disable=W0718
logger.error(traceback.format_exc())
def httpx_request_hook(span: Span, request: RequestInfo):
"""
HTTPX Request Hook
"""
span.update_name(f"{request.method.decode()} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method.decode(),
}
)
def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo):
"""
HTTPX Response Hook
"""
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code)
span.set_status(
StatusCode.ERROR
if response.status_code >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)
async def httpx_async_request_hook(span: Span, request: RequestInfo):
"""
Async Request Hook
"""
httpx_request_hook(span, request)
async def httpx_async_response_hook(
span: Span, request: RequestInfo, response: ResponseInfo
):
"""
Async Response Hook
"""
httpx_response_hook(span, request, response)
def aiohttp_request_hook(span: Span, request: TraceRequestStartParams):
"""
Aiohttp Request Hook
"""
span.update_name(f"{request.method} {str(request.url)}")
span.set_attributes(
attributes={
SpanAttributes.HTTP_URL: str(request.url),
SpanAttributes.HTTP_METHOD: request.method,
}
)
def aiohttp_response_hook(
span: Span, response: Union[TraceRequestExceptionParams, TraceRequestEndParams]
):
"""
Aiohttp Response Hook
"""
if isinstance(response, TraceRequestEndParams):
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.response.status)
span.set_status(
StatusCode.ERROR
if response.response.status >= status.HTTP_400_BAD_REQUEST
else StatusCode.OK
)
elif isinstance(response, TraceRequestExceptionParams):
span.set_status(StatusCode.ERROR)
span.set_attribute(SpanAttributes.ERROR_MESSAGE, str(response.exception))
class Instrumentor(BaseInstrumentor):
"""
Instrument OT
"""
def __init__(self, app: FastAPI, db_engine: Engine):
self.app = app
self.db_engine = db_engine
def instrumentation_dependencies(self) -> Collection[str]:
return []
def _instrument(self, **kwargs):
instrument_fastapi(app=self.app)
SQLAlchemyInstrumentor().instrument(engine=self.db_engine)
RedisInstrumentor().instrument(request_hook=redis_request_hook)
RequestsInstrumentor().instrument(
request_hook=requests_hook, response_hook=response_hook
)
LoggingInstrumentor().instrument()
HTTPXClientInstrumentor().instrument(
request_hook=httpx_request_hook,
response_hook=httpx_response_hook,
async_request_hook=httpx_async_request_hook,
async_response_hook=httpx_async_response_hook,
)
AioHttpClientInstrumentor().instrument(
request_hook=aiohttp_request_hook,
response_hook=aiohttp_response_hook,
)
def _uninstrument(self, **kwargs):
if getattr(self, "instrumentors", None) is None:
return
for instrumentor in self.instrumentors:
instrumentor.uninstrument()

View File

@@ -0,0 +1,23 @@
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from sqlalchemy import Engine
from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
from open_webui.utils.telemetry.instrumentors import Instrumentor
from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
def setup(app: FastAPI, db_engine: Engine):
# set up trace
trace.set_tracer_provider(
TracerProvider(
resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
)
)
# otlp export
exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT)
trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
Instrumentor(app=app, db_engine=db_engine).instrument()

View File

@@ -1,6 +1,9 @@
import inspect
import logging
import re
import inspect
import uuid
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial