feat: pipe function

Timothy J. Baek 2024-06-20 04:38:59 -07:00
parent de26a78a16
commit d6e4aef607
3 changed files with 234 additions and 64 deletions
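This commit wires user-defined "pipe" functions into the chat path: pipe-type functions are exposed as models by the new get_pipe_models() helper, generate_chat_completions() dispatches matching requests to the loaded function module's pipe() entry point in a threadpool, and the new stream_message_template() helper wraps streamed text in OpenAI-style chunks.

As a rough sketch of the module shape the loader expects (inferred from the call sites in this diff, not shipped in it), a minimal pipe module might look like:

# Hypothetical pipe module, inferred from how this commit calls it.
# generate_chat_completions() invokes FUNCTIONS[pipe_id].pipe(body=form_data),
# so a callable `pipe` attribute on the loaded module is all that is required.
# It may return a str, dict, BaseModel, or a generator of chunks.
def pipe(body: dict):
    messages = body.get("messages", [])
    # Echo the last message back as the completion.
    return messages[-1]["content"] if messages else ""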

apps/webui/main.py

@@ -15,6 +15,9 @@ from apps.webui.routers import (
     files,
     functions,
 )
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+
 from config import (
     WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
@@ -97,3 +100,58 @@ async def get_status():
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
+
+
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe")
+    pipe_models = []
+
+    for pipe in pipes:
+        # Check if the function module is already loaded
+        if pipe.id not in app.state.FUNCTIONS:
+            function_module, function_type = load_function_module_by_id(pipe.id)
+            app.state.FUNCTIONS[pipe.id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe.id]
+
+        # Check if the loaded module is a manifold
+        # (the "pipes" attribute lives on the module, not the DB record)
+        if hasattr(function_module, "type") and function_module.type == "manifold":
+            manifold_pipes = []
+
+            # "pipes" can be either a function or a list
+            if callable(function_module.pipes):
+                manifold_pipes = function_module.pipes()
+            else:
+                manifold_pipes = function_module.pipes
+
+            for p in manifold_pipes:
+                manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+
+                manifold_pipe_name = p["name"]
+                if hasattr(pipe, "name"):
+                    manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+
+                pipe_models.append(
+                    {
+                        "id": manifold_pipe_id,
+                        "name": manifold_pipe_name,
+                        "object": "model",
+                        "created": pipe.created_at,
+                        "owned_by": "openai",
+                        "pipe": {"type": pipe.type},
+                    }
+                )
+        else:
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": {"type": "pipe"},
+                }
+            )
+
+    return pipe_models
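For the manifold branch above, the loaded module is expected to set type = "manifold" and to provide pipes as either a list of {"id", "name"} dicts or a zero-argument callable returning one; each entry is listed as its own model with the id <function_id>.<sub_id>, and its display name is prefixed with the pipe function's registered name. A hypothetical manifold module consistent with that contract:

# Hypothetical manifold module; the ids and names here are made up.
type = "manifold"

def pipes():
    # May also be a plain list assigned to a module-level `pipes` variable.
    return [
        {"id": "small", "name": "Small"},
        {"id": "large", "name": "Large"},
    ]

def pipe(body: dict):
    # The dispatcher strips the "<function_id>." prefix for module lookup,
    # but the full dotted id still arrives in body["model"], so branch on
    # it here to pick a backend.
    return f"response from {body.get('model', 'unknown')}"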

main.py

@@ -15,6 +15,7 @@ import uuid
 import inspect
 import asyncio
 
+from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
@@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
@@ -66,7 +67,11 @@ from utils.task import (
     search_query_generation_template,
     tools_function_calling_generation_template,
 )
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
 
 from apps.rag.utils import get_rag_context, rag_template
@@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         model = app.state.MODELS[model_id]
 
         # Check if the model has any filters
-        for filter_id in model["info"]["meta"].get("filterIds", []):
-            filter = Functions.get_function_by_id(filter_id)
-            if filter:
-                if filter_id in webui_app.state.FUNCTIONS:
-                    function_module = webui_app.state.FUNCTIONS[filter_id]
-                else:
-                    function_module, function_type = load_function_module_by_id(
-                        filter_id
-                    )
-                    webui_app.state.FUNCTIONS[filter_id] = function_module
-
-                # Check if the function has a file_handler variable
-                if getattr(function_module, "file_handler"):
-                    skip_files = True
-
-                try:
-                    if hasattr(function_module, "inlet"):
-                        data = function_module.inlet(
-                            data,
-                            {
-                                "id": user.id,
-                                "email": user.email,
-                                "name": user.name,
-                                "role": user.role,
-                            },
-                        )
-                except Exception as e:
-                    print(f"Error: {e}")
-                    return JSONResponse(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        content={"detail": str(e)},
-                    )
+        if "info" in model and "meta" in model["info"]:
+            for filter_id in model["info"]["meta"].get("filterIds", []):
+                filter = Functions.get_function_by_id(filter_id)
+                if filter:
+                    if filter_id in webui_app.state.FUNCTIONS:
+                        function_module = webui_app.state.FUNCTIONS[filter_id]
+                    else:
+                        function_module, function_type = load_function_module_by_id(
+                            filter_id
+                        )
+                        webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                    # Check if the function has a file_handler variable
+                    if getattr(function_module, "file_handler", False):
+                        skip_files = True
+
+                    try:
+                        if hasattr(function_module, "inlet"):
+                            data = function_module.inlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+                    except Exception as e:
+                        print(f"Error: {e}")
+                        return JSONResponse(
+                            status_code=status.HTTP_400_BAD_REQUEST,
+                            content={"detail": str(e)},
+                        )
 
         # Set the task model
         task_model_id = data["model"]
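Filter modules are duck-typed by the middleware above and by chat_completed() below: an optional inlet(body, user) hook runs before the request reaches the model, an optional outlet(body, user) hook runs on the completed response, and a truthy file_handler attribute makes the middleware skip its own file handling. A hypothetical filter consistent with those call sites:

# Hypothetical filter module, inferred from the inlet/outlet call sites.
file_handler = False  # set True to claim file handling (sets skip_files)

def inlet(body: dict, user: dict) -> dict:
    # Runs pre-request; may mutate or replace the body. Raising here
    # returns a 400 with the exception text to the client.
    body.setdefault("metadata", {})["user_email"] = user["email"]
    return body

def outlet(body: dict, user: dict) -> dict:
    # Runs post-completion on the response payload.
    return body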
@@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model = app.state.MODELS[model_id]
     print(model)
 
-    if model.get('pipe') == True:
-        print('hi')
+    pipe = model.get("pipe")
+    if pipe:
+
+        def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                def stream_content():
+                    res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
+
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
+
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
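stream_content() above normalizes whatever pipe() returns into OpenAI-style server-sent events: a plain string becomes a single chat.completion.chunk, an iterator is forwarded line by line (BaseModel items are serialized, existing "data:" lines pass through, bare strings are wrapped), and strings and generators are terminated with a finish_reason="stop" chunk plus "data: [DONE]". A hypothetical streaming pipe that exercises the generator path:

# Hypothetical streaming pipe: each yielded string becomes one chunk.
import time

def pipe(body: dict):
    for token in ["Hel", "lo", "!"]:
        time.sleep(0.1)  # simulate generation latency
        # stream_message_template() wraps this into a chat.completion.chunk;
        # the dispatcher appends the stop chunk and the [DONE] terminator.
        yield token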
@@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         pass
 
     # Check if the model has any filters
-    for filter_id in model["info"]["meta"].get("filterIds", []):
-        filter = Functions.get_function_by_id(filter_id)
-        if filter:
-            if filter_id in webui_app.state.FUNCTIONS:
-                function_module = webui_app.state.FUNCTIONS[filter_id]
-            else:
-                function_module, function_type = load_function_module_by_id(filter_id)
-                webui_app.state.FUNCTIONS[filter_id] = function_module
-
-            try:
-                if hasattr(function_module, "outlet"):
-                    data = function_module.outlet(
-                        data,
-                        {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        },
-                    )
-            except Exception as e:
-                print(f"Error: {e}")
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                try:
+                    if hasattr(function_module, "outlet"):
+                        data = function_module.outlet(
+                            data,
+                            {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        )
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
+                    )
 
     return data

utils/misc.py

@@ -4,6 +4,8 @@ import json
 import re
 from datetime import timedelta
 from typing import Optional, List, Tuple
+import uuid
+import time
 
 
 def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
+def stream_message_template(model: str, message: str):
+    return {
+        "id": f"{model}-{str(uuid.uuid4())}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
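A quick usage sketch for the new helper (values illustrative):

# Wrap one text fragment into an OpenAI-style streaming chunk.
chunk = stream_message_template("my_pipe", "Hel")
# chunk["choices"][0]["delta"] == {"content": "Hel"}
# The dispatcher emits it as an SSE frame: f"data: {json.dumps(chunk)}\n\n"

Note that each call mints a fresh uuid-based id and timestamp, so chunks within one stream will not share an id the way upstream OpenAI chunks do; harmless for most clients, but worth knowing when correlating chunks.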