diff --git a/backend/apps/web/routers/utils.py b/backend/apps/web/routers/utils.py index ee1259751..4a16b5bb5 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/web/routers/utils.py @@ -4,13 +4,14 @@ from starlette.responses import StreamingResponse from pydantic import BaseModel -from utils.misc import calculate_sha256 import requests - - import os -import asyncio +import aiohttp import json + + +from utils.misc import calculate_sha256 + from config import OLLAMA_API_BASE_URL @@ -38,7 +39,7 @@ def parse_huggingface_url(hf_url): return [user_repo, model_file] -def download_file_stream(url, file_path, chunk_size=1024 * 1024): +async def download_file_stream(url, file_path, chunk_size=1024 * 1024): done = False if os.path.exists(file_path): @@ -48,18 +49,39 @@ def download_file_stream(url, file_path, chunk_size=1024 * 1024): headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} - with requests.get(url, headers=headers, stream=True) as response: - total_size = int(response.headers.get("content-length", 0)) + current_size + timeout = aiohttp.ClientTimeout(total=60) # Set the timeout - with open(file_path, "ab") as file: - for data in response.iter_content(chunk_size=chunk_size): - current_size += len(data) - file.write(data) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as response: + total_size = int(response.headers.get("content-length", 0)) + current_size - done = current_size == total_size + with open(file_path, "ab+") as file: + async for data in response.content.iter_chunked(chunk_size): + current_size += len(data) + file.write(data) - progress = round((current_size / total_size) * 100, 2) - yield f'data: {{"progress": {progress}, "current": {current_size}, "total": {total_size}}}\n\n' + done = current_size == total_size + progress = round((current_size / total_size) * 100, 2) + yield f'data: {{"progress": {progress}, "current": {current_size}, "total": {total_size}}}\n\n' + + if done: + file.seek(0) + hashed = calculate_sha256(file) + file.seek(0) + + url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}" + response = requests.post(url, data=file) + + if response.ok: + res = { + "done": done, + "blob": f"sha256:{hashed}", + } + os.remove(file_path) + + yield f"data: {json.dumps(res)}\n\n" + else: + raise "Ollama: Could not create blob, Please try again." @router.get("/download")