Merge branch 'origin/dev' into sidebar-pagination [skip ci]

Aryan Kothari
2024-08-03 09:56:23 -04:00
80 changed files with 3065 additions and 1802 deletions

View File

@@ -1,9 +1,6 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
auths,
users,
@@ -22,12 +19,15 @@ from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
add_or_update_system_message,
)
from utils.task import prompt_template
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
@@ -51,11 +51,9 @@ from config import (
from apps.socket.main import get_event_call, get_event_emitter
import inspect
import uuid
import time
import json
from typing import Iterator, Generator, AsyncGenerator, Optional
from typing import Iterator, Generator, AsyncGenerator
from pydantic import BaseModel
app = FastAPI()
@@ -127,60 +125,58 @@ async def get_status():
}
def get_function_module(pipe_id: str):
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe.id
)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe.id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
function_module = get_function_module(pipe.id)
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
if hasattr(function_module, "pipes"):
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = (
f"{function_module.name}{manifold_pipe_name}"
)
if hasattr(function_module, "name"):
manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"
pipe_flag = {"type": pipe.type}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_flag = {"type": pipe.type}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
if hasattr(function_module, "ChatValves"):
@@ -200,284 +196,211 @@ async def get_pipe_models():
return pipe_models
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
print(pipe_id)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params={}):
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data}
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params["__user__"] = __user__
return params
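A brief aside on the helper above: get_function_params only forwards the extra __event_emitter__/__event_call__/__task__ style arguments that the pipe's signature actually declares. A minimal, self-contained sketch of that signature-based filtering, using a made-up pipe function:
import inspect

def example_pipe(body: dict, __user__: dict):
    # Hypothetical pipe: only declares body and __user__.
    return f"{__user__['name']}: {len(body.get('messages', []))} messages"

sig = inspect.signature(example_pipe)
params = {"body": {"messages": []}}
extra_params = {"__user__": {"name": "demo"}, "__task__": None}
for key, value in extra_params.items():
    if key in sig.parameters:  # __task__ is dropped, __user__ is kept
        params[key] = value

print(example_pipe(**params))  # prints "demo: 0 messages"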
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
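For illustration, a hypothetical call to apply_model_params_to_body as defined above (the parameter values are invented for the example):
params = {"temperature": "0.7", "max_tokens": "1024", "stop": ["\\n\\n"]}
form_data = {"model": "my-pipe", "messages": []}
form_data = apply_model_params_to_body(params, form_data)
# temperature is cast to float, max_tokens to int, and each stop string has
# its escape sequences decoded, so form_data becomes:
# {'model': 'my-pipe', 'messages': [], 'temperature': 0.7,
#  'max_tokens': 1024, 'stop': ['\n\n']}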
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = None
if "metadata" in form_data:
metadata = form_data["metadata"]
del form_data["metadata"]
metadata = form_data.pop("metadata", None)
__event_emitter__ = None
__event_call__ = None
__task__ = None
if metadata:
if (
metadata.get("session_id")
and metadata.get("chat_id")
and metadata.get("message_id")
):
__event_emitter__ = await get_event_emitter(metadata)
__event_call__ = await get_event_call(metadata)
if metadata.get("task"):
__task__ = metadata.get("task")
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
params = model_info.params.model_dump()
form_data = apply_model_params_to_body(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
if model_info.params:
if model_info.params.get("temperature", None) is not None:
form_data["temperature"] = float(model_info.params.get("temperature"))
pipe_id = get_pipe_id(form_data)
function_module = get_function_module(pipe_id)
if model_info.params.get("top_p", None):
form_data["top_p"] = int(model_info.params.get("top_p", None))
pipe = function_module.pipe
params = get_function_params(
function_module,
form_data,
user,
{
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
},
)
if model_info.params.get("max_tokens", None):
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
form_data["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
form_data["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
form_data["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if form_data.get("messages"):
for message in form_data["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
form_data["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
else:
pass
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe_id
)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
if form_data["stream"]:
async def stream_content():
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
res = await execute_pipe(pipe, params)
params = {**params, "__user__": __user__}
if "__event_emitter__" in sig.parameters:
params = {**params, "__event_emitter__": __event_emitter__}
if "__event_call__" in sig.parameters:
params = {**params, "__event_call__": __event_call__}
if "__task__" in sig.parameters:
params = {**params, "__task__": __task__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
if isinstance(res, AsyncGenerator):
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, StreamingResponse):
return res
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
elif isinstance(res, AsyncGenerator):
async for stream in res:
message = f"{message}{stream}"
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
return await job()
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
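For context, a minimal, self-contained sketch (not part of the commit) of how the new execute_pipe and get_message_content helpers cooperate for both sync and async pipes; sync_pipe and async_pipe are made-up examples:
import asyncio
import inspect
from typing import AsyncGenerator, Generator

async def execute_pipe(pipe, params):
    # Await coroutine pipes, call plain functions directly (mirrors the diff above).
    if inspect.iscoroutinefunction(pipe):
        return await pipe(**params)
    return pipe(**params)

async def get_message_content(res):
    # Collapse str / Generator / AsyncGenerator results into one string.
    if isinstance(res, str):
        return res
    if isinstance(res, Generator):
        return "".join(map(str, res))
    if isinstance(res, AsyncGenerator):
        return "".join([str(chunk) async for chunk in res])

def sync_pipe(body: dict):
    yield from ("Hello", ", ", "world")

async def async_pipe(body: dict) -> str:
    return "Hello, world"

async def demo():
    for pipe in (sync_pipe, async_pipe):
        res = await execute_pipe(pipe, {"body": {}})
        print(await get_message_content(res))  # prints "Hello, world" twice

asyncio.run(demo())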

View File

@@ -1,13 +1,11 @@
import json
import logging
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class ModelsTable:
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
}
)
try:
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:
def get_all_models(self) -> List[ModelModel]:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model)
except:
@@ -178,7 +171,6 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()

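A side note on the models table above: ModelModel.model_validate(model) accepts a SQLAlchemy row only when the Pydantic model allows attribute-based validation, which is what the imported ConfigDict is typically used for. A small sketch with made-up classes:
from pydantic import BaseModel, ConfigDict

class ExampleModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    name: str

class FakeRow:
    # Stand-in for a SQLAlchemy-mapped object with matching attributes.
    id = "m1"
    name = "demo"

print(ExampleModel.model_validate(FakeRow()))  # id='m1' name='demo'
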
View File

@@ -1,3 +1,6 @@
from pathlib import Path
import site
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse, FileResponse
@@ -64,8 +67,18 @@ async def download_chat_as_pdf(
pdf = FPDF()
pdf.add_page()
STATIC_DIR = "./static"
FONTS_DIR = f"{STATIC_DIR}/fonts"
# When running in docker, workdir is /app/backend, so fonts is in /app/backend/static/fonts
FONTS_DIR = Path("./static/fonts")
# Non Docker Installation
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")