This commit is contained in:
Timothy Jaeryang Baek 2024-12-12 20:22:17 -08:00
parent 403262d764
commit 4311bb7b99
10 changed files with 1102 additions and 966 deletions

View File

@ -376,7 +376,7 @@ else:
AIOHTTP_CLIENT_TIMEOUT = 300 AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5" "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
) )
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":

View File

@ -0,0 +1,315 @@
import logging
import sys
import inspect
import json
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_function_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module_by_id(pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
}
extra_params["__tools__"] = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ router = APIRouter()
@router.get("/config") @router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return { return {
"enabled": request.app.state.config.ENABLED, "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.ENGINE, "engine": request.app.state.config.ENGINE,
"openai": { "openai": {
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
@ -94,7 +94,7 @@ async def update_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user) request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
): ):
request.app.state.config.ENGINE = form_data.engine request.app.state.config.ENGINE = form_data.engine
request.app.state.config.ENABLED = form_data.enabled request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
@ -131,7 +131,7 @@ async def update_config(
) )
return { return {
"enabled": request.app.state.config.ENABLED, "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.ENGINE, "engine": request.app.state.config.ENGINE,
"openai": { "openai": {
"OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
@ -175,7 +175,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
r.raise_for_status() r.raise_for_status()
return True return True
except Exception: except Exception:
request.app.state.config.ENABLED = False request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.ENGINE == "comfyui": elif request.app.state.config.ENGINE == "comfyui":
try: try:
@ -185,7 +185,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
r.raise_for_status() r.raise_for_status()
return True return True
except Exception: except Exception:
request.app.state.config.ENABLED = False request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
else: else:
return True return True
@ -232,7 +232,7 @@ def get_image_model():
options = r.json() options = r.json()
return options["sd_model_checkpoint"] return options["sd_model_checkpoint"]
except Exception as e: except Exception as e:
request.app.state.config.ENABLED = False request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -351,7 +351,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
) )
) )
except Exception as e: except Exception as e:
request.app.state.config.ENABLED = False request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

View File

