diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 2ab06eb95..1c6365683 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -11,10 +11,8 @@ import re
 import time
 from typing import Optional, Union
 from urllib.parse import urlparse
-
 import aiohttp
 from aiocache import cached
-
 import requests
 
 from fastapi import (
@@ -990,6 +988,8 @@ async def generate_chat_completion(
         )
 
     payload = {**form_data.model_dump(exclude_none=True)}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = payload["model"]
     model_info = Models.get_model_by_id(model_id)
@@ -1408,9 +1408,10 @@ async def download_model(
         return None
 
 
+# TODO: Progress bar does not reflect size & duration of upload.
 @router.post("/models/upload")
 @router.post("/models/upload/{url_idx}")
-def upload_model(
+async def upload_model(
     request: Request,
     file: UploadFile = File(...),
     url_idx: Optional[int] = None,
@@ -1419,62 +1420,90 @@
     if url_idx is None:
         url_idx = 0
     ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+    os.makedirs(UPLOAD_DIR, exist_ok=True)
 
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
+    # --- P1: save the uploaded file locally ---
+    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
+    with open(file_path, "wb") as out_f:
+        while True:
+            chunk = file.file.read(chunk_size)
+            # log.info(f"Chunk: {str(chunk)}")  # DEBUG
+            if not chunk:
+                break
+            out_f.write(chunk)
 
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
+    async def file_process_stream():
         nonlocal ollama_url
         total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
+        log.info(f"Total Model Size: {total_size}")  # DEBUG
+
+        # --- P2: SSE progress + calculate sha256 hash ---
+        file_hash = calculate_sha256(file_path, chunk_size)
+        log.info(f"Model Hash: {file_hash}")  # DEBUG
         try:
             with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
+                bytes_read = 0
+                while chunk := f.read(chunk_size):
+                    bytes_read += len(chunk)
+                    progress = round(bytes_read / total_size * 100, 2)
+                    data_msg = {
                         "progress": progress,
                         "total": total_size,
-                        "completed": total,
+                        "completed": bytes_read,
                     }
-                    yield f"data: {json.dumps(res)}\n\n"
+                    yield f"data: {json.dumps(data_msg)}\n\n"
 
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
+            # --- P3: Upload to ollama /api/blobs ---
+            with open(file_path, "rb") as f:
+                url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
+                response = requests.post(url, data=f)
 
-                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
+            if response.ok:
+                log.info("Uploaded to /api/blobs")  # DEBUG
+                # Remove the local file
+                os.remove(file_path)
 
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
+                # Create the model in ollama
+                model_name, _ext = os.path.splitext(file.filename)
+                log.info(f"Creating model: {model_name}")  # DEBUG
+
+                create_payload = {
+                    "model": model_name,
+                    # Map the original filename to the uploaded blob's digest
+                    "files": {
+                        file.filename: f"sha256:{file_hash}"
+                    },
+                }
+                log.info(f"Model Payload: {create_payload}")  # DEBUG
+
+                # Call ollama /api/create
+                # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
+                create_resp = requests.post(
+                    url=f"{ollama_url}/api/create",
+                    headers={"Content-Type": "application/json"},
+                    data=json.dumps(create_payload),
+                )
+
+                if create_resp.ok:
+                    log.info("Model created via /api/create")  # DEBUG
+                    done_msg = {
+                        "done": True,
+                        "blob": f"sha256:{file_hash}",
+                        "name": file.filename,
+                        "model_created": model_name,
+                    }
+                    yield f"data: {json.dumps(done_msg)}\n\n"
+                else:
+                    raise Exception(
+                        f"Failed to create model in Ollama. {create_resp.text}"
+                    )
+
+            else:
+                raise Exception("Ollama: Could not create blob. Please try again.")
 
         except Exception as e:
             res = {"error": str(e)}
             yield f"data: {json.dumps(res)}\n\n"
 
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
+    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
\ No newline at end of file
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index b07393921..eb90ea5ea 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -244,11 +244,12 @@ def get_gravatar_url(email):
     return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
 
 
-def calculate_sha256(file):
+def calculate_sha256(file_path, chunk_size):
+    # Compute the SHA-256 hash of a file by reading it in chunks.
     sha256 = hashlib.sha256()
-    # Read the file in chunks to efficiently handle large files
-    for chunk in iter(lambda: file.read(8192), b""):
-        sha256.update(chunk)
+    with open(file_path, "rb") as f:
+        while chunk := f.read(chunk_size):
+            sha256.update(chunk)
     return sha256.hexdigest()
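
A client consumes the reworked endpoint as a standard SSE stream: progress events carry `progress`/`total`/`completed`, and the final event carries `done`, `blob`, `name`, and `model_created` (or `error` on failure). A minimal sketch follows; the `/ollama` mount prefix, the Bearer token header, and the `upload_gguf` helper name are assumptions about the deployment, not part of this diff.

```python
import json

import requests


def upload_gguf(base_url: str, token: str, path: str) -> None:
    """Upload a local model file and print the SSE progress events."""
    with open(path, "rb") as f:
        resp = requests.post(
            f"{base_url}/ollama/models/upload",  # assumed mount prefix
            headers={"Authorization": f"Bearer {token}"},  # assumed auth scheme
            files={"file": f},
            stream=True,  # read the event stream incrementally
        )
        for line in resp.iter_lines():
            # SSE frames look like: b"data: {...}"
            if not line.startswith(b"data: "):
                continue
            event = json.loads(line[len(b"data: "):])
            if "error" in event:
                raise RuntimeError(event["error"])
            if event.get("done"):
                print(f"created model: {event['model_created']}")
            else:
                print(f"{event['progress']}% ({event['completed']}/{event['total']} bytes)")


# upload_gguf("http://localhost:8080", "<api-key>", "my-model.gguf")
```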
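The server-side flow is the two-step sequence from the linked Ollama docs: push the raw bytes as a content-addressed blob via `/api/blobs/sha256:<digest>`, then call `/api/create` with a `files` map that points the filename at that digest. A standalone sketch of the same sequence, assuming an Ollama instance at the default local port:

```python
import hashlib
import os

import requests

OLLAMA_URL = "http://localhost:11434"  # assumed local Ollama instance


def push_and_create(path: str, chunk_size: int = 2 * 1024 * 1024) -> str:
    """Push a model file as a blob, then register it via /api/create."""
    # Hash the file in chunks, mirroring calculate_sha256 in the diff.
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            sha256.update(chunk)
    digest = f"sha256:{sha256.hexdigest()}"

    # Step 1: upload the bytes as a content-addressed blob.
    with open(path, "rb") as f:
        requests.post(f"{OLLAMA_URL}/api/blobs/{digest}", data=f).raise_for_status()

    # Step 2: create a model that maps the filename to the blob digest.
    model_name = os.path.splitext(os.path.basename(path))[0]
    requests.post(
        f"{OLLAMA_URL}/api/create",
        json={"model": model_name, "files": {os.path.basename(path): digest}},
    ).raise_for_status()
    return model_name
```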
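The rewritten `calculate_sha256(file_path, chunk_size)` can be sanity-checked against a one-shot hash; the file name and chunk size below are arbitrary:

```python
import hashlib
import os

from open_webui.utils.misc import calculate_sha256

# Write a throwaway file, then verify chunked hashing matches one-shot hashing.
with open("sample.bin", "wb") as f:
    f.write(os.urandom(5 * 1024 * 1024))

chunked = calculate_sha256("sample.bin", 8192)
with open("sample.bin", "rb") as f:
    one_shot = hashlib.sha256(f.read()).hexdigest()
assert chunked == one_shot
os.remove("sample.bin")
```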