This commit is contained in:
Timothy J. Baek 2024-07-11 13:43:44 -07:00
parent 9ab97b834a
commit f462744fc8
4 changed files with 33 additions and 12 deletions

View File

@@ -728,8 +728,10 @@ async def generate_chat_completion(
)
payload = {
**form_data.model_dump(exclude_none=True),
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
}
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.model
model_info = Models.get_model_by_id(model_id)
@@ -894,9 +896,9 @@ async def generate_openai_chat_completion(
):
form_data = OpenAIChatCompletionForm(**form_data)
payload = {
**form_data.model_dump(exclude_none=True),
}
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.model
model_info = Models.get_model_by_id(model_id)

View File

@@ -357,6 +357,8 @@ async def generate_chat_completion(
):
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)

View File

@@ -20,7 +20,6 @@ from apps.webui.routers import (
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
@@ -53,7 +52,7 @@ import uuid
import time
import json
from typing import Iterator, Generator
from typing import Iterator, Generator, Optional
from pydantic import BaseModel
app = FastAPI()
@@ -193,6 +192,14 @@ async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = None
if "metadata" in form_data:
metadata = form_data["metadata"]
del form_data["metadata"]
if metadata:
print(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id

View File

@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)},
)
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task = None
if "task" in body:
task = body["task"]
del body["task"]
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
@@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id = body["id"]
del body["id"]
__event_emitter__ = await get_event_emitter(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
@@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0:
data_items.append({"citations": citations})
body["metadata"] = {
"session_id": session_id,
"chat_id": chat_id,
"message_id": message_id,
"task": task,
}
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
@@ -811,9 +826,6 @@ def filter_pipeline(payload, user):
if "detail" in res:
raise Exception(r.status_code, res["detail"])
if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
del payload["task"]
return payload
@@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
pipe = model.get("pipe")
if pipe:
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)