import asyncio
import json
import logging
import os
import random
import re
import time
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
from aiocache import cached
import requests

from open_webui.apps.webui.models.models import Models
from open_webui.config import (
    CORS_ALLOW_ORIGIN,
    ENABLE_OLLAMA_API,
    OLLAMA_BASE_URLS,
    OLLAMA_API_CONFIGS,
    UPLOAD_DIR,
    AppConfig,
)
from open_webui.env import (
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask

from open_webui.utils.misc import (
    calculate_sha256,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])

app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS

# TODO: Implement a more intelligent load balancing mechanism for distributing
# requests among multiple backend instances. The current implementation picks a
# backend at random (random.choice). Consider incorporating algorithms like
# weighted round-robin, least connections, or least response time for better
# resource utilization and performance optimization.


def _get_api_key(url: str) -> Optional[str]:
    """Return the API key configured for the Ollama base URL *url*, if any."""
    return app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None)


def _build_headers(url: str, json_body: bool = True) -> dict:
    """Build request headers for the backend at base URL *url*.

    Adds a JSON ``Content-Type`` when *json_body* is true and a bearer
    ``Authorization`` header when a key is configured for *url*.
    """
    headers = {"Content-Type": "application/json"} if json_body else {}
    key = _get_api_key(url)
    if key:
        headers["Authorization"] = f"Bearer {key}"
    return headers


def _requests_http_error(
    r: Optional[requests.Response], e: Exception
) -> HTTPException:
    """Translate a failed ``requests`` call into an ``HTTPException``.

    *r* is the response if one was received (``None`` on connection failure)
    and *e* is the exception that was caught.
    """
    error_detail = "Open WebUI: Server Connection Error"
    status_code = 500
    if r is not None:
        # NOTE: compare against None explicitly. ``bool(response)`` is
        # ``response.ok``, which is False for exactly the 4xx/5xx responses
        # whose status we want to forward.
        if r.status_code >= 400:
            status_code = r.status_code
        try:
            res = r.json()
            if "error" in res:
                error_detail = f"Ollama: {res['error']}"
        except Exception:
            error_detail = f"Ollama: {e}"
    return HTTPException(status_code=status_code, detail=error_detail)


async def _resolve_model_urls(
    model_name: str, display_name: Optional[str] = None
) -> list:
    """Return the backend indices that serve *model_name*.

    Raises:
        HTTPException: 400 when no enabled backend lists the model. The error
            message uses *display_name* when given, else *model_name*.
    """
    model_list = await get_all_models()
    models = {entry["model"]: entry for entry in model_list["models"]}
    if model_name not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(display_name or model_name),
        )
    return models[model_name]["urls"]


@app.head("/")
@app.get("/")
async def get_status():
    """Report that this sub-application is reachable."""
    return {"status": True}


class ConnectionVerificationForm(BaseModel):
    """Request body for ``/verify``: backend URL plus optional API key."""

    url: str
    key: Optional[str] = None


@app.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    """Check that ``{url}/api/version`` answers with HTTP 200 and return its JSON.

    Raises:
        HTTPException: 500 on any client error or unexpected failure.
    """
    url = form_data.url
    key = form_data.key

    headers = {}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        try:
            async with session.get(f"{url}/api/version", headers=headers) as r:
                if r.status != 200:
                    # Extract response error details if available
                    error_detail = f"HTTP Error: {r.status}"
                    res = await r.json()
                    if "error" in res:
                        error_detail = f"External Error: {res['error']}"
                    raise Exception(error_detail)

                response_data = await r.json()
                return response_data

        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp requests issues
            log.exception(f"Client error: {str(e)}")
            # Handle aiohttp-specific connection issues, timeout etc.
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            # Generic error handler in case parsing JSON or other steps fail
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    """Return the current Ollama connection settings (admin only)."""
    return {
        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
    }


class OllamaConfigForm(BaseModel):
    """Request body for ``/config/update``."""

    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    OLLAMA_API_CONFIGS: dict


@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
    """Replace the Ollama connection settings and return the stored values.

    Per-URL configs whose URL is not in ``OLLAMA_BASE_URLS`` are dropped.
    """
    app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
    app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS

    # Remove any extra configs: keep only entries that belong to a configured
    # base URL. The result is assigned in one step rather than mutated in place.
    base_urls = set(form_data.OLLAMA_BASE_URLS)
    app.state.config.OLLAMA_API_CONFIGS = {
        url: config
        for url, config in form_data.OLLAMA_API_CONFIGS.items()
        if url in base_urls
    }

    return {
        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
    }


