mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	Merge pull request #4295 from michaelpoluektov/refactor-tools
refactor: Refactor OpenAI API to use helper functions, silence LSP/linter warnings
This commit is contained in:
		
						commit
						91851114e4
					
				@ -1,6 +1,6 @@
 | 
			
		||||
from fastapi import FastAPI, Request, Response, HTTPException, Depends
 | 
			
		||||
from fastapi import FastAPI, Request, HTTPException, Depends
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 | 
			
		||||
from fastapi.responses import StreamingResponse, FileResponse
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import aiohttp
 | 
			
		||||
@ -12,16 +12,12 @@ from pydantic import BaseModel
 | 
			
		||||
from starlette.background import BackgroundTask
 | 
			
		||||
 | 
			
		||||
from apps.webui.models.models import Models
 | 
			
		||||
from apps.webui.models.users import Users
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
from utils.utils import (
 | 
			
		||||
    decode_token,
 | 
			
		||||
    get_verified_user,
 | 
			
		||||
    get_verified_user,
 | 
			
		||||
    get_admin_user,
 | 
			
		||||
)
 | 
			
		||||
from utils.task import prompt_template
 | 
			
		||||
from utils.misc import add_or_update_system_message
 | 
			
		||||
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 | 
			
		||||
 | 
			
		||||