@ -195,12 +195,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": { "openai_config": {
"url": request.app.state.config.OPENAI_API_BASE_URL, "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.OPENAI_API_KEY, "key": request.app.state.config.RAG_OPENAI_API_KEY,
}, },
"ollama_config": { "ollama_config": {
"url": request.app.state.config.OLLAMA_BASE_URL, "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.OLLAMA_API_KEY, "key": request.app.state.config.RAG_OLLAMA_API_KEY,
}, },
} }
@ -244,14 +244,20 @@ async def update_embedding_config(
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config is not None: if form_data.openai_config is not None:
request.app.state.config.OPENAI_API_BASE_URL = ( request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url form_data.openai_config.url
) )
request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
if form_data.ollama_config is not None: if form_data.ollama_config is not None:
request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url request.app.state.config.RAG_OLLAMA_BASE_URL = (
request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key form_data.ollama_config.url
)
request.app.state.config.RAG_OLLAMA_API_KEY = (
form_data.ollama_config.key
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size form_data.embedding_batch_size
@ -267,14 +273,14 @@ async def update_embedding_config(
request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef, request.app.state.ef,
( (
request.app.state.config.OPENAI_API_BASE_URL request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_BASE_URL else request.app.state.config.RAG_OLLAMA_BASE_URL
), ),
( (
request.app.state.config.OPENAI_API_KEY request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_API_KEY else request.app.state.config.RAG_OLLAMA_API_KEY
), ),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
) )
@ -285,12 +291,12 @@ async def update_embedding_config(
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": { "openai_config": {
"url": request.app.state.config.OPENAI_API_BASE_URL, "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.OPENAI_API_KEY, "key": request.app.state.config.RAG_OPENAI_API_KEY,
}, },
"ollama_config": { "ollama_config": {
"url": request.app.state.config.OLLAMA_BASE_URL, "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.OLLAMA_API_KEY, "key": request.app.state.config.RAG_OLLAMA_API_KEY,
}, },
} }
except Exception as e: except Exception as e:
@ -747,14 +753,14 @@ def save_docs_to_vector_db(
request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef, request.app.state.ef,
( (
request.app.state.config.OPENAI_API_BASE_URL request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_BASE_URL else request.app.state.config.RAG_OLLAMA_BASE_URL
), ),
( (
request.app.state.config.OPENAI_API_KEY request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_API_KEY else request.app.state.config.RAG_OLLAMA_API_KEY
), ),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
) )

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel
from typing import Optional from typing import Optional
import logging import logging
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import ( from open_webui.utils.task import (
title_generation_template, title_generation_template,
query_generation_template, query_generation_template,
@ -193,7 +193,7 @@ Artificial Intelligence in Healthcare
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -208,7 +208,7 @@ Artificial Intelligence in Healthcare
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
@router.post("/tags/completions") @router.post("/tags/completions")
@ -282,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -297,7 +297,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
@router.post("/queries/completions") @router.post("/queries/completions")
@ -363,7 +363,7 @@ async def generate_queries(
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -378,7 +378,7 @@ async def generate_queries(
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
@router.post("/auto/completions") @router.post("/auto/completions")
@ -449,7 +449,7 @@ async def generate_autocompletion(
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -464,7 +464,7 @@ async def generate_autocompletion(
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
@router.post("/emoji/completions") @router.post("/emoji/completions")
@ -523,7 +523,7 @@ Message: """{{prompt}}"""
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -538,7 +538,7 @@ Message: """{{prompt}}"""
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)
@router.post("/moa/completions") @router.post("/moa/completions")
@ -590,7 +590,7 @@ Responses from models: {{responses}}"""
} }
try: try:
payload = process_pipeline_inlet_filter(payload, user, models) payload = process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e: except Exception as e:
if len(e.args) > 1: if len(e.args) > 1:
return JSONResponse( return JSONResponse(
@ -605,4 +605,4 @@ Responses from models: {{responses}}"""
if "chat_id" in payload: if "chat_id" in payload:
del payload["chat_id"] del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completion(request, form_data=payload, user=user)

View File

@ -0,0 +1,380 @@
import time
import logging
import sys
from aiocache import cached
from typing import Any
import random
import json
import inspect
from fastapi import Request
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.functions import generate_function_chat_completion
from open_webui.routers.openai import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.routers.ollama import (
generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.routers.pipelines import (
process_pipeline_outlet_filter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.access_control import has_access
from open_webui.utils.models import get_all_models
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_chat_completion(
request: Request,
form_data: dict,
user: Any,
bypass_filter: bool = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if model.get("arena"):
if not has_access(
user.id,
type="read",
access_control=model.get("info", {})
.get("meta", {})
.get("access_control", {}),
):
raise Exception("Model not found")
else:
model_info = Models.get_model_by_id(model_id)
if not model_info:
raise Exception("Model not found")
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise Exception("Model not found")
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models(request)
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models(request)
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(await generate_chat_completion(form_data, user, bypass_filter=True)),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.stream:
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
async def chat_completed(request: Request, form_data: dict, user: Any):
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
if "." in action_id:
action_id, sub_action_id = action_id.split(".")
else:
sub_action_id = None
action = Functions.get_function_by_id(action_id)
if not action:
raise Exception(f"Action not found: {action_id}")
await get_all_models(request)
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
if action_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):
try:
action = function_module.action
# Get the signature of the function
sig = inspect.signature(action)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": sub_action_id if sub_action_id is not None else action_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(action):
data = await action(**params)
else:
data = action(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data

View File

@ -0,0 +1,222 @@
import time
import logging
import sys
from aiocache import cached
from fastapi import Request
from open_webui.routers import openai, ollama
from open_webui.functions import get_function_models
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.config import (
DEFAULT_ARENA_MODEL,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request)
ollama_models = [
{
"id": model["model"],
"name": model["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
}
for model in ollama_models["models"]
]
function_models = await get_function_models()
models = function_models + openai_models + ollama_models
return models
@cached(ttl=3)
async def get_all_models(request):
models = await get_all_base_models(request)
# If there are no models, return an empty list
if len(models) == 0:
return []
# Add arena models
if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in request.app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
]
enabled_action_ids = [
function.id
for function in Functions.get_functions_by_type("action", active_only=True)
]
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id is None:
for model in models:
if (
custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0]
):
if custom_model.is_active:
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
action_ids = []
if "info" in model and "meta" in model["info"]:
action_ids.extend(
model["info"]["meta"].get("actionIds", [])
)
model["action_ids"] = action_ids
else:
models.remove(model)
elif custom_model.is_active and (
custom_model.id not in [model["id"] for model in models]
):
owned_by = "openai"
pipe = None
action_ids = []
for model in models:
if (
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
break
if custom_model.meta:
meta = custom_model.meta.model_dump()
if "actionIds" in meta:
action_ids.extend(meta["actionIds"])
models.append(
{
"id": f"{custom_model.id}",
"name": custom_model.name,
"object": "model",
"created": custom_model.created_at,
"owned_by": owned_by,
"info": custom_model.model_dump(),
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
"action_ids": action_ids,
}
)
# Process action_ids to get the actions
def get_action_items_from_module(function, module):
actions = []
if hasattr(module, "actions"):
actions = module.actions
return [
{
"id": f"{function.id}.{action['id']}",
"name": action.get("name", f"{function.name} ({action['id']})"),
"description": function.meta.description,
"icon_url": action.get(
"icon_url", function.meta.manifest.get("icon_url", None)
),
}
for action in actions
]
else:
return [
{
"id": function.id,
"name": function.name,
"description": function.meta.description,
"icon_url": function.meta.manifest.get("icon_url", None),
}
]
def get_function_module_by_id(function_id):
if function_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[function_id]
else:
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
for model in models:
action_ids = [
action_id
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
if action_id in enabled_action_ids
]
model["actions"] = []
for action_id in action_ids:
action_function = Functions.get_function_by_id(action_id)
if action_function is None:
raise Exception(f"Action not found: {action_id}")
function_module = get_function_module_by_id(action_id)
model["actions"].extend(
get_action_items_from_module(action_function, function_module)
)
log.debug(f"get_all_models() returned {len(models)} models")
request.app.state.MODELS = {model["id"]: model for model in models}
return models

View File

@ -4,11 +4,15 @@ import re
from typing import Any, Awaitable, Callable, get_type_hints from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial from functools import update_wrapper, partial
from fastapi import Request
from pydantic import BaseModel, Field, create_model
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from open_webui.models.tools import Tools from open_webui.models.tools import Tools
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tools_module_by_id from open_webui.utils.plugin import load_tools_module_by_id
from pydantic import BaseModel, Field, create_model
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -32,7 +36,7 @@ def apply_extra_params_to_tool_function(
# Mutation on extra_params # Mutation on extra_params
def get_tools( def get_tools(
webui_app, tool_ids: list[str], user: UserModel, extra_params: dict request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]: ) -> dict[str, dict]:
tools_dict = {} tools_dict = {}
@ -41,10 +45,10 @@ def get_tools(
if tools is None: if tools is None:
continue continue
module = webui_app.state.TOOLS.get(tool_id, None) module = request.app.state.TOOLS.get(tool_id, None)
if module is None: if module is None:
module, _ = load_tools_module_by_id(tool_id) module, _ = load_tools_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module request.app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id extra_params["__id__"] = tool_id
if hasattr(module, "valves") and hasattr(module, "Valves"): if hasattr(module, "valves") and hasattr(module, "Valves"):

View File

@ -110,7 +110,7 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct
export const getTaskConfig = async (token: string = '') => { export const getTaskConfig = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/config`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -138,7 +138,7 @@ export const getTaskConfig = async (token: string = '') => {
export const updateTaskConfig = async (token: string, config: object) => { export const updateTaskConfig = async (token: string, config: object) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/config/update`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/config/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -176,7 +176,7 @@ export const generateTitle = async (
) => { ) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/title/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/title/completions`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -216,7 +216,7 @@ export const generateTags = async (
) => { ) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/tags/completions`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -288,7 +288,7 @@ export const generateEmoji = async (
) => { ) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/emoji/completions`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -337,7 +337,7 @@ export const generateQueries = async (
) => { ) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/queries/completions`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -407,7 +407,7 @@ export const generateAutoCompletion = async (
const controller = new AbortController(); const controller = new AbortController();
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/auto/completions`, {
signal: controller.signal, signal: controller.signal,
method: 'POST', method: 'POST',
headers: { headers: {
@ -477,7 +477,7 @@ export const generateMoACompletion = async (
const controller = new AbortController(); const controller = new AbortController();
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/moa/completions`, {
signal: controller.signal, signal: controller.signal,
method: 'POST', method: 'POST',
headers: { headers: {
@ -507,7 +507,7 @@ export const generateMoACompletion = async (
export const getPipelinesList = async (token: string = '') => { export const getPipelinesList = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/list`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/list`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -541,7 +541,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string)
formData.append('file', file); formData.append('file', file);
formData.append('urlIdx', urlIdx); formData.append('urlIdx', urlIdx);
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/upload`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/upload`, {
method: 'POST', method: 'POST',
headers: { headers: {
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
@ -573,7 +573,7 @@ export const uploadPipeline = async (token: string, file: File, urlIdx: string)
export const downloadPipeline = async (token: string, url: string, urlIdx: string) => { export const downloadPipeline = async (token: string, url: string, urlIdx: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/add`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/add`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -609,7 +609,7 @@ export const downloadPipeline = async (token: string, url: string, urlIdx: strin
export const deletePipeline = async (token: string, id: string, urlIdx: string) => { export const deletePipeline = async (token: string, id: string, urlIdx: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/delete`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines/delete`, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -650,7 +650,7 @@ export const getPipelines = async (token: string, urlIdx?: string) => {
searchParams.append('urlIdx', urlIdx); searchParams.append('urlIdx', urlIdx);
} }
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines?${searchParams.toString()}`, { const res = await fetch(`${WEBUI_BASE_URL}/api/v1/pipelines?${searchParams.toString()}`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -685,7 +685,7 @@ export const getPipelineValves = async (token: string, pipeline_id: string, urlI
} }
const res = await fetch( const res = await fetch(
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves?${searchParams.toString()}`, `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves?${searchParams.toString()}`,
{ {
method: 'GET', method: 'GET',
headers: { headers: {
@ -721,7 +721,7 @@ export const getPipelineValvesSpec = async (token: string, pipeline_id: string,
} }
const res = await fetch( const res = await fetch(
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`, `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/spec?${searchParams.toString()}`,
{ {
method: 'GET', method: 'GET',
headers: { headers: {
@ -762,7 +762,7 @@ export const updatePipelineValves = async (
} }
const res = await fetch( const res = await fetch(
`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`, `${WEBUI_BASE_URL}/api/v1/pipelines/${pipeline_id}/valves/update?${searchParams.toString()}`,
{ {
method: 'POST', method: 'POST',
headers: { headers: {