async def aiohttp_get(url, key=None):
    """GET *url* and return the decoded JSON body, or ``None`` on any failure.

    Failures are logged, not raised, so callers can gather many backends and
    skip the unreachable ones.
    """
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    try:
        headers = {"Authorization": f"Bearer {key}"} if key else {}
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Close *response* and *session* if they were created."""
    if response:
        response.close()
    if session:
        await session.close()


async def post_streaming_url(
    url: str,
    payload: Union[str, bytes],
    stream: bool = True,
    content_type=None,
    key: Optional[str] = None,
):
    """POST *payload* to *url* and relay the backend's answer.

    Args:
        url: Full endpoint URL (base URL plus path).
        payload: JSON request body.
        stream: When true, return a ``StreamingResponse`` that forwards the
            backend body; otherwise return the decoded JSON.
        content_type: Overrides the ``Content-Type`` of the streamed response.
        key: Bearer token for the backend. When omitted, the config entry for
            *url* is consulted as before; since configs are keyed by base URL,
            callers should pass the key explicitly.

    Raises:
        HTTPException: with the backend's status (or 500) on any failure.
    """
    r = None
    session = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        if key is None:
            api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
            key = api_config.get("key", None)

        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"

        r = await session.post(
            url,
            data=payload,
            headers=headers,
        )
        r.raise_for_status()

        if stream:
            response_headers = dict(r.headers)
            if content_type:
                response_headers["Content-Type"] = content_type
            # The session must outlive this function; it is closed by the
            # background task once the body has been streamed.
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=response_headers,
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res

    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        status_code = 500
        if r is not None:
            # Never report a success status for a failed request.
            if r.status >= 400:
                status_code = r.status
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        # No background task will run on this path, so release the
        # connection here.
        await cleanup_response(r, session)

        raise HTTPException(
            status_code=status_code,
            detail=error_detail,
        )


def merge_models_lists(model_lists):
    """Merge per-backend model lists into one list keyed by model name.

    Each merged entry gets a ``"urls"`` list with the indices of the input
    lists that contained it. ``None`` entries in *model_lists* are skipped but
    still consume an index.
    """
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                model_name = model["model"]
                if model_name not in merged_models:
                    model["urls"] = [idx]
                    merged_models[model_name] = model
                else:
                    merged_models[model_name]["urls"].append(idx)

    return list(merged_models.values())


@cached(ttl=3)
async def get_all_models():
    """Collect ``/api/tags`` from every enabled backend and merge the results.

    Returns ``{"models": [...]}``; the list is empty when the Ollama API is
    disabled. The ``cached`` decorator is applied with ``ttl=3``.
    """
    log.info("get_all_models()")
    if not app.state.config.ENABLE_OLLAMA_API:
        return {"models": []}

    tasks = []
    for url in app.state.config.OLLAMA_BASE_URLS:
        if url not in app.state.config.OLLAMA_API_CONFIGS:
            tasks.append(aiohttp_get(f"{url}/api/tags"))
        else:
            api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
            enable = api_config.get("enable", True)
            key = api_config.get("key", None)

            if enable:
                tasks.append(aiohttp_get(f"{url}/api/tags", key))
            else:
                # Placeholder that resolves to None so response indices stay
                # aligned with OLLAMA_BASE_URLS.
                tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

    responses = await asyncio.gather(*tasks)

    for idx, response in enumerate(responses):
        if response:
            url = app.state.config.OLLAMA_BASE_URLS[idx]
            api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})

            prefix_id = api_config.get("prefix_id", None)
            model_ids = api_config.get("model_ids", [])

            if len(model_ids) != 0 and "models" in response:
                response["models"] = [
                    model
                    for model in response["models"]
                    if model["model"] in model_ids
                ]

            if prefix_id:
                for model in response.get("models", []):
                    model["model"] = f"{prefix_id}.{model['model']}"

    return {
        "models": merge_models_lists(
            response.get("models", []) if response else None
            for response in responses
        )
    }


@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    """List models from all backends, or from backend *url_idx* only.

    For users with role ``"user"`` the list is reduced to models they own or
    have read access to.
    """
    models = []
    if url_idx is None:
        models = await get_all_models()
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
                headers=_build_headers(url, json_body=False),
            )
            r.raise_for_status()

            models = r.json()
        except Exception as e:
            log.exception(e)
            raise _requests_http_error(r, e)

    if user.role == "user":
        # Filter models based on user access control
        filtered_models = []
        for model in models.get("models", []):
            model_info = Models.get_model_by_id(model["model"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        # Build a new dict: `models` may be the object cached by
        # get_all_models(), which must not be narrowed for later callers.
        models = {**models, "models": filtered_models}

    return models


@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
    """Return the backend version.

    Without *url_idx* the lowest version among all reachable backends is
    returned; with it, that backend's ``/api/version`` JSON is returned.
    Returns ``{"version": False}`` when the Ollama API is disabled.
    """
    if not app.state.config.ENABLE_OLLAMA_API:
        return {"version": False}

    if url_idx is None:
        # returns lowest version
        tasks = [
            aiohttp_get(f"{url}/api/version", _get_api_key(url))
            for url in app.state.config.OLLAMA_BASE_URLS
        ]
        responses = await asyncio.gather(*tasks)
        responses = [response for response in responses if response is not None]

        if len(responses) > 0:
            lowest_version = min(
                responses,
                # Strip a leading "v" and any "-suffix" before comparing
                # the dotted components numerically.
                key=lambda x: tuple(
                    map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                ),
            )

            return {"version": lowest_version["version"]}
        else:
            raise HTTPException(
                status_code=500,
                detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
            )
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/version",
                headers=_build_headers(url, json_body=False),
            )
            r.raise_for_status()

            return r.json()
        except Exception as e:
            log.exception(e)
            raise _requests_http_error(r, e)


class ModelNameForm(BaseModel):
    """Request body carrying a single model name."""

    name: str


@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    """Stream ``/api/pull`` from backend *url_idx* (admin only)."""
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    return await post_streaming_url(
        f"{url}/api/pull", json.dumps(payload), key=_get_api_key(url)
    )


class PushModelForm(BaseModel):
    """Request body for ``/api/push``."""

    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


# NOTE(review): these routes are registered with DELETE although they perform a
# push; left unchanged because clients may depend on it. Confirm intent.
@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
async def push_model(
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Stream ``/api/push`` from the first backend serving the model."""
    if url_idx is None:
        url_idx = (await _resolve_model_urls(form_data.name))[0]

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/push",
        form_data.model_dump_json(exclude_none=True).encode(),
        key=_get_api_key(url),
    )