from config import (
 | 
			
		||||
    SRC_LOG_LEVELS,
 | 
			
		||||
@ -34,7 +30,7 @@ from config import (
 | 
			
		||||
    MODEL_FILTER_LIST,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
from typing import List, Optional, Literal, overload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import hashlib
 | 
			
		||||
@ -69,8 +65,6 @@ app.state.MODELS = {}
 | 
			
		||||
async def check_url(request: Request, call_next):
 | 
			
		||||
    if len(app.state.MODELS) == 0:
 | 
			
		||||
        await get_all_models()
 | 
			
		||||
    else:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    response = await call_next(request)
 | 
			
		||||
    return response
 | 
			
		||||
@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
                    res = r.json()
 | 
			
		||||
                    if "error" in res:
 | 
			
		||||
                        error_detail = f"External: {res['error']}"
 | 
			
		||||
                except:
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    error_detail = f"External: {e}"
 | 
			
		||||
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
@ -234,64 +228,68 @@ def merge_models_lists(model_lists):
 | 
			
		||||
    return merged_list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_models(raw: bool = False):
 | 
			
		||||
def is_openai_api_disabled():
 | 
			
		||||
    api_keys = app.state.config.OPENAI_API_KEYS
 | 
			
		||||
    no_keys = len(api_keys) == 1 and api_keys[0] == ""
 | 
			
		||||
    return no_keys or not app.state.config.ENABLE_OPENAI_API
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_models_raw() -> list:
 | 
			
		||||
    if is_openai_api_disabled():
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    # Check if API KEYS length is same than API URLS length
 | 
			
		||||
    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
    num_keys = len(app.state.config.OPENAI_API_KEYS)
 | 
			
		||||
 | 
			
		||||
    if num_keys != num_urls:
 | 
			
		||||
        # if there are more keys than urls, remove the extra keys
 | 
			
		||||
        if num_keys > num_urls:
 | 
			
		||||
            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
 | 
			
		||||
            app.state.config.OPENAI_API_KEYS = new_keys
 | 
			
		||||
        # if there are more urls than keys, add empty keys
 | 
			
		||||
        else:
 | 
			
		||||
            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
 | 
			
		||||
 | 
			
		||||
    tasks = [
 | 
			
		||||
        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
 | 
			
		||||
        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    responses = await asyncio.gather(*tasks)
 | 
			
		||||
    log.debug(f"get_all_models:responses() {responses}")
 | 
			
		||||
 | 
			
		||||
    return responses
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
async def get_all_models(raw: Literal[True]) -> list: ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_models(raw=False) -> dict[str, list] | list:
 | 
			
		||||
    log.info("get_all_models()")
 | 
			
		||||
    if is_openai_api_disabled():
 | 
			
		||||
        return [] if raw else {"data": []}
 | 
			
		||||
 | 
			
		||||
    if (
 | 
			
		||||
        len(app.state.config.OPENAI_API_KEYS) == 1
 | 
			
		||||
        and app.state.config.OPENAI_API_KEYS[0] == ""
 | 
			
		||||
    ) or not app.state.config.ENABLE_OPENAI_API:
 | 
			
		||||
        models = {"data": []}
 | 
			
		||||
    else:
 | 
			
		||||
        # Check if API KEYS length is same than API URLS length
 | 
			
		||||
        if len(app.state.config.OPENAI_API_KEYS) != len(
 | 
			
		||||
            app.state.config.OPENAI_API_BASE_URLS
 | 
			
		||||
        ):
 | 
			
		||||
            # if there are more keys than urls, remove the extra keys
 | 
			
		||||
            if len(app.state.config.OPENAI_API_KEYS) > len(
 | 
			
		||||
                app.state.config.OPENAI_API_BASE_URLS
 | 
			
		||||
            ):
 | 
			
		||||
                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
 | 
			
		||||
                    : len(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
                ]
 | 
			
		||||
            # if there are more urls than keys, add empty keys
 | 
			
		||||
            else:
 | 
			
		||||
                app.state.config.OPENAI_API_KEYS += [
 | 
			
		||||
                    ""
 | 
			
		||||
                    for _ in range(
 | 
			
		||||
                        len(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
                        - len(app.state.config.OPENAI_API_KEYS)
 | 
			
		||||
                    )
 | 
			
		||||
                ]
 | 
			
		||||
    responses = await get_all_models_raw()
 | 
			
		||||
    if raw:
 | 
			
		||||
        return responses
 | 
			
		||||
 | 
			
		||||
        tasks = [
 | 
			
		||||
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
 | 
			
		||||
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
 | 
			
		||||
        ]
 | 
			
		||||
    def extract_data(response):
 | 
			
		||||
        if response and "data" in response:
 | 
			
		||||
            return response["data"]
 | 
			
		||||
        if isinstance(response, list):
 | 
			
		||||
            return response
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
        responses = await asyncio.gather(*tasks)
 | 
			
		||||
        log.debug(f"get_all_models:responses() {responses}")
 | 
			
		||||
    models = {"data": merge_models_lists(map(extract_data, responses))}
 | 
			
		||||
 | 
			
		||||
        if raw:
 | 
			
		||||
            return responses
 | 
			
		||||
 | 
			
		||||
        models = {
 | 
			
		||||
            "data": merge_models_lists(
 | 
			
		||||
                list(
 | 
			
		||||
                    map(
 | 
			
		||||
                        lambda response: (
 | 
			
		||||
                            response["data"]
 | 
			
		||||
                            if (response and "data" in response)
 | 
			
		||||
                            else (response if isinstance(response, list) else None)
 | 
			
		||||
                        ),
 | 
			
		||||
                        responses,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        log.debug(f"models: {models}")
 | 
			
		||||
        app.state.MODELS = {model["id"]: model for model in models["data"]}
 | 
			
		||||
    log.debug(f"models: {models}")
 | 
			
		||||
    app.state.MODELS = {model["id"]: model for model in models["data"]}
 | 
			
		||||
 | 
			
		||||
    return models
 | 
			
		||||
 | 
			
		||||
@ -299,7 +297,7 @@ async def get_all_models(raw: bool = False):
 | 
			
		||||
@app.get("/models")
 | 
			
		||||
@app.get("/models/{url_idx}")
 | 
			
		||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
 | 
			
		||||
    if url_idx == None:
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        models = await get_all_models()
 | 
			
		||||
        if app.state.config.ENABLE_MODEL_FILTER:
 | 
			
		||||
            if user.role == "user":
 | 
			
		||||
@ -340,7 +338,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
 | 
			
		||||
                    res = r.json()
 | 
			
		||||
                    if "error" in res:
 | 
			
		||||
                        error_detail = f"External: {res['error']}"
 | 
			
		||||
                except:
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    error_detail = f"External: {e}"
 | 
			
		||||
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
@ -358,8 +356,7 @@ async def generate_chat_completion(
 | 
			
		||||
):
 | 
			
		||||
    idx = 0
 | 
			
		||||
    payload = {**form_data}
 | 
			
		||||
    if "metadata" in payload:
 | 
			
		||||
        del payload["metadata"]
 | 
			
		||||
    payload.pop("metadata")
 | 
			
		||||
 | 
			
		||||
    model_id = form_data.get("model")
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
@ -368,70 +365,9 @@ async def generate_chat_completion(
 | 
			
		||||
        if model_info.base_model_id:
 | 
			
		||||
            payload["model"] = model_info.base_model_id
 | 
			
		||||
 | 
			
		||||
        model_info.params = model_info.params.model_dump()
 | 
			
		||||
 | 
			
		||||
        if model_info.params:
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("temperature", None) is not None
 | 
			
		||||
                and payload.get("temperature") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["temperature"] = float(model_info.params.get("temperature"))
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
 | 
			
		||||
                payload["top_p"] = int(model_info.params.get("top_p", None))
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("max_tokens", None)
 | 
			
		||||
                and payload.get("max_tokens") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("frequency_penalty", None)
 | 
			
		||||
                and payload.get("frequency_penalty") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["frequency_penalty"] = int(
 | 
			
		||||
                    model_info.params.get("frequency_penalty", None)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("seed", None) is not None
 | 
			
		||||
                and payload.get("seed") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["seed"] = model_info.params.get("seed", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("stop", None) and payload.get("stop") is None:
 | 
			
		||||
                payload["stop"] = (
 | 
			
		||||
                    [
 | 
			
		||||
                        bytes(stop, "utf-8").decode("unicode_escape")
 | 
			
		||||
                        for stop in model_info.params["stop"]
 | 
			
		||||
                    ]
 | 
			
		||||
                    if model_info.params.get("stop", None)
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        system = model_info.params.get("system", None)
 | 
			
		||||
        if system:
 | 
			
		||||
            system = prompt_template(
 | 
			
		||||
                system,
 | 
			
		||||
                **(
 | 
			
		||||
                    {
 | 
			
		||||
                        "user_name": user.name,
 | 
			
		||||
                        "user_location": (
 | 
			
		||||
                            user.info.get("location") if user.info else None
 | 
			
		||||
                        ),
 | 
			
		||||
                    }
 | 
			
		||||
                    if user
 | 
			
		||||
                    else {}
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            if payload.get("messages"):
 | 
			
		||||
                payload["messages"] = add_or_update_system_message(
 | 
			
		||||
                    system, payload["messages"]
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        pass
 | 
			
		||||
        params = model_info.params.model_dump()
 | 
			
		||||
        payload = apply_model_params_to_body(params, payload)
 | 
			
		||||
        payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
    model = app.state.MODELS[payload.get("model")]
 | 
			
		||||
    idx = model["urlIdx"]
 | 
			
		||||
@ -444,13 +380,6 @@ async def generate_chat_completion(
 | 
			
		||||
            "role": user.role,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
 | 
			
		||||
    # This is a workaround until OpenAI fixes the issue with this model
 | 
			
		||||
    if payload.get("model") == "gpt-4-vision-preview":
 | 
			
		||||
        if "max_tokens" not in payload:
 | 
			
		||||
            payload["max_tokens"] = 4000
 | 
			
		||||
        log.debug("Modified payload:", payload)
 | 
			
		||||
 | 
			
		||||
    # Convert the modified body back to JSON
 | 
			
		||||
    payload = json.dumps(payload)
 | 
			
		||||
 | 
			
		||||
@ -506,7 +435,7 @@ async def generate_chat_completion(
 | 
			
		||||
                print(res)
 | 
			
		||||
                if "error" in res:
 | 
			
		||||
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                error_detail = f"External: {e}"
 | 
			
		||||
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
 | 
			
		||||
    finally:
 | 
			
		||||
@ -569,7 +498,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 | 
			
		||||
                print(res)
 | 
			
		||||
                if "error" in res:
 | 
			
		||||
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                error_detail = f"External: {e}"
 | 
			
		||||
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
 | 
			
		||||
    finally:
 | 
			
		||||
 | 
			
		||||
@ -44,23 +44,26 @@ async def user_join(sid, data):
 | 
			
		||||
    print("user-join", sid, data)
 | 
			
		||||
 | 
			
		||||
    auth = data["auth"] if "auth" in data else None
 | 
			
		||||
    if not auth or "token" not in auth:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if auth and "token" in auth:
 | 
			
		||||
        data = decode_token(auth["token"])
 | 
			
		||||
    data = decode_token(auth["token"])
 | 
			
		||||
    if data is None or "id" not in data:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
        if data is not None and "id" in data:
 | 
			
		||||
            user = Users.get_user_by_id(data["id"])
 | 
			
		||||
    user = Users.get_user_by_id(data["id"])
 | 
			
		||||
    if not user:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
        if user:
 | 
			
		||||
            SESSION_POOL[sid] = user.id
 | 
			
		||||
            if user.id in USER_POOL:
 | 
			
		||||
                USER_POOL[user.id].append(sid)
 | 
			
		||||
            else:
 | 
			
		||||
                USER_POOL[user.id] = [sid]
 | 
			
		||||
    SESSION_POOL[sid] = user.id
 | 
			
		||||
    if user.id in USER_POOL:
 | 
			
		||||
        USER_POOL[user.id].append(sid)
 | 
			
		||||
    else:
 | 
			
		||||
        USER_POOL[user.id] = [sid]
 | 
			
		||||
 | 
			
		||||
            print(f"user {user.name}({user.id}) connected with session ID {sid}")
 | 
			
		||||
    print(f"user {user.name}({user.id}) connected with session ID {sid}")
 | 
			
		||||
 | 
			
		||||
            await sio.emit("user-count", {"count": len(set(USER_POOL))})
 | 
			
		||||
    await sio.emit("user-count", {"count": len(set(USER_POOL))})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@sio.on("user-count")
 | 
			
		||||
 | 
			
		||||
@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
 | 
			
		||||
from utils.misc import (
 | 
			
		||||
    openai_chat_chunk_message_template,
 | 
			
		||||
    openai_chat_completion_message_template,
 | 
			
		||||
    add_or_update_system_message,
 | 
			
		||||
    apply_model_params_to_body,
 | 
			
		||||
    apply_model_system_prompt_to_body,
 | 
			
		||||
)
 | 
			
		||||
from utils.task import prompt_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from config import (
 | 
			
		||||
@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
 | 
			
		||||
    if not params:
 | 
			
		||||
        return form_data
 | 
			
		||||
 | 
			
		||||
    mappings = {
 | 
			
		||||
        "temperature": float,
 | 
			
		||||
        "top_p": int,
 | 
			
		||||
        "max_tokens": int,
 | 
			
		||||
        "frequency_penalty": int,
 | 
			
		||||
        "seed": lambda x: x,
 | 
			
		||||
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for key, cast_func in mappings.items():
 | 
			
		||||
        if (value := params.get(key)) is not None:
 | 
			
		||||
            form_data[key] = cast_func(value)
 | 
			
		||||
 | 
			
		||||
    return form_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 | 
			
		||||
    system = params.get("system", None)
 | 
			
		||||
    if not system:
 | 
			
		||||
        return form_data
 | 
			
		||||
 | 
			
		||||
    if user:
 | 
			
		||||
        template_params = {
 | 
			
		||||
            "user_name": user.name,
 | 
			
		||||
            "user_location": user.info.get("location") if user.info else None,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
        template_params = {}
 | 
			
		||||
    system = prompt_template(system, **template_params)
 | 
			
		||||
    form_data["messages"] = add_or_update_system_message(
 | 
			
		||||
        system, form_data.get("messages", [])
 | 
			
		||||
    )
 | 
			
		||||
    return form_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def generate_function_chat_completion(form_data, user):
 | 
			
		||||
    model_id = form_data.get("model")
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
 | 
			
		||||
@ -1,12 +1,8 @@
 | 
			
		||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from typing import List, Union, Optional
 | 
			
		||||
from fastapi import Depends, HTTPException, status, Request
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from apps.webui.models.users import Users
 | 
			
		||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 | 
			
		||||
from apps.webui.utils import load_toolkit_module_by_id
 | 
			
		||||
 | 
			
		||||
@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
 | 
			
		||||
from utils.tools import get_tools_specs
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
 | 
			
		||||
from importlib import util
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
@ -69,7 +64,7 @@ async def create_new_toolkit(
 | 
			
		||||
    form_data.id = form_data.id.lower()
 | 
			
		||||
 | 
			
		||||
    toolkit = Tools.get_tool_by_id(form_data.id)
 | 
			
		||||
    if toolkit == None:
 | 
			
		||||
    if toolkit is None:
 | 
			
		||||
        toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
 | 
			
		||||
        try:
 | 
			
		||||
            with open(toolkit_path, "w") as tool_file:
 | 
			
		||||
@ -98,7 +93,7 @@ async def create_new_toolkit(
 | 
			
		||||
            print(e)
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -170,7 +165,7 @@ async def update_toolkit_by_id(
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
            detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
 | 
			
		||||
        if id in request.app.state.TOOLS:
 | 
			
		||||
            toolkit_module = request.app.state.TOOLS[id]
 | 
			
		||||
        else:
 | 
			
		||||
            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
 | 
			
		||||
            toolkit_module, _ = load_toolkit_module_by_id(id)
 | 
			
		||||
            request.app.state.TOOLS[id] = toolkit_module
 | 
			
		||||
 | 
			
		||||
        if hasattr(toolkit_module, "Valves"):
 | 
			
		||||
@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
 | 
			
		||||
        if id in request.app.state.TOOLS:
 | 
			
		||||
            toolkit_module = request.app.state.TOOLS[id]
 | 
			
		||||
        else:
 | 
			
		||||
            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
 | 
			
		||||
            toolkit_module, _ = load_toolkit_module_by_id(id)
 | 
			
		||||
            request.app.state.TOOLS[id] = toolkit_module
 | 
			
		||||
 | 
			
		||||
        if hasattr(toolkit_module, "Valves"):
 | 
			
		||||
@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
 | 
			
		||||
                print(e)
 | 
			
		||||
                raise HTTPException(
 | 
			
		||||
                    status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                    detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
 | 
			
		||||
        if id in request.app.state.TOOLS:
 | 
			
		||||
            toolkit_module = request.app.state.TOOLS[id]
 | 
			
		||||
        else:
 | 
			
		||||
            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
 | 
			
		||||
            toolkit_module, _ = load_toolkit_module_by_id(id)
 | 
			
		||||
            request.app.state.TOOLS[id] = toolkit_module
 | 
			
		||||
 | 
			
		||||
        if hasattr(toolkit_module, "UserValves"):
 | 
			
		||||
@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
 | 
			
		||||
        if id in request.app.state.TOOLS:
 | 
			
		||||
            toolkit_module = request.app.state.TOOLS[id]
 | 
			
		||||
        else:
 | 
			
		||||
            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
 | 
			
		||||
            toolkit_module, _ = load_toolkit_module_by_id(id)
 | 
			
		||||
            request.app.state.TOOLS[id] = toolkit_module
 | 
			
		||||
 | 
			
		||||
        if hasattr(toolkit_module, "UserValves"):
 | 
			
		||||
@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
 | 
			
		||||
                print(e)
 | 
			
		||||
                raise HTTPException(
 | 
			
		||||
                    status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                    detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
 | 
			
		||||
@ -957,7 +957,7 @@ async def get_all_models():
 | 
			
		||||
 | 
			
		||||
    custom_models = Models.get_all_models()
 | 
			
		||||
    for custom_model in custom_models:
 | 
			
		||||
        if custom_model.base_model_id == None:
 | 
			
		||||
        if custom_model.base_model_id is None:
 | 
			
		||||
            for model in models:
 | 
			
		||||
                if (
 | 
			
		||||
                    custom_model.id == model["id"]
 | 
			
		||||
@ -1656,13 +1656,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 | 
			
		||||
 | 
			
		||||
@app.get("/api/pipelines/list")
 | 
			
		||||
async def get_pipelines_list(user=Depends(get_admin_user)):
 | 
			
		||||
    responses = await get_openai_models(raw=True)
 | 
			
		||||
    responses = await get_openai_models(raw = True)
 | 
			
		||||
 | 
			
		||||
    print(responses)
 | 
			
		||||
    urlIdxs = [
 | 
			
		||||
        idx
 | 
			
		||||
        for idx, response in enumerate(responses)
 | 
			
		||||
        if response != None and "pipelines" in response
 | 
			
		||||
        if response is not None and "pipelines" in response
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
@ -1723,7 +1723,7 @@ async def upload_pipeline(
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -1769,7 +1769,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -1811,7 +1811,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -1844,7 +1844,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -1859,7 +1859,6 @@ async def get_pipeline_valves(
 | 
			
		||||
    pipeline_id: str,
 | 
			
		||||
    user=Depends(get_admin_user),
 | 
			
		||||
):
 | 
			
		||||
    models = await get_all_models()
 | 
			
		||||
    r = None
 | 
			
		||||
    try:
 | 
			
		||||
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
 | 
			
		||||
@ -1898,8 +1897,6 @@ async def get_pipeline_valves_spec(
 | 
			
		||||
    pipeline_id: str,
 | 
			
		||||
    user=Depends(get_admin_user),
 | 
			
		||||
):
 | 
			
		||||
    models = await get_all_models()
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
    try:
 | 
			
		||||
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
 | 
			
		||||
@ -1922,7 +1919,7 @@ async def get_pipeline_valves_spec(
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -1938,8 +1935,6 @@ async def update_pipeline_valves(
 | 
			
		||||
    form_data: dict,
 | 
			
		||||
    user=Depends(get_admin_user),
 | 
			
		||||
):
 | 
			
		||||
    models = await get_all_models()
 | 
			
		||||
 | 
			
		||||
    r = None
 | 
			
		||||
    try:
 | 
			
		||||
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
 | 
			
		||||
@ -1967,7 +1962,7 @@ async def update_pipeline_valves(
 | 
			
		||||
                res = r.json()
 | 
			
		||||
                if "detail" in res:
 | 
			
		||||
                    detail = res["detail"]
 | 
			
		||||
            except:
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -2068,7 +2063,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/api/version")
 | 
			
		||||
async def get_app_config():
 | 
			
		||||
async def get_app_version():
 | 
			
		||||
    return {
 | 
			
		||||
        "version": VERSION,
 | 
			
		||||
    }
 | 
			
		||||
@ -2091,7 +2086,7 @@ async def get_app_latest_release_version():
 | 
			
		||||
                latest_version = data["tag_name"]
 | 
			
		||||
 | 
			
		||||
                return {"current": VERSION, "latest": latest_version[1:]}
 | 
			
		||||
    except aiohttp.ClientError as e:
 | 
			
		||||
    except aiohttp.ClientError:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
 | 
			
		||||
            detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
 | 
			
		||||
import uuid
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from utils.task import prompt_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
 | 
			
		||||
    for message in reversed(messages):
 | 
			
		||||
@ -112,6 +114,47 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 | 
			
		||||
    system = params.get("system", None)
 | 
			
		||||
    if not system:
 | 
			
		||||
        return form_data
 | 
			
		||||
 | 
			
		||||
    if user:
 | 
			
		||||
        template_params = {
 | 
			
		||||
            "user_name": user.name,
 | 
			
		||||
            "user_location": user.info.get("location") if user.info else None,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
        template_params = {}
 | 
			
		||||
    system = prompt_template(system, **template_params)
 | 
			
		||||
    form_data["messages"] = add_or_update_system_message(
 | 
			
		||||
        system, form_data.get("messages", [])
 | 
			
		||||
    )
 | 
			
		||||
    return form_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
 | 
			
		||||
    if not params:
 | 
			
		||||
        return form_data
 | 
			
		||||
 | 
			
		||||
    mappings = {
 | 
			
		||||
        "temperature": float,
 | 
			
		||||
        "top_p": int,
 | 
			
		||||
        "max_tokens": int,
 | 
			
		||||
        "frequency_penalty": int,
 | 
			
		||||
        "seed": lambda x: x,
 | 
			
		||||
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for key, cast_func in mappings.items():
 | 
			
		||||
        if (value := params.get(key)) is not None:
 | 
			
		||||
            form_data[key] = cast_func(value)
 | 
			
		||||
 | 
			
		||||
    return form_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_gravatar_url(email):
 | 
			
		||||
    # Trim leading and trailing whitespace from
 | 
			
		||||
    # an email address and force all characters
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ from typing import Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prompt_template(
 | 
			
		||||
    template: str, user_name: str = None, user_location: str = None
 | 
			
		||||
    template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 | 
			
		||||
) -> str:
 | 
			
		||||
    # Get the current date
 | 
			
		||||
    current_date = datetime.now()
 | 
			
		||||
@ -83,7 +83,6 @@ def title_generation_template(
 | 
			
		||||
def search_query_generation_template(
 | 
			
		||||
    template: str, prompt: str, user: Optional[dict] = None
 | 
			
		||||
) -> str:
 | 
			
		||||
 | 
			
		||||
    def replacement_function(match):
 | 
			
		||||
        full_match = match.group(0)
 | 
			
		||||
        start_length = match.group(1)
 | 
			
		||||
 | 
			
		||||
@ -1,15 +1,12 @@
 | 
			
		||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 | 
			
		||||
from fastapi import HTTPException, status, Depends, Request
 | 
			
		||||
from sqlalchemy.orm import Session
 | 
			
		||||
 | 
			
		||||
from apps.webui.models.users import Users
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from typing import Union, Optional
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
from passlib.context import CryptContext
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
import requests
 | 
			
		||||
import jwt
 | 
			
		||||
import uuid
 | 
			
		||||
import logging
 | 
			
		||||
@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
 | 
			
		||||
    try:
 | 
			
		||||
        decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
 | 
			
		||||
        return decoded
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
    except Exception:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
 | 
			
		||||
    try:
 | 
			
		||||
        scheme, credentials = auth_header.split(" ")
 | 
			
		||||
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 | 
			
		||||
    except:
 | 
			
		||||
    except Exception:
 | 
			
		||||
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -96,7 +93,7 @@ def get_current_user(
 | 
			
		||||
 | 
			
		||||
    # auth by jwt token
 | 
			
		||||
    data = decode_token(token)
 | 
			
		||||
    if data != None and "id" in data:
 | 
			
		||||
    if data is not None and "id" in data:
 | 
			
		||||
        user = Users.get_user_by_id(data["id"])
 | 
			
		||||
        if user is None:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user