fix: ollama streaming cancellation using aiohttp

Jun Siang Cheah 2024-06-02 18:06:12 +01:00
parent 24c35c308d
commit 4dd51badfe
6 changed files with 154 additions and 490 deletions
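This change drops the server-side REQUEST_POOL and the /cancel/{request_id} endpoint: each streaming route now proxies the upstream Ollama response through aiohttp and returns it as a StreamingResponse, with a starlette BackgroundTask closing the aiohttp response and session once the client-facing stream ends. Cancellation is therefore driven by the client aborting its fetch (AbortController) instead of calling a separate cancel endpoint. A minimal, self-contained sketch of the proxy pattern, assuming a standalone FastAPI app (the route path and OLLAMA_URL are illustrative, not part of this commit):

import aiohttp
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()
OLLAMA_URL = "http://localhost:11434"  # assumed local Ollama instance


async def cleanup(response, session):
    # Runs when the client-facing stream finishes or the client disconnects;
    # closing the upstream response and session releases the in-flight
    # request to Ollama.
    if response:
        response.close()
    if session:
        await session.close()


@app.post("/proxy/generate")
async def proxy_generate(payload: dict):
    session = aiohttp.ClientSession()
    try:
        r = await session.post(f"{OLLAMA_URL}/api/generate", json=payload)
        r.raise_for_status()
    except Exception:
        await session.close()
        raise HTTPException(status_code=500, detail="Open WebUI: Server Connection Error")

    return StreamingResponse(
        r.content,  # aiohttp StreamReader: yields bytes as Ollama streams them
        status_code=r.status,
        headers=dict(r.headers),
        background=BackgroundTask(cleanup, response=r, session=session),
    )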

View File

@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}

REQUEST_POOL = []

# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.

@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
@@ -154,6 +143,45 @@ async def fetch_url(url)
        return None

async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()

async def post_streaming_url(url, payload):
    r = None
    try:
        session = aiohttp.ClientSession()
        r = await session.post(url, data=payload)
        r.raise_for_status()

        return StreamingResponse(
            r.content,
            status_code=r.status,
            headers=dict(r.headers),
            background=BackgroundTask(cleanup_response, response=r, session=session),
        )
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=error_detail,
        )

def merge_models_lists(model_lists):
    merged_models = {}
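With cleanup_response and post_streaming_url in place, each streaming route in the hunks below collapses to a single call such as return await post_streaming_url(f"{url}/api/chat", json.dumps(payload)); the old get_request / run_in_threadpool wrappers and the REQUEST_POOL bookkeeping are removed.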
@@ -313,65 +341,7 @@ async def pull_model(
    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    def get_request():
        nonlocal url
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/pull",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))

class PushModelForm(BaseModel):
@@ -399,49 +369,8 @@ async def push_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    r = None

    def get_request():
        nonlocal url
        nonlocal r
        try:

            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/push",
                data=form_data.model_dump_json(exclude_none=True).encode(),
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(
        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
    )
@@ -461,52 +390,8 @@ async def create_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal url
        nonlocal r
        try:

            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/create",
                data=form_data.model_dump_json(exclude_none=True).encode(),
                stream=True,
            )

            r.raise_for_status()

            log.debug(f"r: {r}")

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(
        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
    )
@@ -797,65 +682,8 @@ async def generate_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/generate",
                data=form_data.model_dump_json(exclude_none=True).encode(),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(
        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
    )
@@ -981,67 +809,7 @@ async def generate_chat_completion(
    print(payload)

    r = None

    def get_request():
        nonlocal payload
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if payload.get("stream", True):
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/chat",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            log.exception(e)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))

# TODO: we should update this part once Ollama supports other types
@@ -1132,68 +900,7 @@ async def generate_openai_chat_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal payload
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if payload.get("stream"):
                        yield json.dumps(
                            {"request_id": request_id, "done": False}
                        ) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/v1/chat/completions",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
@app.get("/v1/models") @app.get("/v1/models")

View File

@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
return [res, controller];
};
export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
method: 'GET',
headers: {
'Content-Type': 'text/event-stream',
Authorization: `Bearer ${token}`
}
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const createModel = async (token: string, tagName: string, content: string) => {
let error = null;
@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
let error = null;
const controller = new AbortController();
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
signal: controller.signal,
method: 'POST',
headers: {
Accept: 'application/json',
@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
if (error) {
throw error;
}
return res;
return [res, controller];
};
export const downloadModel = async ( export const downloadModel = async (
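pullModel now returns both the fetch Response and the AbortController that owns it, so the model-download handlers below can cancel an in-progress pull locally with controller.abort() (the controller is kept next to the stream reader in MODEL_DOWNLOAD_POOL) instead of calling the removed cancelOllamaRequest endpoint.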

View File

@@ -26,7 +26,7 @@
splitStream
} from '$lib/utils';
import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateChatCompletion } from '$lib/apis/ollama';
import {
addTagById,
createNewChat,
@@ -65,7 +65,6 @@
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
let currentRequestId = null;
let showModelSelector = true;
@@ -130,10 +129,6 @@
//////////////////////////
const initNewChat = async () => {
if (currentRequestId !== null) {
await cancelOllamaRequest(localStorage.token, currentRequestId);
currentRequestId = null;
}
window.history.replaceState(history.state, '', `/`);
await chatId.set('');
@@ -616,7 +611,6 @@
if (stopResponseFlag) {
controller.abort('User: Stop Response');
await cancelOllamaRequest(localStorage.token, currentRequestId);
} else {
const messages = createMessagesList(responseMessageId);
const res = await chatCompleted(localStorage.token, {
@@ -647,8 +641,6 @@
}
}
currentRequestId = null;
break;
}
@@ -669,10 +661,6 @@
throw data;
}
if ('id' in data) {
console.log(data);
currentRequestId = data.id;
} else {
if (data.done == false) {
if (responseMessage.content == '' && data.message.content == '\n') {
continue;
@@ -729,7 +717,6 @@
}
}
}
}
} catch (error) {
console.log(error);
if ('detail' in error) {
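In the chat flow, stopping a response now relies entirely on controller.abort('User: Stop Response'); the currentRequestId bookkeeping and the follow-up cancelOllamaRequest call are gone, since the backend now closes the upstream aiohttp stream itself when the client connection ends.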

View File

@@ -8,7 +8,7 @@
import Check from '$lib/components/icons/Check.svelte';
import Search from '$lib/components/icons/Search.svelte';
import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
import { toast } from 'svelte-sonner';
@@ -72,10 +72,12 @@
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
const reader = res.body
@@ -83,6 +85,16 @@
.pipeThrough(splitStream('\n'))
.getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
while (true) {
try {
const { value, done } = await reader.read();
@@ -101,19 +113,6 @@
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
@@ -181,11 +180,12 @@
});
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
if (abortController) {
abortController.abort();
}
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL

View File

@@ -8,7 +8,6 @@
getOllamaUrls,
getOllamaVersion,
pullModel,
cancelOllamaRequest,
uploadModel
} from '$lib/apis/ollama';
@@ -67,12 +66,14 @@
console.log(model);
updateModelId = model.id;
const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
(error) => {
toast.error(error);
return null;
}
);
const [res, controller] = await pullModel(
localStorage.token,
model.id,
selectedOllamaUrlIdx
).catch((error) => {
toast.error(error);
return null;
});
if (res) {
const reader = res.body
@@ -141,10 +142,12 @@
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
const reader = res.body
@@ -152,6 +155,16 @@
.pipeThrough(splitStream('\n'))
.getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
while (true) {
try {
const { value, done } = await reader.read();
@@ -170,19 +183,6 @@
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
@@ -416,11 +416,12 @@
};
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
if (abortController) {
abortController.abort();
}
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL

View File

@@ -8,7 +8,7 @@
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateChatCompletion } from '$lib/apis/ollama';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { splitStream } from '$lib/utils';
@@ -24,7 +24,6 @@
let selectedModelId = '';
let loading = false;
let currentRequestId = null;
let stopResponseFlag = false;
let messagesContainerElement: HTMLDivElement;
@@ -46,14 +45,6 @@
}
};
// const cancelHandler = async () => {
// if (currentRequestId) {
// const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
// currentRequestId = null;
// loading = false;
// }
// };
const stopResponse = () => {
stopResponseFlag = true;
console.log('stopResponse');
@@ -171,8 +162,6 @@
if (stopResponseFlag) {
controller.abort('User: Stop Response');
}
currentRequestId = null;
break;
}
@@ -229,7 +218,6 @@
loading = false;
stopResponseFlag = false;
currentRequestId = null;
}
};