class CreateModelForm(BaseModel):
    """Request body for ``/api/create``."""

    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


@app.post("/api/create")
@app.post("/api/create/{url_idx}")
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    """Stream ``/api/create`` from backend *url_idx* (admin only)."""
    log.debug(f"form_data: {form_data}")
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/create",
        form_data.model_dump_json(exclude_none=True).encode(),
        key=_get_api_key(url),
    )


class CopyModelForm(BaseModel):
    """Request body for ``/api/copy``."""

    source: str
    destination: str


@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
async def copy_model(
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Copy a model on the first backend serving ``source``; return ``True``."""
    if url_idx is None:
        url_idx = (await _resolve_model_urls(form_data.source))[0]

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        # The request is inside the try so connection failures are reported
        # through the same HTTPException path as backend errors.
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            headers=_build_headers(url),
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)
        raise _requests_http_error(r, e)


@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
async def delete_model(
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Delete a model on the first backend serving it; return ``True``."""
    if url_idx is None:
        url_idx = (await _resolve_model_urls(form_data.name))[0]

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
            headers=_build_headers(url),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)
        raise _requests_http_error(r, e)


@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
    """Return ``/api/show`` JSON from a randomly chosen backend serving the model."""
    url_idx = random.choice(await _resolve_model_urls(form_data.name))
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            headers=_build_headers(url),
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)
        raise _requests_http_error(r, e)


class GenerateEmbeddingsForm(BaseModel):
    """Request body for ``/api/embeddings`` (single prompt)."""

    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


