diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 6f56f3cf6..8bb01ed46 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -1,23 +1,39 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
+from fastapi import (
+    FastAPI,
+    Request,
+    Response,
+    HTTPException,
+    Depends,
+    status,
+    UploadFile,
+    File,
+    BackgroundTasks,
+)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
 from pydantic import BaseModel, ConfigDict
 
+import os
 import random
 import requests
 import json
 import uuid
 import aiohttp
 import asyncio
+import aiofiles
+from urllib.parse import urlparse
+from typing import Optional, List, Union
+
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
+from utils.misc import calculate_sha256
 
-from typing import Optional, List, Union
+
+from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
 
 app = FastAPI()
@@ -897,6 +913,170 @@ async def generate_openai_chat_completion(
     )
 
 
+class UrlForm(BaseModel):
+    url: str
+
+
+class UploadBlobForm(BaseModel):
+    filename: str
+
+
+def parse_huggingface_url(hf_url):
+    try:
+        # Parse the URL
+        parsed_url = urlparse(hf_url)
+
+        # Get the path and split it into components
+        path_components = parsed_url.path.split("/")
+
+        # Extract the desired output
+        user_repo = "/".join(path_components[1:3])
+        model_file = path_components[-1]
+
+        return model_file
+    except ValueError:
+        return None
+
+
+async def download_file_stream(
+    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
+):
+    done = False
+
+    if os.path.exists(file_path):
+        current_size = os.path.getsize(file_path)
+    else:
+        current_size = 0
+
+    # Resume a partial download by requesting only the remaining bytes.
+    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
+
+    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(file_url, headers=headers) as response:
+            total_size = int(response.headers.get("content-length", 0)) + current_size
+
+            with open(file_path, "ab+") as file:
+                async for data in response.content.iter_chunked(chunk_size):
+                    current_size += len(data)
+                    file.write(data)
+
+                    done = current_size == total_size
+                    progress = round((current_size / total_size) * 100, 2)
+                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
+
+                if done:
+                    file.seek(0)
+                    hashed = calculate_sha256(file)
+                    file.seek(0)
+
+                    # Register the finished file with Ollama as a blob keyed by its digest.
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=file)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file_name,
+                        }
+                        os.remove(file_path)
+
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise Exception("Ollama: Could not create blob, please try again.")
+ + +# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" +@app.post("/models/download") +@app.post("/models/download/{url_idx}") +async def download_model( + form_data: UrlForm, + url_idx: Optional[int] = None, +): + + if url_idx == None: + url_idx = 0 + url = app.state.OLLAMA_BASE_URLS[url_idx] + + file_name = parse_huggingface_url(form_data.url) + + if file_name: + file_path = f"{UPLOAD_DIR}/{file_name}" + + return StreamingResponse( + download_file_stream(url, form_data.url, file_path, file_name) + ) + else: + return None + + +@app.post("/models/upload") +@app.post("/models/upload/{url_idx}") +def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): + if url_idx == None: + url_idx = 0 + ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] + + file_path = f"{UPLOAD_DIR}/{file.filename}" + + # Save file in chunks + with open(file_path, "wb+") as f: + for chunk in file.file: + f.write(chunk) + + def file_process_stream(): + nonlocal ollama_url + total_size = os.path.getsize(file_path) + chunk_size = 1024 * 1024 + try: + with open(file_path, "rb") as f: + total = 0 + done = False + + while not done: + chunk = f.read(chunk_size) + if not chunk: + done = True + continue + + total += len(chunk) + progress = round((total / total_size) * 100, 2) + + res = { + "progress": progress, + "total": total_size, + "completed": total, + } + yield f"data: {json.dumps(res)}\n\n" + + if done: + f.seek(0) + hashed = calculate_sha256(f) + f.seek(0) + + url = f"{ollama_url}/blobs/sha256:{hashed}" + response = requests.post(url, data=f) + + if response.ok: + res = { + "done": done, + "blob": f"sha256:{hashed}", + "name": file.filename, + } + os.remove(file_path) + yield f"data: {json.dumps(res)}\n\n" + else: + raise Exception( + "Ollama: Could not create blob, Please try again." 
diff --git a/backend/apps/web/routers/utils.py b/backend/apps/web/routers/utils.py
index 0d34b0405..4b5ac8cfa 100644
--- a/backend/apps/web/routers/utils.py
+++ b/backend/apps/web/routers/utils.py
@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
 router = APIRouter()
 
-class UploadBlobForm(BaseModel):
-    filename: str
-
-
-from urllib.parse import urlparse
-
-
-def parse_huggingface_url(hf_url):
-    try:
-        # Parse the URL
-        parsed_url = urlparse(hf_url)
-
-        # Get the path and split it into components
-        path_components = parsed_url.path.split("/")
-
-        # Extract the desired output
-        user_repo = "/".join(path_components[1:3])
-        model_file = path_components[-1]
-
-        return model_file
-    except ValueError:
-        return None
-
-
-async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
-    done = False
-
-    if os.path.exists(file_path):
-        current_size = os.path.getsize(file_path)
-    else:
-        current_size = 0
-
-    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
-
-    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
-
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        async with session.get(url, headers=headers) as response:
-            total_size = int(response.headers.get("content-length", 0)) + current_size
-
-            with open(file_path, "ab+") as file:
-                async for data in response.content.iter_chunked(chunk_size):
-                    current_size += len(data)
-                    file.write(data)
-
-                    done = current_size == total_size
-                    progress = round((current_size / total_size) * 100, 2)
-                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
-
-                if done:
-                    file.seek(0)
-                    hashed = calculate_sha256(file)
-                    file.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=file)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file_name,
-                        }
-                        os.remove(file_path)
-
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise "Ollama: Could not create blob, Please try again."
-
-
-@router.get("/download")
-async def download(
-    url: str,
-):
-    # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
-    file_name = parse_huggingface_url(url)
-
-    if file_name:
-        file_path = f"{UPLOAD_DIR}/{file_name}"
-
-        return StreamingResponse(
-            download_file_stream(url, file_path, file_name),
-            media_type="text/event-stream",
-        )
-    else:
-        return None
-
-
-@router.post("/upload")
-def upload(file: UploadFile = File(...)):
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
-
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
-        total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
-        try:
-            with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
-                        "progress": progress,
-                        "total": total_size,
-                        "completed": total,
-                    }
-                    yield f"data: {json.dumps(res)}\n\n"
-
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
-
-        except Exception as e:
-            res = {"error": str(e)}
-            yield f"data: {json.dumps(res)}\n\n"
-
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
-
-
 @router.get("/gravatar")
 async def get_gravatar(
     email: str,
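The /download and /upload routes removed here are superseded by the /models/download and /models/upload routes added to the Ollama app above, which stream the same progress events. For callers outside the web UI, a small sketch of consuming the new download endpoint's SSE progress stream in Python (the base URL and token are placeholders, not values from this diff):

    import json
    import requests

    BASE_URL = "http://localhost:8080/ollama"  # placeholder deployment URL
    TOKEN = "<api-token>"  # placeholder bearer token

    def watch_download(model_url: str):
        # Each event arrives as a line of the form: data: {"progress": ...}
        with requests.post(
            f"{BASE_URL}/models/download",
            headers={"Authorization": f"Bearer {TOKEN}"},
            json={"url": model_url},
            stream=True,
        ) as response:
            for line in response.iter_lines():
                if line.startswith(b"data: "):
                    print(json.loads(line[len(b"data: "):]))

    watch_download(
        "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
    )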
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts
index 2047fedef..a461d71bd 100644
--- a/src/lib/apis/ollama/index.ts
+++ b/src/lib/apis/ollama/index.ts
@@ -390,6 +390,71 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
 	return res;
 };
 
+export const downloadModel = async (
+	token: string,
+	download_url: string,
+	urlIdx: string | null = null
+) => {
+	let error = null;
+
+	const res = await fetch(
+		`${OLLAMA_API_BASE_URL}/models/download${urlIdx !== null ? `/${urlIdx}` : ''}`,
+		{
+			method: 'POST',
+			headers: {
+				'Content-Type': 'application/json',
+				Authorization: `Bearer ${token}`
+			},
+			body: JSON.stringify({
+				url: download_url
+			})
+		}
+	).catch((err) => {
+		console.log(err);
+		error = err;
+
+		if ('detail' in err) {
+			error = err.detail;
+		}
+
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const uploadModel = async (token: string, file: File, urlIdx: string | null = null) => {
+	let error = null;
+
+	const formData = new FormData();
+	formData.append('file', file);
+
+	const res = await fetch(
+		`${OLLAMA_API_BASE_URL}/models/upload${urlIdx !== null ? `/${urlIdx}` : ''}`,
+		{
+			method: 'POST',
+			headers: {
+				Authorization: `Bearer ${token}`
+			},
+			body: formData
+		}
+	).catch((err) => {
+		console.log(err);
+		error = err;
+
+		if ('detail' in err) {
+			error = err.detail;
+		}
+
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 // export const pullModel = async (token: string, tagName: string) => {
 // 	return await fetch(`${OLLAMA_API_BASE_URL}/pull`, {
 // 		method: 'POST',
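uploadModel sends the file as a multipart form with a single `file` field, so any HTTP client can drive the same endpoint. An equivalent sketch in Python (again with placeholder base URL and token):

    import requests

    BASE_URL = "http://localhost:8080/ollama"  # placeholder deployment URL
    TOKEN = "<api-token>"  # placeholder bearer token

    def upload_model_file(path: str, url_idx: int | None = None):
        # Optionally target a specific Ollama instance by index, mirroring
        # the /models/upload/{url_idx} route variant.
        suffix = f"/{url_idx}" if url_idx is not None else ""
        with open(path, "rb") as f:
            response = requests.post(
                f"{BASE_URL}/models/upload{suffix}",
                headers={"Authorization": f"Bearer {TOKEN}"},
                files={"file": f},
                stream=True,
            )
        # The response streams the same SSE progress events as a download.
        for line in response.iter_lines():
            if line:
                print(line.decode())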
diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte
index 7b75e3729..f74e85ff4 100644
--- a/src/lib/components/chat/Settings/Models.svelte
+++ b/src/lib/components/chat/Settings/Models.svelte
@@ -5,9 +5,11 @@
 	import {
 		createModel,
 		deleteModel,
+		downloadModel,
 		getOllamaUrls,
 		getOllamaVersion,
-		pullModel
+		pullModel,
+		uploadModel
 	} from '$lib/apis/ollama';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, models, user } from '$lib/stores';
@@ -60,7 +62,7 @@
 	let pullProgress = null;
 
 	let modelUploadMode = 'file';
-	let modelInputFile = '';
+	let modelInputFile: File[] | null = null;
 	let modelFileUrl = '';
 	let modelFileContent = `TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop "</s>"\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`;
 	let modelFileDigest = '';
@@ -191,30 +193,23 @@
 		let name = '';
 
 		if (modelUploadMode === 'file') {
-			const file = modelInputFile[0];
-			const formData = new FormData();
-			formData.append('file', file);
+			const file = modelInputFile ? modelInputFile[0] : null;
 
-			fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/upload`, {
-				method: 'POST',
-				headers: {
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
-				},
-				body: formData
-			}).catch((error) => {
-				console.log(error);
-				return null;
-			});
+			if (file) {
+				fileResponse = await uploadModel(localStorage.token, file, selectedOllamaUrlIdx).catch(
+					(error) => {
+						toast.error(error);
+						return null;
+					}
+				);
+			}
 		} else {
-			fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/download?url=${modelFileUrl}`, {
-				method: 'GET',
-				headers: {
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
-				}
-			}).catch((error) => {
-				console.log(error);
-				return null;
-			});
+			fileResponse = await downloadModel(localStorage.token, modelFileUrl, selectedOllamaUrlIdx).catch(
+				(error) => {
+					toast.error(error);
+					return null;
+				}
+			);
 		}
 
 		if (fileResponse && fileResponse.ok) {
@@ -318,7 +313,8 @@
 		}
 
 		modelFileUrl = '';
-		modelInputFile = '';
+		modelUploadInputElement.value = '';
+		modelInputFile = null;
 		modelTransferring = false;
 		uploadProgress = null;
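Once the upload or download stream reports a blob digest, the component goes on to build a Modelfile that references that blob (FROM @sha256:...) and registers it through createModel. The same final step, sketched directly against a local Ollama instance (the URL, model name, and digest are illustrative, not values from this diff):

    import requests

    OLLAMA_URL = "http://localhost:11434"  # placeholder Ollama instance
    digest = "sha256:<digest-from-the-progress-stream>"  # illustrative

    modelfile = (
        "FROM @" + digest + "\n"
        'TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """'
    )

    # /api/create streams JSON status lines while Ollama assembles the model.
    response = requests.post(
        f"{OLLAMA_URL}/api/create",
        json={"name": "stablelm-zephyr-3b", "modelfile": modelfile},
        stream=True,
    )
    for line in response.iter_lines():
        if line:
            print(line.decode())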