feat: direct connections integration

Timothy Jaeryang Baek 2025-02-12 22:56:33 -08:00
parent 304ce2a14d
commit c83e68282d
6 changed files with 387 additions and 94 deletions
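At a glance, the change wires "direct connection" models through the client instead of the server-side model registry: the chat endpoints accept a model_item payload from the frontend, mark the request with request.state.direct / request.state.model, and a new generate_direct_chat_completion helper relays the completion request back to the originating browser session over Socket.IO on a per-request channel, which the frontend layout handler answers by calling the configured OpenAI-compatible endpoint and streaming the result back on that channel. As a reading aid, here is a minimal sketch of that relay event; the field names are taken from the hunks below, and everything else about it is illustrative rather than part of the commit:

    # channel is unique per request: "<user_id>:<session_id>:<request_id>"
    channel = f"{user_id}:{session_id}:{request_id}"
    await event_emitter(
        {
            "type": "request:chat:completion",
            "data": {
                "form_data": form_data,               # OpenAI-style chat payload to forward
                "model": models[form_data["model"]],  # the direct model entry (carries urlIdx)
                "channel": channel,                   # client replies here: SSE lines, then {"done": True}
                "session_id": session_id,
            },
        }
    )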

View File

@@ -900,20 +900,30 @@ async def chat_completion(
     if not request.app.state.MODELS:
         await get_all_models(request)
 
+    model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)
     try:
-        model_id = form_data.get("model", None)
-        if model_id not in request.app.state.MODELS:
-            raise Exception("Model not found")
-
-        model = request.app.state.MODELS[model_id]
-        model_info = Models.get_model_by_id(model_id)
-
-        # Check if user has access to the model
-        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
-            try:
-                check_model_access(user, model)
-            except Exception as e:
-                raise e
+        if not model_item.get("direct", False):
+            model_id = form_data.get("model", None)
+            if model_id not in request.app.state.MODELS:
+                raise Exception("Model not found")
+
+            model = request.app.state.MODELS[model_id]
+            model_info = Models.get_model_by_id(model_id)
+
+            # Check if user has access to the model
+            if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
+                try:
+                    check_model_access(user, model)
+                except Exception as e:
+                    raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model
 
         metadata = {
             "user_id": user.id,
@@ -925,6 +935,7 @@ async def chat_completion(
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),
             "model": model_info,
+            "direct": model_item.get("direct", False),
             **(
                 {"function_calling": "native"}
                 if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@ async def chat_completion(
                 else {}
             ),
         }
+        request.state.metadata = metadata
         form_data["metadata"] = metadata
 
         form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@ async def chat_completion(
         )
     except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_completed_handler(request, form_data, user)
     except Exception as e:
         raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
 ):
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_action_handler(request, action_id, form_data, user)
     except Exception as e:
         raise HTTPException(
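For context on the new model_item field consumed above: the frontend sends its full entry from the $models store (see the Chat.svelte hunks further down), and the backend only relies on a few of its fields. A rough sketch of the shape, where the id value is hypothetical and everything beyond id, direct and urlIdx is simply whatever the client's model list contains:

    model_item = {
        "id": "my-direct-model",  # hypothetical id; becomes the key of the one-entry models dict
        "direct": True,           # marks the request as a direct (client-relayed) connection
        "urlIdx": 0,              # index into the direct-connection base URL / API key lists in the frontend settings
        # ...remaining fields are taken verbatim from the client's model entry
    }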

View File

@@ -139,7 +139,12 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
             }
         ),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@ async def generate_image_prompt(
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
             }
         ),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }
 
     try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,

View File

@@ -7,6 +7,8 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio
 
 from fastapi import Request
 from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse
 from open_webui.models.users import UserModel
 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_emitter = get_event_emitter(metadata)
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        # Define a generator to stream responses
+        async def event_generator():
+            nonlocal q
+
+            async def message_listener(sid, data):
+                """
+                Handle received socket messages and push them into the queue.
+                """
+                await q.put(data)
+
+            # Register the listener
+            sio.on(channel, message_listener)
+
+            # Start processing chat completion in background
+            await event_emitter(
+                {
+                    "type": "request:chat:completion",
+                    "data": {
+                        "form_data": form_data,
+                        "model": models[form_data["model"]],
+                        "channel": channel,
+                        "session_id": session_id,
+                    },
+                }
+            )
+
+            try:
+                while True:
+                    data = await q.get()  # Wait for new messages
+                    if isinstance(data, dict):
+                        if "error" in data:
+                            raise Exception(data["error"])
+
+                        if "done" in data and data["done"]:
+                            break  # Stop streaming when 'done' is received
+
+                        yield f"data: {json.dumps(data)}\n\n"
+                    elif isinstance(data, str):
+                        yield data
+            finally:
+                del sio.handlers["/"][channel]  # Remove the listener
+
+        # Return the streaming response
+        return StreamingResponse(event_generator(), media_type="text/event-stream")
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print(res)
+
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
@@ -66,7 +156,12 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
@@ -87,78 +182,90 @@
         except Exception as e:
             raise e
 
-    if model["owned_by"] == "arena":
-        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-        if model_ids and filter_mode == "exclude":
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena" and model["id"] not in model_ids
-            ]
-
-        selected_model_id = None
-        if isinstance(model_ids, list) and model_ids:
-            selected_model_id = random.choice(model_ids)
-        else:
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena"
-            ]
-            selected_model_id = random.choice(model_ids)
-
-        form_data["model"] = selected_model_id
-
-        if form_data.get("stream") == True:
-
-            async def stream_wrapper(stream):
-                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
-                async for chunk in stream:
-                    yield chunk
-
-            response = await generate_chat_completion(
-                request, form_data, user, bypass_filter=True
-            )
-            return StreamingResponse(
-                stream_wrapper(response.body_iterator),
-                media_type="text/event-stream",
-                background=response.background,
-            )
-        else:
-            return {
-                **(
-                    await generate_chat_completion(
-                        request, form_data, user, bypass_filter=True
-                    )
-                ),
-                "selected_model_id": selected_model_id,
-            }
-
-    if model.get("pipe"):
-        # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
-        return await generate_function_chat_completion(
-            request, form_data, user=user, models=models
-        )
-    if model["owned_by"] == "ollama":
-        # Using /ollama/api/chat endpoint
-        form_data = convert_payload_openai_to_ollama(form_data)
-        response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
-        if form_data.get("stream"):
-            response.headers["content-type"] = "text/event-stream"
-            return StreamingResponse(
-                convert_streaming_response_ollama_to_openai(response),
-                headers=dict(response.headers),
-                background=response.background,
-            )
-        else:
-            return convert_response_ollama_to_openai(response)
-    else:
-        return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
+    if request.state.direct:
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
+    else:
+        if model["owned_by"] == "arena":
+            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+            if model_ids and filter_mode == "exclude":
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
+                ]
+
+            selected_model_id = None
+            if isinstance(model_ids, list) and model_ids:
+                selected_model_id = random.choice(model_ids)
+            else:
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena"
+                ]
+                selected_model_id = random.choice(model_ids)
+
+            form_data["model"] = selected_model_id
+
+            if form_data.get("stream") == True:
+
+                async def stream_wrapper(stream):
+                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                    async for chunk in stream:
+                        yield chunk
+
+                response = await generate_chat_completion(
+                    request, form_data, user, bypass_filter=True
+                )
+                return StreamingResponse(
+                    stream_wrapper(response.body_iterator),
+                    media_type="text/event-stream",
+                    background=response.background,
+                )
+            else:
+                return {
+                    **(
+                        await generate_chat_completion(
+                            request, form_data, user, bypass_filter=True
+                        )
+                    ),
+                    "selected_model_id": selected_model_id,
+                }
+
+        if model.get("pipe"):
+            # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
+            return await generate_function_chat_completion(
+                request, form_data, user=user, models=models
+            )
+        if model["owned_by"] == "ollama":
+            # Using /ollama/api/chat endpoint
+            form_data = convert_payload_openai_to_ollama(form_data)
+            response = await generate_ollama_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
+            )
+            if form_data.get("stream"):
+                response.headers["content-type"] = "text/event-stream"
+                return StreamingResponse(
+                    convert_streaming_response_ollama_to_openai(response),
+                    headers=dict(response.headers),
+                    background=response.background,
+                )
+            else:
+                return convert_response_ollama_to_openai(response)
+        else:
+            return await generate_openai_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
+            )
 
 
 chat_completion = generate_chat_completion
@@ -167,7 +274,13 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]
@@ -227,7 +340,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]

View File

@@ -622,7 +622,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
     task_model_id = get_task_model_id(
         form_data["model"],
         request.app.state.config.TASK_MODEL,
@@ -1677,6 +1683,9 @@ async def process_chat_response(
                                 "data": {
                                     "id": str(uuid4()),
                                     "code": code,
+                                    "session_id": metadata.get(
+                                        "session_id", None
+                                    ),
                                 },
                             }
                         )

View File

@@ -838,6 +838,7 @@
 					timestamp: m.timestamp,
 					...(m.sources ? { sources: m.sources } : {})
 				})),
+				model_item: $models.find((m) => m.id === modelId),
 				chat_id: chatId,
 				session_id: $socket?.id,
 				id: responseMessageId
@@ -896,6 +897,7 @@
 					...(m.sources ? { sources: m.sources } : {})
 				})),
 				...(event ? { event: event } : {}),
+				model_item: $models.find((m) => m.id === modelId),
 				chat_id: chatId,
 				session_id: $socket?.id,
 				id: responseMessageId
@@ -1574,6 +1576,7 @@
 					$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
 				)
 			},
+			model_item: $models.find((m) => m.id === model.id),
 			session_id: $socket?.id,
 			chat_id: $chatId,

View File

@@ -45,6 +45,7 @@
 	import { getAllTags, getChatList } from '$lib/apis/chats';
 	import NotificationToast from '$lib/components/NotificationToast.svelte';
 	import AppSidebar from '$lib/components/app/AppSidebar.svelte';
+	import { chatCompletion } from '$lib/apis/openai';
 
 	setContext('i18n', i18n);
@@ -251,10 +252,100 @@
 			} else if (type === 'chat:tags') {
 				tags.set(await getAllTags(localStorage.token));
 			}
-		} else {
+		} else if (data?.session_id === $socket.id) {
 			if (type === 'execute:python') {
 				console.log('execute:python', data);
 				executePythonAsWorker(data.id, data.code, cb);
+			} else if (type === 'request:chat:completion') {
+				console.log(data, $socket.id);
+				const { session_id, channel, form_data, model } = data;
+
+				try {
+					const directConnections = $settings?.directConnections ?? {};
+
+					if (directConnections) {
+						const urlIdx = model?.urlIdx;
+
+						console.log(model, directConnections);
+						const OPENAI_API_URL = directConnections.OPENAI_API_BASE_URLS[urlIdx];
+						const OPENAI_API_KEY = directConnections.OPENAI_API_KEYS[urlIdx];
+						const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx];
+
+						try {
+							const [res, controller] = await chatCompletion(
+								OPENAI_API_KEY,
+								form_data,
+								OPENAI_API_URL
+							);
+
+							if (res && res.ok) {
+								if (form_data?.stream ?? false) {
+									// res will either be SSE or JSON
+									const reader = res.body.getReader();
+									const decoder = new TextDecoder();
+
+									const processStream = async () => {
+										while (true) {
+											// Read data chunks from the response stream
+											const { done, value } = await reader.read();
+											if (done) {
+												break;
+											}
+
+											// Decode the received chunk
+											const chunk = decoder.decode(value, { stream: true });
+
+											// Process lines within the chunk
+											const lines = chunk.split('\n').filter((line) => line.trim() !== '');
+
+											for (const line of lines) {
+												$socket?.emit(channel, line);
+											}
+										}
+									};
+
+									// Process the stream in the background
+									await processStream();
+								} else {
+									const data = await res.json();
+									cb(data);
+								}
+							} else {
+								throw new Error('An error occurred while fetching the completion');
+							}
+						} catch (error) {
+							console.error('chatCompletion', error);
+
+							if (form_data?.stream ?? false) {
+								$socket.emit(channel, {
+									error: error
+								});
+							} else {
+								cb({
+									error: error
+								});
+							}
+						}
+					}
+				} catch (error) {
+					console.error('chatCompletion', error);
+					if (form_data?.stream ?? false) {
+						$socket.emit(channel, {
+							error: error
+						});
+					} else {
+						cb({
+							error: error
+						});
+					}
+				} finally {
+					$socket.emit(channel, {
+						done: true
+					});
+				}
 			} else {
 				console.log('chatEventHandler', event);
 			}
 		}
 	};