diff --git a/.gitignore b/.gitignore
index 528e1f830..2ccac4d50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,7 +166,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
 # Logs
 logs
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 2a67b02d2..370e54baf 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -250,11 +250,26 @@ async def pull_model(
     def get_request():
         nonlocal url
         nonlocal r
+
+        # Tag this pull with an id so the client can cancel it via /cancel/{request_id}.
+        request_id = str(uuid.uuid4())
         try:
+            REQUEST_POOL.append(request_id)
 
             def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
+                try:
+                    # Emit the request id first so the client can reference it when canceling.
+                    yield json.dumps({"id": request_id, "done": False}) + "\n"
+
+                    for chunk in r.iter_content(chunk_size=8192):
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
+                finally:
+                    if hasattr(r, "close"):
+                        r.close()
+                    if request_id in REQUEST_POOL:
+                        REQUEST_POOL.remove(request_id)
 
             r = requests.request(
                 method="POST",
@@ -275,6 +290,7 @@ async def pull_model(
 
     try:
         return await run_in_threadpool(get_request)
+
     except Exception as e:
         print(e)
         error_detail = "Open WebUI: Server Connection Error"
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts
index c3b37e007..7c4e809e5 100644
--- a/src/lib/apis/ollama/index.ts
+++ b/src/lib/apis/ollama/index.ts
@@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) => {
 	return [res, controller];
 };
 
-export const cancelChatCompletion = async (token: string = '', requestId: string) => {
+export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
 	let error = null;
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte
index b19984b45..6de76a483 100644
--- a/src/lib/components/chat/Settings/Models.svelte
+++ b/src/lib/components/chat/Settings/Models.svelte
@@ -9,6 +9,7 @@
 		getOllamaUrls,
 		getOllamaVersion,
 		pullModel,
+		cancelOllamaRequest,
 		uploadModel
 	} from '$lib/apis/ollama';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
@@ -163,7 +164,7 @@
 			// Remove the downloaded model
 			delete modelDownloadStatus[modelName];
 
-			console.log(data);
+			modelDownloadStatus = { ...modelDownloadStatus };
 
 			if (!data.success) {
 				toast.error(data.error);
@@ -372,12 +373,24 @@
 				for (const line of lines) {
 					if (line !== '') {
 						let data = JSON.parse(line);
+						console.log(data);
 						if (data.error) {
 							throw data.error;
 						}
 						if (data.detail) {
 							throw data.detail;
 						}
+
+						if (data.id) {
+							modelDownloadStatus[opts.modelName] = {
+								...modelDownloadStatus[opts.modelName],
+								requestId: data.id,
+								reader,
+								done: false
+							};
+							console.log(data);
+						}
+
 						if (data.status) {
 							if (data.digest) {
 								let downloadProgress = 0;
@@ -387,11 +400,17 @@
 									downloadProgress = 100;
 								}
 								modelDownloadStatus[opts.modelName] = {
+									...modelDownloadStatus[opts.modelName],
 									pullProgress: downloadProgress,
 									digest: data.digest
 								};
 							} else {
 								toast.success(data.status);
+
+								modelDownloadStatus[opts.modelName] = {
+									...modelDownloadStatus[opts.modelName],
+									done: data.status === 'success'
+								};
 							}
 						}
 					}
@@ -404,7 +423,14 @@
 				opts.callback({ success: false, error, modelName: opts.modelName });
 			}
 		}
-		opts.callback({ success: true, modelName: opts.modelName });
+
+		console.log(modelDownloadStatus[opts.modelName]);
+
+		if (modelDownloadStatus[opts.modelName]?.done) {
+			opts.callback({ success: true, modelName: opts.modelName });
+		} else {
+			opts.callback({ success: false, error: 'Download canceled', modelName: opts.modelName });
+		}
 	}
 };
@@ -474,6 +500,18 @@
 		ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 	});
