diff --git a/backend/main.py b/backend/main.py
index befa7828a..3ffb5bdd7 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
         return_citations = False
 
         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
@@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
@@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                         else:
                             pass
 
+            if "chat_id" in data:
+                del data["chat_id"]
+
             modified_body_bytes = json.dumps(data).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
@@ -464,6 +469,75 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    """Run a finished chat through every matching pipeline outlet filter.
+
+    Filters run in ascending priority order and each one receives the body
+    returned by the previous one. Filters whose endpoint has no API key are
+    skipped. A failing filter is logged and skipped, unless its response is
+    JSON containing "detail", which is returned to the client as-is.
+    """
+    data = form_data
+    model_id = data["model"]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+
+    # Loop variable is not called `filter` so the builtin stays usable.
+    for pipeline_filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = pipeline_filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{pipeline_filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Best effort: one broken filter must not lose the chat body.
+            log.exception(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except (ValueError, TypeError):
+                    # Error body is not JSON (or not a container); skip it.
+                    pass
+
+    return data
+
+
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
     responses = await get_openai_models(raw=True)
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index 6cc3d5405..63a0c21b9 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -49,6 +49,55 @@ export const getModels = async (token: string = '') => {
 	return models;
 };
 
+// NOTE(review): field types inferred from the Chat.svelte call sites — confirm.
+type ChatCompletedMessage = {
+	id: string;
+	role: string;
+	content: string;
+	timestamp?: number;
+};
+
+type ChatCompletedForm = {
+	model: string;
+	messages: ChatCompletedMessage[];
+	chat_id: string;
+};
+
+// POSTs the finished chat to /api/chat/completed and resolves with the parsed
+// JSON response; rejects with the response's `detail` (or the raw error).
+export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/chat/completed`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(body)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index ffbed16a7..bb0975445 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -48,6 +48,7 @@
 	import { runWebSearch } from '$lib/apis/rag';
 	import Banner from '../common/Banner.svelte';
 	import { getUserSettings } from '$lib/apis/users';
+	import { chatCompleted } from '$lib/apis';
 
 	const i18n: Writable = getContext('i18n');
 
@@ -576,7 +577,8 @@
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
 			docs: docs.length > 0 ? docs : undefined,
-			citations: docs.length > 0
+			citations: docs.length > 0,
+			chat_id: $chatId
 		});
 
 		if (res && res.ok) {
@@ -596,6 +598,27 @@
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 						await cancelOllamaRequest(localStorage.token, currentRequestId);
+					} else {
+						const res = await chatCompleted(localStorage.token, {
+							model: model,
+							messages: messages.map((m) => ({
+								id: m.id,
+								role: m.role,
+								content: m.content,
+								timestamp: m.timestamp
+							})),
+							chat_id: $chatId
+						}).catch((error) => {
+							console.error(error);
+							return null;
+						});
+
+						if (res !== null && Array.isArray(res.messages)) {
+							// Update chat history with the new messages
+							for (const message of res.messages) {
+								history.messages[message.id] = { ...history.messages[message.id], ...message };
+							}
+						}
 					}
 
 					currentRequestId = null;
@@ -829,7 +852,8 @@
 						frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
 						max_tokens: $settings?.params?.max_tokens ?? undefined,
 						docs: docs.length > 0 ? docs : undefined,
-						citations: docs.length > 0
+						citations: docs.length > 0,
+						chat_id: $chatId
 					},
 					`${OPENAI_API_BASE_URL}`
 				);
@@ -855,6 +879,27 @@
 
 						if (stopResponseFlag) {
 							controller.abort('User: Stop Response');
+						} else {
+							const res = await chatCompleted(localStorage.token, {
+								model: model,
+								messages: messages.map((m) => ({
+									id: m.id,
+									role: m.role,
+									content: m.content,
+									timestamp: m.timestamp
+								})),
+								chat_id: $chatId
+							}).catch((error) => {
+								console.error(error);
+								return null;
+							});
+
+							if (res !== null && Array.isArray(res.messages)) {
+								// Update chat history with the new messages
+								for (const message of res.messages) {
+									history.messages[message.id] = { ...history.messages[message.id], ...message };
+								}
+							}
 						}
 
 						break;