diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index c712709a5..a0d8f3750 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+from fastapi.responses import StreamingResponse, FileResponse
 
 import requests
 import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
 from starlette.background import BackgroundTask
 
 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template
-from utils.misc import add_or_update_system_message
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 
 from config import (
     SRC_LOG_LEVELS,
@@ -34,7 +30,7 @@ from config import (
     MODEL_FILTER_LIST,
     AppConfig,
 )
-from typing import List, Optional
+from typing import List, Optional, Literal, overload
 
 import hashlib
 
@@ -69,8 +65,6 @@ app.state.MODELS = {}
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
         await get_all_models()
-    else:
-        pass
 
     response = await call_next(request)
     return response
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
 
             raise HTTPException(
@@ -234,64 +228,68 @@ def merge_models_lists(model_lists):
     return merged_list
 
 
-async def get_all_models(raw: bool = False):
+def is_openai_api_disabled():
+    api_keys = app.state.config.OPENAI_API_KEYS
+    no_keys = len(api_keys) == 1 and api_keys[0] == ""
+    return no_keys or not app.state.config.ENABLE_OPENAI_API
+
+
+async def get_all_models_raw() -> list:
+    if is_openai_api_disabled():
+        return []
+
+    # Check if API KEYS length is the same as API URLS length
+    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(app.state.config.OPENAI_API_KEYS)
+
+    if num_keys != num_urls:
+        # if there are more keys than urls, remove the extra keys
+        if num_keys > num_urls:
+            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
+            app.state.config.OPENAI_API_KEYS = new_keys
+        # if there are more urls than keys, add empty keys
+        else:
+            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    tasks = [
+        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
+    ]
+
+    responses = await asyncio.gather(*tasks)
+    log.debug(f"get_all_models:responses() {responses}")
+
+    return responses
+
+
+@overload
+async def get_all_models(raw: Literal[True]) -> list: ...
+
+
+@overload
+async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
+
+
+async def get_all_models(raw=False) -> dict[str, list] | list:
     log.info("get_all_models()")
+    if is_openai_api_disabled():
+        return [] if raw else {"data": []}
 
-    if (
-        len(app.state.config.OPENAI_API_KEYS) == 1
-        and app.state.config.OPENAI_API_KEYS[0] == ""
-    ) or not app.state.config.ENABLE_OPENAI_API:
-        models = {"data": []}
-    else:
-        # Check if API KEYS length is same than API URLS length
-        if len(app.state.config.OPENAI_API_KEYS) != len(
-            app.state.config.OPENAI_API_BASE_URLS
-        ):
-            # if there are more keys than urls, remove the extra keys
-            if len(app.state.config.OPENAI_API_KEYS) > len(
-                app.state.config.OPENAI_API_BASE_URLS
-            ):
-                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                    : len(app.state.config.OPENAI_API_BASE_URLS)
-                ]
-            # if there are more urls than keys, add empty keys
-            else:
-                app.state.config.OPENAI_API_KEYS += [
-                    ""
-                    for _ in range(
-                        len(app.state.config.OPENAI_API_BASE_URLS)
-                        - len(app.state.config.OPENAI_API_KEYS)
-                    )
-                ]
+    responses = await get_all_models_raw()
+    if raw:
+        return responses
 
-        tasks = [
-            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-        ]
+    def extract_data(response):
+        if response and "data" in response:
+            return response["data"]
+        if isinstance(response, list):
+            return response
+        return None
 
-        responses = await asyncio.gather(*tasks)
-        log.debug(f"get_all_models:responses() {responses}")
+    models = {"data": merge_models_lists(map(extract_data, responses))}
 
-        if raw:
-            return responses
-
-        models = {
-            "data": merge_models_lists(
-                list(
-                    map(
-                        lambda response: (
-                            response["data"]
-                            if (response and "data" in response)
-                            else (response if isinstance(response, list) else None)
-                        ),
-                        responses,
-                    )
-                )
-            )
-        }
-
-        log.debug(f"models: {models}")
-        app.state.MODELS = {model["id"]: model for model in models["data"]}
+    log.debug(f"models: {models}")
+    app.state.MODELS = {model["id"]: model for model in models["data"]}
 
     return models
 
 
@@ -299,7 +297,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
@@ -340,7 +338,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
                     res = r.json()
                     if "error" in res:
                         error_detail = f"External: {res['error']}"
-                except:
+                except Exception:
                     error_detail = f"External: {e}"
 
             raise HTTPException(
@@ -358,8 +356,7 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata", None)
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
@@ -368,70 +365,9 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload.get("temperature") is None
-            ):
-                payload["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
-                payload["top_p"] = int(model_info.params.get("top_p", None))
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload.get("max_tokens") is None
-            ):
-                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload.get("frequency_penalty") is None
-            ):
-                payload["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload.get("seed") is None
-            ):
-                payload["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None) and payload.get("stop") is None:
-                payload["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
-
-    else:
-        pass
+        params = model_info.params.model_dump()
+        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_system_prompt_to_body(params, payload, user)
 
     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
@@ -444,13 +380,6 @@ async def generate_chat_completion(
         "role": user.role,
     }
 
-    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-    # This is a workaround until OpenAI fixes the issue with this model
-    if payload.get("model") == "gpt-4-vision-preview":
-        if "max_tokens" not in payload:
-            payload["max_tokens"] = 4000
-        log.debug("Modified payload:", payload)
-
     # Convert the modified body back to JSON
     payload = json.dumps(payload)
@@ -506,7 +435,7 @@ async def generate_chat_completion(
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
@@ -569,7 +498,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py
index 1d98d37ff..fcffca420 100644
--- a/backend/apps/socket/main.py
+++ b/backend/apps/socket/main.py
@@ -44,23 +44,26 @@ async def user_join(sid, data):
     print("user-join", sid, data)
 
     auth = data["auth"] if "auth" in data else None
+    if not auth or "token" not in auth:
+        return
 
-    if auth and "token" in auth:
-        data = decode_token(auth["token"])
+    data = decode_token(auth["token"])
+    if data is None or "id" not in data:
+        return
 
-        if data is not None and "id" in data:
-            user = Users.get_user_by_id(data["id"])
+    user = Users.get_user_by_id(data["id"])
+    if not user:
+        return
 
-            if user:
-                SESSION_POOL[sid] = user.id
-                if user.id in USER_POOL:
-                    USER_POOL[user.id].append(sid)
-                else:
-                    USER_POOL[user.id] = [sid]
+    SESSION_POOL[sid] = user.id
+    if user.id in USER_POOL:
+        USER_POOL[user.id].append(sid)
+    else:
+        USER_POOL[user.id] = [sid]
 
-                print(f"user {user.name}({user.id}) connected with session ID {sid}")
+    print(f"user {user.name}({user.id}) connected with session ID {sid}")
 
-                await sio.emit("user-count", {"count": len(set(USER_POOL))})
+    await sio.emit("user-count", {"count": len(set(USER_POOL))})
 
 
 @sio.on("user-count")
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 972562a04..a0b9f5008 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    add_or_update_system_message,
+    apply_model_params_to_body,
+    apply_model_system_prompt_to_body,
 )
-from utils.task import prompt_template
 
 
 from config import (
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
     return params
 
 
-# inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
-    if not params:
-        return form_data
-
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-
-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
-
-    return form_data
-
-
-# inplace function: form_data is modified
-def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
-    system = params.get("system", None)
-    if not system:
-        return form_data
-
-    if user:
-        template_params = {
-            "user_name": user.name,
-            "user_location": user.info.get("location") if user.info else None,
-        }
-    else:
-        template_params = {}
-    system = prompt_template(system, **template_params)
-    form_data["messages"] = add_or_update_system_message(
-        system, form_data.get("messages", [])
-    )
-    return form_data
-
-
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py
index ea9db8180..7e60fe4d1 100644
--- a/backend/apps/webui/routers/tools.py
+++ b/backend/apps/webui/routers/tools.py
@@ -1,12 +1,8 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import List, Union, Optional
+from fastapi import Depends, HTTPException, status, Request
+from typing import List, Optional
 
 from fastapi import APIRouter
-from pydantic import BaseModel
-import json
 
-from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
 from utils.tools import get_tools_specs
 
 from constants import ERROR_MESSAGES
-from importlib import util
 
 import os
 from pathlib import Path
@@ -69,7 +64,7 @@ async def create_new_toolkit(
     form_data.id = form_data.id.lower()
 
     toolkit = Tools.get_tool_by_id(form_data.id)
-    if toolkit == None:
+    if toolkit is None:
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
             with open(toolkit_path, "w") as tool_file:
@@ -98,7 +93,7 @@ async def create_new_toolkit(
             print(e)
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
     else:
         raise HTTPException(
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
+            detail=ERROR_MESSAGES.DEFAULT(str(e)),
         )
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
     else:
         raise HTTPException(
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
 
         if hasattr(toolkit_module, "Valves"):
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
 
         if hasattr(toolkit_module, "Valves"):
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
             print(e)
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
     else:
         raise HTTPException(
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
     else:
         raise HTTPException(
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
 
         if hasattr(toolkit_module, "UserValves"):
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
 
         if hasattr(toolkit_module, "UserValves"):
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
             print(e)
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
     else:
         raise HTTPException(
diff --git a/backend/main.py b/backend/main.py
index a7dd8bc23..181944606 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -957,7 +957,7 @@ async def get_all_models():
     custom_models = Models.get_all_models()
     for custom_model in custom_models:
-        if custom_model.base_model_id == None:
+        if custom_model.base_model_id is None:
             for model in models:
                 if (
                     custom_model.id == model["id"]
@@ -1656,13 +1656,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
     responses = await get_openai_models(raw=True)
     print(responses)
 
     urlIdxs = [
         idx
         for idx, response in enumerate(responses)
-        if response != None and "pipelines" in response
+        if response is not None and "pipelines" in response
     ]
 
     return {
@@ -1723,7 +1723,7 @@ async def upload_pipeline(
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -1769,7 +1769,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -1811,7 +1811,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -1844,7 +1844,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -1859,7 +1859,6 @@ async def get_pipeline_valves(
     pipeline_id: str,
     user=Depends(get_admin_user),
 ):
-    models = await get_all_models()
     r = None
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1898,8 +1897,6 @@ async def get_pipeline_valves_spec(
     pipeline_id: str,
     user=Depends(get_admin_user),
 ):
-    models = await get_all_models()
-
     r = None
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1922,7 +1919,7 @@ async def get_pipeline_valves_spec(
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -1938,8 +1935,6 @@ async def update_pipeline_valves(
     form_data: dict,
     user=Depends(get_admin_user),
 ):
-    models = await get_all_models()
-
     r = None
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1967,7 +1962,7 @@ async def update_pipeline_valves(
                 res = r.json()
                 if "detail" in res:
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
 
         raise HTTPException(
@@ -2068,7 +2063,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
 
 
 @app.get("/api/version")
-async def get_app_config():
+async def get_app_version():
     return {
         "version": VERSION,
     }
@@ -2091,7 +2086,7 @@ async def get_app_latest_release_version():
                 latest_version = data["tag_name"]
 
                 return {"current": VERSION, "latest": latest_version[1:]}
-    except aiohttp.ClientError as e:
+    except aiohttp.ClientError:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
             detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index 3aadd3fb9..25dd4dd5b 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
 import uuid
 import time
 
+from utils.task import prompt_template
+
 
 def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
     for message in reversed(messages):
@@ -112,6 +114,47 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
     return template
 
 
+# inplace function: form_data is modified
+def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+    system = params.get("system", None)
+    if not system:
+        return form_data
+
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
+    else:
+        template_params = {}
+    system = prompt_template(system, **template_params)
+    form_data["messages"] = add_or_update_system_message(
+        system, form_data.get("messages", [])
+    )
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
+
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
diff --git a/backend/utils/task.py b/backend/utils/task.py
index 053a526a8..1b2276c9c 100644
--- a/backend/utils/task.py
+++ b/backend/utils/task.py
@@ -6,7 +6,7 @@ from typing import Optional
 
 
 def prompt_template(
-    template: str, user_name: str = None, user_location: str = None
+    template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str:
     # Get the current date
     current_date = datetime.now()
@@ -83,7 +83,6 @@ def title_generation_template(
 def search_query_generation_template(
     template: str, prompt: str, user: Optional[dict] = None
 ) -> str:
-
     def replacement_function(match):
         full_match = match.group(0)
         start_length = match.group(1)
diff --git a/backend/utils/utils.py b/backend/utils/utils.py
index fbc539af5..288db1fb5 100644
--- a/backend/utils/utils.py
+++ b/backend/utils/utils.py
@@ -1,15 +1,12 @@
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi import HTTPException, status, Depends, Request
-from sqlalchemy.orm import Session
 
 from apps.webui.models.users import Users
-from pydantic import BaseModel
 from typing import Union, Optional
 from constants import ERROR_MESSAGES
 
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
-import requests
 import jwt
 import uuid
 import logging
@@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
     try:
         decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
         return decoded
-    except Exception as e:
+    except Exception:
         return None
 
 
@@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
     try:
         scheme, credentials = auth_header.split(" ")
         return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
-    except:
+    except Exception:
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
 
 
@@ -96,7 +93,7 @@ def get_current_user(
 
         # auth by jwt token
         data = decode_token(token)
-        if data != None and "id" in data:
+        if data is not None and "id" in data:
             user = Users.get_user_by_id(data["id"])
             if user is None:
                 raise HTTPException(