This commit is contained in:
Timothy Jaeryang Baek
2024-12-12 20:22:17 -08:00
parent 403262d764
commit 4311bb7b99
10 changed files with 1102 additions and 966 deletions

View File

@@ -0,0 +1,380 @@
import time
import logging
import sys
from aiocache import cached
from typing import Any
import random
import json
import inspect
from fastapi import Request
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.functions import generate_function_chat_completion
from open_webui.routers.openai import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.routers.ollama import (
generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.routers.pipelines import (
process_pipeline_outlet_filter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.access_control import has_access
from open_webui.utils.models import get_all_models
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_chat_completion(
request: Request,
form_data: dict,
user: Any,
bypass_filter: bool = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if model.get("arena"):
if not has_access(
user.id,
type="read",
access_control=model.get("info", {})
.get("meta", {})
.get("access_control", {}),
):
raise Exception("Model not found")
else:
model_info = Models.get_model_by_id(model_id)
if not model_info:
raise Exception("Model not found")
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise Exception("Model not found")
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models(request)
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models(request)
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(await generate_chat_completion(form_data, user, bypass_filter=True)),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.stream:
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
async def chat_completed(request: Request, form_data: dict, user: Any):
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
if "." in action_id:
action_id, sub_action_id = action_id.split(".")
else:
sub_action_id = None
action = Functions.get_function_by_id(action_id)
if not action:
raise Exception(f"Action not found: {action_id}")
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
if action_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):
try:
action = function_module.action
# Get the signature of the function
sig = inspect.signature(action)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": sub_action_id if sub_action_id is not None else action_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(action):
data = await action(**params)
else:
data = action(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data

View File

@@ -0,0 +1,222 @@
import time
import logging
import sys
from aiocache import cached
from fastapi import Request
from open_webui.routers import openai, ollama
from open_webui.functions import get_function_models
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.config import (
DEFAULT_ARENA_MODEL,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request)
ollama_models = [
{
"id": model["model"],
"name": model["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
}
for model in ollama_models["models"]
]
function_models = await get_function_models()
models = function_models + openai_models + ollama_models
return models
@cached(ttl=3)
async def get_all_models(request):
models = await get_all_base_models(request)
# If there are no models, return an empty list
if len(models) == 0:
return []
# Add arena models
if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in request.app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
]
enabled_action_ids = [
function.id
for function in Functions.get_functions_by_type("action", active_only=True)
]
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id is None:
for model in models:
if (
custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0]
):
if custom_model.is_active:
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
action_ids = []
if "info" in model and "meta" in model["info"]:
action_ids.extend(
model["info"]["meta"].get("actionIds", [])
)
model["action_ids"] = action_ids
else:
models.remove(model)
elif custom_model.is_active and (
custom_model.id not in [model["id"] for model in models]
):
owned_by = "openai"
pipe = None
action_ids = []
for model in models:
if (
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
break
if custom_model.meta:
meta = custom_model.meta.model_dump()
if "actionIds" in meta:
action_ids.extend(meta["actionIds"])
models.append(
{
"id": f"{custom_model.id}",
"name": custom_model.name,
"object": "model",
"created": custom_model.created_at,
"owned_by": owned_by,
"info": custom_model.model_dump(),
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
"action_ids": action_ids,
}
)
# Process action_ids to get the actions
def get_action_items_from_module(function, module):
actions = []
if hasattr(module, "actions"):
actions = module.actions
return [
{
"id": f"{function.id}.{action['id']}",
"name": action.get("name", f"{function.name} ({action['id']})"),
"description": function.meta.description,
"icon_url": action.get(
"icon_url", function.meta.manifest.get("icon_url", None)
),
}
for action in actions
]
else:
return [
{
"id": function.id,
"name": function.name,
"description": function.meta.description,
"icon_url": function.meta.manifest.get("icon_url", None),
}
]
def get_function_module_by_id(function_id):
if function_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[function_id]
else:
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
for model in models:
action_ids = [
action_id
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
if action_id in enabled_action_ids
]
model["actions"] = []
for action_id in action_ids:
action_function = Functions.get_function_by_id(action_id)
if action_function is None:
raise Exception(f"Action not found: {action_id}")
function_module = get_function_module_by_id(action_id)
model["actions"].extend(
get_action_items_from_module(action_function, function_module)
)
log.debug(f"get_all_models() returned {len(models)} models")
request.app.state.MODELS = {model["id"]: model for model in models}
return models

View File

@@ -4,11 +4,15 @@ import re
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial
from fastapi import Request
from pydantic import BaseModel, Field, create_model
from langchain_core.utils.function_calling import convert_to_openai_function
from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tools_module_by_id
from pydantic import BaseModel, Field, create_model
log = logging.getLogger(__name__)
@@ -32,7 +36,7 @@ def apply_extra_params_to_tool_function(
# Mutation on extra_params
def get_tools(
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools_dict = {}
@@ -41,10 +45,10 @@ def get_tools(
if tools is None:
continue
module = webui_app.state.TOOLS.get(tool_id, None)
module = request.app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_tools_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
request.app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"):