refac: Refactor functions

Michael Poluektov 2024-07-31 13:35:02 +01:00
parent 9d58bb1c66
commit 3978efd710
4 changed files with 215 additions and 292 deletions

View File

@@ -52,7 +52,6 @@ async def user_join(sid, data):
     user = Users.get_user_by_id(data["id"])
     if user:
         SESSION_POOL[sid] = user.id
         if user.id in USER_POOL:
             USER_POOL[user.id].append(sid)
@@ -80,7 +79,6 @@ def get_models_in_use():
 @sio.on("usage")
 async def usage(sid, data):
     model_id = data["model"]
     # Cancel previous callback if there is one
@@ -139,7 +137,7 @@ async def disconnect(sid):
         print(f"Unknown session ID {sid} disconnected")

-async def get_event_emitter(request_info):
+def get_event_emitter(request_info):
     async def __event_emitter__(event_data):
         await sio.emit(
             "chat-events",
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
     return __event_emitter__

-async def get_event_call(request_info):
+def get_event_call(request_info):
     async def __event_call__(event_data):
         response = await sio.call(
             "chat-events",

View File

@@ -1,9 +1,6 @@
-from fastapi import FastAPI, Depends
-from fastapi.routing import APIRoute
+from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from sqlalchemy.orm import Session
 from apps.webui.routers import (
     auths,
     users,
@@ -27,7 +24,6 @@ from utils.task import prompt_template
 from config import (
-    WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
     WEBUI_AUTH,
@@ -55,7 +51,7 @@ import uuid
 import time
 import json
-from typing import Iterator, Generator, AsyncGenerator, Optional
+from typing import Iterator, Generator, AsyncGenerator
 from pydantic import BaseModel

 app = FastAPI()
@@ -127,60 +123,60 @@ async def get_status():
     }

+def get_function_module(pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module

 async def get_pipe_models():
     pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []

     for pipe in pipes:
-        # Check if function is already loaded
-        if pipe.id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe.id
-            )
-            app.state.FUNCTIONS[pipe.id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe.id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe.id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+        function_module = get_function_module(pipe.id)

         # Check if function is a manifold
         if hasattr(function_module, "type"):
-            if function_module.type == "manifold":
-                manifold_pipes = []
+            if not function_module.type == "manifold":
+                continue
+            manifold_pipes = []

             # Check if pipes is a function or a list
             if callable(function_module.pipes):
                 manifold_pipes = function_module.pipes()
             else:
                 manifold_pipes = function_module.pipes

             for p in manifold_pipes:
                 manifold_pipe_id = f'{pipe.id}.{p["id"]}'
                 manifold_pipe_name = p["name"]

                 if hasattr(function_module, "name"):
-                    manifold_pipe_name = (
-                        f"{function_module.name}{manifold_pipe_name}"
-                    )
+                    manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"

                 pipe_flag = {"type": pipe.type}
                 if hasattr(function_module, "ChatValves"):
                     pipe_flag["valves_spec"] = function_module.ChatValves.schema()

                 pipe_models.append(
                     {
                         "id": manifold_pipe_id,
                         "name": manifold_pipe_name,
                         "object": "model",
                         "created": pipe.created_at,
                         "owned_by": "openai",
                         "pipe": pipe_flag,
                     }
                 )
         else:
             pipe_flag = {"type": "pipe"}
             if hasattr(function_module, "ChatValves"):
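
The new get_function_module helper folds together the load-or-reuse lookup in app.state.FUNCTIONS and the Valves hydration that used to be repeated inline wherever a pipe module was resolved; a rough usage sketch (the pipe id is illustrative):

    module = get_function_module("my_pipe")            # first call loads the module and caches it
    assert get_function_module("my_pipe") is module    # later calls reuse app.state.FUNCTIONS["my_pipe"]

Inverting the manifold check into an early continue also lets the rest of the loop body sit one indentation level shallower.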
@@ -200,162 +196,179 @@ async def get_pipe_models():
     return pipe_models

+async def execute_pipe(pipe, params):
+    if inspect.iscoroutinefunction(pipe):
+        return await pipe(**params)
+    else:
+        return pipe(**params)
+
+async def get_message(res: str | Generator | AsyncGenerator) -> str:
+    if isinstance(res, str):
+        return res
+    if isinstance(res, Generator):
+        return "".join(map(str, res))
+    if isinstance(res, AsyncGenerator):
+        return "".join([str(stream) async for stream in res])
+
+def get_final_message(form_data: dict, message: str | None = None) -> dict:
+    choice = {
+        "index": 0,
+        "logprobs": None,
+        "finish_reason": "stop",
+    }
+    # If message is None, we're dealing with a chunk
+    if not message:
+        choice["delta"] = {}
+    else:
+        choice["message"] = {"role": "assistant", "content": message}
+
+    return {
+        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+        "created": int(time.time()),
+        "model": form_data["model"],
+        "object": "chat.completion" if message is not None else "chat.completion.chunk",
+        "choices": [choice],
+    }
+
+def process_line(form_data: dict, line):
+    if isinstance(line, BaseModel):
+        line = line.model_dump_json()
+        line = f"data: {line}"
+    if isinstance(line, dict):
+        line = f"data: {json.dumps(line)}"
+
+    try:
+        line = line.decode("utf-8")
+    except Exception:
+        pass
+
+    if line.startswith("data:"):
+        return f"{line}\n\n"
+    else:
+        line = stream_message_template(form_data["model"], line)
+        return f"data: {json.dumps(line)}\n\n"
+
+def get_pipe_id(form_data: dict) -> str:
+    pipe_id = form_data["model"]
+    if "." in pipe_id:
+        pipe_id, _ = pipe_id.split(".", 1)
+    print(pipe_id)
+    return pipe_id
+
+def get_params_dict(pipe, form_data, user, extra_params, function_module):
+    pipe_id = get_pipe_id(form_data)
+    # Get the signature of the function
+    sig = inspect.signature(pipe)
+    params = {"body": form_data}
+
+    for key, value in extra_params.items():
+        if key in sig.parameters:
+            params[key] = value
+
+    if "__user__" in sig.parameters:
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+        try:
+            if hasattr(function_module, "UserValves"):
+                __user__["valves"] = function_module.UserValves(
+                    **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+                )
+        except Exception as e:
+            print(e)
+        params["__user__"] = __user__
+    return params
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-    metadata = None
-    if "metadata" in form_data:
-        metadata = form_data["metadata"]
-        del form_data["metadata"]
+    metadata = form_data.pop("metadata", None)

-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
+    __event_emitter__ = __event_call__ = __task__ = None

     if metadata:
-        if (
-            metadata.get("session_id")
-            and metadata.get("chat_id")
-            and metadata.get("message_id")
-        ):
-            __event_emitter__ = await get_event_emitter(metadata)
-            __event_call__ = await get_event_call(metadata)
-
-        if metadata.get("task"):
-            __task__ = metadata.get("task")
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)

-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
+    if not model_info:
+        return
+    if model_info.base_model_id:
+        form_data["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
-        if model_info.params:
-            if model_info.params.get("temperature", None) is not None:
-                form_data["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None):
-                form_data["top_p"] = int(model_info.params.get("top_p", None))
-
-            if model_info.params.get("max_tokens", None):
-                form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if model_info.params.get("frequency_penalty", None):
-                form_data["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if model_info.params.get("seed", None):
-                form_data["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None):
-                form_data["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
+    params = model_info.params.model_dump()
+    if params:
+        mappings = {
+            "temperature": float,
+            "top_p": int,
+            "max_tokens": int,
+            "frequency_penalty": int,
+            "seed": lambda x: x,
+            "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+        }
+
+        for key, cast_func in mappings.items():
+            if (value := params.get(key)) is not None:
+                form_data[key] = cast_func(value)

-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if form_data.get("messages"):
-                    for message in form_data["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        form_data["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
-    else:
-        pass
+        system = params.get("system", None)
+        if not system:
+            return
+
+        if user:
+            template_params = {
+                "user_name": user.name,
+                "user_location": user.info.get("location") if user.info else None,
+            }
+        else:
+            template_params = {}
+        system = prompt_template(system, **template_params)
+
+        # Check if the payload already has a system message
+        # If not, add a system message to the payload
+        for message in form_data.get("messages", []):
+            if message.get("role") == "system":
+                message["content"] = system + message["content"]
+                break
+        else:
+            if form_data.get("messages"):
+                form_data["messages"].insert(0, {"role": "system", "content": system})
+
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
     async def job():
-        pipe_id = form_data["model"]
-        if "." in pipe_id:
-            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
-        print(pipe_id)
-
-        # Check if function is already loaded
-        if pipe_id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe_id
-            )
-            app.state.FUNCTIONS[pipe_id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe_id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+        pipe_id = get_pipe_id(form_data)
+        function_module = get_function_module(pipe_id)

         pipe = function_module.pipe
+        params = get_params_dict(pipe, form_data, user, extra_params, function_module)

-        # Get the signature of the function
-        sig = inspect.signature(pipe)
-        params = {"body": form_data}
-
-        if "__user__" in sig.parameters:
-            __user__ = {
-                "id": user.id,
-                "email": user.email,
-                "name": user.name,
-                "role": user.role,
-            }
-            try:
-                if hasattr(function_module, "UserValves"):
-                    __user__["valves"] = function_module.UserValves(
-                        **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-                    )
-            except Exception as e:
-                print(e)
-            params = {**params, "__user__": __user__}
-
-        if "__event_emitter__" in sig.parameters:
-            params = {**params, "__event_emitter__": __event_emitter__}
-        if "__event_call__" in sig.parameters:
-            params = {**params, "__event_call__": __event_call__}
-        if "__task__" in sig.parameters:
-            params = {**params, "__task__": __task__}

         if form_data["stream"]:

             async def stream_content():
                 try:
-                    if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(**params)
-                    else:
-                        res = pipe(**params)
+                    res = await execute_pipe(pipe, params)

                     # Directly return if the response is a StreamingResponse
                     if isinstance(res, StreamingResponse):
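
Inside job(), the pipe is now resolved with the helpers introduced above: get_pipe_id strips any manifold suffix from the model name, get_function_module loads or reuses the cached module, and get_params_dict inspects the pipe's signature so only the extra arguments it actually declares get forwarded. A hedged sketch of that signature-based filtering (the pipe below and the emitter/caller variables are hypothetical):

    import inspect

    def pipe(body, __event_emitter__=None):   # hypothetical pipe: declares no __event_call__, no __task__
        ...

    sig = inspect.signature(pipe)
    extra_params = {"__event_emitter__": emitter, "__event_call__": caller, "__task__": None}
    params = {"body": form_data}
    params.update({k: v for k, v in extra_params.items() if k in sig.parameters})
    # params now holds only "body" and "__event_emitter__"; the undeclared extras are dropped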
@@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user):
                 if isinstance(res, Iterator):
                     for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                    if isinstance(res, str) or isinstance(res, Generator):
-                        finish_message = {
-                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                            "object": "chat.completion.chunk",
-                            "created": int(time.time()),
-                            "model": form_data["model"],
-                            "choices": [
-                                {
-                                    "index": 0,
-                                    "delta": {},
-                                    "logprobs": None,
-                                    "finish_reason": "stop",
-                                }
-                            ],
-                        }
-
-                        yield f"data: {json.dumps(finish_message)}\n\n"
-                        yield f"data: [DONE]"
+                        yield process_line(form_data, line)

                 if isinstance(res, AsyncGenerator):
                     async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
+                        yield process_line(form_data, line)
+
+                if isinstance(res, str) or isinstance(res, Generator):
+                    finish_message = get_final_message(form_data)
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield "data: [DONE]"

             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
             try:
-                if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(**params)
-                else:
-                    res = pipe(**params)
-
-                if isinstance(res, StreamingResponse):
-                    return res
+                res = await execute_pipe(pipe, params)
             except Exception as e:
                 print(f"Error: {e}")
                 return {"error": {"detail": str(e)}}

-            if isinstance(res, dict):
+            if isinstance(res, StreamingResponse) or isinstance(res, dict):
                 return res
-            elif isinstance(res, BaseModel):
+            if isinstance(res, BaseModel):
                 return res.model_dump()
-            else:
-                message = ""
-                if isinstance(res, str):
-                    message = res
-                elif isinstance(res, Generator):
-                    for stream in res:
-                        message = f"{message}{stream}"
-                elif isinstance(res, AsyncGenerator):
-                    async for stream in res:
-                        message = f"{message}{stream}"
-
-                return {
-                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data["model"],
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
+
+            message = await get_message(res)
+            return get_final_message(form_data, message)

     return await job()
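
Both the streaming and non-streaming branches now go through the same helpers: execute_pipe awaits coroutine pipes and calls plain ones, process_line normalises each chunk into an SSE "data: ..." line, get_message collapses a string, generator, or async generator result into one string, and get_final_message builds either the terminating chunk or the full completion object. A small sketch of the two get_final_message shapes as defined earlier in this diff (model name illustrative):

    get_final_message({"model": "my_pipe"})
    # -> {..., "object": "chat.completion.chunk",
    #     "choices": [{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}]}

    get_final_message({"model": "my_pipe"}, "Hello!")
    # -> {..., "object": "chat.completion",
    #     "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"},
    #                  "logprobs": None, "finish_reason": "stop"}]}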

View File

@@ -1,13 +1,11 @@
-import json
 import logging
-from typing import Optional
+from typing import Optional, List

 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy import Column, BigInteger, Text

 from apps.webui.internal.db import Base, JSONField, get_db
-from typing import List, Union, Optional

 from config import SRC_LOG_LEVELS

 import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
 class ModelsTable:
     def insert_new_model(
         self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@
             }
         )
         try:
             with get_db() as db:
                 result = Model(**model.model_dump())
                 db.add(result)
                 db.commit()
@@ -144,13 +139,11 @@
     def get_all_models(self) -> List[ModelModel]:
         with get_db() as db:
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]

     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
                 model = db.get(Model, id)
                 return ModelModel.model_validate(model)
         except:
@@ -178,7 +171,6 @@
     def delete_model_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
                 db.query(Model).filter_by(id=id).delete()
                 db.commit()

View File

@@ -13,8 +13,6 @@ import aiohttp
 import requests
 import mimetypes
 import shutil
-import os
-import uuid
 import inspect

 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse

-from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
+from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         message_id = body["id"]
         del body["id"]

-        __event_emitter__ = await get_event_emitter(
+        __event_emitter__ = get_event_emitter(
             {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
         )
-        __event_call__ = await get_event_call(
+        __event_call__ = get_event_call(
             {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
         )
@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                             status_code=r.status_code,
                             content=res,
                         )
-            except:
+            except Exception:
                 pass
         else:
             pass

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1334,14 +1332,14 @@ async def chat_completed(
     )
     model = app.state.MODELS[model_id]

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
             "session_id": data["session_id"],
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel):
 @app.post("/api/pipelines/add")
 async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
     r = None
     try:
         urlIdx = form_data.urlIdx
@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel):
 @app.delete("/api/pipelines/delete")
 async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
     r = None
     try:
         urlIdx = form_data.urlIdx
@@ -1891,7 +1887,6 @@ async def get_pipeline_valves(
     models = await get_all_models()
     r = None
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]