feat: direct connections integration

Timothy Jaeryang Baek 2025-02-12 22:56:33 -08:00
parent 304ce2a14d
commit c83e68282d
6 changed files with 387 additions and 94 deletions
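
In short: the frontend now attaches a model_item to each chat request, and when model_item.direct is set the backend no longer proxies to Ollama/OpenAI itself. Instead it emits a request:chat:completion socket.io event back to the originating browser session, which runs the completion against its own configured endpoint and relays the result over a per-request channel. A rough sketch of the message shapes, reconstructed from the diff below (illustrative only, not an authoritative schema):

# Channel naming and message shapes for the direct-connection round trip,
# as read from the diff below; concrete field values are illustrative.
import uuid

user_id, session_id = "user-1", "socket-session-1"  # from request metadata
request_id = str(uuid.uuid4())  # generated per request
channel = f"{user_id}:{session_id}:{request_id}"

# Server -> client (session event): ask the browser to run the completion.
server_request = {
    "type": "request:chat:completion",
    "data": {
        "form_data": {"model": "my-model", "stream": True},  # chat payload
        "model": {"id": "my-model", "direct": True},  # the client's model_item
        "channel": channel,
        "session_id": session_id,
    },
}

# Client -> server, emitted on the channel: raw SSE lines while streaming,
# then a terminator; an {"error": ...} payload aborts the request instead.
client_done = {"done": True}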

View File

@@ -900,20 +900,30 @@ async def chat_completion(
if not request.app.state.MODELS:
await get_all_models(request)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
try:
if not model_item.get("direct", False):
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
else:
model = model_item
model_info = None
request.state.direct = True
request.state.model = model
metadata = {
"user_id": user.id,
@@ -925,6 +935,7 @@ async def chat_completion(
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model_info,
"direct": model_item.get("direct", False),
**(
{"function_calling": "native"}
if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@ async def chat_completion(
else {}
),
}
request.state.metadata = metadata
form_data["metadata"] = metadata
form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@ async def chat_completion(
)
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_completed_handler(request, form_data, user)
except Exception as e:
raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_action_handler(request, action_id, form_data, user)
except Exception as e:
raise HTTPException(

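For orientation, the model_item popped from form_data above is the full model entry the frontend sends along (see the Svelte changes below). For a direct connection it plausibly looks like this; only the fields actually read in this diff are certain, the rest is hypothetical:

# Hypothetical model_item as received in form_data for a direct connection.
model_item = {
    "id": "llama-3-direct",  # becomes the single key of the models dict
    "name": "Llama 3 (direct)",  # assumed display field, not read here
    "direct": True,  # flips request.state.direct on the backend
    "urlIdx": 0,  # indexes the client's directConnections arrays
}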
View File

@@ -139,7 +139,12 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
}
),
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@ async def generate_chat_tags(
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"task": str(TASKS.EMOJI_GENERATION),
"task_body": form_data,
},
}
try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
**(getattr(request.state, "metadata", None) or {}),
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,

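Every task endpoint in this file now repeats the same five-line models override. A small helper would consolidate it; a sketch of a hypothetical refactor, not part of this commit:

# Hypothetical helper, not in this commit: one home for the repeated
# "a direct model replaces the global model table" pattern above.
def get_request_models(request):
    if getattr(request.state, "direct", False) and request.state.model:
        # Direct connection: the only visible model is the client-supplied one.
        return {request.state.model["id"]: request.state.model}
    return request.app.state.MODELS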
View File

@@ -7,6 +7,8 @@ from typing import Any, Optional
import random
import json
import inspect
import uuid
import asyncio
from fastapi import Request
from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse
from open_webui.models.users import UserModel
from open_webui.socket.main import (
sio,
get_event_call,
get_event_emitter,
)
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_direct_chat_completion(
request: Request,
form_data: dict,
user: Any,
models: dict,
):
print("generate_direct_chat_completion")
metadata = form_data.pop("metadata", {})
user_id = metadata.get("user_id")
session_id = metadata.get("session_id")
request_id = str(uuid.uuid4()) # Generate a unique request ID
event_emitter = get_event_emitter(metadata)
event_caller = get_event_call(metadata)
channel = f"{user_id}:{session_id}:{request_id}"
if form_data.get("stream"):
q = asyncio.Queue()
# Define a generator to stream responses
async def event_generator():
nonlocal q
async def message_listener(sid, data):
"""
Handle received socket messages and push them into the queue.
"""
await q.put(data)
# Register the listener
sio.on(channel, message_listener)
# Start processing chat completion in background
await event_emitter(
{
"type": "request:chat:completion",
"data": {
"form_data": form_data,
"model": models[form_data["model"]],
"channel": channel,
"session_id": session_id,
},
}
)
try:
while True:
data = await q.get() # Wait for new messages
if isinstance(data, dict):
if "error" in data:
raise Exception(data["error"])
if "done" in data and data["done"]:
break # Stop streaming when 'done' is received
yield f"data: {json.dumps(data)}\n\n"
elif isinstance(data, str):
yield data
finally:
del sio.handlers["/"][channel] # Remove the listener
# Return the streaming response
return StreamingResponse(event_generator(), media_type="text/event-stream")
else:
res = await event_caller(
{
"type": "request:chat:completion",
"data": {
"form_data": form_data,
"model": models[form_data["model"]],
"channel": channel,
"session_id": session_id,
},
}
)
log.debug(res)
if "error" in res:
raise Exception(res["error"])
return res
async def generate_chat_completion(
request: Request,
form_data: dict,
@@ -66,7 +156,12 @@ async def generate_chat_completion(
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -87,78 +182,90 @@ async def generate_chat_completion(
except Exception as e:
raise e
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route that uses this function, and it already bypasses the filter
return await generate_function_chat_completion(
if getattr(request.state, "direct", False):
return await generate_direct_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route that uses this function, and it already bypasses the filter
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
chat_completion = generate_chat_completion
@@ -167,7 +274,13 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
@@ -227,7 +340,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: Any
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
data = form_data
model_id = data["model"]

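The heart of generate_direct_chat_completion above is a per-request socket.io handler feeding an asyncio.Queue that an SSE generator drains. The pattern in isolation, assuming a python-socketio AsyncServer named sio (a minimal sketch, not the full implementation):

import asyncio
import json

async def sse_from_channel(sio, channel: str):
    """Yield SSE frames for every message a client emits on the channel."""
    q: asyncio.Queue = asyncio.Queue()

    async def listener(sid, data):
        await q.put(data)

    sio.on(channel, listener)  # handler registered for this request only
    try:
        while True:
            data = await q.get()
            if isinstance(data, dict):
                if "error" in data:
                    raise Exception(data["error"])
                if data.get("done"):
                    break
                yield f"data: {json.dumps(data)}\n\n"
            elif isinstance(data, str):
                yield data  # already a raw SSE line from the client
    finally:
        del sio.handlers["/"][channel]  # deregister to avoid leaking handlers

Wrapping this generator in Starlette's StreamingResponse with media_type="text/event-stream" gives exactly the streaming branch above.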
View File

@@ -622,7 +622,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Initialize events to store additional events to be sent to the client
# Initialize contexts and citation
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
task_model_id = get_task_model_id(
form_data["model"],
request.app.state.config.TASK_MODEL,
@@ -1677,6 +1683,9 @@ async def process_chat_response(
"data": {
"id": str(uuid4()),
"code": code,
"session_id": metadata.get(
"session_id", None
),
},
}
)

View File

@@ -838,6 +838,7 @@
timestamp: m.timestamp,
...(m.sources ? { sources: m.sources } : {})
})),
model_item: $models.find((m) => m.id === modelId),
chat_id: chatId,
session_id: $socket?.id,
id: responseMessageId
@@ -896,6 +897,7 @@
...(m.sources ? { sources: m.sources } : {})
})),
...(event ? { event: event } : {}),
model_item: $models.find((m) => m.id === modelId),
chat_id: chatId,
session_id: $socket?.id,
id: responseMessageId
@@ -1574,6 +1576,7 @@
$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
)
},
model_item: $models.find((m) => m.id === model.id),
session_id: $socket?.id,
chat_id: $chatId,

View File

@@ -45,6 +45,7 @@
import { getAllTags, getChatList } from '$lib/apis/chats';
import NotificationToast from '$lib/components/NotificationToast.svelte';
import AppSidebar from '$lib/components/app/AppSidebar.svelte';
import { chatCompletion } from '$lib/apis/openai';
setContext('i18n', i18n);
@@ -251,10 +252,100 @@
} else if (type === 'chat:tags') {
tags.set(await getAllTags(localStorage.token));
}
} else {
} else if (data?.session_id === $socket.id) {
if (type === 'execute:python') {
console.log('execute:python', data);
executePythonAsWorker(data.id, data.code, cb);
} else if (type === 'request:chat:completion') {
console.log(data, $socket.id);
const { session_id, channel, form_data, model } = data;
try {
const directConnections = $settings?.directConnections ?? null;
if (directConnections) {
const urlIdx = model?.urlIdx;
console.log(model, directConnections);
const OPENAI_API_URL = directConnections.OPENAI_API_BASE_URLS[urlIdx];
const OPENAI_API_KEY = directConnections.OPENAI_API_KEYS[urlIdx];
const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx];
try {
const [res, controller] = await chatCompletion(
OPENAI_API_KEY,
form_data,
OPENAI_API_URL
);
if (res && res.ok) {
if (form_data?.stream ?? false) {
// res will either be SSE or JSON
const reader = res.body.getReader();
const decoder = new TextDecoder();
const processStream = async () => {
while (true) {
// Read data chunks from the response stream
const { done, value } = await reader.read();
if (done) {
break;
}
// Decode the received chunk
const chunk = decoder.decode(value, { stream: true });
// Process lines within the chunk
const lines = chunk.split('\n').filter((line) => line.trim() !== '');
for (const line of lines) {
$socket?.emit(channel, line);
}
}
};
// Drain the stream until the upstream response completes
await processStream();
} else {
const data = await res.json();
cb(data);
}
} else {
throw new Error('An error occurred while fetching the completion');
}
} catch (error) {
console.error('chatCompletion', error);
if (form_data?.stream ?? false) {
$socket.emit(channel, {
error: `${error}` // Error objects don't serialize cleanly over socket.io
});
} else {
cb({
error: `${error}`
});
}
}
}
} catch (error) {
console.error('chatCompletion', error);
if (form_data?.stream ?? false) {
$socket.emit(channel, {
error: error
});
} else {
cb({
error: `${error}`
});
}
} finally {
$socket.emit(channel, {
done: true
});
}
} else {
console.log('chatEventHandler', event);
}
}
};
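
The handler above is the browser half of the protocol; in principle any socket.io client that answers request:chat:completion could serve a direct connection. A minimal headless sketch of the streaming path using python-socketio and httpx (hypothetical: the real responder is the Svelte code above, and the URL, key, and socket path here are placeholders):

import asyncio

import httpx
import socketio  # python-socketio

sio = socketio.AsyncClient()

@sio.on("request:chat:completion")
async def on_request(data):
    channel, form_data = data["channel"], data["form_data"]
    try:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                "https://my-endpoint.example/v1/chat/completions",  # placeholder
                json=form_data,
                headers={"Authorization": "Bearer sk-placeholder"},
            ) as res:
                async for line in res.aiter_lines():
                    if line.strip():
                        await sio.emit(channel, line)  # forward raw SSE lines
    except Exception as e:
        await sio.emit(channel, {"error": str(e)})
    finally:
        await sio.emit(channel, {"done": True})

async def main():
    # Assumed socket path; Open WebUI mounts socket.io under /ws/socket.io.
    await sio.connect("http://localhost:8080", socketio_path="/ws/socket.io")
    await sio.wait()

asyncio.run(main())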