From 4dd51badfe2be1a0a85b817a1325227e92938949 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 2 Jun 2024 18:06:12 +0100 Subject: [PATCH] fix: ollama streaming cancellation using aiohttp --- backend/apps/ollama/main.py | 399 +++--------------- src/lib/apis/ollama/index.ts | 25 +- src/lib/components/chat/Chat.svelte | 109 +++-- .../chat/ModelSelector/Selector.svelte | 42 +- .../components/chat/Settings/Models.svelte | 55 +-- .../components/workspace/Playground.svelte | 14 +- 6 files changed, 154 insertions(+), 490 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 3ad3d8808..76709b0ee 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,6 +29,8 @@ import time from urllib.parse import urlparse from typing import Optional, List, Union +from starlette.background import BackgroundTask + from apps.webui.models.models import Models from apps.webui.models.users import Users from constants import ERROR_MESSAGES @@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} -REQUEST_POOL = [] - - # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. @@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} -@app.get("/cancel/{request_id}") -async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): - if user: - if request_id in REQUEST_POOL: - REQUEST_POOL.remove(request_id) - return True - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - async def fetch_url(url): timeout = aiohttp.ClientTimeout(total=5) try: @@ -154,6 +143,45 @@ async def fetch_url(url): return None +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +async def post_streaming_url(url, payload): + r = None + try: + session = aiohttp.ClientSession() + r = await session.post(url, data=payload) + r.raise_for_status() + + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask(cleanup_response, response=r, session=session), + ) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = await r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status if r else 500, + detail=error_detail, + ) + + def merge_models_lists(model_lists): merged_models = {} @@ -313,65 +341,7 @@ async def pull_model( # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - def get_request(): - nonlocal url - nonlocal r - - request_id = str(uuid.uuid4()) - try: - REQUEST_POOL.append(request_id) - - def stream_content(): - try: - yield json.dumps({"id": request_id, "done": False}) + "\n" - - for chunk in r.iter_content(chunk_size=8192): - if request_id in REQUEST_POOL: - yield chunk - else: - log.warning("User: canceled request") - break - finally: - if 
hasattr(r, "close"): - r.close() - if request_id in REQUEST_POOL: - REQUEST_POOL.remove(request_id) - - r = requests.request( - method="POST", - url=f"{url}/api/pull", - data=json.dumps(payload), - stream=True, - ) - - r.raise_for_status() - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - raise e - - try: - return await run_in_threadpool(get_request) - - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) class PushModelForm(BaseModel): @@ -399,50 +369,9 @@ async def push_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") - r = None - - def get_request(): - nonlocal url - nonlocal r - try: - - def stream_content(): - for chunk in r.iter_content(chunk_size=8192): - yield chunk - - r = requests.request( - method="POST", - url=f"{url}/api/push", - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - - r.raise_for_status() - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - raise e - - try: - return await run_in_threadpool(get_request) - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url( + f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() + ) class CreateModelForm(BaseModel): @@ -461,53 +390,9 @@ async def create_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - r = None - - def get_request(): - nonlocal url - nonlocal r - try: - - def stream_content(): - for chunk in r.iter_content(chunk_size=8192): - yield chunk - - r = requests.request( - method="POST", - url=f"{url}/api/create", - data=form_data.model_dump_json(exclude_none=True).encode(), - stream=True, - ) - - r.raise_for_status() - - log.debug(f"r: {r}") - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - raise e - - try: - return await run_in_threadpool(get_request) - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url( + f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() + ) class CopyModelForm(BaseModel): @@ -797,66 +682,9 @@ async def generate_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - r = None - - def get_request(): - nonlocal form_data - nonlocal r - - request_id = str(uuid.uuid4()) - try: - REQUEST_POOL.append(request_id) - - def stream_content(): - try: - if form_data.stream: - yield json.dumps({"id": request_id, 
"done": False}) + "\n" - - for chunk in r.iter_content(chunk_size=8192): - if request_id in REQUEST_POOL: - yield chunk - else: - log.warning("User: canceled request") - break - finally: - if hasattr(r, "close"): - r.close() - if request_id in REQUEST_POOL: - REQUEST_POOL.remove(request_id) - - r = requests.request( - method="POST", - url=f"{url}/api/generate", - data=form_data.model_dump_json(exclude_none=True).encode(), - stream=True, - ) - - r.raise_for_status() - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - raise e - - try: - return await run_in_threadpool(get_request) - except Exception as e: - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url( + f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() + ) class ChatMessage(BaseModel): @@ -981,67 +809,7 @@ async def generate_chat_completion( print(payload) - r = None - - def get_request(): - nonlocal payload - nonlocal r - - request_id = str(uuid.uuid4()) - try: - REQUEST_POOL.append(request_id) - - def stream_content(): - try: - if payload.get("stream", True): - yield json.dumps({"id": request_id, "done": False}) + "\n" - - for chunk in r.iter_content(chunk_size=8192): - if request_id in REQUEST_POOL: - yield chunk - else: - log.warning("User: canceled request") - break - finally: - if hasattr(r, "close"): - r.close() - if request_id in REQUEST_POOL: - REQUEST_POOL.remove(request_id) - - r = requests.request( - method="POST", - url=f"{url}/api/chat", - data=json.dumps(payload), - stream=True, - ) - - r.raise_for_status() - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - log.exception(e) - raise e - - try: - return await run_in_threadpool(get_request) - except Exception as e: - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url(f"{url}/api/chat", json.dumps(payload)) # TODO: we should update this part once Ollama supports other types @@ -1132,68 +900,7 @@ async def generate_openai_chat_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - r = None - - def get_request(): - nonlocal payload - nonlocal r - - request_id = str(uuid.uuid4()) - try: - REQUEST_POOL.append(request_id) - - def stream_content(): - try: - if payload.get("stream"): - yield json.dumps( - {"request_id": request_id, "done": False} - ) + "\n" - - for chunk in r.iter_content(chunk_size=8192): - if request_id in REQUEST_POOL: - yield chunk - else: - log.warning("User: canceled request") - break - finally: - if hasattr(r, "close"): - r.close() - if request_id in REQUEST_POOL: - REQUEST_POOL.remove(request_id) - - r = requests.request( - method="POST", - url=f"{url}/v1/chat/completions", - data=json.dumps(payload), - stream=True, - ) - - r.raise_for_status() - - return StreamingResponse( - stream_content(), - status_code=r.status_code, - headers=dict(r.headers), - ) - except Exception as e: - 
raise e - - try: - return await run_in_threadpool(get_request) - except Exception as e: - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload)) @app.get("/v1/models") diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 5ab2363cb..aa1ac182b 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) = return [res, controller]; }; -export const cancelOllamaRequest = async (token: string = '', requestId: string) => { - let error = null; - - const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { - method: 'GET', - headers: { - 'Content-Type': 'text/event-stream', - Authorization: `Bearer ${token}` - } - }).catch((err) => { - error = err; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - export const createModel = async (token: string, tagName: string, content: string) => { let error = null; @@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => { let error = null; + const controller = new AbortController(); const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, { + signal: controller.signal, method: 'POST', headers: { Accept: 'application/json', @@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string | if (error) { throw error; } - return res; + return [res, controller]; }; export const downloadModel = async ( diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 7e389f458..d1bea2e89 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -26,7 +26,7 @@ splitStream } from '$lib/utils'; - import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama'; + import { generateChatCompletion } from '$lib/apis/ollama'; import { addTagById, createNewChat, @@ -65,7 +65,6 @@ let autoScroll = true; let processing = ''; let messagesContainerElement: HTMLDivElement; - let currentRequestId = null; let showModelSelector = true; @@ -130,10 +129,6 @@ ////////////////////////// const initNewChat = async () => { - if (currentRequestId !== null) { - await cancelOllamaRequest(localStorage.token, currentRequestId); - currentRequestId = null; - } window.history.replaceState(history.state, '', `/`); await chatId.set(''); @@ -616,7 +611,6 @@ if (stopResponseFlag) { controller.abort('User: Stop Response'); - await cancelOllamaRequest(localStorage.token, currentRequestId); } else { const messages = createMessagesList(responseMessageId); const res = await chatCompleted(localStorage.token, { @@ -647,8 +641,6 @@ } } - currentRequestId = null; - break; } @@ -669,63 +661,58 @@ throw data; } - if ('id' in data) { - console.log(data); - currentRequestId = data.id; - } else { - if (data.done == false) { - if (responseMessage.content == '' && data.message.content == '\n') { - continue; - } else { - responseMessage.content += data.message.content; - messages = messages; - } + if (data.done == false) { + if 
(responseMessage.content == '' && data.message.content == '\n') { + continue; } else { - responseMessage.done = true; - - if (responseMessage.content == '') { - responseMessage.error = { - code: 400, - content: `Oops! No text generated from Ollama, Please try again.` - }; - } - - responseMessage.context = data.context ?? null; - responseMessage.info = { - total_duration: data.total_duration, - load_duration: data.load_duration, - sample_count: data.sample_count, - sample_duration: data.sample_duration, - prompt_eval_count: data.prompt_eval_count, - prompt_eval_duration: data.prompt_eval_duration, - eval_count: data.eval_count, - eval_duration: data.eval_duration - }; + responseMessage.content += data.message.content; messages = messages; + } + } else { + responseMessage.done = true; - if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `${model}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png` - } - ); - } + if (responseMessage.content == '') { + responseMessage.error = { + code: 400, + content: `Oops! No text generated from Ollama, Please try again.` + }; + } - if ($settings.responseAutoCopy) { - copyToClipboard(responseMessage.content); - } + responseMessage.context = data.context ?? null; + responseMessage.info = { + total_duration: data.total_duration, + load_duration: data.load_duration, + sample_count: data.sample_count, + sample_duration: data.sample_duration, + prompt_eval_count: data.prompt_eval_count, + prompt_eval_duration: data.prompt_eval_duration, + eval_count: data.eval_count, + eval_duration: data.eval_duration + }; + messages = messages; - if ($settings.responseAutoPlayback) { - await tick(); - document.getElementById(`speak-button-${responseMessage.id}`)?.click(); - } + if ($settings.notificationEnabled && !document.hasFocus()) { + const notification = new Notification( + selectedModelfile + ? `${ + selectedModelfile.title.charAt(0).toUpperCase() + + selectedModelfile.title.slice(1) + }` + : `${model}`, + { + body: responseMessage.content, + icon: selectedModelfile?.imageUrl ?? 
`${WEBUI_BASE_URL}/static/favicon.png` + } + ); + } + + if ($settings.responseAutoCopy) { + copyToClipboard(responseMessage.content); + } + + if ($settings.responseAutoPlayback) { + await tick(); + document.getElementById(`speak-button-${responseMessage.id}`)?.click(); } } } diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 868c75da7..dad1fa512 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -8,7 +8,7 @@ import Check from '$lib/components/icons/Check.svelte'; import Search from '$lib/components/icons/Search.svelte'; - import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama'; + import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama'; import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores'; import { toast } from 'svelte-sonner'; @@ -72,10 +72,12 @@ return; } - const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => { - toast.error(error); - return null; - }); + const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch( + (error) => { + toast.error(error); + return null; + } + ); if (res) { const reader = res.body @@ -83,6 +85,16 @@ .pipeThrough(splitStream('\n')) .getReader(); + MODEL_DOWNLOAD_POOL.set({ + ...$MODEL_DOWNLOAD_POOL, + [sanitizedModelTag]: { + ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], + abortController: controller, + reader, + done: false + } + }); + while (true) { try { const { value, done } = await reader.read(); @@ -101,19 +113,6 @@ throw data.detail; } - if (data.id) { - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL, - [sanitizedModelTag]: { - ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], - requestId: data.id, - reader, - done: false - } - }); - console.log(data); - } - if (data.status) { if (data.digest) { let downloadProgress = 0; @@ -181,11 +180,12 @@ }); const cancelModelPullHandler = async (model: string) => { - const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model]; + const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model]; + if (abortController) { + abortController.abort(); + } if (reader) { await reader.cancel(); - - await cancelOllamaRequest(localStorage.token, requestId); delete $MODEL_DOWNLOAD_POOL[model]; MODEL_DOWNLOAD_POOL.set({ ...$MODEL_DOWNLOAD_POOL diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte index 7254f2d27..eb4b7f7b5 100644 --- a/src/lib/components/chat/Settings/Models.svelte +++ b/src/lib/components/chat/Settings/Models.svelte @@ -8,7 +8,6 @@ getOllamaUrls, getOllamaVersion, pullModel, - cancelOllamaRequest, uploadModel } from '$lib/apis/ollama'; @@ -67,12 +66,14 @@ console.log(model); updateModelId = model.id; - const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch( - (error) => { - toast.error(error); - return null; - } - ); + const [res, controller] = await pullModel( + localStorage.token, + model.id, + selectedOllamaUrlIdx + ).catch((error) => { + toast.error(error); + return null; + }); if (res) { const reader = res.body @@ -141,10 +142,12 @@ return; } - const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => { - toast.error(error); - return null; - }); + const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch( + (error) => { + toast.error(error); + return null; + } + ); 
if (res) { const reader = res.body @@ -152,6 +155,16 @@ .pipeThrough(splitStream('\n')) .getReader(); + MODEL_DOWNLOAD_POOL.set({ + ...$MODEL_DOWNLOAD_POOL, + [sanitizedModelTag]: { + ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], + abortController: controller, + reader, + done: false + } + }); + while (true) { try { const { value, done } = await reader.read(); @@ -170,19 +183,6 @@ throw data.detail; } - if (data.id) { - MODEL_DOWNLOAD_POOL.set({ - ...$MODEL_DOWNLOAD_POOL, - [sanitizedModelTag]: { - ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag], - requestId: data.id, - reader, - done: false - } - }); - console.log(data); - } - if (data.status) { if (data.digest) { let downloadProgress = 0; @@ -416,11 +416,12 @@ }; const cancelModelPullHandler = async (model: string) => { - const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model]; + const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model]; + if (abortController) { + abortController.abort(); + } if (reader) { await reader.cancel(); - - await cancelOllamaRequest(localStorage.token, requestId); delete $MODEL_DOWNLOAD_POOL[model]; MODEL_DOWNLOAD_POOL.set({ ...$MODEL_DOWNLOAD_POOL diff --git a/src/lib/components/workspace/Playground.svelte b/src/lib/components/workspace/Playground.svelte index 476ce774d..b7453e3f3 100644 --- a/src/lib/components/workspace/Playground.svelte +++ b/src/lib/components/workspace/Playground.svelte @@ -8,7 +8,7 @@ import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_NAME, config, user, models, settings } from '$lib/stores'; - import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama'; + import { generateChatCompletion } from '$lib/apis/ollama'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { splitStream } from '$lib/utils'; @@ -24,7 +24,6 @@ let selectedModelId = ''; let loading = false; - let currentRequestId = null; let stopResponseFlag = false; let messagesContainerElement: HTMLDivElement; @@ -46,14 +45,6 @@ } }; - // const cancelHandler = async () => { - // if (currentRequestId) { - // const res = await cancelOllamaRequest(localStorage.token, currentRequestId); - // currentRequestId = null; - // loading = false; - // } - // }; - const stopResponse = () => { stopResponseFlag = true; console.log('stopResponse'); @@ -171,8 +162,6 @@ if (stopResponseFlag) { controller.abort('User: Stop Response'); } - - currentRequestId = null; break; } @@ -229,7 +218,6 @@ loading = false; stopResponseFlag = false; - currentRequestId = null; } };
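
For reference, the core backend change above replaces the old approach (a global `REQUEST_POOL` of request IDs polled inside a threadpool generator, plus a `/cancel/{request_id}` endpoint) with a streaming-proxy pattern: the aiohttp response body is handed directly to Starlette's `StreamingResponse`, and a `BackgroundTask` closes the upstream response and session when the stream ends or the client disconnects. The sketch below is a minimal, self-contained illustration of that pattern, not the project's code; the `UPSTREAM_URL` constant, the `/proxy/generate` route name, and the extra session cleanup on the error path are assumptions added for the sketch.

```python
# Minimal sketch of the aiohttp streaming-proxy pattern this patch adopts.
# Assumptions (not from the patch): UPSTREAM_URL, the /proxy/generate route,
# and the defensive session cleanup in the except branch are illustrative only.
import aiohttp
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

UPSTREAM_URL = "http://localhost:11434"  # hypothetical Ollama base URL

app = FastAPI()


async def cleanup_response(response: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs after the StreamingResponse finishes (or the client disconnects),
    # releasing the upstream connection and closing the session.
    response.close()
    await session.close()


@app.post("/proxy/generate")
async def proxy_generate(request: Request):
    payload = await request.body()
    session = aiohttp.ClientSession()
    r = None
    try:
        r = await session.post(f"{UPSTREAM_URL}/api/generate", data=payload)
        r.raise_for_status()
        # r.content is an async byte stream; Starlette iterates it, and when the
        # client aborts the fetch, cancellation propagates to aiohttp, which stops
        # reading from the upstream -- no request-ID bookkeeping needed.
        return StreamingResponse(
            r.content,
            status_code=r.status,
            headers=dict(r.headers),
            background=BackgroundTask(cleanup_response, response=r, session=session),
        )
    except Exception as e:
        # Defensive cleanup on failure (an addition in this sketch).
        if r is not None:
            r.close()
        await session.close()
        raise HTTPException(status_code=r.status if r else 500, detail=f"Upstream error: {e}")
```

On the client side, the companion changes drop `cancelOllamaRequest` entirely: `pullModel` now creates an `AbortController`, passes its `signal` to `fetch`, and returns `[res, controller]`, while `Chat.svelte` and the model-download handlers call `controller.abort(...)` to stop a response or pull. Aborting the fetch tears down the HTTP stream, which is what triggers the server-side cleanup shown above.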