mirror of
https://github.com/open-webui/open-webui
synced 2025-01-01 08:42:14 +00:00
wip
This commit is contained in:
parent
403262d764
commit
4311bb7b99
@ -376,7 +376,7 @@ else:
|
|||||||
AIOHTTP_CLIENT_TIMEOUT = 300
|
AIOHTTP_CLIENT_TIMEOUT = 300
|
||||||
|
|
||||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
|
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
|
||||||
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5"
|
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
|
||||||
)
|
)
|
||||||
|
|
||||||
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
|
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
|
||||||
|
315
backend/open_webui/functions.py
Normal file
315
backend/open_webui/functions.py
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import AsyncGenerator, Generator, Iterator
|
||||||
|
from fastapi import (
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
File,
|
||||||
|
Form,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from starlette.responses import Response, StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.socket.main import (
|
||||||
|
get_event_call,
|
||||||
|
get_event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.models.functions import Functions
|
||||||
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
|
from open_webui.utils.tools import get_tools
|
||||||
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
|
||||||
|
from open_webui.utils.misc import (
|
||||||
|
add_or_update_system_message,
|
||||||
|
get_last_user_message,
|
||||||
|
prepend_to_first_user_message_content,
|
||||||
|
openai_chat_chunk_message_template,
|
||||||
|
openai_chat_completion_message_template,
|
||||||
|
)
|
||||||
|
from open_webui.utils.payload import (
|
||||||
|
apply_model_params_to_body_openai,
|
||||||
|
apply_model_system_prompt_to_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||||
|
# Check if function is already loaded
|
||||||
|
if pipe_id not in request.app.state.FUNCTIONS:
|
||||||
|
function_module, _, _ = load_function_module_by_id(pipe_id)
|
||||||
|
request.app.state.FUNCTIONS[pipe_id] = function_module
|
||||||
|
else:
|
||||||
|
function_module = request.app.state.FUNCTIONS[pipe_id]
|
||||||
|
|
||||||
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
|
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||||
|
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||||
|
return function_module
|
||||||
|
|
||||||
|
|
||||||
|
async def get_function_models():
|
||||||
|
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
||||||
|
pipe_models = []
|
||||||
|
|
||||||
|
for pipe in pipes:
|
||||||
|
function_module = get_function_module_by_id(pipe.id)
|
||||||
|
|
||||||
|
# Check if function is a manifold
|
||||||
|
if hasattr(function_module, "pipes"):
|
||||||
|
sub_pipes = []
|
||||||
|
|
||||||
|
# Check if pipes is a function or a list
|
||||||
|
|
||||||
|
try:
|
||||||
|
if callable(function_module.pipes):
|
||||||
|
sub_pipes = function_module.pipes()
|
||||||
|
else:
|
||||||
|
sub_pipes = function_module.pipes
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
sub_pipes = []
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in sub_pipes:
|
||||||
|
sub_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||||
|
sub_pipe_name = p["name"]
|
||||||
|
|
||||||
|
if hasattr(function_module, "name"):
|
||||||
|
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
|
||||||
|
|
||||||
|
pipe_flag = {"type": pipe.type}
|
||||||
|
|
||||||
|
pipe_models.append(
|
||||||
|
{
|
||||||
|
"id": sub_pipe_id,
|
||||||
|
"name": sub_pipe_name,
|
||||||
|
"object": "model",
|
||||||
|
"created": pipe.created_at,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"pipe": pipe_flag,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pipe_flag = {"type": "pipe"}
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe_models.append(
|
||||||
|
{
|
||||||
|
"id": pipe.id,
|
||||||
|
"name": pipe.name,
|
||||||
|
"object": "model",
|
||||||
|
"created": pipe.created_at,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"pipe": pipe_flag,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipe_models
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_function_chat_completion(
|
||||||
|
request, form_data, user, models: dict = {}
|
||||||
|
):
|
||||||
|
async def execute_pipe(pipe, params):
|
||||||
|
if inspect.iscoroutinefunction(pipe):
|
||||||
|
return await pipe(**params)
|
||||||
|
else:
|
||||||
|
return pipe(**params)
|
||||||
|
|
||||||
|
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
|
||||||
|
if isinstance(res, str):
|
||||||
|
return res
|
||||||
|
if isinstance(res, Generator):
|
||||||
|
return "".join(map(str, res))
|
||||||
|
if isinstance(res, AsyncGenerator):
|
||||||
|
return "".join([str(stream) async for stream in res])
|
||||||
|
|
||||||
|
def process_line(form_data: dict, line):
|
||||||
|
if isinstance(line, BaseModel):
|
||||||
|
line = line.model_dump_json()
|
||||||
|
line = f"data: {line}"
|
||||||
|
if isinstance(line, dict):
|
||||||
|
line = f"data: {json.dumps(line)}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
line = line.decode("utf-8")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if line.startswith("data:"):
|
||||||
|
return f"{line}\n\n"
|
||||||
|
else:
|
||||||
|
line = openai_chat_chunk_message_template(form_data["model"], line)
|
||||||
|
return f"data: {json.dumps(line)}\n\n"
|
||||||
|
|
||||||
|
def get_pipe_id(form_data: dict) -> str:
|
||||||
|
pipe_id = form_data["model"]
|
||||||
|
if "." in pipe_id:
|
||||||
|
pipe_id, _ = pipe_id.split(".", 1)
|
||||||
|
return pipe_id
|
||||||
|
|
||||||
|
def get_function_params(function_module, form_data, user, extra_params=None):
|
||||||
|
if extra_params is None:
|
||||||
|
extra_params = {}
|
||||||
|
|
||||||
|
pipe_id = get_pipe_id(form_data)
|
||||||
|
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(function_module.pipe)
|
||||||
|
params = {"body": form_data} | {
|
||||||
|
k: v for k, v in extra_params.items() if k in sig.parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||||
|
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
||||||
|
try:
|
||||||
|
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
params["__user__"]["valves"] = function_module.UserValves()
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
model_id = form_data.get("model")
|
||||||
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
|
metadata = form_data.pop("metadata", {})
|
||||||
|
|
||||||
|
files = metadata.get("files", [])
|
||||||
|
tool_ids = metadata.get("tool_ids", [])
|
||||||
|
# Check if tool_ids is None
|
||||||
|
if tool_ids is None:
|
||||||
|
tool_ids = []
|
||||||
|
|
||||||
|
__event_emitter__ = None
|
||||||
|
__event_call__ = None
|
||||||
|
__task__ = None
|
||||||
|
__task_body__ = None
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
|
||||||
|
__event_emitter__ = get_event_emitter(metadata)
|
||||||
|
__event_call__ = get_event_call(metadata)
|
||||||
|
__task__ = metadata.get("task", None)
|
||||||
|
__task_body__ = metadata.get("task_body", None)
|
||||||
|
|
||||||
|
extra_params = {
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
"__task__": __task__,
|
||||||
|
"__task_body__": __task_body__,
|
||||||
|
"__files__": files,
|
||||||
|
"__user__": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
},
|
||||||
|
"__metadata__": metadata,
|
||||||
|
}
|
||||||
|
extra_params["__tools__"] = get_tools(
|
||||||
|
request,
|
||||||
|
tool_ids,
|
||||||
|
user,
|
||||||
|
{
|
||||||
|
**extra_params,
|
||||||
|
"__model__": models.get(form_data["model"], None),
|
||||||
|
"__messages__": form_data["messages"],
|
||||||
|
"__files__": files,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
if model_info.base_model_id:
|
||||||
|
form_data["model"] = model_info.base_model_id
|
||||||
|
|
||||||
|
params = model_info.params.model_dump()
|
||||||
|
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||||
|
form_data = apply_model_system_prompt_to_body(params, form_data, user)
|
||||||
|
|
||||||
|
pipe_id = get_pipe_id(form_data)
|
||||||
|
function_module = get_function_module_by_id(pipe_id)
|
||||||
|
|
||||||
|
pipe = function_module.pipe
|
||||||
|
params = get_function_params(function_module, form_data, user, extra_params)
|
||||||
|
|
||||||
|
if form_data.get("stream", False):
|
||||||
|
|
||||||
|
async def stream_content():
|
||||||
|
try:
|
||||||
|
res = await execute_pipe(pipe, params)
|
||||||
|
|
||||||
|
# Directly return if the response is a StreamingResponse
|
||||||
|
if isinstance(res, StreamingResponse):
|
||||||
|
async for data in res.body_iterator:
|
||||||
|
yield data
|
||||||
|
return
|
||||||
|
if isinstance(res, dict):
|
||||||
|
yield f"data: {json.dumps(res)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(res, str):
|
||||||
|
message = openai_chat_chunk_message_template(form_data["model"], res)
|
||||||
|
yield f"data: {json.dumps(message)}\n\n"
|
||||||
|
|
||||||
|
if isinstance(res, Iterator):
|
||||||
|
for line in res:
|
||||||
|
yield process_line(form_data, line)
|
||||||
|
|
||||||
|
if isinstance(res, AsyncGenerator):
|
||||||
|
async for line in res:
|
||||||
|
yield process_line(form_data, line)
|
||||||
|
|
||||||
|
if isinstance(res, str) or isinstance(res, Generator):
|
||||||
|
finish_message = openai_chat_chunk_message_template(
|
||||||
|
form_data["model"], ""
|
||||||
|
)
|
||||||
|
finish_message["choices"][0]["finish_reason"] = "stop"
|
||||||
|
yield f"data: {json.dumps(finish_message)}\n\n"
|
||||||
|
yield "data: [DONE]"
|
||||||
|
|
||||||
|
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
res = await execute_pipe(pipe, params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error: {e}")
|
||||||
|
return {"error": {"detail": str(e)}}
|
||||||
|
|
||||||
|
if isinstance(res, StreamingResponse) or isinstance(res, dict):
|
||||||
|
return res
|
||||||
|
if isinstance(res, BaseModel):
|
||||||
|
return res.model_dump()
|
||||||
|
|
||||||
|
message = await get_message_content(res)
|
||||||
|
return openai_chat_completion_message_template(form_data["model"], message)
|
File diff suppressed because it is too large
Load Diff
@ -41,7 +41,7 @@ router = APIRouter()
|
|||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
"enabled": request.app.state.config.ENABLED,
|
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
||||||
"engine": request.app.state.config.ENGINE,
|
"engine": request.app.state.config.ENGINE,
|
||||||
"openai": {
|
"openai": {
|
||||||
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
||||||
@ -94,7 +94,7 @@ async def update_config(
|
|||||||
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
|
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
request.app.state.config.ENGINE = form_data.engine
|
request.app.state.config.ENGINE = form_data.engine
|
||||||
request.app.state.config.ENABLED = form_data.enabled
|
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
|
||||||
|
|
||||||
request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
||||||
request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
||||||
@ -131,7 +131,7 @@ async def update_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"enabled": request.app.state.config.ENABLED,
|
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
||||||
"engine": request.app.state.config.ENGINE,
|
"engine": request.app.state.config.ENGINE,
|
||||||
"openai": {
|
"openai": {
|
||||||
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
|
||||||
@ -175,7 +175,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
request.app.state.config.ENABLED = False
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||||
elif request.app.state.config.ENGINE == "comfyui":
|
elif request.app.state.config.ENGINE == "comfyui":
|
||||||
try:
|
try:
|
||||||
@ -185,7 +185,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
request.app.state.config.ENABLED = False
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
@ -232,7 +232,7 @@ def get_image_model():
|
|||||||
options = r.json()
|
options = r.json()
|
||||||
return options["sd_model_checkpoint"]
|
return options["sd_model_checkpoint"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
request.app.state.config.ENABLED = False
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
|
|
||||||
@ -351,7 +351,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
request.app.state.config.ENABLED = False
|
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
|
||||||
|
|
||||||
|
@ -195,12 +195,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
|||||||
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
"openai_config": {
|
"openai_config": {
|
||||||
"url": request.app.state.config.OPENAI_API_BASE_URL,
|
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
|
||||||
"key": request.app.state.config.OPENAI_API_KEY,
|
"key": request.app.state.config.RAG_OPENAI_API_KEY,
|
||||||
},
|
},
|
||||||
"ollama_config": {
|
"ollama_config": {
|
||||||
"url": request.app.state.config.OLLAMA_BASE_URL,
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||||
"key": request.app.state.config.OLLAMA_API_KEY,
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,14 +244,20 @@ async def update_embedding_config(
|
|||||||
|
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||||
if form_data.openai_config is not None:
|
if form_data.openai_config is not None:
|
||||||
request.app.state.config.OPENAI_API_BASE_URL = (
|
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||||
form_data.openai_config.url
|
form_data.openai_config.url
|
||||||
)
|
)
|
||||||
request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||||
|
form_data.openai_config.key
|
||||||
|
)
|
||||||
|
|
||||||
if form_data.ollama_config is not None:
|
if form_data.ollama_config is not None:
|
||||||
request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
|
request.app.state.config.RAG_OLLAMA_BASE_URL = (
|
||||||
request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
|
form_data.ollama_config.url
|
||||||
|
)
|
||||||
|
request.app.state.config.RAG_OLLAMA_API_KEY = (
|
||||||
|
form_data.ollama_config.key
|
||||||
|
)
|
||||||
|
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||||
form_data.embedding_batch_size
|
form_data.embedding_batch_size
|
||||||
@ -267,14 +273,14 @@ async def update_embedding_config(
|
|||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
request.app.state.ef,
|
request.app.state.ef,
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_BASE_URL
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.OLLAMA_BASE_URL
|
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_KEY
|
request.app.state.config.RAG_OPENAI_API_KEY
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.OLLAMA_API_KEY
|
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||||
),
|
),
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
@ -285,12 +291,12 @@ async def update_embedding_config(
|
|||||||
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
"openai_config": {
|
"openai_config": {
|
||||||
"url": request.app.state.config.OPENAI_API_BASE_URL,
|
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
|
||||||
"key": request.app.state.config.OPENAI_API_KEY,
|
"key": request.app.state.config.RAG_OPENAI_API_KEY,
|
||||||
},
|
},
|
||||||
"ollama_config": {
|
"ollama_config": {
|
||||||
"url": request.app.state.config.OLLAMA_BASE_URL,
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||||
"key": request.app.state.config.OLLAMA_API_KEY,
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -747,14 +753,14 @@ def save_docs_to_vector_db(
|
|||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
request.app.state.ef,
|
request.app.state.ef,
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_BASE_URL
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.OLLAMA_BASE_URL
|
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_KEY
|
request.app.state.config.RAG_OPENAI_API_KEY
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.OLLAMA_API_KEY
|
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||||
),
|
),
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from open_webui.utils.chat import generate_chat_completion
|
||||||
from open_webui.utils.task import (
|
from open_webui.utils.task import (
|
||||||
title_generation_template,
|
title_generation_template,
|
||||||
query_generation_template,
|
query_generation_template,
|
||||||
@ -193,7 +193,7 @@ Artificial Intelligence in Healthcare
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -208,7 +208,7 @@ Artificial Intelligence in Healthcare
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tags/completions")
|
@router.post("/tags/completions")
|
||||||
@ -282,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -297,7 +297,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/queries/completions")
|
@router.post("/queries/completions")
|
||||||
@ -363,7 +363,7 @@ async def generate_queries(
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -378,7 +378,7 @@ async def generate_queries(
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auto/completions")
|
@router.post("/auto/completions")
|
||||||
@ -449,7 +449,7 @@ async def generate_autocompletion(
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -464,7 +464,7 @@ async def generate_autocompletion(
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/emoji/completions")
|
@router.post("/emoji/completions")
|
||||||
@ -523,7 +523,7 @@ Message: """{{prompt}}"""
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -538,7 +538,7 @@ Message: """{{prompt}}"""
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/moa/completions")
|
@router.post("/moa/completions")
|
||||||
@ -590,7 +590,7 @@ Responses from models: {{responses}}"""
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = process_pipeline_inlet_filter(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -605,4 +605,4 @@ Responses from models: {{responses}}"""
|
|||||||
if "chat_id" in payload:
|
if "chat_id" in payload:
|
||||||
del payload["chat_id"]
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||||
|
380
backend/open_webui/utils/chat.py
Normal file
380
backend/open_webui/utils/chat.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from aiocache import cached
|
||||||
|
from typing import Any
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.responses import Response, StreamingResponse
|
||||||
|
|
||||||
|
from open_webui.socket.main import (
|
||||||
|
get_event_call,
|
||||||
|
get_event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.functions import generate_function_chat_completion
|
||||||
|
|
||||||
|
from open_webui.routers.openai import (
|
||||||
|
generate_chat_completion as generate_openai_chat_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.routers.ollama import (
|
||||||
|
generate_chat_completion as generate_ollama_chat_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.routers.pipelines import (
|
||||||
|
process_pipeline_outlet_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.models.functions import Functions
|
||||||
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
|
from open_webui.utils.access_control import has_access
|
||||||
|
from open_webui.utils.models import get_all_models
|
||||||
|
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||||
|
from open_webui.utils.response import (
|
||||||
|
convert_response_ollama_to_openai,
|
||||||
|
convert_streaming_response_ollama_to_openai,
|
||||||
|
)
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_chat_completion(
|
||||||
|
request: Request,
|
||||||
|
form_data: dict,
|
||||||
|
user: Any,
|
||||||
|
bypass_filter: bool = False,
|
||||||
|
):
|
||||||
|
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||||
|
bypass_filter = True
|
||||||
|
|
||||||
|
models = request.app.state.MODELS
|
||||||
|
|
||||||
|
model_id = form_data["model"]
|
||||||
|
if model_id not in models:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
|
||||||
|
model = models[model_id]
|
||||||
|
|
||||||
|
# Check if user has access to the model
|
||||||
|
if not bypass_filter and user.role == "user":
|
||||||
|
if model.get("arena"):
|
||||||
|
if not has_access(
|
||||||
|
user.id,
|
||||||
|
type="read",
|
||||||
|
access_control=model.get("info", {})
|
||||||
|
.get("meta", {})
|
||||||
|
.get("access_control", {}),
|
||||||
|
):
|
||||||
|
raise Exception("Model not found")
|
||||||
|
else:
|
||||||
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
if not model_info:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
elif not (
|
||||||
|
user.id == model_info.user_id
|
||||||
|
or has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise Exception("Model not found")
|
||||||
|
|
||||||
|
if model["owned_by"] == "arena":
|
||||||
|
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||||
|
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
||||||
|
if model_ids and filter_mode == "exclude":
|
||||||
|
model_ids = [
|
||||||
|
model["id"]
|
||||||
|
for model in await get_all_models(request)
|
||||||
|
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
selected_model_id = None
|
||||||
|
if isinstance(model_ids, list) and model_ids:
|
||||||
|
selected_model_id = random.choice(model_ids)
|
||||||
|
else:
|
||||||
|
model_ids = [
|
||||||
|
model["id"]
|
||||||
|
for model in await get_all_models(request)
|
||||||
|
if model.get("owned_by") != "arena"
|
||||||
|
]
|
||||||
|
selected_model_id = random.choice(model_ids)
|
||||||
|
|
||||||
|
form_data["model"] = selected_model_id
|
||||||
|
|
||||||
|
if form_data.get("stream") == True:
|
||||||
|
|
||||||
|
async def stream_wrapper(stream):
|
||||||
|
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
|
||||||
|
async for chunk in stream:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
response = await generate_chat_completion(
|
||||||
|
form_data, user, bypass_filter=True
|
||||||
|
)
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_wrapper(response.body_iterator), media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
**(await generate_chat_completion(form_data, user, bypass_filter=True)),
|
||||||
|
"selected_model_id": selected_model_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.get("pipe"):
|
||||||
|
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||||
|
return await generate_function_chat_completion(
|
||||||
|
form_data, user=user, models=models
|
||||||
|
)
|
||||||
|
if model["owned_by"] == "ollama":
|
||||||
|
# Using /ollama/api/chat endpoint
|
||||||
|
form_data = convert_payload_openai_to_ollama(form_data)
|
||||||
|
response = await generate_ollama_chat_completion(
|
||||||
|
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||||
|
)
|
||||||
|
if form_data.stream:
|
||||||
|
response.headers["content-type"] = "text/event-stream"
|
||||||
|
return StreamingResponse(
|
||||||
|
convert_streaming_response_ollama_to_openai(response),
|
||||||
|
headers=dict(response.headers),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return convert_response_ollama_to_openai(response)
|
||||||
|
else:
|
||||||
|
return await generate_openai_chat_completion(
|
||||||
|
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||||
|
await get_all_models(request)
|
||||||
|
models = request.app.state.MODELS
|
||||||
|
|
||||||
|
data = form_data
|
||||||
|
model_id = data["model"]
|
||||||
|
if model_id not in models:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
|
||||||
|
model = models[model_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = process_pipeline_outlet_filter(request, data, user, models)
|
||||||
|
except Exception as e:
|
||||||
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
|
__event_emitter__ = get_event_emitter(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
__event_call__ = get_event_call(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_priority(function_id):
|
||||||
|
function = Functions.get_function_by_id(function_id)
|
||||||
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel to include vavles
|
||||||
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||||
|
if "info" in model and "meta" in model["info"]:
|
||||||
|
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||||
|
filter_ids = list(set(filter_ids))
|
||||||
|
|
||||||
|
enabled_filter_ids = [
|
||||||
|
function.id
|
||||||
|
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||||
|
]
|
||||||
|
filter_ids = [
|
||||||
|
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
# Sort filter_ids by priority, using the get_priority function
|
||||||
|
filter_ids.sort(key=get_priority)
|
||||||
|
|
||||||
|
for filter_id in filter_ids:
|
||||||
|
filter = Functions.get_function_by_id(filter_id)
|
||||||
|
if not filter:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if filter_id in request.app.state.FUNCTIONS:
|
||||||
|
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||||
|
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||||
|
|
||||||
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
|
valves = Functions.get_function_valves_by_id(filter_id)
|
||||||
|
function_module.valves = function_module.Valves(
|
||||||
|
**(valves if valves else {})
|
||||||
|
)
|
||||||
|
|
||||||
|
if not hasattr(function_module, "outlet"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
outlet = function_module.outlet
|
||||||
|
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(outlet)
|
||||||
|
params = {"body": data}
|
||||||
|
|
||||||
|
# Extra parameters to be passed to the function
|
||||||
|
extra_params = {
|
||||||
|
"__model__": model,
|
||||||
|
"__id__": filter_id,
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add extra params in contained in function signature
|
||||||
|
for key, value in extra_params.items():
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
|
if "__user__" in sig.parameters:
|
||||||
|
__user__ = {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(function_module, "UserValves"):
|
||||||
|
__user__["valves"] = function_module.UserValves(
|
||||||
|
**Functions.get_user_valves_by_id_and_user_id(
|
||||||
|
filter_id, user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
params = {**params, "__user__": __user__}
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(outlet):
|
||||||
|
data = await outlet(**params)
|
||||||
|
else:
|
||||||
|
data = outlet(**params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||||
|
if "." in action_id:
|
||||||
|
action_id, sub_action_id = action_id.split(".")
|
||||||
|
else:
|
||||||
|
sub_action_id = None
|
||||||
|
|
||||||
|
action = Functions.get_function_by_id(action_id)
|
||||||
|
if not action:
|
||||||
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
|
await get_all_models(request)
|
||||||
|
models = request.app.state.MODELS
|
||||||
|
|
||||||
|
data = form_data
|
||||||
|
model_id = data["model"]
|
||||||
|
|
||||||
|
if model_id not in models:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
model = models[model_id]
|
||||||
|
|
||||||
|
__event_emitter__ = get_event_emitter(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
__event_call__ = get_event_call(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_id in request.app.state.FUNCTIONS:
|
||||||
|
function_module = request.app.state.FUNCTIONS[action_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(action_id)
|
||||||
|
request.app.state.FUNCTIONS[action_id] = function_module
|
||||||
|
|
||||||
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
|
valves = Functions.get_function_valves_by_id(action_id)
|
||||||
|
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||||
|
|
||||||
|
if hasattr(function_module, "action"):
|
||||||
|
try:
|
||||||
|
action = function_module.action
|
||||||
|
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(action)
|
||||||
|
params = {"body": data}
|
||||||
|
|
||||||
|
# Extra parameters to be passed to the function
|
||||||
|
extra_params = {
|
||||||
|
"__model__": model,
|
||||||
|
"__id__": sub_action_id if sub_action_id is not None else action_id,
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add extra params in contained in function signature
|
||||||
|
for key, value in extra_params.items():
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
|
if "__user__" in sig.parameters:
|
||||||
|
__user__ = {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(function_module, "UserValves"):
|
||||||
|
__user__["valves"] = function_module.UserValves(
|
||||||
|
**Functions.get_user_valves_by_id_and_user_id(
|
||||||
|
action_id, user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
params = {**params, "__user__": __user__}
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(action):
|
||||||
|
data = await action(**params)
|
||||||
|
else:
|
||||||
|
data = action(**params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
|
return data
|
222
backend/open_webui/utils/models.py
Normal file
222
backend/open_webui/utils/models.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from aiocache import cached
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from open_webui.routers import openai, ollama
|
||||||
|
from open_webui.functions import get_function_models
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.models.functions import Functions
|
||||||
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.config import (
|
||||||
|
DEFAULT_ARENA_MODEL,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_base_models(request: Request):
|
||||||
|
function_models = []
|
||||||
|
openai_models = []
|
||||||
|
ollama_models = []
|
||||||
|
|
||||||
|
if request.app.state.config.ENABLE_OPENAI_API:
|
||||||
|
openai_models = await openai.get_all_models(request)
|
||||||
|
openai_models = openai_models["data"]
|
||||||
|
|
||||||
|
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||||
|
ollama_models = await ollama.get_all_models(request)
|
||||||
|
ollama_models = [
|
||||||
|
{
|
||||||
|
"id": model["model"],
|
||||||
|
"name": model["name"],
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "ollama",
|
||||||
|
"ollama": model,
|
||||||
|
}
|
||||||
|
for model in ollama_models["models"]
|
||||||
|
]
|
||||||
|
|
||||||
|
function_models = await get_function_models()
|
||||||
|
models = function_models + openai_models + ollama_models
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
@cached(ttl=3)
|
||||||
|
async def get_all_models(request):
|
||||||
|
models = await get_all_base_models(request)
|
||||||
|
|
||||||
|
# If there are no models, return an empty list
|
||||||
|
if len(models) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Add arena models
|
||||||
|
if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
|
||||||
|
arena_models = []
|
||||||
|
if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
|
||||||
|
arena_models = [
|
||||||
|
{
|
||||||
|
"id": model["id"],
|
||||||
|
"name": model["name"],
|
||||||
|
"info": {
|
||||||
|
"meta": model["meta"],
|
||||||
|
},
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "arena",
|
||||||
|
"arena": True,
|
||||||
|
}
|
||||||
|
for model in request.app.state.config.EVALUATION_ARENA_MODELS
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Add default arena model
|
||||||
|
arena_models = [
|
||||||
|
{
|
||||||
|
"id": DEFAULT_ARENA_MODEL["id"],
|
||||||
|
"name": DEFAULT_ARENA_MODEL["name"],
|
||||||
|
"info": {
|
||||||
|
"meta": DEFAULT_ARENA_MODEL["meta"],
|
||||||
|
},
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "arena",
|
||||||
|
"arena": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
models = models + arena_models
|
||||||
|
|
||||||
|
global_action_ids = [
|
||||||
|
function.id for function in Functions.get_global_action_functions()
|
||||||
|
]
|
||||||
|
enabled_action_ids = [
|
||||||
|
function.id
|
||||||
|
for function in Functions.get_functions_by_type("action", active_only=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
custom_models = Models.get_all_models()
|
||||||
|
for custom_model in custom_models:
|
||||||
|
if custom_model.base_model_id is None:
|
||||||
|
for model in models:
|
||||||
|
if (
|
||||||
|
custom_model.id == model["id"]
|
||||||
|
or custom_model.id == model["id"].split(":")[0]
|
||||||
|
):
|
||||||
|
if custom_model.is_active:
|
||||||
|
model["name"] = custom_model.name
|
||||||
|
model["info"] = custom_model.model_dump()
|
||||||
|
|
||||||
|
action_ids = []
|
||||||
|
if "info" in model and "meta" in model["info"]:
|
||||||
|
action_ids.extend(
|
||||||
|
model["info"]["meta"].get("actionIds", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
model["action_ids"] = action_ids
|
||||||
|
else:
|
||||||
|
models.remove(model)
|
||||||
|
|
||||||
|
elif custom_model.is_active and (
|
||||||
|
custom_model.id not in [model["id"] for model in models]
|
||||||
|
):
|
||||||
|
owned_by = "openai"
|
||||||
|
pipe = None
|
||||||
|
action_ids = []
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
if (
|
||||||
|
custom_model.base_model_id == model["id"]
|
||||||
|
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||||
|
):
|
||||||
|
owned_by = model["owned_by"]
|
||||||
|
if "pipe" in model:
|
||||||
|
pipe = model["pipe"]
|
||||||
|
break
|
||||||
|
|
||||||
|
if custom_model.meta:
|
||||||
|
meta = custom_model.meta.model_dump()
|
||||||
|
if "actionIds" in meta:
|
||||||
|
action_ids.extend(meta["actionIds"])
|
||||||
|
|
||||||
|
models.append(
|
||||||
|
{
|
||||||
|
"id": f"{custom_model.id}",
|
||||||
|
"name": custom_model.name,
|
||||||
|
"object": "model",
|
||||||
|
"created": custom_model.created_at,
|
||||||
|
"owned_by": owned_by,
|
||||||
|
"info": custom_model.model_dump(),
|
||||||
|
"preset": True,
|
||||||
|
**({"pipe": pipe} if pipe is not None else {}),
|
||||||
|
"action_ids": action_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process action_ids to get the actions
|
||||||
|
def get_action_items_from_module(function, module):
|
||||||
|
actions = []
|
||||||
|
if hasattr(module, "actions"):
|
||||||
|
actions = module.actions
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": f"{function.id}.{action['id']}",
|
||||||
|
"name": action.get("name", f"{function.name} ({action['id']})"),
|
||||||
|
"description": function.meta.description,
|
||||||
|
"icon_url": action.get(
|
||||||
|
"icon_url", function.meta.manifest.get("icon_url", None)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for action in actions
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": function.id,
|
||||||
|
"name": function.name,
|
||||||
|
"description": function.meta.description,
|
||||||
|
"icon_url": function.meta.manifest.get("icon_url", None),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_function_module_by_id(function_id):
|
||||||
|
if function_id in request.app.state.FUNCTIONS:
|
||||||
|
function_module = request.app.state.FUNCTIONS[function_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(function_id)
|
||||||
|
request.app.state.FUNCTIONS[function_id] = function_module
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
action_ids = [
|
||||||
|
action_id
|
||||||
|
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
|
||||||
|
if action_id in enabled_action_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
model["actions"] = []
|
||||||
|
for action_id in action_ids:
|
||||||
|
action_function = Functions.get_function_by_id(action_id)
|
||||||
|
if action_function is None:
|
||||||
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
|
function_module = get_function_module_by_id(action_id)
|
||||||
|
model["actions"].extend(
|
||||||
|
get_action_items_from_module(action_function, function_module)
|
||||||
|
)
|
||||||
|
log.debug(f"get_all_models() returned {len(models)} models")
|
||||||
|
|
||||||
|
request.app.state.MODELS = {model["id"]: model for model in models}
|
||||||
|
return models
|
@ -4,11 +4,15 @@ import re
|
|||||||
from typing import Any, Awaitable, Callable, get_type_hints
|
from typing import Any, Awaitable, Callable, get_type_hints
|
||||||
from functools import update_wrapper, partial
|
from functools import update_wrapper, partial
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.tools import Tools
|
from open_webui.models.tools import Tools
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
from open_webui.utils.plugin import load_tools_module_by_id
|
from open_webui.utils.plugin import load_tools_module_by_id
|
||||||
from pydantic import BaseModel, Field, create_model
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -32,7 +36,7 @@ def apply_extra_params_to_tool_function(
|
|||||||
|
|
||||||
# Mutation on extra_params
|
# Mutation on extra_params
|
||||||
def get_tools(
|
def get_tools(
|
||||||
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
|
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
||||||
@ -41,10 +45,10 @@ def get_tools(
|
|||||||
if tools is None:
|
if tools is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
module = webui_app.state.TOOLS.get(tool_id, None)
|
module = request.app.state.TOOLS.get(tool_id, None)
|
||||||
if module is None:
|
if module is None:
|
||||||
module, _ = load_tools_module_by_id(tool_id)
|
module, _ = load_tools_module_by_id(tool_id)
|
||||||
webui_app.state.TOOLS[tool_id] = module
|
request.app.state.TOOLS[tool_id] = module
|
||||||
|
|
||||||
extra_params["__id__"] = tool_id
|
extra_params["__id__"] = tool_id
|
||||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||||
|
@ -110,7 +110,7 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct
|
|||||||
export const getTaskConfig = async (token: string = '') => {
|
export const getTaskConfig = async (token: string = '') => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -138,7 +138,7 @@ export const getTaskConfig = async (token: string = '') => {
|
|||||||
export const updateTaskConfig = async (token: string, config: object) => {
|
export const updateTaskConfig = async (token: string, config: object) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config/update`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -176,7 +176,7 @@ export const generateTitle = async (
|
|||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/title/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -216,7 +216,7 @@ export const generateTags = async (
|
|||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/tags/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -288,7 +288,7 @@ export const generateEmoji = async (
|
|||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/emoji/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -337,7 +337,7 @@ export const generateQueries = async (
|
|||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/queries/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -407,7 +407,7 @@ export const generateAutoCompletion = async (
|
|||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/auto/completions`, {
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@ -477,7 +477,7 @@ export const generateMoACompletion = async (
|
|||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/moa/completions`, {
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@ -507,7 +507,7 @@ export const generateMoACompletion = async (
|
|||||||
export const getPipelinesList = async (token: string = '') => {
|
export const getPipelinesList = async (token: string = '') => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/list`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/list`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -541,7 +541,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string)
|
|||||||
formData.append('file', file);
|
formData.append('file', file);
|
||||||
formData.append('urlIdx', urlIdx);
|
formData.append('urlIdx', urlIdx);
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/upload`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
...(token && { authorization: `Bearer ${token}` })
|
...(token && { authorization: `Bearer ${token}` })
|
||||||
@ -573,7 +573,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string)
|
|||||||
export const downloadPipeline = async (token: string, url: string, urlIdx: string) => {
|
export const downloadPipeline = async (token: string, url: string, urlIdx: string) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/add`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/add`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -609,7 +609,7 @@ export const downloadPipeline = async (token: string, url: string, urlIdx: strin
|
|||||||
export const deletePipeline = async (token: string, id: string, urlIdx: string) => {
|
export const deletePipeline = async (token: string, id: string, urlIdx: string) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/delete`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/delete`, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -650,7 +650,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => {
|
|||||||
searchParams.append('urlIdx', urlIdx);
|
searchParams.append('urlIdx', urlIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines?${searchParams.toString()}`, {
|
const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines?${searchParams.toString()}`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
@ -685,7 +685,7 @@ export const getPipelineValves = async (token: string, pipeline_id: string, urlI
|
|||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(
|
const res = await fetch(
|
||||||
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves?${searchParams.toString()}`,
|
`${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves?${searchParams.toString()}`,
|
||||||
{
|
{
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
@ -721,7 +721,7 @@ export const getPipelineValvesSpec = async (token: string, pipeline_id: string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(
|
const res = await fetch(
|
||||||
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`,
|
`${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`,
|
||||||
{
|
{
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
@ -762,7 +762,7 @@ export const updatePipelineValves = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(
|
const res = await fetch(
|
||||||
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`,
|
`${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`,
|
||||||
{
|
{
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
Loading…
Reference in New Issue
Block a user