"""WebUI sub-application: mounts the WebUI API routers and serves function ("pipe") models."""

from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
    auths,
    users,
    chats,
    documents,
    tools,
    models,
    prompts,
    configs,
    memories,
    utils,
    files,
    functions,
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id

from utils.misc import stream_message_template
from utils.task import prompt_template

from config import (
    WEBUI_BUILD_HASH,
    SHOW_ADMIN_DETAILS,
    ADMIN_EMAIL,
    WEBUI_AUTH,
    DEFAULT_MODELS,
    DEFAULT_PROMPT_SUGGESTIONS,
    DEFAULT_USER_ROLE,
    ENABLE_SIGNUP,
    ENABLE_LOGIN_FORM,
    USER_PERMISSIONS,
    WEBHOOK_URL,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
    WEBUI_AUTH_TRUSTED_NAME_HEADER,
    JWT_EXPIRES_IN,
    WEBUI_BANNERS,
    ENABLE_COMMUNITY_SHARING,
    AppConfig,
    OAUTH_USERNAME_CLAIM,
    OAUTH_PICTURE_CLAIM,
)

from apps.socket.main import get_event_call, get_event_emitter

import inspect
import uuid
import time
import json

from typing import Iterator, Generator, AsyncGenerator, Optional
from pydantic import BaseModel

app = FastAPI()

origins = ["*"]

app.state.config = AppConfig()

app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER


app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL


app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS

app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING

app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM

# In-process caches; FUNCTIONS maps function id -> loaded function module.
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; consider restricting `origins` per deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])

app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])

app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])

app.include_router(utils.router, prefix="/utils", tags=["utils"])


@app.get("/")
async def get_status():
    """Report auth status plus the configured default models and prompt suggestions."""
    return {
        "status": True,
        "auth": WEBUI_AUTH,
        "default_models": app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
    }


def _get_function_module(function_id: str):
    """Return the module for ``function_id``, loading and caching it on first use.

    When the module exposes both a ``valves`` attribute and a ``Valves`` class,
    ``valves`` is rebuilt from the stored configuration on every call so that
    edits made since the module was cached are picked up.
    """
    if function_id not in app.state.FUNCTIONS:
        function_module, _function_type, _frontmatter = load_function_module_by_id(
            function_id
        )
        app.state.FUNCTIONS[function_id] = function_module
    else:
        function_module = app.state.FUNCTIONS[function_id]

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(function_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    return function_module


def _pipe_flag(function_module, pipe_type: str) -> dict:
    """Build the ``pipe`` descriptor for a model entry, adding the ChatValves schema if present."""
    flag = {"type": pipe_type}
    if hasattr(function_module, "ChatValves"):
        flag["valves_spec"] = function_module.ChatValves.schema()
    return flag


async def get_pipe_models():
    """List every active "pipe" function as OpenAI-style model entries.

    A function module whose ``type`` attribute equals ``"manifold"`` contributes
    one entry per item of its ``pipes`` (a list, or a callable returning one),
    each with id ``"<function id>.<sub id>"``. Every other pipe contributes a
    single entry using the function's own id and name.
    """
    pipes = Functions.get_functions_by_type("pipe", active_only=True)
    pipe_models = []

    for pipe in pipes:
        function_module = _get_function_module(pipe.id)

        # getattr keeps modules that set a non-manifold `type` in the list;
        # previously any module with a `type` attribute other than "manifold"
        # was silently skipped.
        if getattr(function_module, "type", None) == "manifold":
            manifold_pipes = (
                function_module.pipes()
                if callable(function_module.pipes)
                else function_module.pipes
            )

            for p in manifold_pipes:
                manifold_pipe_name = p["name"]
                if hasattr(function_module, "name"):
                    manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"

                pipe_models.append(
                    {
                        "id": f'{pipe.id}.{p["id"]}',
                        "name": manifold_pipe_name,
                        "object": "model",
                        "created": pipe.created_at,
                        "owned_by": "openai",
                        "pipe": _pipe_flag(function_module, pipe.type),
                    }
                )
        else:
            pipe_models.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": _pipe_flag(function_module, "pipe"),
                }
            )

    return pipe_models


