feat: pipe function

Timothy J. Baek 2024-06-20 04:38:59 -07:00
parent de26a78a16
commit d6e4aef607
3 changed files with 234 additions and 64 deletions

View File

@@ -15,6 +15,9 @@ from apps.webui.routers import (
    files,
    functions,
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from config import (
    WEBUI_BUILD_HASH,
    SHOW_ADMIN_DETAILS,
@@ -97,3 +100,58 @@ async def get_status():
        "default_models": app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
    }


async def get_pipe_models():
    pipes = Functions.get_functions_by_type("pipe")
    pipe_models = []

    for pipe in pipes:
        # Check if function is already loaded
        if pipe.id not in app.state.FUNCTIONS:
            function_module, function_type = load_function_module_by_id(pipe.id)
            app.state.FUNCTIONS[pipe.id] = function_module
        else:
            function_module = app.state.FUNCTIONS[pipe.id]

        # Check if function is a manifold
        if hasattr(function_module, "type"):
            if function_module.type == "manifold":
                manifold_pipes = []

                # Check if pipes is a function or a list
                if callable(pipe.pipes):
                    manifold_pipes = pipe.pipes()
                else:
                    manifold_pipes = pipe.pipes

                for p in manifold_pipes:
                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
                    manifold_pipe_name = p["name"]

                    if hasattr(pipe, "name"):
                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"

                    pipe_models.append(
                        {
                            "id": manifold_pipe_id,
                            "name": manifold_pipe_name,
                            "object": "model",
                            "created": pipe.created_at,
                            "owned_by": "openai",
                            "pipe": {"type": pipe.type},
                        }
                    )
        else:
            pipe_models.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": {"type": "pipe"},
                }
            )

    return pipe_models
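
For orientation, here is a rough sketch of the kind of function source this loop is written against: the loaded object may declare type = "manifold" together with a pipes list (or a callable returning one), and anything else is listed as a single pipe. The class name, attribute placement, and example ids below are illustrative assumptions, not part of this commit.

# Hypothetical pipe function source (illustration only). Assumes the object
# returned by load_function_module_by_id exposes the attributes checked above.
class Pipe:
    def __init__(self):
        # "manifold" makes get_pipe_models expand this function into one
        # model entry per item in `pipes`.
        self.type = "manifold"
        self.pipes = [
            {"id": "small", "name": "Small"},
            {"id": "large", "name": "Large"},
        ]

    def pipe(self, body: dict):
        # The chat-completions payload arrives as `body`; returning a plain
        # string is the simplest supported result.
        return f"echo: {body['messages'][-1]['content']}"

Given those two entries, the function would surface as models with ids like <function_id>.small and <function_id>.large, each name prefixed with the function's own name.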

View File

@@ -15,6 +15,7 @@ import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
from pydantic import BaseModel
from typing import List, Optional
from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
@@ -66,7 +67,11 @@ from utils.task import (
    search_query_generation_template,
    tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from utils.misc import (
    get_last_user_message,
    add_or_update_system_message,
    stream_message_template,
)
from apps.rag.utils import get_rag_context, rag_template
@@ -347,6 +352,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
        model = app.state.MODELS[model_id]
        # Check if the model has any filters
        if "info" in model and "meta" in model["info"]:
            for filter_id in model["info"]["meta"].get("filterIds", []):
                filter = Functions.get_function_by_id(filter_id)
                if filter:
@@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    model = app.state.MODELS[model_id]
    print(model)
    pipe = model.get("pipe")
    if pipe:

        def job():
            pipe_id = form_data["model"]
            if "." in pipe_id:
                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
            print(pipe_id)

            if model.get('pipe') == True:
                print('hi')

            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
            if form_data["stream"]:

                def stream_content():
                    res = pipe(body=form_data)

                    if isinstance(res, str):
                        message = stream_message_template(form_data["model"], res)
                        yield f"data: {json.dumps(message)}\n\n"

                    if isinstance(res, Iterator):
                        for line in res:
                            if isinstance(line, BaseModel):
                                line = line.model_dump_json()
                                line = f"data: {line}"

                            try:
                                line = line.decode("utf-8")
                            except:
                                pass

                            if line.startswith("data:"):
                                yield f"{line}\n\n"
                            else:
                                line = stream_message_template(form_data["model"], line)
                                yield f"data: {json.dumps(line)}\n\n"

                    if isinstance(res, str) or isinstance(res, Generator):
                        finish_message = {
                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": form_data["model"],
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {},
                                    "logprobs": None,
                                    "finish_reason": "stop",
                                }
                            ],
                        }
                        yield f"data: {json.dumps(finish_message)}\n\n"
                        yield f"data: [DONE]"

                return StreamingResponse(
                    stream_content(), media_type="text/event-stream"
                )
            else:
                res = pipe(body=form_data)

                if isinstance(res, dict):
                    return res
                elif isinstance(res, BaseModel):
                    return res.model_dump()
                else:
                    message = ""
                    if isinstance(res, str):
                        message = res
                    if isinstance(res, Generator):
                        for stream in res:
                            message = f"{message}{stream}"

                    return {
                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": form_data["model"],
                        "choices": [
                            {
                                "index": 0,
                                "message": {
                                    "role": "assistant",
                                    "content": message,
                                },
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    }

        return await run_in_threadpool(job)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
    else:
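
The pipe branch above relies on a loose contract rather than a fixed return type: pipe(body=...) may return a plain string, a dict or Pydantic model (returned as-is in the non-streaming case), or an iterator/generator of chunks. A hypothetical generator-based pipe, to make the streaming path concrete (names assumed, not taken from this commit):

# Illustration only: a pipe whose result is a generator, which stream_content()
# consumes chunk by chunk.
class Pipe:
    def pipe(self, body: dict):
        # Each yielded string is wrapped by stream_message_template into a
        # chat.completion.chunk; because a generator is also a Generator
        # instance, the closing stop chunk and "data: [DONE]" are appended.
        prompt = body["messages"][-1]["content"]
        for word in prompt.split():
            yield word + " "

If the same request is sent with "stream": false, the else branch drains the generator and returns one consolidated chat.completion object instead.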
@@ -877,13 +967,16 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
        pass

    # Check if the model has any filters
    if "info" in model and "meta" in model["info"]:
        for filter_id in model["info"]["meta"].get("filterIds", []):
            filter = Functions.get_function_by_id(filter_id)
            if filter:
                if filter_id in webui_app.state.FUNCTIONS:
                    function_module = webui_app.state.FUNCTIONS[filter_id]
                else:
                    function_module, function_type = load_function_module_by_id(filter_id)
                    function_module, function_type = load_function_module_by_id(
                        filter_id
                    )
                    webui_app.state.FUNCTIONS[filter_id] = function_module
                try:

View File

@@ -4,6 +4,8 @@ import json
import re
from datetime import timedelta
from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
    return messages


def stream_message_template(model: str, message: str):
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": message},
                "logprobs": None,
                "finish_reason": None,
            }
        ],
    }
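
For reference, this is the chunk shape the streaming branch in main.py wraps around every piece of pipe output before emitting it as a server-sent event; the model id and the assertions below are just an illustrative usage, not part of the commit.

import json

from utils.misc import stream_message_template  # assuming the backend package layout above

chunk = stream_message_template("myfunction.small", "Hello")
assert chunk["object"] == "chat.completion.chunk"
assert chunk["choices"][0]["delta"] == {"content": "Hello"}

# main.py then emits the chunk as one SSE event:
sse_line = f"data: {json.dumps(chunk)}\n\n"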
def get_gravatar_url(email):
    # Trim leading and trailing whitespace from
    # an email address and force all characters