From 3978efd7104cbaf21ad054080e694f8919f1b306 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 13:35:02 +0100 Subject: [PATCH 01/12] refac: Refactor functions --- backend/apps/socket/main.py | 6 +- backend/apps/webui/main.py | 468 ++++++++++++---------------- backend/apps/webui/models/models.py | 12 +- backend/main.py | 21 +- 4 files changed, 215 insertions(+), 292 deletions(-) diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 18ce7a607..1d98d37ff 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -52,7 +52,6 @@ async def user_join(sid, data): user = Users.get_user_by_id(data["id"]) if user: - SESSION_POOL[sid] = user.id if user.id in USER_POOL: USER_POOL[user.id].append(sid) @@ -80,7 +79,6 @@ def get_models_in_use(): @sio.on("usage") async def usage(sid, data): - model_id = data["model"] # Cancel previous callback if there is one @@ -139,7 +137,7 @@ async def disconnect(sid): print(f"Unknown session ID {sid} disconnected") -async def get_event_emitter(request_info): +def get_event_emitter(request_info): async def __event_emitter__(event_data): await sio.emit( "chat-events", @@ -154,7 +152,7 @@ async def get_event_emitter(request_info): return __event_emitter__ -async def get_event_call(request_info): +def get_event_call(request_info): async def __event_call__(event_data): response = await sio.call( "chat-events", diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 97165a11b..331713b07 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,9 +1,6 @@ -from fastapi import FastAPI, Depends -from fastapi.routing import APIRoute +from fastapi import FastAPI from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.sessions import SessionMiddleware -from sqlalchemy.orm import Session from apps.webui.routers import ( auths, users, @@ -27,7 +24,6 @@ from utils.task import 
prompt_template from config import ( - WEBUI_BUILD_HASH, SHOW_ADMIN_DETAILS, ADMIN_EMAIL, WEBUI_AUTH, @@ -55,7 +51,7 @@ import uuid import time import json -from typing import Iterator, Generator, AsyncGenerator, Optional +from typing import Iterator, Generator, AsyncGenerator from pydantic import BaseModel app = FastAPI() @@ -127,60 +123,60 @@ async def get_status(): } +def get_function_module(pipe_id: str): + # Check if function is already loaded + if pipe_id not in app.state.FUNCTIONS: + function_module, _, _ = load_function_module_by_id(pipe_id) + app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe_id] + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(pipe_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + return function_module + + async def get_pipe_models(): pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: - # Check if function is already loaded - if pipe.id not in app.state.FUNCTIONS: - function_module, function_type, frontmatter = load_function_module_by_id( - pipe.id - ) - app.state.FUNCTIONS[pipe.id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe.id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe.id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + function_module = get_function_module(pipe.id) # Check if function is a manifold if hasattr(function_module, "type"): - if function_module.type == "manifold": - manifold_pipes = [] + if not function_module.type == "manifold": + continue + manifold_pipes = [] - # Check if pipes is a function or a list - if callable(function_module.pipes): - manifold_pipes = function_module.pipes() - else: - manifold_pipes = function_module.pipes + # Check if pipes is a 
function or a list + if callable(function_module.pipes): + manifold_pipes = function_module.pipes() + else: + manifold_pipes = function_module.pipes - for p in manifold_pipes: - manifold_pipe_id = f'{pipe.id}.{p["id"]}' - manifold_pipe_name = p["name"] + for p in manifold_pipes: + manifold_pipe_id = f'{pipe.id}.{p["id"]}' + manifold_pipe_name = p["name"] - if hasattr(function_module, "name"): - manifold_pipe_name = ( - f"{function_module.name}{manifold_pipe_name}" - ) + if hasattr(function_module, "name"): + manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}" - pipe_flag = {"type": pipe.type} - if hasattr(function_module, "ChatValves"): - pipe_flag["valves_spec"] = function_module.ChatValves.schema() + pipe_flag = {"type": pipe.type} + if hasattr(function_module, "ChatValves"): + pipe_flag["valves_spec"] = function_module.ChatValves.schema() - pipe_models.append( - { - "id": manifold_pipe_id, - "name": manifold_pipe_name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) + pipe_models.append( + { + "id": manifold_pipe_id, + "name": manifold_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": pipe_flag, + } + ) else: pipe_flag = {"type": "pipe"} if hasattr(function_module, "ChatValves"): @@ -200,162 +196,179 @@ async def get_pipe_models(): return pipe_models +async def execute_pipe(pipe, params): + if inspect.iscoroutinefunction(pipe): + return await pipe(**params) + else: + return pipe(**params) + + +async def get_message(res: str | Generator | AsyncGenerator) -> str: + if isinstance(res, str): + return res + if isinstance(res, Generator): + return "".join(map(str, res)) + if isinstance(res, AsyncGenerator): + return "".join([str(stream) async for stream in res]) + + +def get_final_message(form_data: dict, message: str | None = None) -> dict: + choice = { + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + + # If message is None, we're 
dealing with a chunk + if not message: + choice["delta"] = {} + else: + choice["message"] = {"role": "assistant", "content": message} + + return { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "created": int(time.time()), + "model": form_data["model"], + "object": "chat.completion" if message is not None else "chat.completion.chunk", + "choices": [choice], + } + + +def process_line(form_data: dict, line): + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + if isinstance(line, dict): + line = f"data: {json.dumps(line)}" + + try: + line = line.decode("utf-8") + except Exception: + pass + + if line.startswith("data:"): + return f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + return f"data: {json.dumps(line)}\n\n" + + +def get_pipe_id(form_data: dict) -> str: + pipe_id = form_data["model"] + if "." in pipe_id: + pipe_id, _ = pipe_id.split(".", 1) + print(pipe_id) + return pipe_id + + +def get_params_dict(pipe, form_data, user, extra_params, function_module): + pipe_id = get_pipe_id(form_data) + # Get the signature of the function + sig = inspect.signature(pipe) + params = {"body": form_data} + + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + ) + except Exception as e: + print(e) + + params["__user__"] = __user__ + return params + + async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = None - if "metadata" in form_data: - metadata = form_data["metadata"] - del form_data["metadata"] + metadata = form_data.pop("metadata", None) - 
__event_emitter__ = None - __event_call__ = None - __task__ = None + __event_emitter__ = __event_call__ = __task__ = None if metadata: - if ( - metadata.get("session_id") - and metadata.get("chat_id") - and metadata.get("message_id") - ): - __event_emitter__ = await get_event_emitter(metadata) - __event_call__ = await get_event_call(metadata) + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) - if metadata.get("task"): - __task__ = metadata.get("task") + if not model_info: + return - if model_info: - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() - if model_info.params: - if model_info.params.get("temperature", None) is not None: - form_data["temperature"] = float(model_info.params.get("temperature")) + if params: + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + } - if model_info.params.get("top_p", None): - form_data["top_p"] = int(model_info.params.get("top_p", None)) + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) - if model_info.params.get("max_tokens", None): - form_data["max_tokens"] = int(model_info.params.get("max_tokens", None)) - - if model_info.params.get("frequency_penalty", None): - form_data["frequency_penalty"] = int( - model_info.params.get("frequency_penalty", None) - ) - - if model_info.params.get("seed", None): - form_data["seed"] = model_info.params.get("seed", None) - - if model_info.params.get("stop", None): - form_data["stop"] = ( - [ - bytes(stop, 
"utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) - - system = model_info.params.get("system", None) - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), - ) - # Check if the payload already has a system message - # If not, add a system message to the payload - if form_data.get("messages"): - for message in form_data["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - form_data["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) + system = params.get("system", None) + if not system: + return + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } else: - pass + template_params = {} + + system = prompt_template(system, **template_params) + + # Check if the payload already has a system message + # If not, add a system message to the payload + for message in form_data.get("messages", []): + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + if form_data.get("messages"): + form_data["messages"].insert(0, {"role": "system", "content": system}) + + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + } async def job(): - pipe_id = form_data["model"] - if "." 
in pipe_id: - pipe_id, sub_pipe_id = pipe_id.split(".", 1) - print(pipe_id) - - # Check if function is already loaded - if pipe_id not in app.state.FUNCTIONS: - function_module, function_type, frontmatter = load_function_module_by_id( - pipe_id - ) - app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = app.state.FUNCTIONS[pipe_id] - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - - valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + pipe_id = get_pipe_id(form_data) + function_module = get_function_module(pipe_id) pipe = function_module.pipe - - # Get the signature of the function - sig = inspect.signature(pipe) - params = {"body": form_data} - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if "__event_emitter__" in sig.parameters: - params = {**params, "__event_emitter__": __event_emitter__} - - if "__event_call__" in sig.parameters: - params = {**params, "__event_call__": __event_call__} - - if "__task__" in sig.parameters: - params = {**params, "__task__": __task__} + params = get_params_dict(pipe, form_data, user, extra_params, function_module) if form_data["stream"]: async def stream_content(): try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) + res = await execute_pipe(pipe, params) # Directly return if the response is a StreamingResponse if isinstance(res, StreamingResponse): @@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user): if isinstance(res, Iterator): for line in res: - if 
isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" - - try: - line = line.decode("utf-8") - except: - pass - - if line.startswith("data:"): - yield f"{line}\n\n" - else: - line = stream_message_template(form_data["model"], line) - yield f"data: {json.dumps(line)}\n\n" - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "delta": {}, - "logprobs": None, - "finish_reason": "stop", - } - ], - } - - yield f"data: {json.dumps(finish_message)}\n\n" - yield f"data: [DONE]" + yield process_line(form_data, line) if isinstance(res, AsyncGenerator): async for line in res: - if isinstance(line, BaseModel): - line = line.model_dump_json() - line = f"data: {line}" - if isinstance(line, dict): - line = f"data: {json.dumps(line)}" + yield process_line(form_data, line) - try: - line = line.decode("utf-8") - except: - pass - - if line.startswith("data:"): - yield f"{line}\n\n" - else: - line = stream_message_template(form_data["model"], line) - yield f"data: {json.dumps(line)}\n\n" + if isinstance(res, str) or isinstance(res, Generator): + finish_message = get_final_message(form_data) + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" return StreamingResponse(stream_content(), media_type="text/event-stream") else: - try: - if inspect.iscoroutinefunction(pipe): - res = await pipe(**params) - else: - res = pipe(**params) + res = await execute_pipe(pipe, params) - if isinstance(res, StreamingResponse): - return res except Exception as e: print(f"Error: {e}") return {"error": {"detail": str(e)}} - if isinstance(res, dict): + if isinstance(res, StreamingResponse) or isinstance(res, dict): return res - elif isinstance(res, BaseModel): + if isinstance(res, 
BaseModel): return res.model_dump() - else: - message = "" - if isinstance(res, str): - message = res - elif isinstance(res, Generator): - for stream in res: - message = f"{message}{stream}" - elif isinstance(res, AsyncGenerator): - async for stream in res: - message = f"{message}{stream}" - return { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "object": "chat.completion", - "created": int(time.time()), - "model": form_data["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": message, - }, - "logprobs": None, - "finish_reason": "stop", - } - ], - } + message = await get_message(res) + return get_final_message(form_data, message) return await job() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 3b128c7d6..8277d1d0b 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -1,13 +1,11 @@ -import json import logging -from typing import Optional +from typing import Optional, List from pydantic import BaseModel, ConfigDict -from sqlalchemy import String, Column, BigInteger, Text +from sqlalchemy import Column, BigInteger, Text from apps.webui.internal.db import Base, JSONField, get_db -from typing import List, Union, Optional from config import SRC_LOG_LEVELS import time @@ -113,7 +111,6 @@ class ModelForm(BaseModel): class ModelsTable: - def insert_new_model( self, form_data: ModelForm, user_id: str ) -> Optional[ModelModel]: @@ -126,9 +123,7 @@ class ModelsTable: } ) try: - with get_db() as db: - result = Model(**model.model_dump()) db.add(result) db.commit() @@ -144,13 +139,11 @@ class ModelsTable: def get_all_models(self) -> List[ModelModel]: with get_db() as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: - model = db.get(Model, id) return ModelModel.model_validate(model) except: @@ -178,7 +171,6 @@ class 
ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: with get_db() as db: - db.query(Model).filter_by(id=id).delete() db.commit() diff --git a/backend/main.py b/backend/main.py index 360f5f415..915abe333 100644 --- a/backend/main.py +++ b/backend/main.py @@ -13,8 +13,6 @@ import aiohttp import requests import mimetypes import shutil -import os -import uuid import inspect from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form @@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import StreamingResponse, Response, RedirectResponse -from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call +from apps.socket.main import app as socket_app, get_event_emitter, get_event_call from apps.ollama.main import ( app as ollama_app, get_all_models as get_ollama_models, @@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): message_id = body["id"] del body["id"] - __event_emitter__ = await get_event_emitter( + __event_emitter__ = get_event_emitter( {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} ) - __event_call__ = await get_event_call( + __event_call__ = get_event_call( {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} ) @@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): status_code=r.status_code, content=res, ) - except: + except Exception: pass else: pass - __event_emitter__ = await get_event_emitter( + __event_emitter__ = get_event_emitter( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): } ) - __event_call__ = await get_event_call( + __event_call__ = get_event_call( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1334,14 +1332,14 @@ async def chat_completed( ) model = app.state.MODELS[model_id] - __event_emitter__ = await 
get_event_emitter( + __event_emitter__ = get_event_emitter( { "chat_id": data["chat_id"], "message_id": data["id"], "session_id": data["session_id"], } ) - __event_call__ = await get_event_call( + __event_call__ = get_event_call( { "chat_id": data["chat_id"], "message_id": data["id"], @@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel): @app.post("/api/pipelines/add") async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): - r = None try: urlIdx = form_data.urlIdx @@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel): @app.delete("/api/pipelines/delete") async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): - r = None try: urlIdx = form_data.urlIdx @@ -1891,7 +1887,6 @@ async def get_pipeline_valves( models = await get_all_models() r = None try: - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] From deec41d29a850afa62f1607da231ad25aee88cc2 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 13:51:25 +0100 Subject: [PATCH 02/12] fix: function early returns --- backend/apps/webui/main.py | 110 +++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 331713b07..13761f8cb 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -291,12 +291,7 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): return params -async def generate_function_chat_completion(form_data, user): - model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) - - metadata = form_data.pop("metadata", None) - +def get_extra_params(metadata: dict): __event_emitter__ = __event_call__ = __task__ = None if metadata: @@ -305,58 +300,67 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) - if not 
model_info: - return - - if model_info.base_model_id: - form_data["model"] = model_info.base_model_id - - params = model_info.params.model_dump() - - if params: - mappings = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - } - - for key, cast_func in mappings.items(): - if (value := params.get(key)) is not None: - form_data[key] = cast_func(value) - - system = params.get("system", None) - if not system: - return - - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - - system = prompt_template(system, **template_params) - - # Check if the payload already has a system message - # If not, add a system message to the payload - for message in form_data.get("messages", []): - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - if form_data.get("messages"): - form_data["messages"].insert(0, {"role": "system", "content": system}) - - extra_params = { + return { "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, "__task__": __task__, } + +async def generate_function_chat_completion(form_data, user): + print("entry point") + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + metadata = form_data.pop("metadata", None) + extra_params = get_extra_params(metadata) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + + if params: + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [ + bytes(s, "utf-8").decode("unicode_escape") for s in x + ], + } + + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + 
form_data[key] = cast_func(value) + + system = params.get("system", None) + if system: + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + + system = prompt_template(system, **template_params) + + # Check if the payload already has a system message + # If not, add a system message to the payload + for message in form_data.get("messages", []): + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + if form_data.get("messages"): + form_data["messages"].insert( + 0, {"role": "system", "content": system} + ) + async def job(): pipe_id = get_pipe_id(form_data) function_module = get_function_module(pipe_id) From 22a5e196c9fd460099c9703a02b8298b5fd4de74 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 14:01:40 +0100 Subject: [PATCH 03/12] simplify main.py --- backend/main.py | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/backend/main.py b/backend/main.py index 915abe333..474df20b6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -617,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) - # Extract valves from the request body - valves = None - if "valves" in body: - valves = body["valves"] - del body["valves"] + metadata = { + "chat_id": body.pop("chat_id", None), + "message_id": body.pop("id", None), + "session_id": body.pop("session_id", None), + "valves": body.pop("valves", None), + } - # Extract session_id, chat_id and message_id from the request body - session_id = None - if "session_id" in body: - session_id = body["session_id"] - del body["session_id"] - chat_id = None - if "chat_id" in body: - chat_id = body["chat_id"] - del body["chat_id"] - message_id = None - if "id" in body: - message_id = body["id"] - del body["id"] - - __event_emitter__ = get_event_emitter( - 
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id} - ) - __event_call__ = get_event_call( - {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} - ) + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) # Initialize data_items to store additional data to be sent to the client data_items = [] @@ -707,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(citations) > 0: data_items.append({"citations": citations}) - body["metadata"] = { - "session_id": session_id, - "chat_id": chat_id, - "message_id": message_id, - "valves": valves, - } - + body["metadata"] = metadata modified_body_bytes = json.dumps(body).encode("utf-8") # Replace the request body with the modified one request._body = modified_body_bytes From 29a3b82336cecf7a602e7569478b964f3003667e Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 15:26:26 +0100 Subject: [PATCH 04/12] refac: reuse stream_message_template --- backend/apps/webui/main.py | 37 +++++++++---------------------------- backend/utils/misc.py | 26 ++++++++++++++++---------- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 13761f8cb..51fa711ca 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,7 +19,7 @@ from apps.webui.models.functions import Functions from apps.webui.models.models import Models from apps.webui.utils import load_function_module_by_id -from utils.misc import stream_message_template +from utils.misc import stream_message_template, whole_message_template from utils.task import prompt_template @@ -203,7 +203,7 @@ async def execute_pipe(pipe, params): return pipe(**params) -async def get_message(res: str | Generator | AsyncGenerator) -> str: +async def get_message_content(res: str | Generator | AsyncGenerator) -> str: if isinstance(res, str): return res if isinstance(res, Generator): @@ -212,28 +212,6 @@ 
async def get_message(res: str | Generator | AsyncGenerator) -> str: return "".join([str(stream) async for stream in res]) -def get_final_message(form_data: dict, message: str | None = None) -> dict: - choice = { - "index": 0, - "logprobs": None, - "finish_reason": "stop", - } - - # If message is None, we're dealing with a chunk - if not message: - choice["delta"] = {} - else: - choice["message"] = {"role": "assistant", "content": message} - - return { - "id": f"{form_data['model']}-{str(uuid.uuid4())}", - "created": int(time.time()), - "model": form_data["model"], - "object": "chat.completion" if message is not None else "chat.completion.chunk", - "choices": [choice], - } - - def process_line(form_data: dict, line): if isinstance(line, BaseModel): line = line.model_dump_json() @@ -292,7 +270,9 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): def get_extra_params(metadata: dict): - __event_emitter__ = __event_call__ = __task__ = None + __event_emitter__ = None + __event_call__ = None + __task__ = None if metadata: if all(k in metadata for k in ("session_id", "chat_id", "message_id")): @@ -401,7 +381,8 @@ async def generate_function_chat_completion(form_data, user): yield process_line(form_data, line) if isinstance(res, str) or isinstance(res, Generator): - finish_message = get_final_message(form_data) + finish_message = stream_message_template(form_data, "") + finish_message["choices"][0]["finish_reason"] = "stop" yield f"data: {json.dumps(finish_message)}\n\n" yield "data: [DONE]" @@ -419,7 +400,7 @@ async def generate_function_chat_completion(form_data, user): if isinstance(res, BaseModel): return res.model_dump() - message = await get_message(res) - return get_final_message(form_data, message) + message = await get_message_content(res) + return whole_message_template(form_data["model"], message) return await job() diff --git a/backend/utils/misc.py b/backend/utils/misc.py index f44a7ce7a..a1e0a8e80 100644 --- a/backend/utils/misc.py 
+++ b/backend/utils/misc.py
@@ -87,23 +87,30 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
-def stream_message_template(model: str, message: str):
+def message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }
 
 
+def stream_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+
+def whole_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+    return template
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters

