From 1c4e7f03245f6994d66f0200fde1fdbaed2429d2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 24 Jun 2024 11:17:18 -0700 Subject: [PATCH] refac --- backend/apps/webui/main.py | 159 ++++++++++++++++++++++++++++++++ backend/main.py | 180 +++---------------------------------- 2 files changed, 169 insertions(+), 170 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 8255cd393..260d305f0 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,5 +1,6 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute +from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from apps.webui.routers import ( auths, @@ -17,6 +18,7 @@ from apps.webui.routers import ( ) from apps.webui.models.functions import Functions from apps.webui.utils import load_function_module_by_id +from utils.misc import stream_message_template from config import ( WEBUI_BUILD_HASH, @@ -37,6 +39,14 @@ from config import ( AppConfig, ) +import inspect +import uuid +import time +import json + +from typing import Iterator, Generator +from pydantic import BaseModel + app = FastAPI() origins = ["*"] @@ -166,3 +176,152 @@ async def get_pipe_models(): ) return pipe_models + + +async def generate_function_chat_completion(form_data, user): + async def job(): + pipe_id = form_data["model"] + if "." in pipe_id: + pipe_id, sub_pipe_id = pipe_id.split(".", 1) + print(pipe_id) + + # Check if function is already loaded + if pipe_id not in app.state.FUNCTIONS: + function_module, function_type, frontmatter = load_function_module_by_id( + pipe_id + ) + app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) + + pipe = function_module.pipe + + # Get the signature of the function + sig = inspect.signature(pipe) + params = {"body": form_data} + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if form_data["stream"]: + + async def stream_content(): + try: + if inspect.iscoroutinefunction(pipe): + res = await pipe(**params) + else: + res = pipe(**params) + except Exception as e: + print(f"Error: {e}") + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return + + if isinstance(res, str): + message = stream_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + try: + line = line.decode("utf-8") + except: + pass + + if line.startswith("data:"): + yield f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + yield f"data: {json.dumps(line)}\n\n" + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + yield f"data: {json.dumps(finish_message)}\n\n" + yield f"data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + + try: + if inspect.iscoroutinefunction(pipe): + res = await pipe(**params) + else: + res = pipe(**params) + except Exception as e: + print(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if inspect.iscoroutinefunction(pipe): + res = await pipe(**params) + else: + res = pipe(**params) + + if isinstance(res, dict): + return res + elif isinstance(res, BaseModel): + return res.model_dump() + else: + message = "" + if isinstance(res, str): + message = res + if isinstance(res, Generator): + for stream in res: + message = f"{message}{stream}" + + return { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + return await job() diff --git a/backend/main.py b/backend/main.py index 12e2d937e..426805276 100644 --- a/backend/main.py +++ b/backend/main.py @@ -43,7 +43,11 @@ from apps.openai.main import ( from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app -from apps.webui.main import app as webui_app, get_pipe_models +from apps.webui.main import ( + app as webui_app, + get_pipe_models, + generate_function_chat_completion, +) from pydantic import BaseModel @@ -228,10 +232,7 @@ async def get_function_call_response( response = None try: - if model["owned_by"] == "ollama": - response = await generate_ollama_chat_completion(payload, user=user) - else: - response = await generate_openai_chat_completion(payload, user=user) + response = await generate_chat_completions(form_data=payload, user=user) content = None @@ -900,159 +901,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: - - async def job(): - pipe_id = form_data["model"] - if "." in pipe_id: - pipe_id, sub_pipe_id = pipe_id.split(".", 1) - print(pipe_id) - - # Check if function is already loaded - if pipe_id not in webui_app.state.FUNCTIONS: - function_module, function_type, frontmatter = ( - load_function_module_by_id(pipe_id) - ) - webui_app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = webui_app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr( - function_module, "Valves" - ): - - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - pipe = function_module.pipe - - # Get the signature of the function - sig = inspect.signature(pipe) - params = {"body": form_data} - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - pipe_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if form_data["stream"]: - - async def stream_content(): - try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - except Exception as e: - print(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = stream_message_template(form_data["model"], res) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - try: - line = line.decode("utf-8") - except: - pass - - if line.startswith("data:"): - yield f"{line}\n\n" - else: - line = stream_message_template(form_data["model"], line) - yield f"data: {json.dumps(line)}\n\n" - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "delta": {}, - "logprobs": None, - "finish_reason": "stop", - } - ], - } - - yield f"data: {json.dumps(finish_message)}\n\n" - yield f"data: [DONE]" - - return StreamingResponse( - stream_content(), media_type="text/event-stream" - ) - else: - - try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - except Exception as e: - print(f"Error: {e}") - return {"error": {"detail": str(e)}} - - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) - - if isinstance(res, dict): - return res - elif isinstance(res, BaseModel): - return res.model_dump() - else: - message = "" - if isinstance(res, str): - message = res - if isinstance(res, Generator): - for stream in res: - message = f"{message}{stream}" - - return { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": message, - }, - "logprobs": None, - "finish_reason": "stop", - } - ], - } - - return await job() + return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: @@ -1334,10 +1183,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): content={"detail": e.args[1]}, ) - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(payload, user=user) - else: - return await generate_openai_chat_completion(payload, user=user) + return await generate_chat_completions(form_data=payload, user=user) @app.post("/api/task/query/completions") @@ -1397,10 +1243,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) content={"detail": e.args[1]}, ) - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(payload, user=user) - else: - return await generate_openai_chat_completion(payload, user=user) + return await generate_chat_completions(form_data=payload, user=user) @app.post("/api/task/emoji/completions") @@ -1464,10 +1307,7 @@ Message: """{{prompt}}""" content={"detail": e.args[1]}, ) - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(payload, user=user) - else: - return await generate_openai_chat_completion(payload, user=user) + return await generate_chat_completions(form_data=payload, user=user) @app.post("/api/task/tools/completions")