diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 1eeae85fc..a94dc37f1 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool import requests import json +import uuid from pydantic import BaseModel from apps.web.models.users import Users @@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL # TARGET_SERVER_URL = OLLAMA_API_BASE_URL +REQUEST_POOL = [] + + @app.get("/url") async def get_ollama_api_url(user=Depends(get_current_user)): if user and user.role == "admin": @@ -49,6 +53,16 @@ async def update_ollama_api_url( raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +@app.get("/cancel/{request_id}") +async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): + if user: + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + return True + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" @@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): def get_request(): nonlocal r + + request_id = str(uuid.uuid4()) try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if path in ["chat"]: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + REQUEST_POOL.remove(request_id) + r = requests.request( method=request.method, url=target_url, @@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): r.raise_for_status() + # r.close() + return StreamingResponse( - r.iter_content(chunk_size=8192), + stream_content(), status_code=r.status_code, headers=dict(r.headers), ) diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index e863e51ec..625019660 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa }; export const generateChatCompletion = async (token: string = '', body: object) => { + let controller = new AbortController(); let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { + signal: controller.signal, method: 'POST', headers: { 'Content-Type': 'text/event-stream', @@ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) = throw error; } + return [res, controller]; +}; + +export const cancelChatCompletion = async (token: string = '', requestId: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { + method: 'GET', + headers: { + 'Content-Type': 'text/event-stream', + Authorization: `Bearer ${token}` + } + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + return res; }; diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index e0be0757f..2f3983a3b 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -9,7 +9,7 @@ import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; import { copyToClipboard, splitStream } from '$lib/utils'; - import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; + import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats'; import { queryVectorDB } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; @@ -24,6 +24,8 @@ let autoScroll = true; let processing = ''; + let currentRequestId = null; + let selectedModels = ['']; let selectedModelfile = null; @@ -279,7 +281,7 @@ // Scroll down window.scrollTo({ top: document.body.scrollHeight }); - const res = await generateChatCompletion(localStorage.token, { + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model, messages: [ $settings.system @@ -307,6 +309,8 @@ }); if (res && res.ok) { + console.log('controller', controller); + const reader = res.body .pipeThrough(new TextDecoderStream()) .pipeThrough(splitStream('\n')) @@ -317,6 +321,14 @@ if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; messages = messages; + + if (stopResponseFlag) { + controller.abort('User: Stop Response'); + await cancelChatCompletion(localStorage.token, currentRequestId); + } + + currentRequestId = null; + break; } @@ -332,52 +344,57 @@ throw data; } - if (data.done == false) { - if (responseMessage.content == '' && data.message.content == '\n') { - continue; - } else { - responseMessage.content += data.message.content; - messages = messages; - } + if ('id' in data) { + console.log(data); + currentRequestId = data.id; } else { - responseMessage.done = true; + if (data.done == false) { + if (responseMessage.content == '' && data.message.content == '\n') { + continue; + } else { + responseMessage.content += data.message.content; + messages = messages; + } + } else { + responseMessage.done = true; - if (responseMessage.content == '') { - responseMessage.error = true; - responseMessage.content = - 'Oops! No text generated from Ollama, Please try again.'; - } + if (responseMessage.content == '') { + responseMessage.error = true; + responseMessage.content = + 'Oops! No text generated from Ollama, Please try again.'; + } - responseMessage.context = data.context ?? null; - responseMessage.info = { - total_duration: data.total_duration, - load_duration: data.load_duration, - sample_count: data.sample_count, - sample_duration: data.sample_duration, - prompt_eval_count: data.prompt_eval_count, - prompt_eval_duration: data.prompt_eval_duration, - eval_count: data.eval_count, - eval_duration: data.eval_duration - }; - messages = messages; + responseMessage.context = data.context ?? null; + responseMessage.info = { + total_duration: data.total_duration, + load_duration: data.load_duration, + sample_count: data.sample_count, + sample_duration: data.sample_duration, + prompt_eval_count: data.prompt_eval_count, + prompt_eval_duration: data.prompt_eval_duration, + eval_count: data.eval_count, + eval_duration: data.eval_duration + }; + messages = messages; - if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `Ollama - ${model}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? '/favicon.png' - } - ); - } + if ($settings.notificationEnabled && !document.hasFocus()) { + const notification = new Notification( + selectedModelfile + ? `${ + selectedModelfile.title.charAt(0).toUpperCase() + + selectedModelfile.title.slice(1) + }` + : `Ollama - ${model}`, + { + body: responseMessage.content, + icon: selectedModelfile?.imageUrl ?? '/favicon.png' + } + ); + } - if ($settings.responseAutoCopy) { - copyToClipboard(responseMessage.content); + if ($settings.responseAutoCopy) { + copyToClipboard(responseMessage.content); + } } } } diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index e0a79e35c..bb33aa808 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -297,7 +297,7 @@ // Scroll down window.scrollTo({ top: document.body.scrollHeight }); - const res = await generateChatCompletion(localStorage.token, { + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model, messages: [ $settings.system @@ -335,6 +335,10 @@ if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; messages = messages; + + if (stopResponseFlag) { + controller.abort('User: Stop Response'); + } break; } @@ -350,52 +354,56 @@ throw data; } - if (data.done == false) { - if (responseMessage.content == '' && data.message.content == '\n') { - continue; - } else { - responseMessage.content += data.message.content; - messages = messages; - } + if ('id' in data) { + console.log(data); } else { - responseMessage.done = true; + if (data.done == false) { + if (responseMessage.content == '' && data.message.content == '\n') { + continue; + } else { + responseMessage.content += data.message.content; + messages = messages; + } + } else { + responseMessage.done = true; - if (responseMessage.content == '') { - responseMessage.error = true; - responseMessage.content = - 'Oops! No text generated from Ollama, Please try again.'; - } + if (responseMessage.content == '') { + responseMessage.error = true; + responseMessage.content = + 'Oops! No text generated from Ollama, Please try again.'; + } - responseMessage.context = data.context ?? null; - responseMessage.info = { - total_duration: data.total_duration, - load_duration: data.load_duration, - sample_count: data.sample_count, - sample_duration: data.sample_duration, - prompt_eval_count: data.prompt_eval_count, - prompt_eval_duration: data.prompt_eval_duration, - eval_count: data.eval_count, - eval_duration: data.eval_duration - }; - messages = messages; + responseMessage.context = data.context ?? null; + responseMessage.info = { + total_duration: data.total_duration, + load_duration: data.load_duration, + sample_count: data.sample_count, + sample_duration: data.sample_duration, + prompt_eval_count: data.prompt_eval_count, + prompt_eval_duration: data.prompt_eval_duration, + eval_count: data.eval_count, + eval_duration: data.eval_duration + }; + messages = messages; - if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `Ollama - ${model}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? '/favicon.png' - } - ); - } + if ($settings.notificationEnabled && !document.hasFocus()) { + const notification = new Notification( + selectedModelfile + ? `${ + selectedModelfile.title.charAt(0).toUpperCase() + + selectedModelfile.title.slice(1) + }` + : `Ollama - ${model}`, + { + body: responseMessage.content, + icon: selectedModelfile?.imageUrl ?? '/favicon.png' + } + ); + } - if ($settings.responseAutoCopy) { - copyToClipboard(responseMessage.content); + if ($settings.responseAutoCopy) { + copyToClipboard(responseMessage.content); + } } } }