Timothy J. Baek 2024-07-11 13:43:44 -07:00
parent 9ab97b834a
commit f462744fc8
4 changed files with 33 additions and 12 deletions

View File

```diff
@@ -728,8 +728,10 @@ async def generate_chat_completion(
     )
 
     payload = {
-        **form_data.model_dump(exclude_none=True),
+        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -894,9 +896,9 @@ async def generate_openai_chat_completion(
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
 
-    payload = {
-        **form_data.model_dump(exclude_none=True),
-    }
+    payload = {**form_data}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
```
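Both hunks in this file apply the same sanitization before the payload is forwarded upstream: the client-side `metadata` field is excluded when the form is dumped and, as a fallback, deleted from the resulting dict. A minimal runnable sketch of that pattern, assuming pydantic v2; `ChatForm` and its fields are illustrative stand-ins, not the app's real form classes:

```python
# Minimal sketch of the metadata-stripping pattern, assuming pydantic v2.
from typing import Optional
from pydantic import BaseModel

class ChatForm(BaseModel):
    model: str
    messages: list
    metadata: Optional[dict] = None  # client-side bookkeeping, not an upstream field

form = ChatForm(
    model="llama3",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"chat_id": "abc"},
)

# Exclude the field at dump time (the diff passes a list; a set is the documented type)...
payload = {**form.model_dump(exclude_none=True, exclude={"metadata"})}
# ...and drop it defensively afterwards, mirroring the `if "metadata" in payload` guard.
payload.pop("metadata", None)
assert "metadata" not in payload
```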

View File

```diff
@@ -357,6 +357,8 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
```
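This route already builds `payload` from a plain dict, so the equivalent cleanup is just a key deletion before the request is proxied on. A tiny illustrative example (the payload values are made up):

```python
# Plain-dict payload: strip the internal field before forwarding.
payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {"chat_id": "abc"},
}
if "metadata" in payload:
    del payload["metadata"]  # equivalently: payload.pop("metadata", None)
assert "metadata" not in payload
```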

View File

```diff
@@ -20,7 +20,6 @@ from apps.webui.routers import (
 )
 from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
 
 from utils.misc import stream_message_template
@@ -53,7 +52,7 @@ import uuid
 import time
 import json
 
-from typing import Iterator, Generator
+from typing import Iterator, Generator, Optional
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -193,6 +192,14 @@ async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
 
+    metadata = None
+    if "metadata" in form_data:
+        metadata = form_data["metadata"]
+        del form_data["metadata"]
+
+    if metadata:
+        print(metadata)
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
```

View File

```diff
@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                 )
 
+            # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
+            task = None
+            if "task" in body:
+                task = body["task"]
+                del body["task"]
+
             # Extract session_id, chat_id and message_id from the request body
             session_id = None
             if "session_id" in body:
```
```diff
@@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 message_id = body["id"]
                 del body["id"]
 
             __event_emitter__ = await get_event_emitter(
                 {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
             )
```
```diff
@@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if len(citations) > 0:
                 data_items.append({"citations": citations})
 
+            body["metadata"] = {
+                "session_id": session_id,
+                "chat_id": chat_id,
+                "message_id": message_id,
+                "task": task,
+            }
+
             modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
```
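Taken together, the middleware hunks pop `task`, `session_id`, `chat_id`, and `id` off the incoming body, bundle them into a single `metadata` dict, and re-encode the JSON that replaces the request body. A runnable sketch of that flow; the helper name `repackage_body` and the field values are hypothetical:

```python
# Condensed, illustrative version of the body rewriting done by ChatCompletionMiddleware.
import json

def repackage_body(body: dict) -> bytes:
    """Move client-side bookkeeping fields into a single `metadata` dict and re-encode."""
    task = body.pop("task", None)
    session_id = body.pop("session_id", None)
    chat_id = body.pop("chat_id", None)
    message_id = body.pop("id", None)

    body["metadata"] = {
        "session_id": session_id,
        "chat_id": chat_id,
        "message_id": message_id,
        "task": task,
    }
    return json.dumps(body).encode("utf-8")

body = {"model": "llama3", "messages": [], "session_id": "s1", "chat_id": "c1", "id": "m1"}
modified_body_bytes = repackage_body(body)
# The middleware then swaps request._body for this modified payload.
```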
```diff
@@ -811,9 +826,6 @@ def filter_pipeline(payload, user):
                     if "detail" in res:
                         raise Exception(r.status_code, res["detail"])
 
-    if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
-        del payload["task"]
-
     return payload
```
```diff
@@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-    pipe = model.get("pipe")
-
-    if pipe:
+    if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
 
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
```
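The last hunk only simplifies the dispatch by folding the temporary `pipe` variable into the `if`. A self-contained sketch of that branch; the `dispatch` wrapper and the stubbed completion functions are illustrative stand-ins for the real proxy calls:

```python
# Self-contained sketch of the simplified routing; stubs replace the real backends.
import asyncio

async def generate_function_chat_completion(form_data, user=None):
    return {"via": "pipe"}

async def generate_ollama_chat_completion(form_data, user=None):
    return {"via": "ollama"}

async def dispatch(model: dict, form_data: dict, user=None):
    # The diff folds `pipe = model.get("pipe"); if pipe:` into a single check.
    if model.get("pipe"):
        return await generate_function_chat_completion(form_data, user=user)
    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
    raise NotImplementedError("remaining backends elided in this sketch")

print(asyncio.run(dispatch({"owned_by": "ollama"}, {"model": "llama3"})))
```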