Merge pull request #3734 from michaelpoluektov/refactor-main

refactor: Refactor backend/main.py
This commit is contained in:
Timothy Jaeryang Baek 2024-07-09 12:48:33 -07:00 committed by GitHub
commit 0444497eea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,10 @@
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
import time
import os
import sys
@ -19,14 +16,11 @@ import shutil
import os
import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
@ -56,14 +49,14 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session, SessionLocal
from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union
from typing import List, Optional
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
@ -86,14 +79,12 @@ from utils.task import (
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
@ -101,7 +92,6 @@ from config import (
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR,
STATIC_DIR,
DEFAULT_LOCALE,
@ -128,9 +118,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
@ -237,7 +226,7 @@ async def get_body_and_model_and_user(request):
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
raise Exception("Model not found")
model = app.state.MODELS[model_id]
user = get_current_user(
@ -300,7 +289,6 @@ async def get_function_call_response(
template,
task_model_id,
user,
model,
__event_emitter__=None,
__event_call__=None,
):
@ -355,121 +343,94 @@ async def get_function_call_response(
else:
content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
print(f"content: {content}")
result = json.loads(content)
print(result)
citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
citation = None
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if "name" not in result:
return None, None, False
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
# Call the function
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
@ -484,87 +445,76 @@ async def chat_completion_functions_handler(
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "inlet"):
continue
try:
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
body = inlet(**params)
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
@ -573,9 +523,7 @@ async def chat_completion_functions_handler(
return body, {}
async def chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
):
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
skip_files = None
contexts = []
@ -596,7 +544,6 @@ async def chat_completion_tools_handler(
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
model=model,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
@ -723,7 +670,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
body, user, __event_emitter__, __event_call__
)
contexts.extend(flags.get("contexts", []))
@ -817,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
##################################
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
def get_sorted_filters(model_id):
filters = [
model
for model in app.state.MODELS.values()
@ -835,6 +780,13 @@ def filter_pipeline(payload, user):
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id)
model = app.state.MODELS[model_id]
@ -867,19 +819,12 @@ def filter_pipeline(payload, user):
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
except:
pass
res = r.json()
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "task" in payload:
del payload["task"]
if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
del payload["task"]
return payload
@ -1114,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
sorted_filters = get_sorted_filters(model_id)
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
@ -1177,21 +1107,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else:
pass
async def __event_emitter__(data):
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"data": data,
"data": event_data,
},
to=data["session_id"],
)
async def __event_call__(data):
async def __event_call__(event_data):
response = await sio.call(
"chat-events",
{"chat_id": data["chat_id"], "message_id": data["id"], "data": data},
{"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
to=data["session_id"],
)
return response
@ -1220,86 +1150,74 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
data = outlet(**params)
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
@ -1375,19 +1293,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@ -1444,19 +1352,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@ -1501,19 +1399,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@ -1568,22 +1456,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context, citation, file_handler = await get_function_call_response(
context, _, _ = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
@ -1647,6 +1526,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
r = None
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
@ -1670,7 +1550,9 @@ async def upload_pipeline(
print(f"Connection error: {e}")
detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None:
status_code = r.status_code
try:
res = r.json()
if "detail" in res:
@ -1679,7 +1561,7 @@ async def upload_pipeline(
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
status_code=status_code,
detail=detail,
)
finally:
@ -1778,8 +1660,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None
try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]