From 932de8f1e2d38c93d0c5a9298aa3927800eb77a1 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 16 Nov 2024 04:41:07 -0800 Subject: [PATCH] refac --- backend/open_webui/apps/ollama/main.py | 83 +++--- backend/open_webui/apps/openai/main.py | 119 ++++---- backend/open_webui/apps/retrieval/utils.py | 23 +- backend/open_webui/apps/webui/main.py | 5 +- backend/open_webui/main.py | 282 +++++++++++------- backend/open_webui/utils/utils.py | 7 +- backend/requirements.txt | 1 + pyproject.toml | 1 + .../components/admin/Settings/Models.svelte | 2 +- 9 files changed, 309 insertions(+), 214 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index fe619b49a..07bf43510 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -68,25 +68,12 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS -app.state.MODELS = {} - # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - - response = await call_next(request) - return response - - @app.head("/") @app.get("/") async def get_status(): @@ -321,8 +308,6 @@ async def get_all_models(): else: models = {"models": []} - app.state.MODELS = {model["model"]: model for model in models["models"]} - return models @@ -470,8 +455,11 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -520,8 +508,11 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.source in app.state.MODELS: - url_idx = app.state.MODELS[form_data.source]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.source in models: + url_idx = models[form_data.source]["urls"][0] else: raise HTTPException( status_code=400, @@ -576,8 +567,11 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - if form_data.name in app.state.MODELS: - url_idx = app.state.MODELS[form_data.name]["urls"][0] + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name in models: + url_idx = models[form_data.name]["urls"][0] else: raise HTTPException( status_code=400, @@ -625,13 +619,16 @@ async def delete_model( @app.post("/api/show") async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): - if form_data.name not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if form_data.name not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url_idx = random.choice(models[form_data.name]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") @@ -701,23 +698,26 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) + return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) -def generate_ollama_embeddings( +async def generate_ollama_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -768,20 +768,23 @@ def generate_ollama_embeddings( ) -def generate_ollama_batch_embeddings( +async def generate_ollama_batch_embeddings( form_data: GenerateEmbedForm, url_idx: Optional[int] = None, ): log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -851,13 +854,16 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + model = form_data.model if ":" not in model: model = f"{model}:latest" - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if model in models: + url_idx = random.choice(models[model]["urls"]) else: raise HTTPException( status_code=400, @@ -892,14 +898,17 @@ class GenerateChatCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -def get_ollama_url(url_idx: Optional[int], model: str): +async def get_ollama_url(url_idx: Optional[int], model: str): if url_idx is None: - if model not in app.state.MODELS: + model_list = await get_all_models() + models = {model["model"]: model for model in model_list["models"]} + + if model not in models: raise HTTPException( status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - url_idx = random.choice(app.state.MODELS[model]["urls"]) + url_idx = random.choice(models[model]["urls"]) url = app.state.config.OLLAMA_BASE_URLS[url_idx] return url @@ -948,7 +957,7 @@ async def generate_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") @@ -1030,7 +1039,7 @@ async def generate_openai_chat_completion( if ":" not in payload["model"]: payload["model"] = f"{payload['model']}:latest" - url = get_ollama_url(url_idx, payload["model"]) + url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index d704238b4..86cecead0 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -36,7 +36,7 @@ from open_webui.utils.payload import ( apply_model_system_prompt_to_body, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.utils import get_admin_user, get_verified_user, has_access log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) @@ -64,17 +64,6 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS -app.state.MODELS = {} - - -@app.middleware("http") -async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - - response = await call_next(request) - return response - @app.get("/config") async def get_config(user=Depends(get_admin_user)): @@ -259,7 +248,7 @@ def merge_models_lists(model_lists): return merged_list -async def get_all_models_raw() -> list: +async def get_all_models_responses() -> list: if not app.state.config.ENABLE_OPENAI_API: return [] @@ -330,22 +319,13 @@ async def get_all_models_raw() -> list: return responses -@overload -async def get_all_models(raw: Literal[True]) -> list: ... - - -@overload -async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... - - -async def get_all_models(raw=False) -> dict[str, list] | list: +async def get_all_models() -> dict[str, list]: log.info("get_all_models()") - if not app.state.config.ENABLE_OPENAI_API: - return [] if raw else {"data": []} - responses = await get_all_models_raw() - if raw: - return responses + if not app.state.config.ENABLE_OPENAI_API: + return {"data": []} + + responses = await get_all_models_responses() def extract_data(response): if response and "data" in response: @@ -355,9 +335,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list: return None models = {"data": merge_models_lists(map(extract_data, responses))} - log.debug(f"models: {models}") - app.state.MODELS = {model["id"]: model for model in models["data"]} return models @@ -365,21 +343,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list: @app.get("/models") @app.get("/models/{url_idx}") async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): + models = { + "data": [], + } + if url_idx is None: models = await get_all_models() - - # TODO: Check User Group and Filter Models - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user": - # models["data"] = list( - # filter( - # lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - # models["data"], - # ) - # ) - # return models - - return models else: url = app.state.config.OPENAI_API_BASE_URLS[url_idx] key = app.state.config.OPENAI_API_KEYS[url_idx] @@ -387,6 +356,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" + if ENABLE_FORWARD_USER_INFO_HEADERS: headers["X-OpenWebUI-User-Name"] = user.name headers["X-OpenWebUI-User-Id"] = user.id @@ -428,8 +398,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us ) ] - return response_data - + models = response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") @@ -443,6 +412,22 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) + if user.role == "user": + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models["data"] = filtered_models + + return models + class ConnectionVerificationForm(BaseModel): url: str @@ -502,18 +487,9 @@ async def generate_chat_completion( del payload["metadata"] model_id = form_data.get("model") - - # TODO: Check User Group and Filter Models - # if not bypass_filter: - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - # raise HTTPException( - # status_code=403, - # detail="Model not found", - # ) - model_info = Models.get_model_by_id(model_id) + # Check model info and override the payload if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -522,12 +498,33 @@ async def generate_chat_completion( payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) - try: - model = app.state.MODELS[payload.get("model")] - idx = model["urlIdx"] - except Exception as e: - raise HTTPException(status_code=404, detail="Model not found") + # Check if user has access to the model + if user.role == "user" and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + # Attemp to get urlIdx from the model + models = await get_all_models() + + # Find the model from the list + model = next( + (model for model in models["data"] if model["id"] == payload.get("model")), + None, + ) + + if model: + idx = model["urlIdx"] + else: + raise HTTPException( + status_code=404, + detail="Model not found", + ) + + # Get the API config for the model api_config = app.state.config.OPENAI_API_CONFIGS.get( app.state.config.OPENAI_API_BASE_URLS[idx], {} ) @@ -536,6 +533,7 @@ async def generate_chat_completion( if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + # Add user info to the payload if the model is a pipeline if "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, @@ -546,8 +544,9 @@ async def generate_chat_completion( url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - is_o1 = payload["model"].lower().startswith("o1-") + # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" + is_o1 = payload["model"].lower().startswith("o1-") # Change max_completion_tokens to max_tokens (Backward compatible) if "api.openai.com" not in url and not is_o1: if "max_completion_tokens" in payload: diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 8f855473e..7d92b7350 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -3,6 +3,7 @@ import os import uuid from typing import Optional, Union +import asyncio import requests from huggingface_hub import snapshot_download @@ -291,7 +292,13 @@ def get_embedding_function( if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - func = lambda query: generate_embeddings( + + # Wrapper to run the async generate_embeddings synchronously. + def sync_generate_embeddings(*args, **kwargs): + return asyncio.run(generate_embeddings(*args, **kwargs)) + + # Semantic expectation from the original version (using sync wrapper). + func = lambda query: sync_generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, @@ -469,7 +476,7 @@ def get_model_path(model: str, update_model: bool = False): return model -def generate_openai_batch_embeddings( +async def generate_openai_batch_embeddings( model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" ) -> Optional[list[list[float]]]: try: @@ -492,14 +499,16 @@ def generate_openai_batch_embeddings( return None -def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): +async def generate_embeddings( + engine: str, model: str, text: Union[str, list[str]], **kwargs +): if engine == "ollama": if isinstance(text, list): - embeddings = generate_ollama_batch_embeddings( + embeddings = await generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": text}) ) else: - embeddings = generate_ollama_batch_embeddings( + embeddings = await generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": [text]}) ) return ( @@ -512,9 +521,9 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], ** url = kwargs.get("url", "https://api.openai.com/v1") if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) + embeddings = await generate_openai_batch_embeddings(model, text, key, url) else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) + embeddings = await generate_openai_batch_embeddings(model, [text], key, url) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index c7f338066..4535501fd 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -142,7 +142,6 @@ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE app.state.config.LDAP_CIPHERS = LDAP_CIPHERS -app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} @@ -369,7 +368,7 @@ def get_function_params(function_module, form_data, user, extra_params=None): return params -async def generate_function_chat_completion(form_data, user): +async def generate_function_chat_completion(form_data, user, models: dict = {}): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) @@ -412,7 +411,7 @@ async def generate_function_chat_completion(form_data, user): user, { **extra_params, - "__model__": app.state.MODELS[form_data["model"]], + "__model__": models.get(form_data["model"], None), "__messages__": form_data["messages"], "__files__": files, }, diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9deb3a5cb..c523fd4c8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -11,6 +11,7 @@ import random from contextlib import asynccontextmanager from typing import Optional +from aiocache import cached import aiohttp import requests from fastapi import ( @@ -45,6 +46,7 @@ from open_webui.apps.openai.main import ( app as openai_app, generate_chat_completion as generate_openai_chat_completion, get_all_models as get_openai_models, + get_all_models_responses as get_openai_models_responses, ) from open_webui.apps.retrieval.main import app as retrieval_app from open_webui.apps.retrieval.utils import get_rag_context, rag_template @@ -132,6 +134,7 @@ from open_webui.utils.utils import ( get_current_user, get_http_authorization_cred, get_verified_user, + has_access, ) if SAFE_MODE: @@ -196,20 +199,22 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE -app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE + app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION +app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE + + +app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE ) -app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) -app.state.MODELS = {} - - ################################## # # ChatCompletion Middleware @@ -217,26 +222,6 @@ app.state.MODELS = {} ################################## -def get_task_model_id(default_model_id): - # Set the task model - task_model_id = default_model_id - # Check if the user has a custom task model and use that model - if app.state.MODELS[task_model_id]["owned_by"] == "ollama": - if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL - else: - if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS - ): - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - - return task_model_id - - def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) @@ -366,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]: return content +def get_task_model_id( + default_model_id: str, task_model: str, task_model_external: str, models +) -> str: + # Set the task model + task_model_id = default_model_id + # Check if the user has a custom task model and use that model + if models[task_model_id]["owned_by"] == "ollama": + if task_model and task_model in models: + task_model_id = task_model + else: + if task_model_external and task_model_external in models: + task_model_id = task_model_external + + return task_model_id + + async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, models, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -381,14 +382,19 @@ async def chat_completion_tools_handler( contexts = [] citations = [] - task_model_id = get_task_model_id(body["model"]) + task_model_id = get_task_model_id( + body["model"], + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) tools = get_tools( webui_app, tool_ids, user, { **extra_params, - "__model__": app.state.MODELS[task_model_id], + "__model__": models[task_model_id], "__messages__": body["messages"], "__files__": metadata.get("files", []), }, @@ -412,7 +418,7 @@ async def chat_completion_tools_handler( ) try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: raise e @@ -513,16 +519,16 @@ def is_chat_completion_request(request): ) -async def get_body_and_model_and_user(request): +async def get_body_and_model_and_user(request, models): # Read the original request body body = await request.body() body_str = body.decode("utf-8") body = json.loads(body_str) if body_str else {} model_id = body["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise Exception("Model not found") - model = app.state.MODELS[model_id] + model = models[model_id] user = get_current_user( request, @@ -538,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): return await call_next(request) log.debug(f"request.url.path: {request.url.path}") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + try: - body, model, user = await get_body_and_model_and_user(request) + body, model, user = await get_body_and_model_and_user(request, models) except Exception as e: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(e)}, ) + model_info = Models.get_model_by_id(model["id"]) + if user.role == "user": + if model_info and not has_access( + user.id, type="read", access_control=model_info.access_control + ): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "User does not have access to the model"}, + ) + metadata = { "chat_id": body.pop("chat_id", None), "message_id": body.pop("id", None), @@ -582,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) + tool_ids = body.pop("tool_ids", None) + files = body.pop("files", None) + metadata = { **metadata, - "tool_ids": body.pop("tool_ids", None), - "files": body.pop("files", None), + "tool_ids": tool_ids, + "files": files, } body["metadata"] = metadata try: - body, flags = await chat_completion_tools_handler(body, user, extra_params) + body, flags = await chat_completion_tools_handler( + body, user, models, extra_params + ) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -687,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware) ################################## -def get_sorted_filters(model_id): +def get_sorted_filters(model_id, models): filters = [ model - for model in app.state.MODELS.values() + for model in models.values() if "pipeline" in model and "type" in model["pipeline"] and model["pipeline"]["type"] == "filter" @@ -706,12 +730,12 @@ def get_sorted_filters(model_id): return sorted_filters -def filter_pipeline(payload, user): +def filter_pipeline(payload, user, models): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id) - model = app.state.MODELS[model_id] + sorted_filters = get_sorted_filters(model_id, models) + model = models[model_id] if "pipeline" in model: sorted_filters.append(model) @@ -782,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware): content={"detail": "Not authenticated"}, ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + try: - data = filter_pipeline(data, user) + data = filter_pipeline(data, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -862,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): - if len(app.state.MODELS) == 0: - await get_all_models() - else: - pass - start_time = int(time.time()) response = await call_next(request) process_time = int(time.time()) - start_time response.headers["X-Process-Time"] = str(process_time) - return response @@ -911,10 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app) app.mount("/api/v1", webui_app) - webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +@cached(ttl=1) async def get_all_base_models(): open_webui_models = [] openai_models = [] @@ -944,6 +965,7 @@ async def get_all_base_models(): return models +@cached(ttl=1) async def get_all_models(): models = await get_all_base_models() @@ -1065,9 +1087,6 @@ async def get_all_models(): function_module = get_function_module_by_id(action_id) model["actions"].extend(get_action_items_from_module(function_module)) - - app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS return models @@ -1082,16 +1101,19 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] - # TODO: Check User Group and Filter Models - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user": - # models = list( - # filter( - # lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, - # models, - # ) - # ) - # return {"data": models} + # Filter out models that the user does not have access to + if user.role == "user": + filtered_models = [] + for model in models: + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + else: + filtered_models.append(model) + models = filtered_models return {"data": models} @@ -1106,24 +1128,27 @@ async def get_base_models(user=Depends(get_admin_user)): async def generate_chat_completions( form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False ): - model_id = form_data["model"] + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} - if model_id not in app.state.MODELS: + model_id = form_data["model"] + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - # TODO: Check User Group and Filter Models - # if not bypass_filter: - # if app.state.config.ENABLE_MODEL_FILTER: - # if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST: - # raise HTTPException( - # status_code=status.HTTP_403_FORBIDDEN, - # detail="Model not found", - # ) - - model = app.state.MODELS[model_id] + model = models[model_id] + # Check if user has access to the model + if user.role == "user": + model_info = Models.get_model_by_id(model_id) + if not has_access( + user.id, type="read", access_control=model_info.access_control + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) if model["owned_by"] == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") @@ -1174,7 +1199,9 @@ async def generate_chat_completions( if model.get("pipe"): # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion(form_data, user=user) + return await generate_function_chat_completion( + form_data, user=user, models=models + ) if model["owned_by"] == "ollama": # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) @@ -1198,16 +1225,20 @@ async def generate_chat_completions( @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + data = form_data model_id = data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - model = app.state.MODELS[model_id] - sorted_filters = get_sorted_filters(model_id) + model = models[model_id] + sorted_filters = get_sorted_filters(model_id, models) if "pipeline" in model: sorted_filters = [model] + sorted_filters @@ -1382,14 +1413,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified detail="Action not found", ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + data = form_data model_id = data["model"] - if model_id not in app.state.MODELS: + + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - model = app.state.MODELS[model_id] + model = models[model_id] __event_emitter__ = get_event_emitter( { @@ -1543,8 +1578,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u async def generate_title(form_data: dict, user=Depends(get_verified_user)): print("generate_title") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1552,10 +1590,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1589,7 +1633,7 @@ Artificial Intelligence in Healthcare "stream": False, **( {"max_tokens": 50} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 50, } @@ -1601,7 +1645,7 @@ Artificial Intelligence in Healthcare # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1628,8 +1672,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): content={"detail": "Tags generation is disabled"}, ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1637,7 +1684,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": @@ -1675,7 +1727,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1702,8 +1754,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) detail=f"Search query generation is disabled", ) + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1711,10 +1766,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE @@ -1741,7 +1801,7 @@ Search Query:""" "stream": False, **( {"max_tokens": 30} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 30, } @@ -1752,7 +1812,7 @@ Search Query:""" # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1774,8 +1834,11 @@ Search Query:""" async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): print("generate_emoji") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1783,10 +1846,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] template = ''' Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). @@ -1808,7 +1876,7 @@ Message: """{{prompt}}""" "stream": False, **( {"max_tokens": 4} - if app.state.MODELS[task_model_id]["owned_by"] == "ollama" + if models[task_model_id]["owned_by"] == "ollama" else { "max_completion_tokens": 4, } @@ -1820,7 +1888,7 @@ Message: """{{prompt}}""" # Handle pipeline filters try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1842,8 +1910,11 @@ Message: """{{prompt}}""" async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): print("generate_moa_response") + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + model_id = form_data["model"] - if model_id not in app.state.MODELS: + if model_id not in models: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", @@ -1851,10 +1922,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - task_model_id = get_task_model_id(model_id) + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) print(task_model_id) - model = app.state.MODELS[task_model_id] + model = models[task_model_id] template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" @@ -1881,7 +1957,7 @@ Responses from models: {{responses}}""" log.debug(payload) try: - payload = filter_pipeline(payload, user) + payload = filter_pipeline(payload, user, models) except Exception as e: if len(e.args) > 1: return JSONResponse( @@ -1911,7 +1987,7 @@ Responses from models: {{responses}}""" @app.get("/api/pipelines/list") async def get_pipelines_list(user=Depends(get_admin_user)): - responses = await get_openai_models(raw=True) + responses = await get_openai_models_responses() print(responses) urlIdxs = [ diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/utils.py index 370a30d6f..fd6aa8bca 100644 --- a/backend/open_webui/utils/utils.py +++ b/backend/open_webui/utils/utils.py @@ -192,15 +192,16 @@ def has_permission( def has_access( user_id: str, - action: str = "write", + type: str = "write", access_control: Optional[dict] = None, ) -> bool: + print("user_id", user_id, "type", type, "access_control", access_control) if access_control is None: - return action == "read" + return type == "read" user_groups = Groups.get_groups_by_member_id(user_id) user_group_ids = [group.id for group in user_groups] - permission_access = access_control.get(action, {}) + permission_access = access_control.get(type, {}) permitted_group_ids = permission_access.get("group_ids", []) permitted_user_ids = permission_access.get("user_ids", []) diff --git a/backend/requirements.txt b/backend/requirements.txt index 44838dd36..a5bfae585 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,6 +13,7 @@ passlib[bcrypt]==1.7.4 requests==2.32.3 aiohttp==3.10.8 async-timeout +aiocache sqlalchemy==2.0.32 alembic==1.13.2 diff --git a/pyproject.toml b/pyproject.toml index 305ced6eb..fa16381f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.8", "async-timeout", + "aiocache", "sqlalchemy==2.0.32", "alembic==1.13.2", diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 0abcd3ec5..42ffa238c 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -71,7 +71,7 @@ const upsertModelHandler = async (model) => { model.base_model_id = null; - if (models.find((m) => m.id === model.id)) { + if (workspaceModels.find((m) => m.id === model.id)) { await updateModelById(localStorage.token, model.id, model).catch((error) => { return null; });