open-webui/backend/apps/webui/main.py
2024-07-11 16:24:59 -07:00

460 lines
15 KiB
Python

from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
auths,
users,
chats,
documents,
tools,
models,
prompts,
configs,
memories,
utils,
files,
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from utils.task import prompt_template
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
)
from apps.socket.main import get_event_call, get_event_emitter
import inspect
import uuid
import time
import json
from typing import Iterator, Generator, Optional
from pydantic import BaseModel
app = FastAPI()
origins = ["*"]
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/")
async def get_status():
return {
"status": True,
"auth": WEBUI_AUTH,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe.id
)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe.id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = (
f"{function_module.name}{manifold_pipe_name}"
)
pipe_flag = {"type": pipe.type}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = None
if "metadata" in form_data:
metadata = form_data["metadata"]
del form_data["metadata"]
__event_emitter__ = None
__event_call__ = None
__task__ = None
if metadata:
if (
metadata.get("session_id")
and metadata.get("chat_id")
and metadata.get("message_id")
):
__event_emitter__ = await get_event_emitter(metadata)
__event_call__ = await get_event_call(metadata)
if metadata.get("task"):
__task__ = metadata.get("task")
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
form_data["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None):
form_data["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
form_data["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
form_data["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
form_data["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if form_data.get("messages"):
for message in form_data["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
form_data["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
else:
pass
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe_id
)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__event_emitter__" in sig.parameters:
params = {**params, "__event_emitter__": __event_emitter__}
if "__event_call__" in sig.parameters:
params = {**params, "__event_call__": __event_call__}
if "__task__" in sig.parameters:
params = {**params, "__task__": __task__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, StreamingResponse):
return res
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()