refac: backend/main.py

This commit is contained in:
Michael Poluektov 2024-07-09 11:51:43 +01:00
parent f9e3c47d4a
commit e3e02e04e8

View File

@ -1,13 +1,10 @@
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
import time
import os
import sys
@ -19,14 +16,11 @@ import shutil
import os
import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
@ -56,14 +49,14 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session, SessionLocal
from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union
from typing import List, Optional
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
@ -86,14 +79,12 @@ from utils.task import (
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
@ -101,7 +92,6 @@ from config import (
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR,
STATIC_DIR,
DEFAULT_LOCALE,
@ -128,9 +118,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
@ -355,19 +344,24 @@ async def get_function_call_response(
else:
content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
citation = None
if "name" not in result:
return None, None, False
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
@ -376,13 +370,9 @@ async def get_function_call_response(
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
@ -391,6 +381,21 @@ async def get_function_call_response(
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
@ -403,55 +408,12 @@ async def get_function_call_response(
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
@ -484,22 +446,20 @@ async def chat_completion_functions_handler(
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
@ -513,6 +473,19 @@ async def chat_completion_functions_handler(
sig = inspect.signature(inlet)
params = {"body": body}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
@ -533,30 +506,6 @@ async def chat_completion_functions_handler(
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
@ -1220,18 +1169,16 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
@ -1245,6 +1192,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
@ -1265,30 +1225,6 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context, citation, file_handler = await get_function_call_response(
context, _, _ = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
@ -1647,6 +1580,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
r = None
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
@ -1670,7 +1604,9 @@ async def upload_pipeline(
print(f"Connection error: {e}")
detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None:
status_code = r.status_code
try:
res = r.json()
if "detail" in res:
@ -1679,7 +1615,7 @@ async def upload_pipeline(
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
status_code=status_code,
detail=detail,
)
finally:
@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None
try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]