From 006fc3495ec00e9ac018442f74379f2f1f0f6e8e Mon Sep 17 00:00:00 2001
From: Michael Poluektov
Date: Wed, 31 Jul 2024 16:45:47 +0100
Subject: [PATCH 05/12] fix: stream_message_template

---
 backend/apps/webui/main.py | 41 +++++++++++++++++++++-----------------
 1 file changed, 23 insertions(+), 18 deletions(-)

diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 51fa711ca..e64995af5 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -287,6 +287,26 @@ def get_extra_params(metadata: dict):
     }
 
 
+def add_model_params(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
+
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+
+ for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) + + return form_data + + async def generate_function_chat_completion(form_data, user): print("entry point") model_id = form_data.get("model") @@ -300,24 +320,9 @@ async def generate_function_chat_completion(form_data, user): form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - - if params: - mappings = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [ - bytes(s, "utf-8").decode("unicode_escape") for s in x - ], - } - - for key, cast_func in mappings.items(): - if (value := params.get(key)) is not None: - form_data[key] = cast_func(value) - system = params.get("system", None) + form_data = add_model_params(params, form_data) + if system: if user: template_params = { @@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user): yield process_line(form_data, line) if isinstance(res, str) or isinstance(res, Generator): - finish_message = stream_message_template(form_data, "") + finish_message = stream_message_template(form_data["model"], "") finish_message["choices"][0]["finish_reason"] = "stop" yield f"data: {json.dumps(finish_message)}\n\n" yield "data: [DONE]" From baf58ef396c0d8652222d1d9edcd6410ce1390eb Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 17:16:07 +0100 Subject: [PATCH 06/12] refac: use add_or_update_system_message --- backend/apps/webui/main.py | 59 +++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index e64995af5..adfb82f2b 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions from apps.webui.models.models import Models from apps.webui.utils import load_function_module_by_id 
-from utils.misc import stream_message_template, whole_message_template +from utils.misc import ( + stream_message_template, + whole_message_template, + add_or_update_system_message, +) from utils.task import prompt_template @@ -47,8 +51,6 @@ from config import ( from apps.socket.main import get_event_call, get_event_emitter import inspect -import uuid -import time import json from typing import Iterator, Generator, AsyncGenerator @@ -287,6 +289,7 @@ def get_extra_params(metadata: dict): } +# inplace function: form_data is modified def add_model_params(params: dict, form_data: dict) -> dict: if not params: return form_data @@ -307,44 +310,40 @@ def add_model_params(params: dict, form_data: dict) -> dict: return form_data +# inplace function: form_data is modified +def populate_system_message(params: dict, form_data: dict, user) -> dict: + system = params.get("system", None) + if not system: + return form_data + + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + system = prompt_template(system, **template_params) + form_data["messages"] = add_or_update_system_message( + system, form_data.get("messages", []) + ) + return form_data + + async def generate_function_chat_completion(form_data, user): - print("entry point") model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", None) - extra_params = get_extra_params(metadata) + # Add extra params such as __event_emitter__ + extra_params = get_extra_params(metadata) if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - system = params.get("system", None) form_data = add_model_params(params, form_data) - - if system: - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params 
= {} - - system = prompt_template(system, **template_params) - - # Check if the payload already has a system message - # If not, add a system message to the payload - for message in form_data.get("messages", []): - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - if form_data.get("messages"): - form_data["messages"].insert( - 0, {"role": "system", "content": system} - ) + form_data = populate_system_message(params, form_data, user) async def job(): pipe_id = get_pipe_id(form_data) From 034411e47eda245678249cd23394403d478b7311 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 17:24:00 +0100 Subject: [PATCH 07/12] fix: type not manifold --- backend/apps/webui/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index adfb82f2b..2f4da0384 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -147,9 +147,7 @@ async def get_pipe_models(): function_module = get_function_module(pipe.id) # Check if function is a manifold - if hasattr(function_module, "type"): - if not function_module.type == "manifold": - continue + if hasattr(function_module, "type") and function_module.type == "manifold": manifold_pipes = [] # Check if pipes is a function or a list From f8726719ef93f28f92e5af550e4109ce3da098b9 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 21:58:40 +0100 Subject: [PATCH 08/12] refac: rename whole_message_template, silence lsp --- backend/apps/webui/main.py | 4 ++-- backend/utils/misc.py | 47 +++++++++++++++++++------------------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 2f4da0384..46fc00af3 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -21,7 +21,7 @@ from apps.webui.utils import load_function_module_by_id from utils.misc import ( stream_message_template, - 
whole_message_template, + openai_chat_completion_message_template, add_or_update_system_message, ) from utils.task import prompt_template @@ -403,6 +403,6 @@ async def generate_function_chat_completion(form_data, user): return res.model_dump() message = await get_message_content(res) - return whole_message_template(form_data["model"], message) + return openai_chat_completion_message_template(form_data["model"], message) return await job() diff --git a/backend/utils/misc.py b/backend/utils/misc.py index a1e0a8e80..5e8a39af8 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,6 +1,5 @@ from pathlib import Path import hashlib -import json import re from datetime import timedelta from typing import Optional, List, Tuple @@ -8,37 +7,39 @@ import uuid import time -def get_last_user_message_item(messages: List[dict]) -> str: +def get_last_user_message_item(messages: List[dict]) -> Optional[dict]: for message in reversed(messages): if message["role"] == "user": return message return None -def get_last_user_message(messages: List[dict]) -> str: - message = get_last_user_message_item(messages) - - if message is not None: - if isinstance(message["content"], list): - for item in message["content"]: - if item["type"] == "text": - return item["text"] +def get_content_from_message(message: dict) -> Optional[str]: + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + else: return message["content"] return None -def get_last_assistant_message(messages: List[dict]) -> str: +def get_last_user_message(messages: List[dict]) -> Optional[str]: + message = get_last_user_message_item(messages) + if message is None: + return None + + return get_content_from_message(message) + + +def get_last_assistant_message(messages: List[dict]) -> Optional[str]: for message in reversed(messages): if message["role"] == "assistant": - if isinstance(message["content"], list): - for item in message["content"]: - if 
item["type"] == "text": - return item["text"] - return message["content"] + return get_content_from_message(message) return None -def get_system_message(messages: List[dict]) -> dict: +def get_system_message(messages: List[dict]) -> Optional[dict]: for message in messages: if message["role"] == "system": return message @@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]: return [message for message in messages if message["role"] != "system"] -def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: +def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]: return get_system_message(messages), remove_system_message(messages) @@ -103,7 +104,7 @@ def stream_message_template(model: str, message: str): return template -def whole_message_template(model: str, message: str): +def openai_chat_completion_message_template(model: str, message: str): template = message_template(model) template["object"] = "chat.completion" template["choices"][0]["message"] = {"content": message, "role": "assistant"} @@ -180,7 +181,7 @@ def extract_folders_after_data_docs(path): tags = [] folders = parts[index_docs:-1] - for idx, part in enumerate(folders): + for idx, _ in enumerate(folders): tags.append("/".join(folders[: idx + 1])) return tags @@ -276,11 +277,11 @@ def parse_ollama_modelfile(model_text): value = param_match.group(1) try: - if param_type == int: + if param_type is int: value = int(value) - elif param_type == float: + elif param_type is float: value = float(value) - elif param_type == bool: + elif param_type is bool: value = value.lower() == "true" except Exception as e: print(e) From 2e0fa1c6a09fbf1ba35c5286458838a96f6d8d1a Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 22:00:00 +0100 Subject: [PATCH 09/12] refac: rename stream_message_template --- backend/apps/webui/main.py | 12 ++++++++---- backend/utils/misc.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git 
a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 46fc00af3..69493f5cc 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -20,7 +20,7 @@ from apps.webui.models.models import Models from apps.webui.utils import load_function_module_by_id from utils.misc import ( - stream_message_template, + openai_chat_chunk_message_template, openai_chat_completion_message_template, add_or_update_system_message, ) @@ -227,7 +227,7 @@ def process_line(form_data: dict, line): if line.startswith("data:"): return f"{line}\n\n" else: - line = stream_message_template(form_data["model"], line) + line = openai_chat_chunk_message_template(form_data["model"], line) return f"data: {json.dumps(line)}\n\n" @@ -371,7 +371,9 @@ async def generate_function_chat_completion(form_data, user): return if isinstance(res, str): - message = stream_message_template(form_data["model"], res) + message = openai_chat_chunk_message_template( + form_data["model"], res + ) yield f"data: {json.dumps(message)}\n\n" if isinstance(res, Iterator): @@ -383,7 +385,9 @@ async def generate_function_chat_completion(form_data, user): yield process_line(form_data, line) if isinstance(res, str) or isinstance(res, Generator): - finish_message = stream_message_template(form_data["model"], "") + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) finish_message["choices"][0]["finish_reason"] = "stop" yield f"data: {json.dumps(finish_message)}\n\n" yield "data: [DONE]" diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5e8a39af8..adbf4f8b2 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -97,7 +97,7 @@ def message_template(model: str): } -def stream_message_template(model: str, message: str): +def openai_chat_chunk_message_template(model: str, message: str): template = message_template(model) template["object"] = "chat.completion.chunk" template["choices"][0]["delta"] = {"content": message} From 
b9b1fdd1a1ff2b0d9984cb78623fc14886d7566e Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 22:01:22 +0100 Subject: [PATCH 10/12] refac: rename message_template --- backend/utils/misc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index adbf4f8b2..c4e2eda6f 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -88,7 +88,7 @@ def add_or_update_system_message(content: str, messages: List[dict]): return messages -def message_template(model: str): +def openai_chat_message_template(model: str): return { "id": f"{model}-{str(uuid.uuid4())}", "created": int(time.time()), @@ -98,14 +98,14 @@ def message_template(model: str): def openai_chat_chunk_message_template(model: str, message: str): - template = message_template(model) + template = openai_chat_message_template(model) template["object"] = "chat.completion.chunk" template["choices"][0]["delta"] = {"content": message} return template def openai_chat_completion_message_template(model: str, message: str): - template = message_template(model) + template = openai_chat_message_template(model) template["object"] = "chat.completion" template["choices"][0]["message"] = {"content": message, "role": "assistant"} template["choices"][0]["finish_reason"] = "stop" From c89b34fd75ed4a76d8d491061a80a1e8a36d2dd3 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 31 Jul 2024 22:05:37 +0100 Subject: [PATCH 11/12] flatten job() --- backend/apps/webui/main.py | 109 ++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 57 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 69493f5cc..96adb5080 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -147,7 +147,7 @@ async def get_pipe_models(): function_module = get_function_module(pipe.id) # Check if function is a manifold - if hasattr(function_module, "type") and function_module.type == "manifold": + if 
hasattr(function_module, "pipes"): manifold_pipes = [] # Check if pipes is a function or a list @@ -343,70 +343,65 @@ async def generate_function_chat_completion(form_data, user): form_data = add_model_params(params, form_data) form_data = populate_system_message(params, form_data, user) - async def job(): - pipe_id = get_pipe_id(form_data) - function_module = get_function_module(pipe_id) + pipe_id = get_pipe_id(form_data) + function_module = get_function_module(pipe_id) - pipe = function_module.pipe - params = get_params_dict(pipe, form_data, user, extra_params, function_module) + pipe = function_module.pipe + params = get_params_dict(pipe, form_data, user, extra_params, function_module) - if form_data["stream"]: + if form_data["stream"]: - async def stream_content(): - try: - res = await execute_pipe(pipe, params) - - # Directly return if the response is a StreamingResponse - if isinstance(res, StreamingResponse): - async for data in res.body_iterator: - yield data - return - if isinstance(res, dict): - yield f"data: {json.dumps(res)}\n\n" - return - - except Exception as e: - print(f"Error: {e}") - yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" - return - - if isinstance(res, str): - message = openai_chat_chunk_message_template( - form_data["model"], res - ) - yield f"data: {json.dumps(message)}\n\n" - - if isinstance(res, Iterator): - for line in res: - yield process_line(form_data, line) - - if isinstance(res, AsyncGenerator): - async for line in res: - yield process_line(form_data, line) - - if isinstance(res, str) or isinstance(res, Generator): - finish_message = openai_chat_chunk_message_template( - form_data["model"], "" - ) - finish_message["choices"][0]["finish_reason"] = "stop" - yield f"data: {json.dumps(finish_message)}\n\n" - yield "data: [DONE]" - - return StreamingResponse(stream_content(), media_type="text/event-stream") - else: + async def stream_content(): try: res = await execute_pipe(pipe, params) + # Directly return if the 
response is a StreamingResponse + if isinstance(res, StreamingResponse): + async for data in res.body_iterator: + yield data + return + if isinstance(res, dict): + yield f"data: {json.dumps(res)}\n\n" + return + except Exception as e: print(f"Error: {e}") - return {"error": {"detail": str(e)}} + yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" + return - if isinstance(res, StreamingResponse) or isinstance(res, dict): - return res - if isinstance(res, BaseModel): - return res.model_dump() + if isinstance(res, str): + message = openai_chat_chunk_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" - message = await get_message_content(res) - return openai_chat_completion_message_template(form_data["model"], message) + if isinstance(res, Iterator): + for line in res: + yield process_line(form_data, line) - return await job() + if isinstance(res, AsyncGenerator): + async for line in res: + yield process_line(form_data, line) + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = openai_chat_chunk_message_template( + form_data["model"], "" + ) + finish_message["choices"][0]["finish_reason"] = "stop" + yield f"data: {json.dumps(finish_message)}\n\n" + yield "data: [DONE]" + + return StreamingResponse(stream_content(), media_type="text/event-stream") + else: + try: + res = await execute_pipe(pipe, params) + + except Exception as e: + print(f"Error: {e}") + return {"error": {"detail": str(e)}} + + if isinstance(res, StreamingResponse) or isinstance(res, dict): + return res + if isinstance(res, BaseModel): + return res.model_dump() + + message = await get_message_content(res) + return openai_chat_completion_message_template(form_data["model"], message) From e6c64282fc920897c89e98c11f0ada475bed9e61 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Fri, 2 Aug 2024 01:45:50 +0200 Subject: [PATCH 12/12] refac --- backend/apps/webui/main.py | 53 +++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 96adb5080..972562a04 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -239,10 +239,10 @@ def get_pipe_id(form_data: dict) -> str: return pipe_id -def get_params_dict(pipe, form_data, user, extra_params, function_module): +def get_function_params(function_module, form_data, user, extra_params={}): pipe_id = get_pipe_id(form_data) # Get the signature of the function - sig = inspect.signature(pipe) + sig = inspect.signature(function_module.pipe) params = {"body": form_data} for key, value in extra_params.items(): @@ -269,26 +269,8 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): return params -def get_extra_params(metadata: dict): - __event_emitter__ = None - __event_call__ = None - __task__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - - return { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - } - - # inplace function: form_data is modified -def add_model_params(params: dict, form_data: dict) -> dict: +def apply_model_params_to_body(params: dict, form_data: dict) -> dict: if not params: return form_data @@ -309,7 +291,7 @@ def add_model_params(params: dict, form_data: dict) -> dict: # inplace function: form_data is modified -def populate_system_message(params: dict, form_data: dict, user) -> dict: +def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: system = params.get("system", None) if not system: return form_data @@ -333,21 +315,38 @@ async def 
generate_function_chat_completion(form_data, user): model_info = Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", None) - # Add extra params such as __event_emitter__ - extra_params = get_extra_params(metadata) + __event_emitter__ = None + __event_call__ = None + __task__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = add_model_params(params, form_data) - form_data = populate_system_message(params, form_data, user) + form_data = apply_model_params_to_body(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) pipe_id = get_pipe_id(form_data) function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_params_dict(pipe, form_data, user, extra_params, function_module) + params = get_function_params( + function_module, + form_data, + user, + { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + }, + ) if form_data["stream"]: