Mirror of https://github.com/open-webui/open-webui, synced 2025-06-09 07:56:42 +00:00.
This feature allows you to generate and iteratively refine visual designs (posters, logos, etc.) directly through the chat interface. Key changes include:

1. **Backend - Intent Detection & Processing:**
   * I've introduced an `ImageGenerationIntentDetector` (exposed as the `image_generation_intent_detector` function) in `backend/open_webui/utils/intent_processors.py`. This function:
     * Uses keyword-based intent detection for new designs and refinements.
     * Extracts prompts from your messages.
     * Modifies previous prompts for iterative refinements.
     * Calls the existing `/api/v1/images/generations` endpoint using `httpx`.
     * Formats the response (image URL or error) as a chat message, including metadata such as `is_generated_design`, `original_prompt`, and `engine_used`.
   * I've integrated this detector into the main chat processing logic in `backend/open_webui/utils/chat.py`. Design generation requests are now handled by the detector, bypassing the LLM when the intent is recognized (see the sketch after this list).
2. **Frontend - Image Display:**
   * My analysis confirmed that the existing Svelte components (`MarkdownInlineTokens.svelte`, which uses `Image.svelte`) can render the Markdown-formatted image URLs (`![alt](image_url)` syntax) sent by the backend.
   * The `Image.svelte` component also provides an image preview feature.
3. **Design Management (MVP Approach):**
   * For the MVP, generated images are saved via the existing file upload mechanism.
   * The chat history, with messages containing image URLs and generation metadata, serves as the primary way for you to access and track your designs and refinements. No new database models for explicit design management were added.
4. **Dependencies:**
   * I've added `httpx>=0.25.0` to `backend/requirements.txt` so the HTTP client used by the intent detector is explicitly listed.
5. **Documentation:**
   * I've drafted updates for `README.md` to highlight the new "AI-Powered Design Generation" feature, replacing the previous, more basic "Image Generation Integration" description.

**Testing Plan:**
* I've prepared detailed manual end-to-end test cases, unit test cases for the `ImageGenerationIntentDetector`, and a frontend visual review checklist to guide developer testing.

This set of changes provides the core functionality for you to conversationally create and refine designs within Open WebUI.
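The implementation of `intent_processors.py` is not included on this page, so the sketch below only illustrates how such a keyword-based detector could look, based on the call site in `chat.py` further down. The keyword lists, the refinement heuristic, the request payload sent to `/api/v1/images/generations`, and the response parsing are illustrative assumptions, not the actual implementation; only the function signature and the shape of the returned assistant message follow the code shown below.

```python
# Illustrative sketch only. The real backend/open_webui/utils/intent_processors.py may
# differ: keyword lists, the refinement heuristic, the request payload, and the response
# parsing are assumptions. Only the function signature and the shape of the returned
# assistant message follow the call site in chat.py.
from typing import Optional

import httpx
from fastapi import Request

GENERATION_KEYWORDS = ("design", "poster", "logo", "generate an image")  # assumed
REFINEMENT_KEYWORDS = ("make it", "change the", "refine", "instead")  # assumed


async def image_generation_intent_detector(
    user_message: str,
    chat_history: list[dict],
    request: Request,
    current_chat_id: Optional[str] = None,  # accepted to match the call site; unused here
) -> Optional[dict]:
    """Return an assistant message dict for design requests, or None to fall through."""
    text = user_message.lower()
    is_new_design = any(k in text for k in GENERATION_KEYWORDS)
    is_refinement = any(k in text for k in REFINEMENT_KEYWORDS)
    if not (is_new_design or is_refinement):
        return None  # No image-generation intent; the caller proceeds to the LLM.

    # For refinements, fold the new instruction into the most recent generated prompt.
    prompt = user_message
    if is_refinement and not is_new_design:
        for msg in reversed(chat_history):
            previous_prompt = (msg.get("metadata") or {}).get("original_prompt")
            if previous_prompt:
                prompt = f"{previous_prompt}, {user_message}"
                break

    # Call the existing image-generation endpoint of this Open WebUI instance.
    base_url = str(request.base_url).rstrip("/")
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            response = await client.post(
                f"{base_url}/api/v1/images/generations",
                json={"prompt": prompt},  # payload shape assumed
                headers={  # forwarding only auth headers; assumed sufficient
                    "Authorization": request.headers.get("authorization", ""),
                    "Cookie": request.headers.get("cookie", ""),
                },
            )
            response.raise_for_status()
            image_url = response.json()[0]["url"]  # response shape assumed
            content = f"![Generated design]({image_url})"
    except Exception as e:
        content = f"Sorry, the design could not be generated: {e}"

    # Message shape consumed by generate_chat_completion() in chat.py.
    return {
        "role": "assistant",
        "content": content,
        "metadata": {
            "is_generated_design": True,
            "original_prompt": prompt,
            "engine_used": "default",  # placeholder; the real detector reports the engine
        },
    }
```

When this function returns a non-`None` message, `generate_chat_completion` in `chat.py` (below) short-circuits the LLM call and wraps the message in an OpenAI-style `chat.completion` response, or in a single SSE event when streaming was requested.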
559 lines · 20 KiB · Python
import time
import logging
import sys

from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
import uuid
import asyncio

from fastapi import Request, status
from starlette.responses import Response, StreamingResponse, JSONResponse


from open_webui.models.users import UserModel

from open_webui.socket.main import (
    sio,
    get_event_call,
    get_event_emitter,
)
from open_webui.functions import generate_function_chat_completion

from open_webui.routers.openai import (
    generate_chat_completion as generate_openai_chat_completion,
)

from open_webui.routers.ollama import (
    generate_chat_completion as generate_ollama_chat_completion,
)

from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
    process_pipeline_outlet_filter,
)

from open_webui.models.functions import Functions
from open_webui.models.models import Models


from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.models import get_all_models, check_model_access
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)
from open_webui.utils.intent_processors import image_generation_intent_detector
from open_webui.models.chats import Chats  # Required for saving the message

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL


logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_direct_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    models: dict,
):
    log.info("generate_direct_chat_completion")

    metadata = form_data.pop("metadata", {})

    user_id = metadata.get("user_id")
    session_id = metadata.get("session_id")
    request_id = str(uuid.uuid4())  # Generate a unique request ID

    event_caller = get_event_call(metadata)

    channel = f"{user_id}:{session_id}:{request_id}"

    if form_data.get("stream"):
        q = asyncio.Queue()

        async def message_listener(sid, data):
            """
            Handle received socket messages and push them into the queue.
            """
            await q.put(data)

        # Register the listener
        sio.on(channel, message_listener)

        # Start processing the chat completion in the background
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        log.info(f"res: {res}")

        if res.get("status", False):
            # Define a generator to stream responses
            async def event_generator():
                nonlocal q
                try:
                    while True:
                        data = await q.get()  # Wait for new messages
                        if isinstance(data, dict):
                            if "done" in data and data["done"]:
                                break  # Stop streaming when 'done' is received

                            yield f"data: {json.dumps(data)}\n\n"
                        elif isinstance(data, str):
                            yield data
                except Exception as e:
                    log.debug(f"Error in event generator: {e}")
                    pass

            # Define a background task that removes the socket listener when streaming ends
            async def background():
                try:
                    del sio.handlers["/"][channel]
                except Exception as e:
                    pass

            # Return the streaming response
            return StreamingResponse(
                event_generator(), media_type="text/event-stream", background=background
            )
        else:
            raise Exception(str(res))
    else:
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        if "error" in res and res["error"]:
            raise Exception(res["error"])

        return res


async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    bypass_filter: bool = False,
):
    log.debug(f"generate_chat_completion: {form_data}")

    # Extract chat_id, the user message, and the history for the intent detector.
    # form_data["messages"] is expected to be the full history, including the latest user message.
    messages = form_data.get("messages", [])
    chat_id = form_data.get("chat_id")  # Assuming chat_id is available in form_data

    if not chat_id and hasattr(request.state, "chat_id"):
        # Try to get it from the request state if available
        chat_id = request.state.chat_id

    log.debug(f"Messages for detector: {messages}")

    user_message_content = ""
    chat_history_for_detector = []

    if messages:
        # The last message is the current user message
        user_message_obj = messages[-1]
        if user_message_obj.get("role") == "user":
            user_message_content = user_message_obj.get("content", "")
            # History is all messages before the last one
            chat_history_for_detector = messages[:-1]
        else:
            # This case should not happen if form_data is well-formed for a new turn
            log.warning(
                "Last message is not from user, cannot process with intent detector."
            )
            # Fall through to normal processing without calling the detector

    # Don't run the detector for direct model connections
    if user_message_content and not getattr(request.state, "direct", False):
        log.debug(
            f"Calling ImageGenerationIntentDetector for chat_id '{chat_id}' "
            f"with message: '{user_message_content}'"
        )
        try:
            detector_response = await image_generation_intent_detector(
                user_message=user_message_content,
                chat_history=chat_history_for_detector,
                request=request,
                current_chat_id=chat_id,
            )
        except Exception as e:
            log.exception("ImageGenerationIntentDetector raised an exception.")
            detector_response = None  # Proceed to the LLM on detector error

        if detector_response:
            log.info(
                f"ImageGenerationIntentDetector returned a response for chat_id "
                f"'{chat_id}': {detector_response}"
            )

            # The detector's response is a complete assistant message dict that needs to
            # be added to the chat history. `generate_chat_completion` itself only returns
            # the generated content, so the calling router is assumed to be responsible
            # for saving the full chat context, including this new message. We return it
            # in a format that mimics a non-streaming LLM response where possible, and
            # handle it specially when streaming.

            # If the original request was for streaming, this is tricky. For simplicity in
            # this integration, if the detector responds, we don't stream; we return the
            # single message and the frontend has to handle it.
            if form_data.get("stream"):
                log.warning(
                    "ImageGenerationIntentDetector responded, but the original request "
                    "was for a stream. Returning as a single event."
                )

                async def single_event_stream():
                    yield f"data: {json.dumps({'id': str(uuid.uuid4()), 'choices': [{'delta': detector_response}]})}\n\n"
                    yield f"data: {json.dumps({'done': True})}\n\n"

                # The response is structured somewhat like an OpenAI streaming chunk to be
                # minimally disruptive to existing stream handling: the whole message is
                # sent as a 'delta' in the first chunk. This is a HACK for streaming; a
                # more robust solution would use custom event types or have the socket
                # emit this message directly.
                return StreamingResponse(
                    single_event_stream(), media_type="text/event-stream"
                )

            # For non-streaming, detector_response is already a dict like:
            # {"role": "assistant", "content": "...", "metadata": {...}}
            # The typical non-streaming response from generate_openai_chat_completion is a
            # dict with a "choices" list, e.g. {"choices": [{"message": detector_response}]}.
            return {
                "id": str(uuid.uuid4()),  # Generate a unique ID for this "response"
                "object": "chat.completion",
                "created": int(time.time()),
                # Use the engine from the metadata, or fall back to the original model
                "model": detector_response.get("metadata", {}).get("engine_used")
                or form_data.get("model"),
                "choices": [
                    {
                        "index": 0,
                        "message": detector_response,  # The full message from the detector
                        "finish_reason": "stop",
                    }
                ],
                "usage": {  # Dummy usage
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }
            # IMPORTANT: The above response is handled by the router. The router (e.g., in
            # chats.py or a socket handler) is responsible for taking this response and
            # saving the `detector_response` message to the database. This function,
            # `generate_chat_completion`, is primarily for *generating* the content.

    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # If the detector did not handle the message, proceed with the normal LLM flow
    log.debug("ImageGenerationIntentDetector did not handle the message, proceeding to LLM.")

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
        log.debug(f"direct connection to model: {models}")
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        # This check might be redundant if the detector already ran and returned, but it
        # is still needed for the path where the detector doesn't run or returns None.
        raise Exception("Model not found")

    model = models[model_id]

    if getattr(request.state, "direct", False):
        # The detector is currently skipped for direct connections, so this path remains unchanged.
        return await generate_direct_chat_completion(
            request, form_data, user=user, models=models
        )
    else:
        # Check if the user has access to the model
        if not bypass_filter and user.role == "user":
            try:
                check_model_access(user, model)
            except Exception as e:
                raise e

        if model.get("owned_by") == "arena":
            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
            if model_ids and filter_mode == "exclude":
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
                ]

            selected_model_id = None
            if isinstance(model_ids, list) and model_ids:
                selected_model_id = random.choice(model_ids)
            else:
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena"
                ]
                selected_model_id = random.choice(model_ids)

            form_data["model"] = selected_model_id

            if form_data.get("stream") == True:

                async def stream_wrapper(stream):
                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                    async for chunk in stream:
                        yield chunk

                response = await generate_chat_completion(
                    request, form_data, user, bypass_filter=True
                )
                return StreamingResponse(
                    stream_wrapper(response.body_iterator),
                    media_type="text/event-stream",
                    background=response.background,
                )
            else:
                return {
                    **(
                        await generate_chat_completion(
                            request, form_data, user, bypass_filter=True
                        )
                    ),
                    "selected_model_id": selected_model_id,
                }

        if model.get("pipe"):
            # This does not require bypass_filter because this is the only route that uses
            # this function, and it already bypasses the filter.
            return await generate_function_chat_completion(
                request, form_data, user=user, models=models
            )
        if model.get("owned_by") == "ollama":
            # Using the /ollama/api/chat endpoint
            form_data = convert_payload_openai_to_ollama(form_data)
            response = await generate_ollama_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )
            if form_data.get("stream"):
                response.headers["content-type"] = "text/event-stream"
                return StreamingResponse(
                    convert_streaming_response_ollama_to_openai(response),
                    headers=dict(response.headers),
                    background=response.background,
                )
            else:
                return convert_response_ollama_to_openai(response)
        else:
            return await generate_openai_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )


chat_completion = generate_chat_completion


async def chat_completed(request: Request, form_data: dict, user: Any):
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise Exception("Model not found")

    model = models[model_id]

    try:
        data = await process_pipeline_outlet_filter(request, data, user, models)
    except Exception as e:
        return Exception(f"Error: {e}")

    metadata = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "filter_ids": data.get("filter_ids", []),
        "session_id": data["session_id"],
        "user_id": user.id,
    }

    extra_params = {
        "__event_emitter__": get_event_emitter(metadata),
        "__event_call__": get_event_call(metadata),
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    try:
        filter_functions = [
            Functions.get_function_by_id(filter_id)
            for filter_id in get_sorted_filter_ids(
                request, model, metadata.get("filter_ids", [])
            )
        ]

        result, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="outlet",
            form_data=data,
            extra_params=extra_params,
        )
        return result
    except Exception as e:
        return Exception(f"Error: {e}")


async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
    if "." in action_id:
        action_id, sub_action_id = action_id.split(".")
    else:
        sub_action_id = None

    action = Functions.get_function_by_id(action_id)
    if not action:
        raise Exception(f"Action not found: {action_id}")

    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]

    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    __event_emitter__ = get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )
    __event_call__ = get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )

    if action_id in request.app.state.FUNCTIONS:
        function_module = request.app.state.FUNCTIONS[action_id]
    else:
        function_module, _, _ = load_function_module_by_id(action_id)
        request.app.state.FUNCTIONS[action_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action = function_module.action

            # Get the signature of the function
            sig = inspect.signature(action)
            params = {"body": data}

            # Extra parameters to be passed to the function
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
                "__request__": request,
            }

            # Add the extra params that are contained in the function signature
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }

                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                action_id, user.id
                            )
                        )
                except Exception as e:
                    log.exception(f"Failed to get user values: {e}")

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action):
                data = await action(**params)
            else:
                data = action(**params)

        except Exception as e:
            return Exception(f"Error: {e}")

    return data