class GenerateEmbedForm(BaseModel):
    """Request body for ``/api/embed`` (one or many inputs)."""

    model: str
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/embed")
@app.post("/api/embed/{url_idx}")
async def embed(
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Route handler for batch embeddings; delegates to the helper below."""
    # The helper is a coroutine function and must be awaited; returning the
    # bare coroutine object cannot be serialized into a response.
    return await generate_ollama_batch_embeddings(form_data, url_idx)


@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Route handler for single-prompt embeddings; delegates to the helper below."""
    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)


async def generate_ollama_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
):
    """POST to ``/api/embeddings`` and return the JSON containing ``"embedding"``.

    When *url_idx* is ``None`` a backend serving the model is picked at random;
    a model name without a tag is looked up as ``name:latest``.

    Raises:
        HTTPException: 400 for an unknown model, otherwise the backend status
            (or 500) when the call fails or the key is missing.
    """
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"
        url_idx = random.choice(await _resolve_model_urls(model, form_data.model))

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers=_build_headers(url),
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()

        log.info(f"generate_ollama_embeddings {data}")

        if "embedding" in data:
            return data
        else:
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(e)
        raise _requests_http_error(r, e)


async def generate_ollama_batch_embeddings(
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
):
    """POST to ``/api/embed`` and return the JSON containing ``"embeddings"``.

    Backend selection follows the same rules as ``generate_ollama_embeddings``.

    Raises:
        HTTPException: 400 for an unknown model, otherwise the backend status
            (or 500). ``HTTPException`` is an ``Exception`` subclass, so
            callers that catch ``Exception`` are unaffected.
    """
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"
        url_idx = random.choice(await _resolve_model_urls(model, form_data.model))

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers=_build_headers(url),
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()

        log.info(f"generate_ollama_batch_embeddings {data}")

        if "embeddings" in data:
            return data
        else:
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(e)
        raise _requests_http_error(r, e)


class GenerateCompletionForm(BaseModel):
    """Request body for ``/api/generate``."""

    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Stream ``/api/generate`` from a backend serving the requested model."""
    if url_idx is None:
        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"
        url_idx = random.choice(await _resolve_model_urls(model, form_data.model))

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        # The backend does not know the prefix added by get_all_models().
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/generate",
        form_data.model_dump_json(exclude_none=True).encode(),
        key=api_config.get("key", None),
    )


class ChatMessage(BaseModel):
    """One message of an Ollama chat request."""

    role: str
    content: str
    images: Optional[list[str]] = None


class GenerateChatCompletionForm(BaseModel):
    """Request body for ``/api/chat``."""

    model: str
    messages: list[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None


async def get_ollama_url(url_idx: Optional[int], model: str):
    """Return the base URL for *url_idx*, or for a random backend serving *model*.

    Raises:
        HTTPException: 400 when *url_idx* is ``None`` and the model is unknown.
    """
    if url_idx is None:
        url_idx = random.choice(await _resolve_model_urls(model))
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url


# NOTE(review): `bypass_filter` is a plain defaulted parameter on a route
# handler, so it is presumably also settable by clients as a query parameter,
# which would skip the access check below. Confirm and restrict if so.
@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    """Apply stored model settings and access control, then proxy ``/api/chat``.

    Raises:
        HTTPException: 403 when the user may not use the model, 400 when no
            backend serves it.
    """
    payload = {**form_data.model_dump(exclude_none=True)}
    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
    payload.pop("metadata", None)

    model_id = payload["model"]
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            if payload.get("options") is None:
                payload["options"] = {}

            payload["options"] = apply_model_params_to_body_ollama(
                params, payload["options"]
            )
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        # Models without a stored record are usable by admins only.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")
    log.debug(f"generate_chat_completion() - 2.payload = {payload}")

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/api/chat",
        json.dumps(payload),
        stream=form_data.stream,
        content_type="application/x-ndjson",
        key=api_config.get("key", None),
    )


# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    """One typed content part of an OpenAI-style message; extra fields allowed."""

    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    """One OpenAI-style chat message; extra fields allowed."""

    role: str
    content: Union[str, list[OpenAIChatMessageContent]]

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    """Request body for ``/v1/chat/completions``; extra fields allowed."""

    model: str
    messages: list[OpenAIChatMessage]

    model_config = ConfigDict(extra="allow")


@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Validate an OpenAI-style request, apply model settings and proxy it.

    Raises:
        HTTPException: 400 for an invalid body or unknown model, 403 when the
            user may not use the model.
    """
    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**completion_form.model_dump(exclude_none=True, exclude={"metadata"})}
    payload.pop("metadata", None)

    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        # Models without a stored record are usable by admins only.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/v1/chat/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
        key=api_config.get("key", None),
    )


