mirror of
https://github.com/open-webui/open-webui
synced 2025-02-20 03:48:48 +00:00
feat: pipe function
This commit is contained in:
parent
de26a78a16
commit
d6e4aef607
@ -15,6 +15,9 @@ from apps.webui.routers import (
|
||||
files,
|
||||
functions,
|
||||
)
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.utils import load_function_module_by_id
|
||||
|
||||
from config import (
|
||||
WEBUI_BUILD_HASH,
|
||||
SHOW_ADMIN_DETAILS,
|
||||
@ -97,3 +100,58 @@ async def get_status():
|
||||
"default_models": app.state.config.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
}
|
||||
|
||||
|
||||
async def get_pipe_models():
|
||||
pipes = Functions.get_functions_by_type("pipe")
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
# Check if function is already loaded
|
||||
if pipe.id not in app.state.FUNCTIONS:
|
||||
function_module, function_type = load_function_module_by_id(pipe.id)
|
||||
app.state.FUNCTIONS[pipe.id] = function_module
|
||||
else:
|
||||
function_module = app.state.FUNCTIONS[pipe.id]
|
||||
|
||||
# Check if function is a manifold
|
||||
if hasattr(function_module, "type"):
|
||||
if function_module.type == "manifold":
|
||||
manifold_pipes = []
|
||||
|
||||
# Check if pipes is a function or a list
|
||||
if callable(pipe.pipes):
|
||||
manifold_pipes = pipe.pipes()
|
||||
else:
|
||||
manifold_pipes = pipe.pipes
|
||||
|
||||
for p in manifold_pipes:
|
||||
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||
manifold_pipe_name = p["name"]
|
||||
|
||||
if hasattr(pipe, "name"):
|
||||
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": manifold_pipe_id,
|
||||
"name": manifold_pipe_name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": {"type": pipe.type},
|
||||
}
|
||||
)
|
||||
else:
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": pipe.id,
|
||||
"name": pipe.name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": {"type": "pipe"},
|
||||
}
|
||||
)
|
||||
|
||||
return pipe_models
|
||||
|
221
backend/main.py
221
backend/main.py
@ -15,6 +15,7 @@ import uuid
|
||||
import inspect
|
||||
import asyncio
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Iterator, Generator, Union
|
||||
|
||||
from apps.webui.models.models import Models, ModelModel
|
||||
from apps.webui.models.tools import Tools
|
||||
@ -66,7 +67,11 @@ from utils.task import (
|
||||
search_query_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||
from utils.misc import (
|
||||
get_last_user_message,
|
||||
add_or_update_system_message,
|
||||
stream_message_template,
|
||||
)
|
||||
|
||||
from apps.rag.utils import get_rag_context, rag_template
|
||||
|
||||
@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
# Check if the model has any filters
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if getattr(function_module, "file_handler"):
|
||||
skip_files = True
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
data = function_module.inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if getattr(function_module, "file_handler"):
|
||||
skip_files = True
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
data = function_module.inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
model = app.state.MODELS[model_id]
|
||||
print(model)
|
||||
|
||||
|
||||
pipe = model.get("pipe")
|
||||
if pipe:
|
||||
|
||||
if model.get('pipe') == True:
|
||||
print('hi')
|
||||
|
||||
|
||||
|
||||
def job():
|
||||
pipe_id = form_data["model"]
|
||||
if "." in pipe_id:
|
||||
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
|
||||
print(pipe_id)
|
||||
|
||||
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
|
||||
if form_data["stream"]:
|
||||
|
||||
def stream_content():
|
||||
res = pipe(body=form_data)
|
||||
|
||||
if isinstance(res, str):
|
||||
message = stream_message_template(form_data["model"], res)
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
|
||||
if isinstance(res, Iterator):
|
||||
for line in res:
|
||||
if isinstance(line, BaseModel):
|
||||
line = line.model_dump_json()
|
||||
line = f"data: {line}"
|
||||
try:
|
||||
line = line.decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
|
||||
if line.startswith("data:"):
|
||||
yield f"{line}\n\n"
|
||||
else:
|
||||
line = stream_message_template(form_data["model"], line)
|
||||
yield f"data: {json.dumps(line)}\n\n"
|
||||
|
||||
if isinstance(res, str) or isinstance(res, Generator):
|
||||
finish_message = {
|
||||
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": form_data["model"],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
yield f"data: {json.dumps(finish_message)}\n\n"
|
||||
yield f"data: [DONE]"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_content(), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
elif isinstance(res, BaseModel):
|
||||
return res.model_dump()
|
||||
else:
|
||||
message = ""
|
||||
if isinstance(res, str):
|
||||
message = res
|
||||
if isinstance(res, Generator):
|
||||
for stream in res:
|
||||
message = f"{message}{stream}"
|
||||
|
||||
return {
|
||||
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": form_data["model"],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": message,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return await run_in_threadpool(job)
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(form_data, user=user)
|
||||
else:
|
||||
@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
pass
|
||||
|
||||
# Check if the model has any filters
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(filter_id)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
data = function_module.outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
data = function_module.outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
@ -4,6 +4,8 @@ import json
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
def get_last_user_message(messages: List[dict]) -> str:
|
||||
@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
|
||||
return messages
|
||||
|
||||
|
||||
def stream_message_template(model: str, message: str):
|
||||
return {
|
||||
"id": f"{model}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": message},
|
||||
"logprobs": None,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
# Trim leading and trailing whitespace from
|
||||
# an email address and force all characters
|
||||
|
Loading…
Reference in New Issue
Block a user