From df48eac22bbd5665d4e3a2d4f8ca635555449900 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Wed, 11 Dec 2024 03:38:45 -0800
Subject: [PATCH] wip

---
 backend/open_webui/main.py           |  12 +-
 backend/open_webui/routers/ollama.py | 378 ++++++++++++++-------------
 2 files changed, 204 insertions(+), 186 deletions(-)

diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 308489ee6..2e1929bb3 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -372,6 +372,7 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
 
+app.state.OLLAMA_MODELS = {}
 
 ########################################
 #
@@ -384,6 +385,7 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
 app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
 
+app.state.OPENAI_MODELS = {}
 
 ########################################
 #
@@ -607,6 +609,14 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
 )
 
 
+########################################
+#
+# WEBUI
+#
+########################################
+
+app.state.MODELS = {}
+
 ##################################
 #
 # ChatCompletion Middleware
@@ -1437,7 +1447,7 @@ async def get_all_base_models():
         openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await get_ollama_models()
+        ollama_models = await ollama.get_all_models()
        ollama_models = [
             {
                 "id": model["model"],
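
The three dictionaries introduced above (app.state.OLLAMA_MODELS, app.state.OPENAI_MODELS, and app.state.MODELS) are in-memory lookup tables keyed by model id, so request handlers can resolve a model to its backend without re-querying every connection. A minimal sketch of how such a cache might be refreshed, assuming the {"model": ..., "urls": [...]} record shape used elsewhere in this patch; the refresh_ollama_models helper and the fetch callable are hypothetical, not part of the commit:

    from fastapi import FastAPI

    app = FastAPI()
    app.state.OLLAMA_MODELS = {}  # model id -> model record, incl. "urls"


    async def refresh_ollama_models(fetch):
        # fetch() stands in for the router's get_all_models(request);
        # hypothetical glue, only the cache shape comes from the patch.
        response = await fetch()
        app.state.OLLAMA_MODELS = {
            model["model"]: model for model in response.get("models", [])
        }
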
diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 8a43d5c52..19bc12e21 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -1,3 +1,7 @@
+# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
+# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
+# least connections, or least response time for better resource utilization and performance optimization.
+
 import asyncio
 import json
 import logging
@@ -29,6 +33,16 @@ from starlette.background import BackgroundTask
 
 from open_webui.models.models import Models
+from open_webui.utils.misc import (
+    calculate_sha256,
+)
+from open_webui.utils.payload import (
+    apply_model_params_to_body_ollama,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access
 
 
 from open_webui.config import (
@@ -41,29 +55,114 @@ from open_webui.env import (
     AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
     BYPASS_MODEL_ACCESS_CONTROL,
 )
-
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.misc import (
-    calculate_sha256,
-)
-from open_webui.utils.payload import (
-    apply_model_params_to_body_ollama,
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-from open_webui.utils.auth import get_admin_user, get_verified_user
-from open_webui.utils.access_control import has_access
-
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 
-router = APIRouter()
+##########################################
+#
+# Utility functions
+#
+##########################################
 
-# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
-# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
-# least connections, or least response time for better resource utilization and performance optimization.
+
+async def send_get_request(url, key=None):
+    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+    try:
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.get(
+                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+            ) as response:
+                return await response.json()
+    except Exception as e:
+        # Handle connection error here
+        log.error(f"Connection error: {e}")
+        return None
+
+
+async def send_post_request(
+    url: str,
+    payload: Union[str, bytes],
+    stream: bool = True,
+    key: Optional[str] = None,
+    content_type: Optional[str] = None,
+):
+    async def cleanup_response(
+        response: Optional[aiohttp.ClientResponse],
+        session: Optional[aiohttp.ClientSession],
+    ):
+        if response:
+            response.close()
+        if session:
+            await session.close()
+
+    r = None
+    try:
+        session = aiohttp.ClientSession(
+            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+        )
+
+        r = await session.post(
+            url,
+            data=payload,
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+        )
+        r.raise_for_status()
+
+        if stream:
+            response_headers = dict(r.headers)
+
+            if content_type:
+                response_headers["Content-Type"] = content_type
+
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=response_headers,
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            res = await r.json()
+            await cleanup_response(r, session)
+            return res
+
+    except Exception as e:
+        detail = None
+
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    detail = f"Ollama: {res.get('error', 'Unknown error')}"
+            except Exception:
+                detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
+
+
+def get_api_key(url, configs):
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    return configs.get(base_url, {}).get("key", None)
+
+
+##########################################
+#
+# API routes
+#
+##########################################
+
+router = APIRouter()
 
 
 @router.head("/")
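
A note on the streaming branch of send_post_request above: when stream=True, the aiohttp session cannot be closed before returning, because StreamingResponse reads r.content lazily; the nested cleanup_response is therefore deferred through Starlette's BackgroundTask and runs only after the client has consumed the stream. A hedged usage sketch, with the URL, model name, and prompt as illustrative values only:

    import json


    async def example_generate_call():
        # send_post_request is the helper added above; values are examples.
        return await send_post_request(
            url="http://localhost:11434/api/generate",
            payload=json.dumps({"model": "llama3.1", "prompt": "Hello"}),
            stream=True,  # StreamingResponse; cleanup deferred to BackgroundTask
        )
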
@@ -84,35 +183,31 @@ async def verify_connection(
     url = form_data.url
     key = form_data.key
 
-    headers = {}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    async with aiohttp.ClientSession(timeout=timeout) as session:
+    async with aiohttp.ClientSession(
+        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+    ) as session:
         try:
-            async with session.get(f"{url}/api/version", headers=headers) as r:
+            async with session.get(
+                f"{url}/api/version",
+                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+            ) as r:
                 if r.status != 200:
-                    # Extract response error details if available
-                    error_detail = f"HTTP Error: {r.status}"
+                    detail = f"HTTP Error: {r.status}"
                     res = await r.json()
+
                     if "error" in res:
-                        error_detail = f"External Error: {res['error']}"
-                    raise Exception(error_detail)
-
-                response_data = await r.json()
-                return response_data
+                        detail = f"External Error: {res['error']}"
+                    raise Exception(detail)
 
+                data = await r.json()
+                return data
         except aiohttp.ClientError as e:
-            # ClientError covers all aiohttp requests issues
             log.exception(f"Client error: {str(e)}")
-            # Handle aiohttp-specific connection issues, timeout etc.
             raise HTTPException(
                 status_code=500, detail="Open WebUI: Server Connection Error"
             )
         except Exception as e:
             log.exception(f"Unexpected error: {e}")
-            # Generic error handler in case parsing JSON or other steps fail
             error_detail = f"Unexpected error: {str(e)}"
             raise HTTPException(status_code=500, detail=error_detail)
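
The rewritten verify_connection above introduces the error-reporting pattern the rest of this patch repeats: prefer the backend's own "error" field from the JSON body, fall back to the caught exception, and only then to the generic connection message. A sketch of that chain as a standalone helper; extract_error_detail is hypothetical and not part of this patch:

    from typing import Optional


    def extract_error_detail(res: Optional[dict], exc: Optional[Exception]) -> str:
        # Mirrors the patch's repeated fallback chain (hypothetical helper).
        if res is not None and "error" in res:
            return f"Ollama: {res['error']}"  # backend-reported error
        if exc is not None:
            return f"Ollama: {exc}"  # transport or JSON-parsing error
        return "Open WebUI: Server Connection Error"  # generic fallback
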
@@ -137,8 +232,8 @@ async def update_config(
     request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
 ):
     request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
-    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
 
+    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
     request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
 
     # Remove any extra configs
@@ -154,127 +249,26 @@
     }
 
 
-async def aiohttp_get(url, key=None):
-    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
-    try:
-        headers = {"Authorization": f"Bearer {key}"} if key else {}
-        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
-            async with session.get(url, headers=headers) as response:
-                return await response.json()
-    except Exception as e:
-        # Handle connection error here
-        log.error(f"Connection error: {e}")
-        return None
-
-
-def get_api_key(url, configs):
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-    return configs.get(base_url, {}).get("key", None)
-
-
-async def cleanup_response(
-    response: Optional[aiohttp.ClientResponse],
-    session: Optional[aiohttp.ClientSession],
-):
-    if response:
-        response.close()
-    if session:
-        await session.close()
-
-
-async def post_streaming_url(
-    url: str,
-    payload: Union[str, bytes],
-    stream: bool = True,
-    key: Optional[str] = None,
-    content_type=None,
-):
-    r = None
-    try:
-        session = aiohttp.ClientSession(
-            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
-        )
-
-        headers = {"Content-Type": "application/json"}
-        if key:
-            headers["Authorization"] = f"Bearer {key}"
-
-        r = await session.post(
-            url,
-            data=payload,
-            headers=headers,
-        )
-        r.raise_for_status()
-
-        if stream:
-            response_headers = dict(r.headers)
-            if content_type:
-                response_headers["Content-Type"] = content_type
-            return StreamingResponse(
-                r.content,
-                status_code=r.status,
-                headers=response_headers,
-                background=BackgroundTask(
-                    cleanup_response, response=r, session=session
-                ),
-            )
-        else:
-            res = await r.json()
-            await cleanup_response(r, session)
-            return res
-
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = await r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status if r else 500,
-            detail=error_detail,
-        )
-
-
-def merge_models_lists(model_lists):
-    merged_models = {}
-
-    for idx, model_list in enumerate(model_lists):
-        if model_list is not None:
-            for model in model_list:
-                id = model["model"]
-                if id not in merged_models:
-                    model["urls"] = [idx]
-                    merged_models[id] = model
-                else:
-                    merged_models[id]["urls"].append(idx)
-
-    return list(merged_models.values())
-
-
 @cached(ttl=3)
-async def get_all_models():
+async def get_all_models(request: Request):
     log.info("get_all_models()")
 
     if request.app.state.config.ENABLE_OLLAMA_API:
-        tasks = []
+        request_tasks = []
+
         for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
             if url not in request.app.state.config.OLLAMA_API_CONFIGS:
-                tasks.append(aiohttp_get(f"{url}/api/tags"))
+                request_tasks.append(send_get_request(f"{url}/api/tags"))
             else:
                 api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
 
                 enable = api_config.get("enable", True)
                 key = api_config.get("key", None)
 
                 if enable:
-                    tasks.append(aiohttp_get(f"{url}/api/tags", key))
+                    request_tasks.append(send_get_request(f"{url}/api/tags", key))
                 else:
-                    tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
 
-        responses = await asyncio.gather(*tasks)
+        responses = await asyncio.gather(*request_tasks)
 
         for idx, response in enumerate(responses):
             if response:
@@ -296,6 +290,21 @@
                     for model in response.get("models", []):
                         model["model"] = f"{prefix_id}.{model['model']}"
 
+    def merge_models_lists(model_lists):
+        merged_models = {}
+
+        for idx, model_list in enumerate(model_lists):
+            if model_list is not None:
+                for model in model_list:
+                    id = model["model"]
+                    if id not in merged_models:
+                        model["urls"] = [idx]
+                        merged_models[id] = model
+                    else:
+                        merged_models[id]["urls"].append(idx)
+
+        return list(merged_models.values())
+
     models = {
         "models": merge_models_lists(
             map(
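
The nested merge_models_lists above deduplicates models across backends by model id while recording every backend index that serves them; the resulting "urls" list is what get_ollama_url later samples from. A small worked example, treating the function as standalone (model names are illustrative):

    # Backends 0 and 1 both serve "llama3.1:latest"; backend 0 also
    # serves "mistral:latest".
    model_lists = [
        [{"model": "llama3.1:latest"}, {"model": "mistral:latest"}],
        [{"model": "llama3.1:latest"}],
    ]
    merged = merge_models_lists(model_lists)
    # -> [{"model": "llama3.1:latest", "urls": [0, 1]},
    #     {"model": "mistral:latest", "urls": [0]}]
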
@@ -311,60 +320,61 @@
     return models
 
 
+async def get_filtered_models(models, user):
+    # Filter models based on user access control
+    filtered_models = []
+    for model in models.get("models", []):
+        model_info = Models.get_model_by_id(model["model"])
+        if model_info:
+            if user.id == model_info.user_id or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                filtered_models.append(model)
+    return filtered_models
+
+
 @router.get("/api/tags")
 @router.get("/api/tags/{url_idx}")
 async def get_ollama_tags(
     request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
     models = []
+
     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-        parsed_url = urlparse(url)
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-        api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-        key = api_config.get("key", None)
-
-        headers = {}
-        if key:
-            headers["Authorization"] = f"Bearer {key}"
+        key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
         r = None
         try:
-            r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
+            r = requests.request(
+                method="GET",
+                url=f"{url}/api/tags",
+                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+            )
             r.raise_for_status()
             models = r.json()
         except Exception as e:
            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
                 try:
                     res = r.json()
                     if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
+                        detail = f"Ollama: {res['error']}"
                 except Exception:
-                    error_detail = f"Ollama: {e}"
 
+                    detail = f"Ollama: {e}"
+
             raise HTTPException(
                 status_code=r.status_code if r else 500,
-                detail=error_detail,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
             )
 
     if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
-        # Filter models based on user access control
-        filtered_models = []
-        for model in models.get("models", []):
-            model_info = Models.get_model_by_id(model["model"])
-            if model_info:
-                if user.id == model_info.user_id or has_access(
-                    user.id, type="read", access_control=model_info.access_control
-                ):
-                    filtered_models.append(model)
-        models["models"] = filtered_models
+        models["models"] = await get_filtered_models(models, user)
 
     return models
@@ -376,7 +386,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
     if url_idx is None:
         # returns lowest version
         tasks = [
-            aiohttp_get(
+            send_get_request(
                 f"{url}/api/version",
                 request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                     "key", None
@@ -412,18 +422,19 @@
             return r.json()
         except Exception as e:
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
                 try:
                     res = r.json()
                     if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
+                        detail = f"Ollama: {res['error']}"
                 except Exception:
-                    error_detail = f"Ollama: {e}"
+                    detail = f"Ollama: {e}"
 
             raise HTTPException(
                 status_code=r.status_code if r else 500,
-                detail=error_detail,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
             )
     else:
         return {"version": False}
@@ -436,7 +447,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
     """
     if request.app.state.config.ENABLE_OLLAMA_API:
         tasks = [
-            aiohttp_get(
+            send_get_request(
                 f"{url}/api/ps",
                 request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                     "key", None
@@ -469,7 +480,7 @@ async def pull_model(
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/api/pull",
         payload=json.dumps(payload),
         key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
@@ -505,7 +516,7 @@ async def push_model(
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.debug(f"url: {url}")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/api/push",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
@@ -531,7 +542,7 @@ async def create_model(
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/api/create",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
@@ -941,7 +952,7 @@ async def generate_completion(
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
 
     log.info(f"url: {url}")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/api/generate",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
@@ -966,15 +977,13 @@ class GenerateChatCompletionForm(BaseModel):
 
-async def get_ollama_url(url_idx: Optional[int], model: str):
+async def get_ollama_url(request: Request, url_idx: Optional[int], model: str):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
-
+        models = request.app.state.OLLAMA_MODELS
         if model not in models:
             raise HTTPException(
                 status_code=400,
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
             )
-        url_idx = random.choice(models[model]["urls"])
+        url_idx = random.choice(models[model].get("urls", []))
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     return url
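
get_ollama_url above selects a backend with random.choice, which is random selection rather than the strict round-robin the file-top TODO describes; that TODO names least connections as one alternative. A minimal sketch of that strategy, assuming an in-flight counter per backend index is maintained elsewhere (the counter and helper are hypothetical, not part of this patch):

    import random
    from typing import Dict, Sequence

    # Hypothetical in-flight request counter, index-aligned with
    # request.app.state.config.OLLAMA_BASE_URLS.
    inflight: Dict[int, int] = {}


    def pick_least_connections(url_idxs: Sequence[int]) -> int:
        # Prefer the backend with the fewest in-flight requests;
        # break ties randomly.
        least = min(inflight.get(i, 0) for i in url_idxs)
        candidates = [i for i in url_idxs if inflight.get(i, 0) == least]
        return random.choice(candidates)
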
@@ -1037,7 +1046,6 @@ async def generate_chat_completion(
             payload["model"] = f"{payload['model']}:latest"
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
+    url = await get_ollama_url(request, url_idx, payload["model"])
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
     parsed_url = urlparse(url)
@@ -1048,7 +1056,7 @@
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/api/chat",
         payload=json.dumps(payload),
         stream=form_data.stream,
@@ -1149,7 +1157,7 @@ async def generate_openai_completion(
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/v1/completions",
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
@@ -1223,7 +1231,7 @@ async def generate_openai_chat_completion(
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
-    return await post_streaming_url(
+    return await send_post_request(
         url=f"{url}/v1/chat/completions",
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
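
Several endpoints in this patch repeat the same prefix_id handling: models from a backend configured with a prefix_id are advertised as "{prefix_id}.{model}", and the prefix is stripped again before the request is forwarded. A sketch of both directions as standalone helpers (the helper names are hypothetical):

    from typing import Optional


    def add_prefix(model: str, prefix_id: Optional[str]) -> str:
        # Advertise a backend's model under its configured prefix.
        return f"{prefix_id}.{model}" if prefix_id else model


    def strip_prefix(model: str, prefix_id: Optional[str]) -> str:
        # Undo the prefix before forwarding to the backend, as the patch
        # does with payload["model"].replace(f"{prefix_id}.", "").
        if prefix_id and model.startswith(f"{prefix_id}."):
            return model[len(prefix_id) + 1 :]
        return model
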