mirror of
https://github.com/open-webui/open-webui
synced 2025-02-23 13:51:07 +00:00
1966 lines
75 KiB
Python
1966 lines
75 KiB
Python
import time
|
|
import logging
|
|
import sys
|
|
import os
|
|
import base64
|
|
|
|
import asyncio
|
|
from aiocache import cached
|
|
from typing import Any, Optional
|
|
import random
|
|
import json
|
|
import html
|
|
import inspect
|
|
import re
|
|
import ast
|
|
|
|
from uuid import uuid4
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
from fastapi import Request
|
|
from fastapi import BackgroundTasks
|
|
|
|
from starlette.responses import Response, StreamingResponse
|
|
|
|
|
|
from open_webui.models.chats import Chats
|
|
from open_webui.models.users import Users
|
|
from open_webui.socket.main import (
|
|
get_event_call,
|
|
get_event_emitter,
|
|
get_active_status_by_user_id,
|
|
)
|
|
from open_webui.routers.tasks import (
|
|
generate_queries,
|
|
generate_title,
|
|
generate_image_prompt,
|
|
generate_chat_tags,
|
|
)
|
|
from open_webui.routers.retrieval import process_web_search, SearchForm
|
|
from open_webui.routers.images import image_generations, GenerateImageForm
|
|
from open_webui.routers.pipelines import (
|
|
process_pipeline_inlet_filter,
|
|
process_pipeline_outlet_filter,
|
|
)
|
|
|
|
from open_webui.utils.webhook import post_webhook
|
|
|
|
|
|
from open_webui.models.users import UserModel
|
|
from open_webui.models.functions import Functions
|
|
from open_webui.models.models import Models
|
|
|
|
from open_webui.retrieval.utils import get_sources_from_files
|
|
|
|
|
|
from open_webui.utils.chat import generate_chat_completion
|
|
from open_webui.utils.task import (
|
|
get_task_model_id,
|
|
rag_template,
|
|
tools_function_calling_generation_template,
|
|
)
|
|
from open_webui.utils.misc import (
|
|
deep_update,
|
|
get_message_list,
|
|
add_or_update_system_message,
|
|
add_or_update_user_message,
|
|
get_last_user_message,
|
|
get_last_assistant_message,
|
|
prepend_to_first_user_message_content,
|
|
)
|
|
from open_webui.utils.tools import get_tools
|
|
from open_webui.utils.plugin import load_function_module_by_id
|
|
from open_webui.utils.filter import (
|
|
get_sorted_filter_ids,
|
|
process_filter_functions,
|
|
)
|
|
from open_webui.utils.code_interpreter import execute_code_jupyter
|
|
|
|
from open_webui.tasks import create_task
|
|
|
|
from open_webui.config import (
|
|
CACHE_DIR,
|
|
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
DEFAULT_CODE_INTERPRETER_PROMPT,
|
|
)
|
|
from open_webui.env import (
|
|
SRC_LOG_LEVELS,
|
|
GLOBAL_LOG_LEVEL,
|
|
BYPASS_MODEL_ACCESS_CONTROL,
|
|
ENABLE_REALTIME_CHAT_SAVE,
|
|
)
|
|
from open_webui.constants import TASKS
|
|
|
|
|
|
# Send log records to stdout at the globally configured level, then create
# this module's logger and set its level from SRC_LOG_LEVELS["MAIN"].
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
|
|
|
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    """Let the task model pick tools via a prompt, run them, and gather outputs.

    The tool specs are rendered into a function-calling prompt and sent to the
    task model as a non-streaming completion. The JSON object found in the
    reply is interpreted either as ``{"tool_calls": [...]}`` or as a single
    call of the form ``{"name": ..., "parameters": {...}}``.

    Args:
        request: Incoming request; ``request.app.state.config`` supplies the
            task model settings and the prompt template.
        body: Chat completion payload; ``body["model"]`` and
            ``body["messages"]`` are read.
        user: User passed on to ``generate_chat_completion``.
        models: Model collection handed to ``get_task_model_id``.
        tools: Mapping of tool name to an entry with ``spec``, ``callable``,
            ``citation``, ``file_handler`` and ``toolkit_id`` keys.

    Returns:
        ``(body, {})`` when the task model produced no content, otherwise
        ``(body, {"sources": [...]})`` with one entry per tool call whose
        output is a string. ``body["metadata"]["files"]`` is deleted when such
        a tool is flagged as a file handler.

    Any exception raised while querying the model or running the tools is
    logged and swallowed; sources collected up to that point are still
    returned.
    """

    async def get_content_from_response(response) -> Optional[str]:
        """Return the assistant message content of a dict or streamed response."""
        if not hasattr(response, "body_iterator"):
            return response["choices"][0]["message"]["content"]

        content = None
        async for chunk in response.body_iterator:
            data = json.loads(chunk.decode("utf-8"))
            content = data["choices"][0]["message"]["content"]

        # Cleanup any remaining background tasks if necessary
        if response.background is not None:
            await response.background()
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        """Build the non-streaming completion payload for the task model."""
        user_message = get_last_user_message(messages)
        # Up to four most recent messages, newest first.
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    skip_files = False
    sources = []

    async def tool_call_handler(tool_call) -> None:
        """Run one requested tool call and record its string output."""
        nonlocal skip_files

        log.debug(f"{tool_call=}")

        tool_function_name = tool_call.get("name", None)
        if tool_function_name not in tools:
            # Unknown tool requested by the model: nothing to run.
            return

        tool = tools[tool_function_name]
        tool_function_params = tool_call.get("parameters", {})

        try:
            parameters_spec = tool.get("spec", {}).get("parameters", {})
            # Accept every declared parameter, not only the required ones;
            # otherwise optional arguments chosen by the model are dropped.
            allowed_params = set(parameters_spec.get("properties", {})) | set(
                parameters_spec.get("required", [])
            )
            tool_function = tool["callable"]
            tool_function_params = {
                k: v for k, v in tool_function_params.items() if k in allowed_params
            }
            tool_output = await tool_function(**tool_function_params)
        except Exception as e:
            # Report the failure back as the tool's output.
            tool_output = str(e)

        if isinstance(tool_output, str):
            source_name = f"TOOL:{tool['toolkit_id']}/{tool_function_name}"
            sources.append(
                {
                    # Only citation-enabled tools get a named source.
                    "source": {"name": source_name} if tool["citation"] else {},
                    "document": [tool_output],
                    "metadata": [{"source": source_name}],
                }
            )

            if tool["file_handler"]:
                skip_files = True

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        # Keep only the outermost {...} span; the slice is empty when the
        # reply contains no such span.
        content = content[content.find("{") : content.rfind("}") + 1]
        if not content:
            raise Exception("No JSON object found in the response")

        result = json.loads(content)

        # check if "tool_calls" in result
        if result.get("tool_calls"):
            for tool_call in result.get("tool_calls"):
                await tool_call_handler(tool_call)
        else:
            await tool_call_handler(result)
    except Exception as e:
        log.exception(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
|
|
|
|
|
|
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate search queries, run web searches, and attach results as files.

    Progress is reported through ``extra_params["__event_emitter__"]`` as
    ``status`` events with ``action`` set to ``"web_search"``.

    Args:
        request: Incoming request; ``RAG_WEB_SEARCH_FULL_CONTEXT`` is read
            from ``request.app.state.config``.
        form_data: Chat payload; ``"messages"`` and ``"model"`` are read and
            ``"files"`` is extended with one entry per successful search.
        extra_params: Must contain the ``"__event_emitter__"`` coroutine.
        user: User passed on to query generation and the search itself.

    Returns:
        The (mutated) ``form_data``.
    """
    event_emitter = extra_params["__event_emitter__"]

    async def emit_status(**data):
        """Emit a web-search status event carrying the given data fields."""
        await event_emitter(
            {"type": "status", "data": {"action": "web_search", **data}}
        )

    await emit_status(description="Generating search query", done=False)

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1

            # rfind() + 1 is 0, never -1, when there is no "}", so compare
            # against the opening brace instead. This also rejects a "}" that
            # precedes the "{". Raising here keeps `response` intact for the
            # fallback below rather than slicing it down to "".
            if bracket_start == -1 or bracket_end <= bracket_start:
                raise Exception("No JSON object found in the response")

            response = response[bracket_start:bracket_end]
            queries = json.loads(response)
            queries = queries.get("queries", [])
        except Exception:
            # Not parseable as a JSON object: use the text itself as the query.
            queries = [response]

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if len(queries) == 0:
        await emit_status(description="No search query generated", done=True)
        return form_data

    all_results = []

    for search_query in queries:
        # "{{searchQuery}}" is emitted literally; it is not interpolated here.
        await emit_status(
            description='Searching "{{searchQuery}}"',
            query=search_query,
            done=False,
        )

        try:
            results = await process_web_search(
                request,
                SearchForm(**{"query": search_query}),
                user=user,
            )

            if results:
                all_results.append(results)
                files = form_data.get("files", [])

                if request.app.state.config.RAG_WEB_SEARCH_FULL_CONTEXT:
                    files.append(
                        {
                            "docs": results.get("docs", []),
                            "name": search_query,
                            "type": "web_search_docs",
                            "urls": results["filenames"],
                        }
                    )
                else:
                    files.append(
                        {
                            "collection_name": results["collection_name"],
                            "name": search_query,
                            "type": "web_search_results",
                            "urls": results["filenames"],
                        }
                    )
                form_data["files"] = files
        except Exception as e:
            log.exception(e)
            await emit_status(
                description='Error searching "{{searchQuery}}"',
                query=search_query,
                done=True,
                error=True,
            )

    if all_results:
        urls = []
        for results in all_results:
            if "filenames" in results:
                urls.extend(results["filenames"])

        await emit_status(
            description="Searched {{count}} sites",
            urls=urls,
            done=True,
        )
    else:
        await emit_status(
            description="No search results found",
            done=True,
            error=True,
        )

    return form_data
|
|
|
|
|
|
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the last user message and report the outcome.

    When ``ENABLE_IMAGE_PROMPT_GENERATION`` is set, the image prompt is first
    requested from ``generate_image_prompt``; any failure along that path
    falls back to the last user message. The outcome of the generation is
    written into the system message so the chat model can tell the user.

    Args:
        request: Incoming request; ``ENABLE_IMAGE_PROMPT_GENERATION`` is read
            from ``request.app.state.config``.
        form_data: Chat payload; ``"messages"`` and ``"model"`` are read and
            ``"messages"`` is replaced with the updated list.
        extra_params: Must contain the ``"__event_emitter__"`` coroutine.
        user: User passed on to prompt and image generation.

    Returns:
        The (mutated) ``form_data``.
    """
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1

                # rfind() + 1 is 0, never -1, when there is no "}", so
                # compare against the opening brace instead.
                if bracket_start == -1 or bracket_end <= bracket_start:
                    raise Exception("No JSON object found in the response")

                response = json.loads(response[bracket_start:bracket_end])
                # A missing or empty "prompt" falls back to the user message
                # instead of handing a list to the image form.
                prompt = response.get("prompt") or user_message
            except Exception:
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await event_emitter(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        for image in images:
            # NOTE(review): assumes each item returned by image_generations is
            # a mapping with a "url" key — confirm against that router.
            await event_emitter(
                {
                    "type": "message",
                    "data": {"content": f"![Generated Image]({image['url']})\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "description": "An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
|
|
|
|
|
|
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Collect retrieval sources for the files attached to a chat request.

    Does nothing unless ``body["metadata"]["files"]`` is non-empty. Retrieval
    queries come from ``generate_queries``; when that fails or yields none,
    the last user message is used as the only query.

    Args:
        request: Incoming request; the embedding and reranking functions and
            the retrieval settings are read from ``request.app.state``.
        body: Chat payload; ``"metadata"``, ``"model"`` and ``"messages"`` are
            read.
        user: User passed on to query generation and the embedding function.

    Returns:
        ``(body, {"sources": sources})``. ``sources`` is an empty list when
        there are no files or when retrieval raised (the error is logged).
    """
    sources = []

    files = body.get("metadata", {}).get("files", None)
    if not files:
        return body, {"sources": sources}

    try:
        queries_response = await generate_queries(
            request,
            {
                "model": body["model"],
                "messages": body["messages"],
                "type": "retrieval",
            },
            user,
        )
        queries_response = queries_response["choices"][0]["message"]["content"]

        try:
            bracket_start = queries_response.find("{")
            bracket_end = queries_response.rfind("}") + 1

            # rfind() + 1 is 0, never -1, when there is no "}", so compare
            # against the opening brace instead. Raising here keeps the raw
            # text for the fallback below rather than slicing it to "".
            if bracket_start == -1 or bracket_end <= bracket_start:
                raise Exception("No JSON object found in the response")

            queries_response = queries_response[bracket_start:bracket_end]
            queries_response = json.loads(queries_response)
        except Exception:
            # Not parseable as a JSON object: use the text itself as the query.
            queries_response = {"queries": [queries_response]}

        queries = queries_response.get("queries", [])
    except Exception as e:
        # Query generation is best-effort; record why it failed and fall back.
        log.exception(e)
        queries = []

    if len(queries) == 0:
        queries = [get_last_user_message(body["messages"])]

    try:
        # Offload get_sources_from_files to a separate thread
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            sources = await loop.run_in_executor(
                executor,
                lambda: get_sources_from_files(
                    files=files,
                    queries=queries,
                    embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
                        query, user=user
                    ),
                    k=request.app.state.config.TOP_K,
                    reranking_function=request.app.state.rf,
                    r=request.app.state.config.RELEVANCE_THRESHOLD,
                    hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                    full_context=request.app.state.config.RAG_FULL_CONTEXT,
                ),
            )
    except Exception as e:
        log.exception(e)

    log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
|
|
|
|
|
|
def apply_params_to_form_data(form_data, model):
    """Move the entries of ``form_data["params"]`` to where the backend reads them.

    ``"params"`` is always removed from ``form_data``. For a model with a
    truthy ``"ollama"`` entry the whole params mapping becomes
    ``form_data["options"]`` and ``format`` / ``keep_alive`` are additionally
    copied to the top level. For any other model only a fixed set of sampling
    keys is copied to the top level and every other param is dropped.

    Args:
        form_data: Request payload; mutated in place.
        model: Model description; only ``model.get("ollama")`` is consulted.

    Returns:
        The same ``form_data`` object.
    """
    params = form_data.pop("params", None)
    if params is None:
        # A missing key and an explicit ``"params": None`` both mean that
        # there are no overrides; None would otherwise break the ``in`` tests.
        params = {}

    if model.get("ollama"):
        form_data["options"] = params
        top_level_keys = ("format", "keep_alive")
    else:
        top_level_keys = (
            "seed",
            "stop",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "reasoning_effort",
        )

    for key in top_level_keys:
        if key in params:
            form_data[key] = params[key]

    return form_data
|
|
|
|
|
|
async def process_chat_payload(request, form_data, metadata, user, model):
    """Prepare a chat request before it is sent to the model.

    Steps, in order: apply request params, attach the model's knowledge
    entries as files, run the pipeline and function inlet filters, run the
    enabled features (web search, image generation, code interpreter prompt),
    resolve tools, collect tool and file sources, and inject the collected
    context into the messages through the RAG template.

    Args:
        request: Incoming request; configuration and the model registry are
            read from ``request.app.state`` and ``request.state``.
        form_data: Chat payload; mutated and/or replaced along the way.
        metadata: Request metadata; a copy extended with ``tool_ids`` and
            ``files`` is stored in ``form_data["metadata"]`` and returned.
        user: Requesting user; ``id``, ``email``, ``name`` and ``role`` are
            exposed to filters and tools.
        model: Description of the selected model.

    Returns:
        ``(form_data, metadata, events)`` where ``events`` holds at most one
        ``{"sources": [...]}`` entry listing the sources that carry a name.

    Raises:
        Exception: If an inlet filter function fails (chained to the original
            error), or if sources were collected but the messages contain no
            user message. Errors raised by ``process_pipeline_inlet_filter``
            propagate unchanged.
    """
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    # events: additional events to be sent to the client
    # sources: contexts and citations gathered from tools and files
    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        files = form_data.get("files", [])
        files.extend(knowledge_files)
        form_data["files"] = files

    # "variables" is removed from the payload; its value is not used here.
    form_data.pop("variables", None)

    # Process the form_data through the pipeline
    form_data = await process_pipeline_inlet_filter(request, form_data, user, models)

    try:
        form_data, _ = await process_filter_functions(
            request=request,
            filter_ids=get_sorted_filter_ids(model),
            filter_type="inlet",
            form_data=form_data,
            extra_params=extra_params,
        )
    except Exception as e:
        # Keep the original error attached as the cause.
        raise Exception(f"Error: {e}") from e

    features = form_data.pop("features", None)
    if features:
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        if "code_interpreter" in features and features["code_interpreter"]:
            form_data["messages"] = add_or_update_user_message(
                (
                    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
                    if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
                    else DEFAULT_CODE_INTERPRETER_PROMPT
                ),
                form_data["messages"],
            )

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)
    # Remove files duplicates
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    # form_data["metadata"] and `metadata` are the same dict object, so the
    # "tools" entry added below is visible through both.
    form_data["metadata"] = metadata

    log.debug(f"{tool_ids=}")

    if tool_ids:
        # If tool_ids field is present, then get the tools
        tools = get_tools(
            request,
            tool_ids,
            user,
            {
                **extra_params,
                "__model__": models[task_model_id],
                "__messages__": form_data["messages"],
                "__files__": metadata.get("files", []),
            },
        )
        log.info(f"{tools=}")

        if metadata.get("function_calling") == "native":
            # Native function calling: hand the tool specs to the model
            metadata["tools"] = tools
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools.values()
            ]
        else:
            # Otherwise run the prompt-based tools function calling handler
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, user, models, tools
                )
                sources.extend(flags.get("sources", []))
            except Exception as e:
                log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # If context is not empty, insert it into the messages
    if len(sources) > 0:
        # Each document is tagged with the index of the source it came from.
        context_string = "".join(
            f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
            for source_idx, source in enumerate(sources)
            if "document" in source
            for doc_context in source["document"]
        ).strip()
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if request.app.state.config.RELEVANCE_THRESHOLD == 0 and context_string == "":
            log.debug(
                "With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        # Workaround for Ollama 2.0+ system prompt issue
        # TODO: replace with add_or_update_system_message
        if model.get("owned_by") == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )

    # If there are citations, add them to the data_items
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, metadata, events
|
|
|
|
|
|
async def process_chat_response(
|
|
request, response, form_data, user, events, metadata, tasks
|
|
):
|
|
async def background_tasks_handler():
|
|
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
|
|
message = message_map.get(metadata["message_id"]) if message_map else None
|
|
|
|
if message:
|
|
messages = get_message_list(message_map, message.get("id"))
|
|
|
|
if tasks and messages:
|
|
if TASKS.TITLE_GENERATION in tasks:
|
|
if tasks[TASKS.TITLE_GENERATION]:
|
|
res = await generate_title(
|
|
request,
|
|
{
|
|
"model": message["model"],
|
|
"messages": messages,
|
|
"chat_id": metadata["chat_id"],
|
|
},
|
|
user,
|
|
)
|
|
|
|
if res and isinstance(res, dict):
|
|
if len(res.get("choices", [])) == 1:
|
|
title_string = (
|
|
res.get("choices", [])[0]
|
|
.get("message", {})
|
|
.get("content", message.get("content", "New Chat"))
|
|
)
|
|
else:
|
|
title_string = ""
|
|
|
|
title_string = title_string[
|
|
title_string.find("{") : title_string.rfind("}") + 1
|
|
]
|
|
|
|
try:
|
|
title = json.loads(title_string).get(
|
|
"title", "New Chat"
|
|
)
|
|
except Exception as e:
|
|
title = ""
|
|
|
|
if not title:
|
|
title = messages[0].get("content", "New Chat")
|
|
|
|
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:title",
|
|
"data": title,
|
|
}
|
|
)
|
|
elif len(messages) == 2:
|
|
title = messages[0].get("content", "New Chat")
|
|
|
|
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:title",
|
|
"data": message.get("content", "New Chat"),
|
|
}
|
|
)
|
|
|
|
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
|
|
res = await generate_chat_tags(
|
|
request,
|
|
{
|
|
"model": message["model"],
|
|
"messages": messages,
|
|
"chat_id": metadata["chat_id"],
|
|
},
|
|
user,
|
|
)
|
|
|
|
if res and isinstance(res, dict):
|
|
if len(res.get("choices", [])) == 1:
|
|
tags_string = (
|
|
res.get("choices", [])[0]
|
|
.get("message", {})
|
|
.get("content", "")
|
|
)
|
|
else:
|
|
tags_string = ""
|
|
|
|
tags_string = tags_string[
|
|
tags_string.find("{") : tags_string.rfind("}") + 1
|
|
]
|
|
|
|
try:
|
|
tags = json.loads(tags_string).get("tags", [])
|
|
Chats.update_chat_tags_by_id(
|
|
metadata["chat_id"], tags, user
|
|
)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:tags",
|
|
"data": tags,
|
|
}
|
|
)
|
|
except Exception as e:
|
|
pass
|
|
|
|
event_emitter = None
|
|
event_caller = None
|
|
if (
|
|
"session_id" in metadata
|
|
and metadata["session_id"]
|
|
and "chat_id" in metadata
|
|
and metadata["chat_id"]
|
|
and "message_id" in metadata
|
|
and metadata["message_id"]
|
|
):
|
|
event_emitter = get_event_emitter(metadata)
|
|
event_caller = get_event_call(metadata)
|
|
|
|
# Non-streaming response
|
|
if not isinstance(response, StreamingResponse):
|
|
if event_emitter:
|
|
if "selected_model_id" in response:
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"selectedModelId": response["selected_model_id"],
|
|
},
|
|
)
|
|
|
|
if response.get("choices", [])[0].get("message", {}).get("content"):
|
|
content = response["choices"][0]["message"]["content"]
|
|
|
|
if content:
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": response,
|
|
}
|
|
)
|
|
|
|
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": {
|
|
"done": True,
|
|
"content": content,
|
|
"title": title,
|
|
},
|
|
}
|
|
)
|
|
|
|
# Save message in the database
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"content": content,
|
|
},
|
|
)
|
|
|
|
# Send a webhook notification if the user is not active
|
|
if get_active_status_by_user_id(user.id) is None:
|
|
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
|
if webhook_url:
|
|
post_webhook(
|
|
request.app.state.WEBUI_NAME,
|
|
webhook_url,
|
|
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
|
{
|
|
"action": "chat",
|
|
"message": content,
|
|
"title": title,
|
|
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
|
},
|
|
)
|
|
|
|
await background_tasks_handler()
|
|
|
|
return response
|
|
else:
|
|
return response
|
|
|
|
# Non standard response
|
|
if not any(
|
|
content_type in response.headers["Content-Type"]
|
|
for content_type in ["text/event-stream", "application/x-ndjson"]
|
|
):
|
|
return response
|
|
|
|
# Streaming response
|
|
if event_emitter and event_caller:
|
|
task_id = str(uuid4()) # Create a unique task ID.
|
|
model_id = form_data.get("model", "")
|
|
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"model": model_id,
|
|
},
|
|
)
|
|
|
|
def split_content_and_whitespace(content):
|
|
content_stripped = content.rstrip()
|
|
original_whitespace = (
|
|
content[len(content_stripped) :]
|
|
if len(content) > len(content_stripped)
|
|
else ""
|
|
)
|
|
return content_stripped, original_whitespace
|
|
|
|
def is_opening_code_block(content):
|
|
backtick_segments = content.split("```")
|
|
# Even number of segments means the last backticks are opening a new block
|
|
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
|
|
|
# Handle as a background task
|
|
async def post_response_handler(response, events):
|
|
def serialize_content_blocks(content_blocks, raw=False):
|
|
content = ""
|
|
|
|
for block in content_blocks:
|
|
if block["type"] == "text":
|
|
content = f"{content}{block['content'].strip()}\n"
|
|
elif block["type"] == "tool_calls":
|
|
attributes = block.get("attributes", {})
|
|
|
|
block_content = block.get("content", [])
|
|
results = block.get("results", [])
|
|
|
|
if results:
|
|
|
|
result_display_content = ""
|
|
|
|
for result in results:
|
|
tool_call_id = result.get("tool_call_id", "")
|
|
tool_name = ""
|
|
|
|
for tool_call in block_content:
|
|
if tool_call.get("id", "") == tool_call_id:
|
|
tool_name = tool_call.get("function", {}).get(
|
|
"name", ""
|
|
)
|
|
break
|
|
|
|
result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}"
|
|
|
|
if not raw:
|
|
content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n'
|
|
else:
|
|
tool_calls_display_content = ""
|
|
|
|
for tool_call in block_content:
|
|
tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}"
|
|
|
|
if not raw:
|
|
content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n'
|
|
|
|
elif block["type"] == "reasoning":
|
|
reasoning_display_content = "\n".join(
|
|
(f"> {line}" if not line.startswith(">") else line)
|
|
for line in block["content"].splitlines()
|
|
)
|
|
|
|
reasoning_duration = block.get("duration", None)
|
|
|
|
if reasoning_duration is not None:
|
|
if raw:
|
|
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
|
else:
|
|
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
|
else:
|
|
if raw:
|
|
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
|
else:
|
|
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
|
|
|
elif block["type"] == "code_interpreter":
|
|
attributes = block.get("attributes", {})
|
|
output = block.get("output", None)
|
|
lang = attributes.get("lang", "")
|
|
|
|
content_stripped, original_whitespace = (
|
|
split_content_and_whitespace(content)
|
|
)
|
|
if is_opening_code_block(content_stripped):
|
|
# Remove trailing backticks that would open a new block
|
|
content = (
|
|
content_stripped.rstrip("`").rstrip()
|
|
+ original_whitespace
|
|
)
|
|
else:
|
|
# Keep content as is - either closing backticks or no backticks
|
|
content = content_stripped + original_whitespace
|
|
|
|
if output:
|
|
output = html.escape(json.dumps(output))
|
|
|
|
if raw:
|
|
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
|
|
else:
|
|
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
|
else:
|
|
if raw:
|
|
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
|
|
else:
|
|
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
|
|
|
|
else:
|
|
block_content = str(block["content"]).strip()
|
|
content = f"{content}{block['type']}: {block_content}\n"
|
|
|
|
return content.strip()
|
|
|
|
def convert_content_blocks_to_messages(content_blocks):
    """Turn a flat list of content blocks into chat-completion messages.

    Blocks are accumulated until a ``tool_calls`` block is reached. At that
    point one assistant message is emitted (the accumulated blocks serialized
    via ``serialize_content_blocks`` plus the block's ``content`` as
    ``tool_calls``), followed by one ``tool`` message per entry in the block's
    ``results``. Blocks remaining after the last ``tool_calls`` block become a
    final assistant message, but only if they serialize to a non-empty string.

    Returns the list of message dicts in emission order.
    """
    messages = []
    pending_blocks = []

    for block in content_blocks:
        if block["type"] != "tool_calls":
            # Not a tool call: keep collecting for the next assistant message.
            pending_blocks.append(block)
            continue

        # Assistant turn that issued the tool calls, with any preceding content.
        assistant_message = {
            "role": "assistant",
            "content": serialize_content_blocks(pending_blocks),
            "tool_calls": block.get("content"),
        }
        messages.append(assistant_message)

        # One tool message per recorded result, in order.
        messages.extend(
            {
                "role": "tool",
                "tool_call_id": result["tool_call_id"],
                "content": result["content"],
            }
            for result in block.get("results", [])
        )

        pending_blocks = []

    if pending_blocks:
        trailing_content = serialize_content_blocks(pending_blocks)
        # Skip the trailing assistant message when it would be empty.
        if trailing_content:
            messages.append(
                {
                    "role": "assistant",
                    "content": trailing_content,
                }
            )

    return messages
|
|
|
|
def tag_content_handler(content_type, tags, content, content_blocks):
    """Detect opening/closing tags in streamed content and update the blocks.

    Args:
        content_type: Block type to create/close (e.g. ``"reasoning"`` or
            ``"code_interpreter"``).
        tags: Iterable of tag names (without angle brackets) that open a
            block of ``content_type``.
        content: The accumulated content string to scan for tags.
        content_blocks: List of block dicts; mutated in place. The last
            block is the one currently being filled.

    Returns:
        A tuple ``(content, content_blocks, end_flag)``. ``end_flag`` is True
        only when a closing tag for the currently open ``content_type`` block
        was found during this call. When that happens, the fully tagged
        section is removed from the returned ``content``.

    If ``content_blocks`` is empty there is no current block to inspect, so
    the inputs are returned unchanged.
    """
    end_flag = False

    def extract_attributes(tag_content):
        """Extract attributes from a tag if they exist."""
        if not tag_content:  # Ensure tag_content is not None
            return {}
        # Match attributes in the format: key="value" (ignores single quotes for simplicity)
        return dict(re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content))

    if not content_blocks:
        # Nothing to inspect; avoid IndexError on content_blocks[-1].
        return content, content_blocks, end_flag

    current_block = content_blocks[-1]

    if current_block["type"] == "text":
        for tag in tags:
            # Match start tag e.g., <tag> or <tag attr="value">.
            # The tag name is escaped so it is always matched literally.
            start_tag_pattern = rf"<{re.escape(tag)}(\s.*?)?>"
            match = re.search(start_tag_pattern, content)
            if not match:
                continue

            attributes = extract_attributes(match.group(1) or "")

            # Capture everything before and after the matched tag
            before_tag = content[: match.start()]
            after_tag = content[match.end() :]

            # Remove the start tag and everything after it from the text
            # block that is currently being filled.
            current_block["content"] = current_block["content"].replace(
                match.group(0) + after_tag, ""
            )

            if before_tag:
                current_block["content"] = before_tag

            if not current_block["content"]:
                content_blocks.pop()

            # Append the new block, seeded with whatever followed the tag.
            content_blocks.append(
                {
                    "type": content_type,
                    "tag": tag,
                    "attributes": attributes,
                    "content": after_tag,
                    "started_at": time.time(),
                }
            )
            break

    elif current_block["type"] == content_type:
        tag = current_block["tag"]
        escaped_tag = re.escape(tag)
        # Match end tag e.g., </tag>
        end_tag_pattern = rf"</{escaped_tag}>"

        # Check if the content has the end tag
        if re.search(end_tag_pattern, content):
            end_flag = True

            # Strip any start tag from the block content, then split on the
            # end tag into "inside the tag" and "after the tag".
            block_content = re.sub(
                rf"<{escaped_tag}(.*?)>", "", current_block["content"]
            ).strip()
            split_content = re.split(end_tag_pattern, block_content, maxsplit=1)

            # Content inside the tag
            block_content = split_content[0].strip() if split_content else ""

            # Leftover content (everything after `</tag>`)
            leftover_content = (
                split_content[1].strip() if len(split_content) > 1 else ""
            )

            if block_content:
                current_block["content"] = block_content
                current_block["ended_at"] = time.time()
                current_block["duration"] = int(
                    current_block["ended_at"] - current_block["started_at"]
                )

                # Start a fresh text block after the closed block. A closed
                # code_interpreter block is intentionally left as the last
                # block, so no text block is appended for it.
                if content_type != "code_interpreter":
                    content_blocks.append(
                        {
                            "type": "text",
                            "content": leftover_content,
                        }
                    )
            else:
                # Remove the block if content is empty
                content_blocks.pop()
                content_blocks.append(
                    {
                        "type": "text",
                        "content": leftover_content,
                    }
                )

            # Clean processed content. With DOTALL, `.*?` already spans
            # newlines; the former `(.|\n)*?` matched the same text but could
            # backtrack exponentially on an unclosed tag followed by newlines.
            content = re.sub(
                rf"<{escaped_tag}(.*?)>.*?</{escaped_tag}>",
                "",
                content,
                flags=re.DOTALL,
            )

    return content, content_blocks, end_flag
|
|
|
|
message = Chats.get_message_by_id_and_message_id(
|
|
metadata["chat_id"], metadata["message_id"]
|
|
)
|
|
|
|
tool_calls = []
|
|
|
|
last_assistant_message = None
|
|
try:
|
|
if form_data["messages"][-1]["role"] == "assistant":
|
|
last_assistant_message = get_last_assistant_message(
|
|
form_data["messages"]
|
|
)
|
|
except Exception as e:
|
|
pass
|
|
|
|
content = (
|
|
message.get("content", "")
|
|
if message
|
|
else last_assistant_message if last_assistant_message else ""
|
|
)
|
|
|
|
content_blocks = [
|
|
{
|
|
"type": "text",
|
|
"content": content,
|
|
}
|
|
]
|
|
|
|
# We might want to disable this by default
|
|
DETECT_REASONING = True
|
|
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
|
|
"code_interpreter", False
|
|
)
|
|
|
|
reasoning_tags = [
|
|
"think",
|
|
"thinking",
|
|
"reason",
|
|
"reasoning",
|
|
"thought",
|
|
"Thought",
|
|
]
|
|
code_interpreter_tags = ["code_interpreter"]
|
|
|
|
try:
|
|
for event in events:
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": event,
|
|
}
|
|
)
|
|
|
|
# Save message in the database
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
**event,
|
|
},
|
|
)
|
|
|
|
async def stream_body_handler(response):
|
|
nonlocal content
|
|
nonlocal content_blocks
|
|
|
|
response_tool_calls = []
|
|
|
|
async for line in response.body_iterator:
|
|
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
|
data = line
|
|
|
|
# Skip empty lines
|
|
if not data.strip():
|
|
continue
|
|
|
|
# "data:" is the prefix for each event
|
|
if not data.startswith("data:"):
|
|
continue
|
|
|
|
# Remove the prefix
|
|
data = data[len("data:") :].strip()
|
|
|
|
try:
|
|
data = json.loads(data)
|
|
|
|
if "selected_model_id" in data:
|
|
model_id = data["selected_model_id"]
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"selectedModelId": model_id,
|
|
},
|
|
)
|
|
else:
|
|
choices = data.get("choices", [])
|
|
if not choices:
|
|
continue
|
|
|
|
delta = choices[0].get("delta", {})
|
|
delta_tool_calls = delta.get("tool_calls", None)
|
|
|
|
if delta_tool_calls:
|
|
for delta_tool_call in delta_tool_calls:
|
|
tool_call_index = delta_tool_call.get("index")
|
|
|
|
if tool_call_index is not None:
|
|
if (
|
|
len(response_tool_calls)
|
|
<= tool_call_index
|
|
):
|
|
response_tool_calls.append(
|
|
delta_tool_call
|
|
)
|
|
else:
|
|
delta_name = delta_tool_call.get(
|
|
"function", {}
|
|
).get("name")
|
|
delta_arguments = delta_tool_call.get(
|
|
"function", {}
|
|
).get("arguments")
|
|
|
|
if delta_name:
|
|
response_tool_calls[
|
|
tool_call_index
|
|
]["function"]["name"] += delta_name
|
|
|
|
if delta_arguments:
|
|
response_tool_calls[
|
|
tool_call_index
|
|
]["function"][
|
|
"arguments"
|
|
] += delta_arguments
|
|
|
|
value = delta.get("content")
|
|
|
|
if value:
|
|
content = f"{content}{value}"
|
|
|
|
if not content_blocks:
|
|
content_blocks.append(
|
|
{
|
|
"type": "text",
|
|
"content": "",
|
|
}
|
|
)
|
|
|
|
content_blocks[-1]["content"] = (
|
|
content_blocks[-1]["content"] + value
|
|
)
|
|
|
|
if DETECT_REASONING:
|
|
content, content_blocks, _ = (
|
|
tag_content_handler(
|
|
"reasoning",
|
|
reasoning_tags,
|
|
content,
|
|
content_blocks,
|
|
)
|
|
)
|
|
|
|
if DETECT_CODE_INTERPRETER:
|
|
content, content_blocks, end = (
|
|
tag_content_handler(
|
|
"code_interpreter",
|
|
code_interpreter_tags,
|
|
content,
|
|
content_blocks,
|
|
)
|
|
)
|
|
|
|
if end:
|
|
break
|
|
|
|
if ENABLE_REALTIME_CHAT_SAVE:
|
|
# Save message in the database
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"content": serialize_content_blocks(
|
|
content_blocks
|
|
),
|
|
},
|
|
)
|
|
else:
|
|
data = {
|
|
"content": serialize_content_blocks(
|
|
content_blocks
|
|
),
|
|
}
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": data,
|
|
}
|
|
)
|
|
except Exception as e:
|
|
done = "data: [DONE]" in line
|
|
if done:
|
|
pass
|
|
else:
|
|
log.debug("Error: ", e)
|
|
continue
|
|
|
|
if content_blocks:
|
|
# Clean up the last text block
|
|
if content_blocks[-1]["type"] == "text":
|
|
content_blocks[-1]["content"] = content_blocks[-1][
|
|
"content"
|
|
].strip()
|
|
|
|
if not content_blocks[-1]["content"]:
|
|
content_blocks.pop()
|
|
|
|
if not content_blocks:
|
|
content_blocks.append(
|
|
{
|
|
"type": "text",
|
|
"content": "",
|
|
}
|
|
)
|
|
|
|
if response_tool_calls:
|
|
tool_calls.append(response_tool_calls)
|
|
|
|
if response.background:
|
|
await response.background()
|
|
|
|
await stream_body_handler(response)
|
|
|
|
MAX_TOOL_CALL_RETRIES = 5
|
|
tool_call_retries = 0
|
|
|
|
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
|
|
tool_call_retries += 1
|
|
|
|
response_tool_calls = tool_calls.pop(0)
|
|
|
|
content_blocks.append(
|
|
{
|
|
"type": "tool_calls",
|
|
"content": response_tool_calls,
|
|
}
|
|
)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": {
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
}
|
|
)
|
|
|
|
tools = metadata.get("tools", {})
|
|
|
|
results = []
|
|
for tool_call in response_tool_calls:
|
|
tool_call_id = tool_call.get("id", "")
|
|
tool_name = tool_call.get("function", {}).get("name", "")
|
|
|
|
tool_function_params = {}
|
|
try:
|
|
# json.loads cannot be used because some models do not produce valid JSON
|
|
tool_function_params = ast.literal_eval(
|
|
tool_call.get("function", {}).get("arguments", "{}")
|
|
)
|
|
except Exception as e:
|
|
log.debug(e)
|
|
|
|
tool_result = None
|
|
|
|
if tool_name in tools:
|
|
tool = tools[tool_name]
|
|
spec = tool.get("spec", {})
|
|
|
|
try:
|
|
required_params = spec.get("parameters", {}).get(
|
|
"required", []
|
|
)
|
|
tool_function = tool["callable"]
|
|
tool_function_params = {
|
|
k: v
|
|
for k, v in tool_function_params.items()
|
|
if k in required_params
|
|
}
|
|
tool_result = await tool_function(
|
|
**tool_function_params
|
|
)
|
|
except Exception as e:
|
|
tool_result = str(e)
|
|
|
|
results.append(
|
|
{
|
|
"tool_call_id": tool_call_id,
|
|
"content": tool_result,
|
|
}
|
|
)
|
|
|
|
content_blocks[-1]["results"] = results
|
|
|
|
content_blocks.append(
|
|
{
|
|
"type": "text",
|
|
"content": "",
|
|
}
|
|
)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": {
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
}
|
|
)
|
|
|
|
try:
|
|
res = await generate_chat_completion(
|
|
request,
|
|
{
|
|
"model": model_id,
|
|
"stream": True,
|
|
"tools": form_data["tools"],
|
|
"messages": [
|
|
*form_data["messages"],
|
|
*convert_content_blocks_to_messages(content_blocks),
|
|
],
|
|
},
|
|
user,
|
|
)
|
|
|
|
if isinstance(res, StreamingResponse):
|
|
await stream_body_handler(res)
|
|
else:
|
|
break
|
|
except Exception as e:
|
|
log.debug(e)
|
|
break
|
|
|
|
if DETECT_CODE_INTERPRETER:
|
|
MAX_RETRIES = 5
|
|
retries = 0
|
|
|
|
while (
|
|
content_blocks[-1]["type"] == "code_interpreter"
|
|
and retries < MAX_RETRIES
|
|
):
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": {
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
}
|
|
)
|
|
|
|
retries += 1
|
|
log.debug(f"Attempt count: {retries}")
|
|
|
|
output = ""
|
|
try:
|
|
if content_blocks[-1]["attributes"].get("type") == "code":
|
|
code = content_blocks[-1]["content"]
|
|
|
|
if (
|
|
request.app.state.config.CODE_INTERPRETER_ENGINE
|
|
== "pyodide"
|
|
):
|
|
output = await event_caller(
|
|
{
|
|
"type": "execute:python",
|
|
"data": {
|
|
"id": str(uuid4()),
|
|
"code": code,
|
|
"session_id": metadata.get(
|
|
"session_id", None
|
|
),
|
|
},
|
|
}
|
|
)
|
|
elif (
|
|
request.app.state.config.CODE_INTERPRETER_ENGINE
|
|
== "jupyter"
|
|
):
|
|
output = await execute_code_jupyter(
|
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
|
code,
|
|
(
|
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
|
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
|
== "token"
|
|
else None
|
|
),
|
|
(
|
|
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
|
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
|
== "password"
|
|
else None
|
|
),
|
|
)
|
|
else:
|
|
output = {
|
|
"stdout": "Code interpreter engine not configured."
|
|
}
|
|
|
|
log.debug(f"Code interpreter output: {output}")
|
|
|
|
if isinstance(output, dict):
|
|
stdout = output.get("stdout", "")
|
|
|
|
if isinstance(stdout, str):
|
|
stdoutLines = stdout.split("\n")
|
|
for idx, line in enumerate(stdoutLines):
|
|
if "data:image/png;base64" in line:
|
|
id = str(uuid4())
|
|
|
|
# ensure the path exists
|
|
os.makedirs(
|
|
os.path.join(CACHE_DIR, "images"),
|
|
exist_ok=True,
|
|
)
|
|
|
|
image_path = os.path.join(
|
|
CACHE_DIR,
|
|
f"images/{id}.png",
|
|
)
|
|
|
|
with open(image_path, "wb") as f:
|
|
f.write(
|
|
base64.b64decode(
|
|
line.split(",")[1]
|
|
)
|
|
)
|
|
|
|
stdoutLines[idx] = (
|
|
f""
|
|
)
|
|
|
|
output["stdout"] = "\n".join(stdoutLines)
|
|
|
|
result = output.get("result", "")
|
|
|
|
if isinstance(result, str):
|
|
resultLines = result.split("\n")
|
|
for idx, line in enumerate(resultLines):
|
|
if "data:image/png;base64" in line:
|
|
id = str(uuid4())
|
|
|
|
# ensure the path exists
|
|
os.makedirs(
|
|
os.path.join(CACHE_DIR, "images"),
|
|
exist_ok=True,
|
|
)
|
|
|
|
image_path = os.path.join(
|
|
CACHE_DIR,
|
|
f"images/{id}.png",
|
|
)
|
|
|
|
with open(image_path, "wb") as f:
|
|
f.write(
|
|
base64.b64decode(
|
|
line.split(",")[1]
|
|
)
|
|
)
|
|
|
|
resultLines[idx] = (
|
|
f""
|
|
)
|
|
|
|
output["result"] = "\n".join(resultLines)
|
|
except Exception as e:
|
|
output = str(e)
|
|
|
|
content_blocks[-1]["output"] = output
|
|
|
|
content_blocks.append(
|
|
{
|
|
"type": "text",
|
|
"content": "",
|
|
}
|
|
)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": {
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
}
|
|
)
|
|
|
|
print(content_blocks, serialize_content_blocks(content_blocks))
|
|
|
|
try:
|
|
res = await generate_chat_completion(
|
|
request,
|
|
{
|
|
"model": model_id,
|
|
"stream": True,
|
|
"messages": [
|
|
*form_data["messages"],
|
|
{
|
|
"role": "assistant",
|
|
"content": serialize_content_blocks(
|
|
content_blocks, raw=True
|
|
),
|
|
},
|
|
],
|
|
},
|
|
user,
|
|
)
|
|
|
|
if isinstance(res, StreamingResponse):
|
|
await stream_body_handler(res)
|
|
else:
|
|
break
|
|
except Exception as e:
|
|
log.debug(e)
|
|
break
|
|
|
|
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
|
data = {
|
|
"done": True,
|
|
"content": serialize_content_blocks(content_blocks),
|
|
"title": title,
|
|
}
|
|
|
|
if not ENABLE_REALTIME_CHAT_SAVE:
|
|
# Save message in the database
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
)
|
|
|
|
# Send a webhook notification if the user is not active
|
|
if get_active_status_by_user_id(user.id) is None:
|
|
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
|
if webhook_url:
|
|
post_webhook(
|
|
request.app.state.WEBUI_NAME,
|
|
webhook_url,
|
|
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
|
|
{
|
|
"action": "chat",
|
|
"message": content,
|
|
"title": title,
|
|
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
|
|
},
|
|
)
|
|
|
|
await event_emitter(
|
|
{
|
|
"type": "chat:completion",
|
|
"data": data,
|
|
}
|
|
)
|
|
|
|
await background_tasks_handler()
|
|
except asyncio.CancelledError:
|
|
print("Task was cancelled!")
|
|
await event_emitter({"type": "task-cancelled"})
|
|
|
|
if not ENABLE_REALTIME_CHAT_SAVE:
|
|
# Save message in the database
|
|
Chats.upsert_message_to_chat_by_id_and_message_id(
|
|
metadata["chat_id"],
|
|
metadata["message_id"],
|
|
{
|
|
"content": serialize_content_blocks(content_blocks),
|
|
},
|
|
)
|
|
|
|
if response.background is not None:
|
|
await response.background()
|
|
|
|
# background_tasks.add_task(post_response_handler, response, events)
|
|
task_id, _ = create_task(post_response_handler(response, events))
|
|
return {"status": True, "task_id": task_id}
|
|
|
|
else:
|
|
|
|
# Fallback to the original response
|
|
async def stream_wrapper(original_generator, events):
    """Yield queued events as SSE frames, then pass the stream through.

    Each item of ``events`` is JSON-encoded and emitted as
    ``data: <json>`` followed by a blank line. After that, every item
    produced by ``original_generator`` is yielded unchanged.
    """
    for queued_event in events:
        yield "data: " + json.dumps(queued_event) + "\n\n"

    async for chunk in original_generator:
        yield chunk
|
|
|
|
return StreamingResponse(
|
|
stream_wrapper(response.body_iterator, events),
|
|
headers=dict(response.headers),
|
|
background=response.background,
|
|
)
|