refac: backend/main.py

This commit is contained in:
Michael Poluektov 2024-07-09 11:51:43 +01:00
parent f9e3c47d4a
commit e3e02e04e8

View File

@ -1,13 +1,10 @@
import base64 import base64
import uuid import uuid
import subprocess
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json import json
import markdown
import time import time
import os import os
import sys import sys
@ -19,14 +16,11 @@ import shutil
import os import os
import uuid import uuid
import inspect import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app from apps.socket.main import sio, app as socket_app
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models, get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion, generate_openai_chat_completion as generate_ollama_chat_completion,
) )
@ -56,14 +49,14 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import Session, SessionLocal from apps.webui.internal.db import Session
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union from typing import List, Optional
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users from apps.webui.models.users import Users
@ -86,14 +79,12 @@ from utils.task import (
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
stream_message_template,
parse_duration, parse_duration,
) )
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
CONFIG_DATA,
WEBUI_NAME, WEBUI_NAME,
WEBUI_URL, WEBUI_URL,
WEBUI_AUTH, WEBUI_AUTH,
@ -101,7 +92,6 @@ from config import (
VERSION, VERSION,
CHANGELOG, CHANGELOG,
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR, CACHE_DIR,
STATIC_DIR, STATIC_DIR,
DEFAULT_LOCALE, DEFAULT_LOCALE,
@ -128,9 +118,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
BACKEND_DIR,
DATABASE_URL,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
@ -355,121 +344,94 @@ async def get_function_call_response(
else: else:
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response # Parse the function response
if content is not None: print(f"content: {content}")
print(f"content: {content}") result = json.loads(content)
result = json.loads(content) print(result)
print(result)
citation = None citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False if "name" not in result:
# check if toolkit_module has file_handler self variable return None, None, False
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr( # Call the function
toolkit_module, "Valves" if tool_id in webui_app.state.TOOLS:
): toolkit_module = webui_app.state.TOOLS[tool_id]
valves = Tools.get_tool_valves_by_id(tool_id) else:
toolkit_module.valves = toolkit_module.Valves( toolkit_module, _ = load_toolkit_module_by_id(tool_id)
**(valves if valves else {}) webui_app.state.TOOLS[tool_id] = toolkit_module
)
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
function = getattr(toolkit_module, result["name"])
function_result = None
try: try:
# Get the signature of the function if hasattr(toolkit_module, "UserValves"):
sig = inspect.signature(function) __user__["valves"] = toolkit_module.UserValves(
params = result["parameters"] **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e: except Exception as e:
print(e) print(e)
# Add the function result to the system prompt params = {**params, "__user__": __user__}
if function_result is not None:
return function_result, citation, file_handler if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
@ -484,87 +446,74 @@ async def chat_completion_functions_handler(
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if not filter:
if filter_id in webui_app.state.FUNCTIONS: continue
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable if filter_id in webui_app.state.FUNCTIONS:
if hasattr(function_module, "file_handler"): function_module = webui_app.state.FUNCTIONS[filter_id]
skip_files = function_module.file_handler else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr( # Check if the function has a file_handler variable
function_module, "Valves" if hasattr(function_module, "file_handler"):
): skip_files = function_module.file_handler
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
if hasattr(function_module, "inlet"): valves = Functions.get_function_valves_by_id(filter_id)
inlet = function_module.inlet function_module.valves = function_module.Valves(
**(valves if valves else {})
)
# Get the signature of the function try:
sig = inspect.signature(inlet) if hasattr(function_module, "inlet"):
params = {"body": body} inlet = function_module.inlet
if "__user__" in sig.parameters: # Get the signature of the function
__user__ = { sig = inspect.signature(inlet)
"id": user.id, params = {"body": body}
"email": user.email,
"name": user.name,
"role": user.role,
}
try: # Extra parameters to be passed to the function
if hasattr(function_module, "UserValves"): extra_params = {
__user__["valves"] = function_module.UserValves( "__model__": model,
**Functions.get_user_valves_by_id_and_user_id( "__id__": filter_id,
filter_id, user.id "__event_emitter__": __event_emitter__,
) "__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
except Exception as e: )
print(e) except Exception as e:
print(e)
params = {**params, "__user__": __user__} params = {**params, "__user__": __user__}
if "__id__" in sig.parameters: if inspect.iscoroutinefunction(inlet):
params = { body = await inlet(**params)
**params, else:
"__id__": filter_id, body = inlet(**params)
}
if "__model__" in sig.parameters: except Exception as e:
params = { print(f"Error: {e}")
**params, raise e
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files: if skip_files:
if "files" in body: if "files" in body:
@ -1220,86 +1169,73 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if not filter:
if filter_id in webui_app.state.FUNCTIONS: continue
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr( if filter_id in webui_app.state.FUNCTIONS:
function_module, "Valves" function_module = webui_app.state.FUNCTIONS[filter_id]
): else:
valves = Functions.get_function_valves_by_id(filter_id) function_module, _, _ = load_function_module_by_id(filter_id)
function_module.valves = function_module.Valves( webui_app.state.FUNCTIONS[filter_id] = function_module
**(valves if valves else {})
)
try: if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
if hasattr(function_module, "outlet"): valves = Functions.get_function_valves_by_id(filter_id)
outlet = function_module.outlet function_module.valves = function_module.Valves(
**(valves if valves else {})
)
# Get the signature of the function try:
sig = inspect.signature(outlet) if hasattr(function_module, "outlet"):
params = {"body": data} outlet = function_module.outlet
if "__user__" in sig.parameters: # Get the signature of the function
__user__ = { sig = inspect.signature(outlet)
"id": user.id, params = {"body": data}
"email": user.email,
"name": user.name,
"role": user.role,
}
try: # Extra parameters to be passed to the function
if hasattr(function_module, "UserValves"): extra_params = {
__user__["valves"] = function_module.UserValves( "__model__": model,
**Functions.get_user_valves_by_id_and_user_id( "__id__": filter_id,
filter_id, user.id "__event_emitter__": __event_emitter__,
) "__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
except Exception as e: )
print(e) except Exception as e:
print(e)
params = {**params, "__user__": __user__} params = {**params, "__user__": __user__}
if "__id__" in sig.parameters: if inspect.iscoroutinefunction(outlet):
params = { data = await outlet(**params)
**params, else:
"__id__": filter_id, data = outlet(**params)
}
if "__model__" in sig.parameters: except Exception as e:
params = { print(f"Error: {e}")
**params, return JSONResponse(
"__model__": model, status_code=status.HTTP_400_BAD_REQUEST,
} content={"detail": str(e)},
)
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data return data
@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = ''' template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context, citation, file_handler = await get_function_call_response( context, _, _ = await get_function_call_response(
form_data["messages"], form_data["messages"],
form_data.get("files", []), form_data.get("files", []),
form_data["tool_id"], form_data["tool_id"],
@ -1647,6 +1580,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True) os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename) file_path = os.path.join(upload_folder, file.filename)
r = None
try: try:
# Save the uploaded file # Save the uploaded file
with open(file_path, "wb") as buffer: with open(file_path, "wb") as buffer:
@ -1670,7 +1604,9 @@ async def upload_pipeline(
print(f"Connection error: {e}") print(f"Connection error: {e}")
detail = "Pipeline not found" detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None: if r is not None:
status_code = r.status_code
try: try:
res = r.json() res = r.json()
if "detail" in res: if "detail" in res:
@ -1679,7 +1615,7 @@ async def upload_pipeline(
pass pass
raise HTTPException( raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), status_code=status_code,
detail=detail, detail=detail,
) )
finally: finally:
@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]