feat: direct connections integration

Timothy Jaeryang Baek 2025-02-12 22:56:33 -08:00
parent 304ce2a14d
commit c83e68282d
6 changed files with 387 additions and 94 deletions
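
Chat requests may now carry a `model_item` with `direct: true`. For such requests the server skips its model registry and access checks, flags the request via `request.state.direct` / `request.state.model`, and relays the completion over a per-request Socket.IO channel (`{user_id}:{session_id}:{request_id}`) to the originating browser session, which calls the OpenAI-compatible endpoint configured in its local `directConnections` settings and streams the result back. Task endpoints (title, tags, image prompt, queries, autocomplete, emoji, MoA) resolve the model from `request.state` for direct requests and merge `request.state.metadata` into their task payloads.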

View File

@@ -900,11 +900,15 @@ async def chat_completion(
     if not request.app.state.MODELS:
         await get_all_models(request)

+    model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)

     try:
+        if not model_item.get("direct", False):
             model_id = form_data.get("model", None)
             if model_id not in request.app.state.MODELS:
                 raise Exception("Model not found")

             model = request.app.state.MODELS[model_id]
             model_info = Models.get_model_by_id(model_id)
@@ -914,6 +918,12 @@ async def chat_completion(
                 check_model_access(user, model)
             except Exception as e:
                 raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model

     metadata = {
         "user_id": user.id,
@@ -925,6 +935,7 @@ async def chat_completion(
         "features": form_data.get("features", None),
         "variables": form_data.get("variables", None),
         "model": model_info,
+        "direct": model_item.get("direct", False),
         **(
             {"function_calling": "native"}
             if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@ async def chat_completion(
             else {}
         ),
     }
+    request.state.metadata = metadata
     form_data["metadata"] = metadata

     form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@ async def chat_completion(
         )
     except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_completed_handler(request, form_data, user)
     except Exception as e:
         raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_action_handler(request, action_id, form_data, user)
     except Exception as e:
         raise HTTPException(
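
For reference, a `model_item` is the client-side model entry forwarded verbatim in the request body. A minimal sketch of its assumed shape, based only on the fields this diff actually reads (`direct`, `id`, and, on the client, `urlIdx`):

# Hypothetical direct model_item; only "direct", "id", and "urlIdx" are
# referenced in this commit. The id value is illustrative.
model_item = {
    "id": "my-direct-model",  # model id, also used as the registry key
    "direct": True,  # marks the request as a direct-connection request
    "urlIdx": 0,  # index into the client's directConnections URL/key lists
}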

View File

@@ -139,6 +139,11 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -198,6 +203,7 @@ async def generate_title(
             }
         ),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -225,6 +231,11 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -261,6 +272,7 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -281,6 +293,11 @@ async def generate_image_prompt(
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -321,6 +338,7 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -356,6 +374,11 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -392,6 +415,7 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -431,6 +455,11 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -467,6 +496,7 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -488,6 +518,11 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -531,7 +566,11 @@ async def generate_emoji(
             }
         ),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }

     try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,

View File

@@ -7,6 +7,8 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio

 from fastapi import Request
 from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse
 from open_webui.models.users import UserModel
 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])


+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_emitter = get_event_emitter(metadata)
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        # Define a generator to stream responses
+        async def event_generator():
+            nonlocal q
+
+            async def message_listener(sid, data):
+                """
+                Handle received socket messages and push them into the queue.
+                """
+                await q.put(data)
+
+            # Register the listener
+            sio.on(channel, message_listener)
+
+            # Start processing chat completion in background
+            await event_emitter(
+                {
+                    "type": "request:chat:completion",
+                    "data": {
+                        "form_data": form_data,
+                        "model": models[form_data["model"]],
+                        "channel": channel,
+                        "session_id": session_id,
+                    },
+                }
+            )
+
+            try:
+                while True:
+                    data = await q.get()  # Wait for new messages
+                    if isinstance(data, dict):
+                        if "error" in data:
+                            raise Exception(data["error"])
+
+                        if "done" in data and data["done"]:
+                            break  # Stop streaming when 'done' is received
+
+                        yield f"data: {json.dumps(data)}\n\n"
+                    elif isinstance(data, str):
+                        yield data
+            finally:
+                del sio.handlers["/"][channel]  # Remove the listener
+
+        # Return the streaming response
+        return StreamingResponse(event_generator(), media_type="text/event-stream")
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print(res)
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
@@ -66,6 +156,11 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
@@ -87,6 +182,12 @@ async def generate_chat_completion(
     except Exception as e:
         raise e

+    if request.state.direct:
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
+    else:
         if model["owned_by"] == "arena":
             model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
             filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
@@ -144,7 +245,10 @@ async def generate_chat_completion(
             # Using /ollama/api/chat endpoint
             form_data = convert_payload_openai_to_ollama(form_data)
             response = await generate_ollama_chat_completion(
-                request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )

             if form_data.get("stream"):
                 response.headers["content-type"] = "text/event-stream"
@@ -157,7 +261,10 @@ async def generate_chat_completion(
                 return convert_response_ollama_to_openai(response)
         else:
             return await generate_openai_chat_completion(
-                request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )
@@ -167,6 +274,12 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
         await get_all_models(request)

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     data = form_data
@@ -227,6 +340,12 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
     if not request.app.state.MODELS:
         await get_all_models(request)

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     data = form_data
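
`generate_direct_chat_completion` above defines a small relay protocol: the server emits a `request:chat:completion` event carrying the form data, the model, and a freshly generated channel name, then consumes whatever the client emits on that channel. Dicts are wrapped as SSE `data:` lines, raw strings are forwarded verbatim, a dict containing `error` aborts the stream, and `{"done": true}` ends it (the listener is deregistered in the `finally` block). The real responder is the Svelte layout handler below; as a rough illustration of the client side of this contract, a hypothetical python-socketio stand-in might look like this (`stream_completion` is an assumed helper, not part of the codebase):

import socketio  # pip install python-socketio

sio = socketio.AsyncClient()


@sio.on("request:chat:completion")
async def handle_direct_completion(data):
    """Hypothetical stand-in for the browser-side handler."""
    channel = data["channel"]
    form_data = data["form_data"]
    try:
        # stream_completion is an assumed helper yielding raw SSE lines
        # from the user's directly connected OpenAI-compatible endpoint.
        async for line in stream_completion(form_data):
            await sio.emit(channel, line)  # relayed verbatim by the server
    except Exception as e:
        await sio.emit(channel, {"error": str(e)})  # aborts the server stream
    finally:
        await sio.emit(channel, {"done": True})  # always signal completion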

View File

@@ -622,7 +622,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation

-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     task_model_id = get_task_model_id(
         form_data["model"],
         request.app.state.config.TASK_MODEL,
@@ -1677,6 +1683,9 @@ async def process_chat_response(
                         "data": {
                             "id": str(uuid4()),
                             "code": code,
+                            "session_id": metadata.get(
+                                "session_id", None
+                            ),
                         },
                     }
                 )
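
The `session_id` added to the code-execution event pairs with the new `data?.session_id === $socket.id` guard in the layout handler below, so `execute:python` requests are only handled by the session that issued them.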

View File

@@ -838,6 +838,7 @@
				timestamp: m.timestamp,
				...(m.sources ? { sources: m.sources } : {})
			})),
+			model_item: $models.find((m) => m.id === modelId),
			chat_id: chatId,
			session_id: $socket?.id,
			id: responseMessageId
@@ -896,6 +897,7 @@
				...(m.sources ? { sources: m.sources } : {})
			})),
			...(event ? { event: event } : {}),
+			model_item: $models.find((m) => m.id === modelId),
			chat_id: chatId,
			session_id: $socket?.id,
			id: responseMessageId
@@ -1574,6 +1576,7 @@
				$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
			)
		},
+		model_item: $models.find((m) => m.id === model.id),
		session_id: $socket?.id,
		chat_id: $chatId,
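
The client now attaches `model_item`, looked up from the `$models` store by id, to the completed/action payloads and the completion request body; this is what the backend's `form_data.pop("model_item", {})` consumes to detect direct models.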

View File

@@ -45,6 +45,7 @@
	import { getAllTags, getChatList } from '$lib/apis/chats';
	import NotificationToast from '$lib/components/NotificationToast.svelte';
	import AppSidebar from '$lib/components/app/AppSidebar.svelte';
+	import { chatCompletion } from '$lib/apis/openai';

	setContext('i18n', i18n);
@@ -251,10 +252,100 @@
				} else if (type === 'chat:tags') {
					tags.set(await getAllTags(localStorage.token));
				}
-			} else {
+			} else if (data?.session_id === $socket.id) {
				if (type === 'execute:python') {
					console.log('execute:python', data);
					executePythonAsWorker(data.id, data.code, cb);
+				} else if (type === 'request:chat:completion') {
+					console.log(data, $socket.id);
+					const { session_id, channel, form_data, model } = data;
+
+					try {
+						const directConnections = $settings?.directConnections ?? {};
+
+						if (directConnections) {
+							const urlIdx = model?.urlIdx;
+							console.log(model, directConnections);
+
+							const OPENAI_API_URL = directConnections.OPENAI_API_BASE_URLS[urlIdx];
+							const OPENAI_API_KEY = directConnections.OPENAI_API_KEYS[urlIdx];
+							const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx];
+
+							try {
+								const [res, controller] = await chatCompletion(
+									OPENAI_API_KEY,
+									form_data,
+									OPENAI_API_URL
+								);
+
+								if (res && res.ok) {
+									if (form_data?.stream ?? false) {
+										// res will either be SSE or JSON
+										const reader = res.body.getReader();
+										const decoder = new TextDecoder();
+
+										const processStream = async () => {
+											while (true) {
+												// Read data chunks from the response stream
+												const { done, value } = await reader.read();
+												if (done) {
+													break;
+												}
+
+												// Decode the received chunk
+												const chunk = decoder.decode(value, { stream: true });
+
+												// Process lines within the chunk
+												const lines = chunk.split('\n').filter((line) => line.trim() !== '');
+
+												for (const line of lines) {
+													$socket?.emit(channel, line);
+												}
+											}
+										};
+
+										// Process the stream in the background
+										await processStream();
+									} else {
+										const data = await res.json();
+										cb(data);
+									}
+								} else {
+									throw new Error('An error occurred while fetching the completion');
+								}
+							} catch (error) {
+								console.error('chatCompletion', error);
+
+								if (form_data?.stream ?? false) {
+									$socket.emit(channel, {
+										error: error
+									});
+								} else {
+									cb({
+										error: error
+									});
+								}
+							}
+						}
+					} catch (error) {
+						console.error('chatCompletion', error);
+
+						if (form_data?.stream ?? false) {
+							$socket.emit(channel, {
+								error: error
+							});
+						} else {
+							cb({
+								error: error
+							});
+						}
+					} finally {
+						$socket.emit(channel, {
+							done: true
+						});
+					}
+				} else {
+					console.log('chatEventHandler', event);
				}
			}
		};