def _apply_model_params(form_data: dict, params: dict, user) -> None:
    """Overlay a custom model's stored parameters onto the request payload, in place.

    Also expands the model's ``system`` prompt template (with the user's name
    and location when ``user`` is given) and merges it into ``form_data["messages"]``.
    """
    if params.get("temperature", None) is not None:
        form_data["temperature"] = float(params.get("temperature"))

    # top_p and frequency_penalty are fractional in the OpenAI chat API;
    # int() would truncate e.g. 0.9 to 0.
    if params.get("top_p", None):
        form_data["top_p"] = float(params.get("top_p", None))

    if params.get("max_tokens", None):
        form_data["max_tokens"] = int(params.get("max_tokens", None))

    if params.get("frequency_penalty", None):
        form_data["frequency_penalty"] = float(params.get("frequency_penalty", None))

    if params.get("seed", None):
        form_data["seed"] = params.get("seed", None)

    if params.get("stop", None):
        # Turn literally-written escape sequences (e.g. "\\n") into real characters.
        form_data["stop"] = [
            bytes(stop, "utf-8").decode("unicode_escape") for stop in params["stop"]
        ]

    system = params.get("system", None)
    if not system:
        return

    system = prompt_template(
        system,
        **(
            {
                "user_name": user.name,
                "user_location": (user.info.get("location") if user.info else None),
            }
            if user
            else {}
        ),
    )

    # Prepend to the first existing system message, or insert a new one at the front.
    messages = form_data.get("messages")
    if messages:
        for message in messages:
            if message.get("role") == "system":
                # Separate the model prompt from the caller's own system text.
                message["content"] = f"{system}\n{message['content']}"
                break
        else:
            messages.insert(0, {"role": "system", "content": system})


def _build_pipe_params(
    pipe, function_module, pipe_id: str, form_data: dict, user, extras: dict
) -> dict:
    """Assemble keyword arguments for ``pipe`` from the parameters its signature declares.

    ``body`` is always passed. ``__user__`` is passed when requested, including
    the user's ``UserValves`` if the module defines that class. Each entry of
    ``extras`` is passed only when its key is a parameter name of ``pipe``.
    """
    sig = inspect.signature(pipe)
    params = {"body": form_data}

    if "__user__" in sig.parameters:
        user_info = {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        }
        try:
            if hasattr(function_module, "UserValves"):
                stored = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
                # `or {}` mirrors the module-level valves handling when nothing is stored.
                user_info["valves"] = function_module.UserValves(**(stored or {}))
        except Exception as e:
            # Best effort: the pipe still runs without user valves.
            print(e)
        params["__user__"] = user_info

    for name, value in extras.items():
        if name in sig.parameters:
            params[name] = value

    return params


async def _call_pipe(pipe, params: dict):
    """Invoke ``pipe`` with ``params``, awaiting it when it is a coroutine function."""
    if inspect.iscoroutinefunction(pipe):
        return await pipe(**params)
    return pipe(**params)


def _stop_chunk(model: str) -> dict:
    """Build the final ``chat.completion.chunk`` carrying ``finish_reason: "stop"``."""
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


def _format_stream_line(model: str, line) -> str:
    """Convert one item yielded by a pipe into an SSE event string.

    Pydantic models and dicts are JSON-encoded; bytes are decoded as UTF-8.
    Text already starting with ``data:`` is forwarded as-is; anything else is
    wrapped with ``stream_message_template``.
    """
    if isinstance(line, BaseModel):
        line = f"data: {line.model_dump_json()}"

    if isinstance(line, dict):
        line = f"data: {json.dumps(line)}"

    if isinstance(line, bytes):
        line = line.decode("utf-8")

    if line.startswith("data:"):
        return f"{line}\n\n"

    message = stream_message_template(model, line)
    return f"data: {json.dumps(message)}\n\n"