@app.get("/v1/models")
@app.get("/v1/models/{url_idx}")
async def get_openai_models(
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """List models in OpenAI ``/v1/models`` shape.

    For users with role ``"user"`` the list is reduced to models they own or
    have read access to.
    """

    def to_openai_models(entries):
        # ``created`` is the time of this call, not the model's creation time.
        return [
            {
                "id": model["model"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
            for model in entries
        ]

    models = []
    if url_idx is None:
        model_list = await get_all_models()
        models = to_openai_models(model_list["models"])
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
                headers=_build_headers(url, json_body=False),
            )
            r.raise_for_status()

            model_list = r.json()
            models = to_openai_models(model_list["models"])
        except Exception as e:
            log.exception(e)
            raise _requests_http_error(r, e)

    if user.role == "user":
        # Filter models based on user access control
        filtered_models = []
        for model in models:
            model_info = Models.get_model_by_id(model["id"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models = filtered_models

    return {
        "data": models,
        "object": "list",
    }


class UrlForm(BaseModel):
    """Request body carrying a download URL."""

    url: str


class UploadBlobForm(BaseModel):
    """Request body carrying a file name."""

    filename: str


def parse_huggingface_url(hf_url):
    """Return the last path component of *hf_url*, or ``None`` if parsing fails."""
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split("/")

        # Extract the desired output
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None


async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    """Download *file_url* to *file_path* and register it as an Ollama blob.

    Yields server-sent-event strings: one progress record per chunk, then a
    final record with the blob digest once the upload to
    ``{ollama_url}/api/blobs/sha256:<digest>`` succeeded. An existing partial
    file is resumed with a ``Range`` request.

    Raises:
        Exception: when the backend rejects the blob.
    """
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    # total_size is 0 when the server sends no content-length
                    # and nothing was on disk; avoid dividing by zero.
                    progress = (
                        round((current_size / total_size) * 100, 2)
                        if total_size
                        else 0
                    )

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    blob_response = requests.post(
                        url,
                        data=file,
                        headers=_build_headers(ollama_url, json_body=False),
                    )

                    if blob_response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        # Raising a bare string is itself a TypeError.
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
# Example of an accepted download URL:
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Download a model file from an allowed host and stream progress events.

    The file is stored under ``UPLOAD_DIR`` using the last path component of
    the URL. Returns ``None`` when the URL has no file name.

    Raises:
        HTTPException: 400 when the URL is not on an allowed host or its file
            name is a directory reference.
    """
    allowed_hosts = ("https://huggingface.co/", "https://github.com/")

    if not form_data.url.startswith(allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    if url_idx is None:
        url_idx = 0
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_name = parse_huggingface_url(form_data.url)

    if file_name in (".", ".."):
        # These would resolve to a directory rather than a file in UPLOAD_DIR.
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    if file_name:
        file_path = f"{UPLOAD_DIR}/{file_name}"

        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
            media_type="text/event-stream",
        )
    else:
        return None


@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Store an uploaded model file and register it as an Ollama blob.

    The upload is written to ``UPLOAD_DIR``, then a server-sent-event stream
    reports read progress and finally the blob digest. Errors during
    processing are reported as an ``{"error": ...}`` event.

    Raises:
        HTTPException: 400 when the upload carries no usable file name.
    """
    if url_idx is None:
        url_idx = 0
    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    # The file name is client-controlled: keep only its final component so it
    # cannot point outside UPLOAD_DIR.
    filename = os.path.basename(file.filename or "")
    if not filename or filename in (".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")

    file_path = f"{UPLOAD_DIR}/{filename}"

    # Save file in chunks
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    # Send the key configured for this backend, if any.
                    key = app.state.config.OLLAMA_API_CONFIGS.get(
                        ollama_url, {}
                    ).get("key", None)
                    headers = {"Authorization": f"Bearer {key}"} if key else {}

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f, headers=headers)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )

        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")