Merge pull request #2742 from open-webui/dev

0.2.2
Timothy Jaeryang Baek 2024-06-02 18:26:09 -07:00 committed by GitHub
commit 0744806523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1386 additions and 570 deletions

View File

@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.2.2] - 2024-06-02
### Added
- **🌊 Mermaid Rendering Support**: We've included support for Mermaid rendering. This allows you to create beautiful diagrams and flowcharts directly within Open WebUI.
- **🔄 New Environment Variable 'RESET_CONFIG_ON_START'**: Introducing a new environment variable: `RESET_CONFIG_ON_START`. Set this variable to reset your configuration settings upon starting the application, making it easier to revert to default settings.
### Fixed
- **🔧 Pipelines Filter Issue**: We've addressed an issue with the pipelines where filters were not functioning as expected.
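For reference, the Mermaid support above keys off the language tag of a fenced code block in a chat response: a reply containing a block like the sketch below is rendered as a diagram instead of a plain code block (the diagram content is illustrative only, not part of this release).

```mermaid
graph TD
    A[User prompt] --> B[Open WebUI]
    B --> C[Ollama backend]
    C -->|streamed reply| B
```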
## [0.2.1] - 2024-06-02
### Added

View File

@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
+from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}

-REQUEST_POOL = []

# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:

@@ -154,6 +143,45 @@ async def fetch_url(url):
        return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession()
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=error_detail,
)
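A note on the two helpers added above: `post_streaming_url` deliberately leaves the `aiohttp` session open when it returns, because the `StreamingResponse` body is only consumed after the route handler has returned; the `BackgroundTask` then runs once the stream has been fully sent and calls `cleanup_response` to close both the upstream response and the session. A minimal self-contained sketch of the same pattern (the route name and upstream URL are assumptions for illustration, not part of this diff):

```python
import aiohttp
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()
UPSTREAM = "http://localhost:11434/api/generate"  # assumed local Ollama default


async def cleanup(response: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs only after the client has finished reading the streamed body.
    response.close()
    await session.close()


@app.post("/proxy")
async def proxy(payload: dict):
    session = aiohttp.ClientSession()
    r = await session.post(UPSTREAM, json=payload)
    r.raise_for_status()
    return StreamingResponse(
        r.content,  # aiohttp StreamReader, iterated chunk by chunk by Starlette
        status_code=r.status,
        headers=dict(r.headers),
        background=BackgroundTask(cleanup, r, session),
    )
```

This also lines up with the removal of `REQUEST_POOL` and the `/cancel/{request_id}` endpoint elsewhere in this diff: with a fully async proxy, the client can stop a generation by aborting its own request, which the frontend now does with an `AbortController` instead of calling a cancel endpoint.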
def merge_models_lists(model_lists):
    merged_models = {}
@@ -313,65 +341,7 @@ async def pull_model(
    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

-    def get_request():
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class PushModelForm(BaseModel):

@@ -399,50 +369,9 @@ async def push_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

-    r = None
-    def get_request():
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/push",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class CreateModelForm(BaseModel):

@@ -461,53 +390,9 @@ async def create_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-    def get_request():
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/create",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
log.debug(f"r: {r}")
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class CopyModelForm(BaseModel):

@@ -797,66 +682,9 @@ async def generate_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-    def get_request():
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )
nonlocal form_data
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if form_data.stream:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/generate",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class ChatMessage(BaseModel):

@@ -1014,67 +842,7 @@ async def generate_chat_completion(
    print(payload)

-    r = None
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
# TODO: we should update this part once Ollama supports other types

@@ -1165,68 +933,7 @@ async def generate_openai_chat_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@app.get("/v1/models")

@@ -1555,7 +1262,7 @@ async def deprecated_proxy(
                if path == "generate":
                    data = json.loads(body.decode("utf-8"))

-                   if not ("stream" in data and data["stream"] == False):
+                   if data.get("stream", True):
                        yield json.dumps({"id": request_id, "done": False}) + "\n"
                elif path == "chat":

View File

@@ -9,6 +9,7 @@ import json
import logging
from pydantic import BaseModel
+from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users

@@ -194,6 +195,16 @@ async def fetch_url(url, key):
        return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
    log.debug(f"merge_models_lists {model_lists}")
    merged_list = []
@@ -447,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
        headers["Content-Type"] = "application/json"

    r = None
+   session = None
+   streaming = False

    try:
-       r = requests.request(
+       session = aiohttp.ClientSession()
+       r = await session.request(
            method=request.method,
            url=target_url,
            data=payload if payload else body,
            headers=headers,
-           stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
+           streaming = True
            return StreamingResponse(
-               r.iter_content(chunk_size=8192),
-               status_code=r.status_code,
+               r.content,
+               status_code=r.status,
                headers=dict(r.headers),
+               background=BackgroundTask(
+                   cleanup_response, response=r, session=session
+               ),
            )
        else:
-           response_data = r.json()
+           response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
-               res = r.json()
+               res = await r.json()
                print(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"

-       raise HTTPException(
-           status_code=r.status_code if r else 500, detail=error_detail
-       )
+       raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+   finally:
+       if not streaming and session:
+           if r:
+               r.close()
+           await session.close()

View File

@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
pass
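The block above is evaluated once at startup, and the check is string-based: only the literal value `true` (any casing) triggers the reset; anything else, including leaving the variable unset, keeps the existing `config.json`. A tiny sketch of the same truthiness rule in isolation (the example value is an assumption for illustration):

```python
import os

# Assumed example value; in a real deployment this comes from the container
# or process environment, e.g. RESET_CONFIG_ON_START=true.
os.environ["RESET_CONFIG_ON_START"] = "True"

reset_config_on_start = (
    os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
print(reset_config_on_start)  # True; values like "1", "yes" or "on" would not enable it
```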
try:
    CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:

View File

@@ -498,10 +498,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

-   model = app.state.MODELS[model_id]
-   if "pipeline" in model:
-       sorted_filters = [model] + sorted_filters
+   print(model_id)
+   if model_id in app.state.MODELS:
+       model = app.state.MODELS[model_id]
+       if "pipeline" in model:
+           sorted_filters = [model] + sorted_filters

    for filter in sorted_filters:
        r = None

@@ -550,7 +552,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
    responses = await get_openai_models(raw=True)
    print(responses)

-   urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
+   urlIdxs = [
+       idx
+       for idx, response in enumerate(responses)
+       if response != None and "pipelines" in response
+   ]

    return {
        "data": [

package-lock.json (generated): file diff suppressed because it is too large.

View File

@@ -1,6 +1,6 @@
{
  "name": "open-webui",
- "version": "0.2.1",
+ "version": "0.2.2",
  "private": true,
  "scripts": {
    "dev": "npm run pyodide:fetch && vite dev --host",

@@ -63,6 +63,7 @@
    "js-sha256": "^0.10.1",
    "katex": "^0.16.9",
    "marked": "^9.1.0",
+   "mermaid": "^10.9.1",
    "pyodide": "^0.26.0-alpha.4",
    "sortablejs": "^1.15.2",
    "svelte-sonner": "^0.3.19",

View File

@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
  return [res, controller];
};
export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
method: 'GET',
headers: {
'Content-Type': 'text/event-stream',
Authorization: `Bearer ${token}`
}
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const createModel = async (token: string, tagName: string, content: string) => {
  let error = null;
@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
  let error = null;
+ const controller = new AbortController();

  const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
+   signal: controller.signal,
    method: 'POST',
    headers: {
      Accept: 'application/json',
@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
  if (error) {
    throw error;
  }

- return res;
+ return [res, controller];
};

export const downloadModel = async (

View File

@@ -1,6 +1,7 @@
<script lang="ts">
  import { v4 as uuidv4 } from 'uuid';
  import { toast } from 'svelte-sonner';
+ import mermaid from 'mermaid';

  import { getContext, onMount, tick } from 'svelte';
  import { goto } from '$app/navigation';

@@ -26,7 +27,7 @@
    splitStream
  } from '$lib/utils';

- import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+ import { generateChatCompletion } from '$lib/apis/ollama';
  import {
    addTagById,
    createNewChat,

@@ -65,7 +66,6 @@
  let autoScroll = true;
  let processing = '';
  let messagesContainerElement: HTMLDivElement;
- let currentRequestId = null;

  let showModelSelector = true;

@@ -130,10 +130,6 @@
  //////////////////////////

  const initNewChat = async () => {
-   if (currentRequestId !== null) {
-     await cancelOllamaRequest(localStorage.token, currentRequestId);
-     currentRequestId = null;
-   }
    window.history.replaceState(history.state, '', `/`);

    await chatId.set('');

@@ -251,6 +247,39 @@
    }
  };
const chatCompletedHandler = async (model, messages) => {
await mermaid.run({
querySelector: '.mermaid'
});
const res = await chatCompleted(localStorage.token, {
model: model.id,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
};
  //////////////////////////
  // Ollama functions
  //////////////////////////

@@ -616,39 +645,11 @@
        if (stopResponseFlag) {
          controller.abort('User: Stop Response');
-         await cancelOllamaRequest(localStorage.token, currentRequestId);
        } else {
          const messages = createMessagesList(responseMessageId);
-         const res = await chatCompleted(localStorage.token, {
+         await chatCompletedHandler(model, messages);
model: model,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
        }
-       currentRequestId = null;

        break;
      }

@@ -669,63 +670,58 @@
          throw data;
        }
-       if ('id' in data) {
-         console.log(data);
-         currentRequestId = data.id;
-       } else {
        if (data.done == false) {
          if (responseMessage.content == '' && data.message.content == '\n') {
            continue;
          } else {
            responseMessage.content += data.message.content;
            messages = messages;
          }
        } else {
          responseMessage.done = true;

          if (responseMessage.content == '') {
            responseMessage.error = {
              code: 400,
              content: `Oops! No text generated from Ollama, Please try again.`
            };
          }

          responseMessage.context = data.context ?? null;
          responseMessage.info = {
            total_duration: data.total_duration,
            load_duration: data.load_duration,
            sample_count: data.sample_count,
            sample_duration: data.sample_duration,
            prompt_eval_count: data.prompt_eval_count,
            prompt_eval_duration: data.prompt_eval_duration,
            eval_count: data.eval_count,
            eval_duration: data.eval_duration
          };

          messages = messages;

          if ($settings.notificationEnabled && !document.hasFocus()) {
            const notification = new Notification(
              selectedModelfile
                ? `${
                    selectedModelfile.title.charAt(0).toUpperCase() +
                    selectedModelfile.title.slice(1)
                  }`
                : `${model}`,
              {
                body: responseMessage.content,
                icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
              }
            );
          }

          if ($settings.responseAutoCopy) {
            copyToClipboard(responseMessage.content);
          }

          if ($settings.responseAutoPlayback) {
            await tick();
            document.getElementById(`speak-button-${responseMessage.id}`)?.click();
          }
        }
-       }
@@ -906,32 +902,7 @@
        } else {
          const messages = createMessagesList(responseMessageId);

-         const res = await chatCompleted(localStorage.token, {
+         await chatCompletedHandler(model, messages);
model: model.id,
messages: messages.map((m) => ({
id: m.id,
role: m.role,
content: m.content,
timestamp: m.timestamp
})),
chat_id: $chatId
}).catch((error) => {
console.error(error);
return null;
});
if (res !== null) {
// Update chat history with the new messages
for (const message of res.messages) {
history.messages[message.id] = {
...history.messages[message.id],
...(history.messages[message.id].content !== message.content
? { originalContent: history.messages[message.id].content }
: {}),
...message
};
}
}
        }

        break;

View File

@@ -1,8 +1,7 @@
<script lang="ts">
  import { v4 as uuidv4 } from 'uuid';

  import { chats, config, settings, user as _user, mobile } from '$lib/stores';
- import { tick, getContext } from 'svelte';
+ import { tick, getContext, onMount } from 'svelte';
  import { toast } from 'svelte-sonner';

  import { getChatList, updateChatById } from '$lib/apis/chats';

View File

@@ -5,6 +5,7 @@
  import tippy from 'tippy.js';
  import auto_render from 'katex/dist/contrib/auto-render.mjs';
  import 'katex/dist/katex.min.css';
+ import mermaid from 'mermaid';

  import { fade } from 'svelte/transition';
  import { createEventDispatcher } from 'svelte';

@@ -343,6 +344,10 @@
  onMount(async () => {
    await tick();
    renderStyling();
+
+   await mermaid.run({
+     querySelector: '.mermaid'
+   });
  });
</script>
@@ -458,11 +463,15 @@
          <!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
          {#each tokens as token, tokenIdx}
            {#if token.type === 'code'}
-             <CodeBlock
-               id={`${message.id}-${tokenIdx}`}
-               lang={token?.lang ?? ''}
-               code={revertSanitizedResponseContent(token?.text ?? '')}
-             />
+             {#if token.lang === 'mermaid'}
+               <pre class="mermaid">{revertSanitizedResponseContent(token.text)}</pre>
+             {:else}
+               <CodeBlock
+                 id={`${message.id}-${tokenIdx}`}
+                 lang={token?.lang ?? ''}
+                 code={revertSanitizedResponseContent(token?.text ?? '')}
+               />
+             {/if}
            {:else}
              {@html marked.parse(token.raw, {
                ...defaults,

View File

@@ -8,7 +8,7 @@
  import Check from '$lib/components/icons/Check.svelte';
  import Search from '$lib/components/icons/Search.svelte';

- import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
+ import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
  import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
  import { toast } from 'svelte-sonner';

@@ -72,10 +72,12 @@
      return;
    }

-   const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-     toast.error(error);
-     return null;
-   });
+   const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+     (error) => {
+       toast.error(error);
+       return null;
+     }
+   );

    if (res) {
      const reader = res.body
@@ -83,6 +85,16 @@
        .pipeThrough(splitStream('\n'))
        .getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
      while (true) {
        try {
          const { value, done } = await reader.read();

@@ -101,19 +113,6 @@
            throw data.detail;
          }
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
          if (data.status) {
            if (data.digest) {
              let downloadProgress = 0;
@@ -181,11 +180,12 @@
  });

  const cancelModelPullHandler = async (model: string) => {
-   const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+   const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+   if (abortController) {
+     abortController.abort();
+   }
    if (reader) {
      await reader.cancel();
-     await cancelOllamaRequest(localStorage.token, requestId);

      delete $MODEL_DOWNLOAD_POOL[model];
      MODEL_DOWNLOAD_POOL.set({
        ...$MODEL_DOWNLOAD_POOL

View File

@@ -8,7 +8,6 @@
    getOllamaUrls,
    getOllamaVersion,
    pullModel,
-   cancelOllamaRequest,
    uploadModel,
    getOllamaConfig
  } from '$lib/apis/ollama';

@@ -70,12 +69,14 @@
    console.log(model);
    updateModelId = model.id;

-   const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
-     (error) => {
-       toast.error(error);
-       return null;
-     }
-   );
+   const [res, controller] = await pullModel(
+     localStorage.token,
+     model.id,
+     selectedOllamaUrlIdx
+   ).catch((error) => {
+     toast.error(error);
+     return null;
+   });

    if (res) {
      const reader = res.body

@@ -144,10 +145,12 @@
      return;
    }

-   const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-     toast.error(error);
-     return null;
-   });
+   const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+     (error) => {
+       toast.error(error);
+       return null;
+     }
+   );

    if (res) {
      const reader = res.body

@@ -155,6 +158,16 @@
      .pipeThrough(splitStream('\n'))
      .getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
      while (true) {
        try {
          const { value, done } = await reader.read();

@@ -173,19 +186,6 @@
            throw data.detail;
          }
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
          if (data.status) {
            if (data.digest) {
              let downloadProgress = 0;

@@ -419,11 +419,12 @@
  };

  const cancelModelPullHandler = async (model: string) => {
-   const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+   const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+   if (abortController) {
+     abortController.abort();
+   }
    if (reader) {
      await reader.cancel();
-     await cancelOllamaRequest(localStorage.token, requestId);

      delete $MODEL_DOWNLOAD_POOL[model];
      MODEL_DOWNLOAD_POOL.set({
        ...$MODEL_DOWNLOAD_POOL

View File

@@ -8,7 +8,7 @@
  import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
  import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';

- import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+ import { generateChatCompletion } from '$lib/apis/ollama';
  import { generateOpenAIChatCompletion } from '$lib/apis/openai';
  import { splitStream } from '$lib/utils';

@@ -24,7 +24,6 @@
  let selectedModelId = '';
  let loading = false;
- let currentRequestId = null;
  let stopResponseFlag = false;

  let messagesContainerElement: HTMLDivElement;

@@ -46,14 +45,6 @@
    }
  };

- // const cancelHandler = async () => {
- //   if (currentRequestId) {
- //     const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
- //     currentRequestId = null;
- //     loading = false;
- //   }
- // };
  const stopResponse = () => {
    stopResponseFlag = true;
    console.log('stopResponse');

@@ -171,8 +162,6 @@
      if (stopResponseFlag) {
        controller.abort('User: Stop Response');
      }
-     currentRequestId = null;

      break;
    }

@@ -229,7 +218,6 @@
    loading = false;
    stopResponseFlag = false;
-   currentRequestId = null;
  }
};