+
+	const cancelModelPullHandler = async (model: string) => {
+		const { reader, requestId } = modelDownloadStatus[model];
+		if (reader) {
+			await reader.cancel();
+
+			await cancelOllamaRequest(localStorage.token, requestId);
+			delete modelDownloadStatus[model];
+			await deleteModel(localStorage.token, model);
+			toast.success(`${model} download has been canceled`);
+		}
+	};
@@ -604,20 +642,58 @@
 			{#if Object.keys(modelDownloadStatus).length > 0}
 				{#each Object.keys(modelDownloadStatus) as model}
-					<div class="flex flex-col">
-						<div class="font-medium mb-1">{model}</div>
-						<div class="">
-							<div
-								class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
-								style="width: {modelDownloadStatus[model].pullProgress ?? 0}%"
-							>
-								{modelDownloadStatus[model].pullProgress ?? 0}%
-							</div>
-							<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
-								{modelDownloadStatus[model].digest}
-							</div>
-						</div>
-					</div>
+					{#if 'pullProgress' in modelDownloadStatus[model]}
+						<div class="flex flex-col">
+							<div class="font-medium mb-1">{model}</div>
+							<div class="">
+								<div class="flex flex-row justify-between space-x-4 pr-2">
+									<div class="flex-1">
+										<div
+											class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
+											style="width: {modelDownloadStatus[model].pullProgress ?? 0}%"
+										>
+											{modelDownloadStatus[model].pullProgress ?? 0}%
+										</div>
+									</div>
+
+									<button
+										class="text-gray-800 dark:text-gray-100"
+										on:click={() => {
+											cancelModelPullHandler(model);
+										}}
+									>
+										<svg
+											class="w-4 h-4 text-gray-800 dark:text-white"
+											viewBox="0 0 24 24"
+											fill="none"
+											xmlns="http://www.w3.org/2000/svg"
+										>
+											<path
+												d="M6 6L18 18M18 6L6 18"
+												stroke="currentColor"
+												stroke-width="2"
+												stroke-linecap="round"
+												stroke-linejoin="round"
+											/>
+										</svg>
+									</button>
+								</div>
+								{#if 'digest' in modelDownloadStatus[model]}
+									<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
+										{modelDownloadStatus[model].digest}
+									</div>
+								{/if}
+							</div>
+						</div>
+					{/if}
 				{/each}
 			{/if}
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 0fca312ae..417ddccda 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
-	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
+	import { generateChatCompletion, cancelOllamaRequest, generateTitle } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -104,7 +104,7 @@
 	const initNewChat = async () => {
 		if (currentRequestId !== null) {
-			await cancelChatCompletion(localStorage.token, currentRequestId);
+			await cancelOllamaRequest(localStorage.token, currentRequestId);
 			currentRequestId = null;
 		}
 		window.history.replaceState(history.state, '', `/`);
@@ -372,7 +372,7 @@
 				if (stopResponseFlag) {
 					controller.abort('User: Stop Response');
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 				}
 
 				currentRequestId = null;
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index faa15b4b6..836fc90a4 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 
-	import { generateChatCompletion, generateTitle, cancelChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion, generateTitle, cancelOllamaRequest } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -382,7 +382,7 @@
 				if (stopResponseFlag) {
 					controller.abort('User: Stop Response');
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 				}
 
 				currentRequestId = null;
@@ -843,7 +843,7 @@
 	shareEnabled={messages.length > 0}
 	initNewChat={async () => {
 		if (currentRequestId !== null) {
-			await cancelChatCompletion(localStorage.token, currentRequestId);
+			await cancelOllamaRequest(localStorage.token, currentRequestId);
 
 			currentRequestId = null;
 		}
diff --git a/src/routes/(app)/playground/+page.svelte b/src/routes/(app)/playground/+page.svelte
index 737eff224..d8e9320dc 100644
--- a/src/routes/(app)/playground/+page.svelte
+++ b/src/routes/(app)/playground/+page.svelte
@@ -13,7 +13,7 @@
 	} from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
-	import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama';
+	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 	import { splitStream } from '$lib/utils';
 
@@ -52,7 +52,7 @@
 	// const cancelHandler = async () => {
 	// 	if (currentRequestId) {
-	// 		const res = await cancelChatCompletion(localStorage.token, currentRequestId);
+	// 		const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
 	// 		currentRequestId = null;
 	// 		loading = false;
 	// 	}
@@ -95,7 +95,7 @@
 			const { value, done } = await reader.read();
 			if (done || stopResponseFlag) {
 				if (stopResponseFlag) {
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 				}
 
 				currentRequestId = null;
@@ -181,7 +181,7 @@
 			const { value, done } = await reader.read();
 			if (done || stopResponseFlag) {
 				if (stopResponseFlag) {
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 				}
 
 				currentRequestId = null;