feat: pipe function

This commit is contained in:
Timothy J. Baek 2024-06-20 04:38:59 -07:00
parent de26a78a16
commit d6e4aef607
3 changed files with 234 additions and 64 deletions

View File

@ -15,6 +15,9 @@ from apps.webui.routers import (
files,
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
@ -97,3 +100,58 @@ async def get_status():
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(pipe.pipes):
manifold_pipes = pipe.pipes()
else:
manifold_pipes = pipe.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(pipe, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": pipe.type},
}
)
else:
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": "pipe"},
}
)
return pipe_models

View File

@ -15,6 +15,7 @@ import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
from pydantic import BaseModel
from typing import List, Optional
from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
@ -66,7 +67,11 @@ from utils.task import (
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
)
from apps.rag.utils import get_rag_context, rag_template
@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
model = app.state.MODELS[model_id]
# Check if the model has any filters
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if getattr(function_module, "file_handler"):
skip_files = True
try:
if hasattr(function_module, "inlet"):
data = function_module.inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if getattr(function_module, "file_handler"):
skip_files = True
try:
if hasattr(function_module, "inlet"):
data = function_module.inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
model = app.state.MODELS[model_id]
print(model)
pipe = model.get("pipe")
if pipe:
if model.get('pipe') == True:
print('hi')
def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]:
def stream_content():
res = pipe(body=form_data)
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
res = pipe(body=form_data)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
pass
# Check if the model has any filters
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
data = function_module.outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
data = function_module.outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data

View File

@ -4,6 +4,8 @@ import json
import re
from datetime import timedelta
from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str:
@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters