From d6e4aef607350ec2d54f7c46b5417e26bb17fc55 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 04:38:59 -0700 Subject: [PATCH] feat: pipe function --- backend/apps/webui/main.py | 58 ++++++++++ backend/main.py | 221 ++++++++++++++++++++++++++----------- backend/utils/misc.py | 19 ++++ 3 files changed, 234 insertions(+), 64 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 4a53b15bf..5ccb8ae58 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -15,6 +15,9 @@ from apps.webui.routers import ( files, functions, ) +from apps.webui.models.functions import Functions +from apps.webui.utils import load_function_module_by_id + from config import ( WEBUI_BUILD_HASH, SHOW_ADMIN_DETAILS, @@ -97,3 +100,58 @@ async def get_status(): "default_models": app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } + + +async def get_pipe_models(): + pipes = Functions.get_functions_by_type("pipe") + pipe_models = [] + + for pipe in pipes: + # Check if function is already loaded + if pipe.id not in app.state.FUNCTIONS: + function_module, function_type = load_function_module_by_id(pipe.id) + app.state.FUNCTIONS[pipe.id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe.id] + + # Check if function is a manifold + if hasattr(function_module, "type"): + if function_module.type == "manifold": + manifold_pipes = [] + + # Check if pipes is a function or a list + if callable(pipe.pipes): + manifold_pipes = pipe.pipes() + else: + manifold_pipes = pipe.pipes + + for p in manifold_pipes: + manifold_pipe_id = f'{pipe.id}.{p["id"]}' + manifold_pipe_name = p["name"] + + if hasattr(pipe, "name"): + manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" + + pipe_models.append( + { + "id": manifold_pipe_id, + "name": manifold_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": pipe.type}, + } + ) 
+ else: + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": "pipe"}, + } + ) + + return pipe_models diff --git a/backend/main.py b/backend/main.py index 3d95d1913..d6a8c8831 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,6 +15,7 @@ import uuid import inspect import asyncio +from fastapi.concurrency import run_in_threadpool from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse @@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models from pydantic import BaseModel -from typing import List, Optional +from typing import List, Optional, Iterator, Generator, Union from apps.webui.models.models import Models, ModelModel from apps.webui.models.tools import Tools @@ -66,7 +67,11 @@ from utils.task import ( search_query_generation_template, tools_function_calling_generation_template, ) -from utils.misc import get_last_user_message, add_or_update_system_message +from utils.misc import ( + get_last_user_message, + add_or_update_system_message, + stream_message_template, +) from apps.rag.utils import get_rag_context, rag_template @@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): model = app.state.MODELS[model_id] # Check if the model has any filters - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id( - filter_id - ) - webui_app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if getattr(function_module, "file_handler"): - skip_files = True - - try: - if hasattr(function_module, "inlet"): - data = 
function_module.inlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if getattr(function_module, "file_handler", False): + skip_files = True + + try: + if hasattr(function_module, "inlet"): + data = function_module.inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, ) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) # Set the task model task_model_id = data["model"] @@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u model = app.state.MODELS[model_id] print(model) - + pipe = model.get("pipe") + if pipe: - if model.get('pipe') == True: - print('hi') - - - + def job(): + pipe_id = form_data["model"] + if "."
in pipe_id: + pipe_id, sub_pipe_id = pipe_id.split(".", 1) + print(pipe_id) + + pipe = webui_app.state.FUNCTIONS[pipe_id].pipe + if form_data.get("stream", False): + + def stream_content(): + res = pipe(body=form_data) + + if isinstance(res, str): + message = stream_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + yield f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + yield f"data: {json.dumps(line)}\n\n" + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" + + return StreamingResponse( + stream_content(), media_type="text/event-stream" + ) + else: + res = pipe(body=form_data) + + if isinstance(res, dict): + return res + elif isinstance(res, BaseModel): + return res.model_dump() + else: + message = "" + if isinstance(res, str): + message = res + if isinstance(res, Generator): + for stream in res: + message = f"{message}{stream}" + + return { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + return await run_in_threadpool(job) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: @@ -877,32 +967,35 @@ async def
chat_completed(form_data: dict, user=Depends(get_verified_user)): pass # Check if the model has any filters - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - try: - if hasattr(function_module, "outlet"): - data = function_module.outlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + try: + if hasattr(function_module, "outlet"): + data = function_module.outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, ) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) return data diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 41fbdcc75..b4e499df8 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -4,6 +4,8 @@ import json import re from datetime import timedelta from typing import Optional, List, Tuple +import uuid +import time def get_last_user_message(messages: List[dict]) -> str: @@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: 
List[dict]): return messages +def stream_message_template(model: str, message: str): + return { + "id": f"{model}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": message}, + "logprobs": None, + "finish_reason": None, + } + ], + } + + def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters