From c83e68282d4f099d478a3f413d598219ff042292 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Wed, 12 Feb 2025 22:56:33 -0800
Subject: [PATCH] feat: direct connections integration

---
 backend/open_webui/main.py             |  49 +++--
 backend/open_webui/routers/tasks.py    |  62 +++++-
 backend/open_webui/utils/chat.py       | 263 ++++++++++++++++++------
 backend/open_webui/utils/middleware.py |  11 +-
 src/lib/components/chat/Chat.svelte    |   3 +
 src/routes/+layout.svelte              |  93 ++++++++-
 6 files changed, 387 insertions(+), 94 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 1de0348c3..0311e82d8 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -900,20 +900,30 @@ async def chat_completion(
     if not request.app.state.MODELS:
         await get_all_models(request)

+    model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)

-    try:
-        model_id = form_data.get("model", None)
-        if model_id not in request.app.state.MODELS:
-            raise Exception("Model not found")
-        model = request.app.state.MODELS[model_id]
-        model_info = Models.get_model_by_id(model_id)
-        # Check if user has access to the model
-        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
-            try:
-                check_model_access(user, model)
-            except Exception as e:
-                raise e
+    try:
+        if not model_item.get("direct", False):
+            model_id = form_data.get("model", None)
+            if model_id not in request.app.state.MODELS:
+                raise Exception("Model not found")
+
+            model = request.app.state.MODELS[model_id]
+            model_info = Models.get_model_by_id(model_id)
+
+            # Check if user has access to the model
+            if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
+                try:
+                    check_model_access(user, model)
+                except Exception as e:
+                    raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model

         metadata = {
             "user_id": user.id,
@@ -925,6 +935,7 @@
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),
             "model": model_info,
+            "direct": model_item.get("direct", False),
             **(
                 {"function_calling": "native"}
                 if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@
                 else {}
             ),
         }
+        request.state.metadata = metadata
         form_data["metadata"] = metadata

         form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@
         )

     except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_completed_handler(request, form_data, user)
     except Exception as e:
         raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_action_handler(request, action_id, form_data, user)
     except Exception as e:
         raise HTTPException(
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
index f56a0232d..c885d764b 100644
--- a/backend/open_webui/routers/tasks.py
+++ b/backend/open_webui/routers/tasks.py
@@ -139,7 +139,12 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
             }
         ),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@
 async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
             }
         ),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }

     try:
@@ -548,7 +587,13 @@
 async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
     model_id = form_data["model"]
     if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 3b6d5ea04..97e1aae74 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -7,6 +7,8 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio

 from fastapi import Request
 from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse

 from open_webui.models.users import UserModel
 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])


+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_emitter = get_event_emitter(metadata)
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        # Define a generator to stream responses
+        async def event_generator():
+            nonlocal q
+
+            async def message_listener(sid, data):
+                """
+                Handle received socket messages and push them into the queue.
+ """ + await q.put(data) + + # Register the listener + sio.on(channel, message_listener) + + # Start processing chat completion in background + await event_emitter( + { + "type": "request:chat:completion", + "data": { + "form_data": form_data, + "model": models[form_data["model"]], + "channel": channel, + "session_id": session_id, + }, + } + ) + + try: + while True: + data = await q.get() # Wait for new messages + if isinstance(data, dict): + if "error" in data: + raise Exception(data["error"]) + + if "done" in data and data["done"]: + break # Stop streaming when 'done' is received + + yield f"data: {json.dumps(data)}\n\n" + elif isinstance(data, str): + yield data + finally: + del sio.handlers["/"][channel] # Remove the listener + + # Return the streaming response + return StreamingResponse(event_generator(), media_type="text/event-stream") + else: + res = await event_caller( + { + "type": "request:chat:completion", + "data": { + "form_data": form_data, + "model": models[form_data["model"]], + "channel": channel, + "session_id": session_id, + }, + } + ) + + print(res) + + if "error" in res: + raise Exception(res["error"]) + + return res + + async def generate_chat_completion( request: Request, form_data: dict, @@ -66,7 +156,12 @@ async def generate_chat_completion( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - models = request.app.state.MODELS + if request.state.direct and request.state.model: + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS model_id = form_data["model"] if model_id not in models: @@ -87,78 +182,90 @@ async def generate_chat_completion( except Exception as e: raise e - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in list(request.app.state.MODELS.values()) - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completion( - request, form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), - media_type="text/event-stream", - background=response.background, - ) - else: - return { - **( - await generate_chat_completion( - request, form_data, user, bypass_filter=True - ) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( + if request.state.direct: + return await generate_direct_chat_completion( request, form_data, user=user, models=models ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - response = await generate_ollama_chat_completion( - request=request, 
form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.get("stream"): - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - background=response.background, - ) - else: - return convert_response_ollama_to_openai(response) + else: - return await generate_openai_chat_completion( - request=request, form_data=form_data, user=user, bypass_filter=bypass_filter - ) + if model["owned_by"] == "arena": + model_ids = model.get("info", {}).get("meta", {}).get("model_ids") + filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") + if model_ids and filter_mode == "exclude": + model_ids = [ + model["id"] + for model in list(request.app.state.MODELS.values()) + if model.get("owned_by") != "arena" and model["id"] not in model_ids + ] + + selected_model_id = None + if isinstance(model_ids, list) and model_ids: + selected_model_id = random.choice(model_ids) + else: + model_ids = [ + model["id"] + for model in list(request.app.state.MODELS.values()) + if model.get("owned_by") != "arena" + ] + selected_model_id = random.choice(model_ids) + + form_data["model"] = selected_model_id + + if form_data.get("stream") == True: + + async def stream_wrapper(stream): + yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" + async for chunk in stream: + yield chunk + + response = await generate_chat_completion( + request, form_data, user, bypass_filter=True + ) + return StreamingResponse( + stream_wrapper(response.body_iterator), + media_type="text/event-stream", + background=response.background, + ) + else: + return { + **( + await generate_chat_completion( + request, form_data, user, bypass_filter=True + ) + ), + "selected_model_id": selected_model_id, + } + + if model.get("pipe"): + # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter + return await generate_function_chat_completion( + request, form_data, user=user, models=models + ) + if model["owned_by"] == "ollama": + # Using /ollama/api/chat endpoint + form_data = convert_payload_openai_to_ollama(form_data) + response = await generate_ollama_chat_completion( + request=request, + form_data=form_data, + user=user, + bypass_filter=bypass_filter, + ) + if form_data.get("stream"): + response.headers["content-type"] = "text/event-stream" + return StreamingResponse( + convert_streaming_response_ollama_to_openai(response), + headers=dict(response.headers), + background=response.background, + ) + else: + return convert_response_ollama_to_openai(response) + else: + return await generate_openai_chat_completion( + request=request, + form_data=form_data, + user=user, + bypass_filter=bypass_filter, + ) chat_completion = generate_chat_completion @@ -167,7 +274,13 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request) - models = request.app.state.MODELS + + if request.state.direct and request.state.model: + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS data = form_data model_id = data["model"] @@ -227,7 +340,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A if not request.app.state.MODELS: await get_all_models(request) - models = request.app.state.MODELS + + if request.state.direct and 
request.state.model: + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS data = form_data model_id = data["model"] diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 5de8f8193..e630af687 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -622,7 +622,13 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Initialize events to store additional event to be sent to the client # Initialize contexts and citation - models = request.app.state.MODELS + if request.state.direct and request.state.model: + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + task_model_id = get_task_model_id( form_data["model"], request.app.state.config.TASK_MODEL, @@ -1677,6 +1683,9 @@ async def process_chat_response( "data": { "id": str(uuid4()), "code": code, + "session_id": metadata.get( + "session_id", None + ), }, } ) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index da966d90d..d3986ca04 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -838,6 +838,7 @@ timestamp: m.timestamp, ...(m.sources ? { sources: m.sources } : {}) })), + model_item: $models.find((m) => m.id === modelId), chat_id: chatId, session_id: $socket?.id, id: responseMessageId @@ -896,6 +897,7 @@ ...(m.sources ? { sources: m.sources } : {}) })), ...(event ? { event: event } : {}), + model_item: $models.find((m) => m.id === modelId), chat_id: chatId, session_id: $socket?.id, id: responseMessageId @@ -1574,6 +1576,7 @@ $settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined ) }, + model_item: $models.find((m) => m.id === model.id), session_id: $socket?.id, chat_id: $chatId, diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 76323f772..4d08b3e3d 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -45,6 +45,7 @@ import { getAllTags, getChatList } from '$lib/apis/chats'; import NotificationToast from '$lib/components/NotificationToast.svelte'; import AppSidebar from '$lib/components/app/AppSidebar.svelte'; + import { chatCompletion } from '$lib/apis/openai'; setContext('i18n', i18n); @@ -251,10 +252,100 @@ } else if (type === 'chat:tags') { tags.set(await getAllTags(localStorage.token)); } - } else { + } else if (data?.session_id === $socket.id) { if (type === 'execute:python') { console.log('execute:python', data); executePythonAsWorker(data.id, data.code, cb); + } else if (type === 'request:chat:completion') { + console.log(data, $socket.id); + const { session_id, channel, form_data, model } = data; + + try { + const directConnections = $settings?.directConnections ?? {}; + + if (directConnections) { + const urlIdx = model?.urlIdx; + + console.log(model, directConnections); + + const OPENAI_API_URL = directConnections.OPENAI_API_BASE_URLS[urlIdx]; + const OPENAI_API_KEY = directConnections.OPENAI_API_KEYS[urlIdx]; + const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx]; + + try { + const [res, controller] = await chatCompletion( + OPENAI_API_KEY, + form_data, + OPENAI_API_URL + ); + + if (res && res.ok) { + if (form_data?.stream ?? 
false) { + // res will either be SSE or JSON + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + + const processStream = async () => { + while (true) { + // Read data chunks from the response stream + const { done, value } = await reader.read(); + if (done) { + break; + } + + // Decode the received chunk + const chunk = decoder.decode(value, { stream: true }); + + // Process lines within the chunk + const lines = chunk.split('\n').filter((line) => line.trim() !== ''); + + for (const line of lines) { + $socket?.emit(channel, line); + } + } + }; + + // Process the stream in the background + await processStream(); + } else { + const data = await res.json(); + cb(data); + } + } else { + throw new Error('An error occurred while fetching the completion'); + } + } catch (error) { + console.error('chatCompletion', error); + + if (form_data?.stream ?? false) { + $socket.emit(channel, { + error: error + }); + } else { + cb({ + error: error + }); + } + } + } + } catch (error) { + console.error('chatCompletion', error); + if (form_data?.stream ?? false) { + $socket.emit(channel, { + error: error + }); + } else { + cb({ + error: error + }); + } + } finally { + $socket.emit(channel, { + done: true + }); + } + } else { + console.log('chatEventHandler', event); } } };