async def _stream_pipe(pipe, params: dict, form_data: dict):
    """Run ``pipe`` and yield its output as server-sent-event strings.

    Errors raised while calling the pipe (or relaying a StreamingResponse it
    returns) are reported as a single ``{"error": {"detail": ...}}`` event.
    """
    try:
        res = await _call_pipe(pipe, params)

        # A pipe may return its own StreamingResponse; relay its body unchanged.
        if isinstance(res, StreamingResponse):
            async for data in res.body_iterator:
                yield data
            return

        if isinstance(res, dict):
            yield f"data: {json.dumps(res)}\n\n"
            return

    except Exception as e:
        print(f"Error: {e}")
        yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
        return

    if isinstance(res, str):
        message = stream_message_template(form_data["model"], res)
        yield f"data: {json.dumps(message)}\n\n"

    if isinstance(res, Iterator):
        for line in res:
            yield _format_stream_line(form_data["model"], line)

    if isinstance(res, str) or isinstance(res, Generator):
        yield f"data: {json.dumps(_stop_chunk(form_data['model']))}\n\n"
        # SSE events are delimited by a blank line; without it a strict parser
        # may never dispatch the final [DONE] event.
        yield "data: [DONE]\n\n"

    if isinstance(res, AsyncGenerator):
        async for line in res:
            yield _format_stream_line(form_data["model"], line)


async def _run_pipe(pipe, params: dict, form_data: dict):
    """Run ``pipe`` once and return a complete (non-streaming) chat completion payload."""
    try:
        res = await _call_pipe(pipe, params)

        if isinstance(res, StreamingResponse):
            return res
    except Exception as e:
        print(f"Error: {e}")
        return {"error": {"detail": str(e)}}

    if isinstance(res, dict):
        return res
    if isinstance(res, BaseModel):
        return res.model_dump()

    message = ""
    if isinstance(res, str):
        message = res
    elif isinstance(res, Iterator):
        # Iterator (not just Generator) matches what the streaming path accepts.
        message = "".join(f"{chunk}" for chunk in res)
    elif isinstance(res, AsyncGenerator):
        message = "".join([f"{chunk}" async for chunk in res])

    return {
        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": form_data["model"],
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": message,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


async def generate_function_chat_completion(form_data, user):
    """Serve a chat completion request with a "pipe" function.

    ``form_data["model"]`` names the pipe (``"<function id>"`` or
    ``"<function id>.<sub id>"`` for manifold pipes), or a custom model whose
    ``base_model_id`` does. A ``metadata`` entry, if present, is removed from
    ``form_data`` and used to build the event emitter/caller and task passed to
    the pipe. Returns a StreamingResponse of SSE events when ``form_data["stream"]``
    is truthy, otherwise a chat completion dict (or whatever dict / StreamingResponse
    the pipe itself returned).
    """
    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", None)

    event_emitter = None
    event_call = None
    task = None

    if metadata:
        if (
            metadata.get("session_id")
            and metadata.get("chat_id")
            and metadata.get("message_id")
        ):
            event_emitter = await get_event_emitter(metadata)
            event_call = await get_event_call(metadata)

        if metadata.get("task"):
            task = metadata.get("task")

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        # Work on a local dict instead of overwriting the model object's field.
        params = model_info.params.model_dump() if model_info.params else {}
        if params:
            _apply_model_params(form_data, params, user)

    pipe_id = form_data["model"]
    if "." in pipe_id:
        # Manifold sub-pipe ids are "<function id>.<sub id>"; load the parent function.
        pipe_id, _sub_pipe_id = pipe_id.split(".", 1)

    function_module = _get_function_module(pipe_id)
    pipe = function_module.pipe

    pipe_params = _build_pipe_params(
        pipe,
        function_module,
        pipe_id,
        form_data,
        user,
        extras={
            "__event_emitter__": event_emitter,
            "__event_call__": event_call,
            "__task__": task,
        },
    )

    if form_data.get("stream", False):
        return StreamingResponse(
            _stream_pipe(pipe, pipe_params, form_data),
            media_type="text/event-stream",
        )

    return await _run_pipe(pipe, pipe_params, form_data)