From f462744fc8d9a164fde286966084b8156058e4c3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 11 Jul 2024 13:43:44 -0700 Subject: [PATCH] refac --- backend/apps/ollama/main.py | 10 ++++++---- backend/apps/openai/main.py | 2 ++ backend/apps/webui/main.py | 11 +++++++++-- backend/main.py | 22 ++++++++++++++++------ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fd4ba7b06..09df6cc33 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -728,8 +728,10 @@ async def generate_chat_completion( ) payload = { - **form_data.model_dump(exclude_none=True), + **form_data.model_dump(exclude_none=True, exclude=["metadata"]), } + if "metadata" in payload: + del payload["metadata"] model_id = form_data.model model_info = Models.get_model_by_id(model_id) @@ -894,9 +896,9 @@ async def generate_openai_chat_completion( ): form_data = OpenAIChatCompletionForm(**form_data) - payload = { - **form_data.model_dump(exclude_none=True), - } + payload = {**form_data} + if "metadata" in payload: + del payload["metadata"] model_id = form_data.model model_info = Models.get_model_by_id(model_id) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 7c67c40ae..8cd321802 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -357,6 +357,8 @@ async def generate_chat_completion( ): idx = 0 payload = {**form_data} + if "metadata" in payload: + del payload["metadata"] model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index ab28868ae..7a0be2d22 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -20,7 +20,6 @@ from apps.webui.routers import ( ) from apps.webui.models.functions import Functions from apps.webui.models.models import Models - from apps.webui.utils import load_function_module_by_id from utils.misc import stream_message_template @@ -53,7 +52,7 @@ import uuid import time import json -from typing import Iterator, Generator +from typing import Iterator, Generator, Optional from pydantic import BaseModel app = FastAPI() @@ -193,6 +192,14 @@ async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) + metadata = None + if "metadata" in form_data: + metadata = form_data["metadata"] + del form_data["metadata"] + + if metadata: + print(metadata) + if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id diff --git a/backend/main.py b/backend/main.py index 869f88908..1ce3f699a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) + # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. + task = None + if "task" in body: + task = body["task"] + del body["task"] + # Extract session_id, chat_id and message_id from the request body session_id = None if "session_id" in body: @@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): message_id = body["id"] del body["id"] + + __event_emitter__ = await get_event_emitter( {"chat_id": chat_id, "message_id": message_id, "session_id": session_id} ) @@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(citations) > 0: data_items.append({"citations": citations}) + body["metadata"] = { + "session_id": session_id, + "chat_id": chat_id, + "message_id": message_id, + "task": task, + } + modified_body_bytes = json.dumps(body).encode("utf-8") # Replace the request body with the modified one request._body = modified_body_bytes @@ -811,9 +826,6 @@ def filter_pipeline(payload, user): if "detail" in res: raise Exception(r.status_code, res["detail"]) - if "pipeline" not in app.state.MODELS[model_id] and "task" in payload: - del payload["task"] - return payload @@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) - model = app.state.MODELS[model_id] - pipe = model.get("pipe") - if pipe: + if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user)