refac: apps/openai/main.py and utils
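Refactor of the OpenAI-compatible backend: the API key/URL reconciliation and upstream model fetching move out of `get_all_models(raw)` into dedicated `is_openai_api_disabled()` and `get_all_models_raw()` helpers; the inline per-parameter payload munging in `generate_chat_completion` is replaced by `apply_model_params_to_body` and `apply_model_system_prompt_to_body` from `utils.misc`; dead `else: pass` branches, unused imports, and bare `except:` clauses are cleaned up.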
apps/openai/main.py
@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+from fastapi.responses import StreamingResponse, FileResponse
 
 import requests
 import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
 from starlette.background import BackgroundTask
 
 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template
-from utils.misc import add_or_update_system_message
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 
 from config import (
     SRC_LOG_LEVELS,
@@ -69,8 +65,6 @@ app.state.MODELS = {}
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
         await get_all_models()
-    else:
-        pass
 
     response = await call_next(request)
     return response
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
 
         raise HTTPException(
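This `except:` → `except Exception:` change (repeated in three later hunks) is more than style: a bare `except` also catches `KeyboardInterrupt` and `SystemExit`, which derive from `BaseException` but not from `Exception`, so the old handlers could swallow a Ctrl-C arriving mid-request. A quick check:

# KeyboardInterrupt and SystemExit sit beside Exception under BaseException,
# so `except Exception:` lets them propagate while a bare `except:` traps them.
print(issubclass(KeyboardInterrupt, Exception))      # False
print(issubclass(SystemExit, Exception))             # False
print(issubclass(KeyboardInterrupt, BaseException))  # True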
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
     return merged_list
 
 
-async def get_all_models(raw: bool = False):
+def is_openai_api_disabled():
+    api_keys = app.state.config.OPENAI_API_KEYS
+    no_keys = len(api_keys) == 1 and api_keys[0] == ""
+    return no_keys or not app.state.config.ENABLE_OPENAI_API
+
+
+async def get_all_models_raw() -> list:
+    if is_openai_api_disabled():
+        return []
+
+    # Check if API KEYS length is same than API URLS length
+    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(app.state.config.OPENAI_API_KEYS)
+
+    if num_keys != num_urls:
+        # if there are more keys than urls, remove the extra keys
+        if num_keys > num_urls:
+            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
+            app.state.config.OPENAI_API_KEYS = new_keys
+        # if there are more urls than keys, add empty keys
+        else:
+            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    tasks = [
+        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
+    ]
+
+    responses = await asyncio.gather(*tasks)
+    log.debug(f"get_all_models:responses() {responses}")
+
+    return responses
+
+
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
+    if is_openai_api_disabled():
+        return {"data": []}
 
-    if (
-        len(app.state.config.OPENAI_API_KEYS) == 1
-        and app.state.config.OPENAI_API_KEYS[0] == ""
-    ) or not app.state.config.ENABLE_OPENAI_API:
-        models = {"data": []}
-    else:
-        # Check if API KEYS length is same than API URLS length
-        if len(app.state.config.OPENAI_API_KEYS) != len(
-            app.state.config.OPENAI_API_BASE_URLS
-        ):
-            # if there are more keys than urls, remove the extra keys
-            if len(app.state.config.OPENAI_API_KEYS) > len(
-                app.state.config.OPENAI_API_BASE_URLS
-            ):
-                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                    : len(app.state.config.OPENAI_API_BASE_URLS)
-                ]
-            # if there are more urls than keys, add empty keys
-            else:
-                app.state.config.OPENAI_API_KEYS += [
-                    ""
-                    for _ in range(
-                        len(app.state.config.OPENAI_API_BASE_URLS)
-                        - len(app.state.config.OPENAI_API_KEYS)
-                    )
-                ]
+    responses = await get_all_models_raw()
 
-        tasks = [
-            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-        ]
+    def extract_data(response):
+        if response and "data" in response:
+            return response["data"]
+        if isinstance(response, list):
+            return response
+        return None
 
-        responses = await asyncio.gather(*tasks)
-        log.debug(f"get_all_models:responses() {responses}")
+    models = {"data": merge_models_lists(map(extract_data, responses))}
 
-        if raw:
-            return responses
-
-        models = {
-            "data": merge_models_lists(
-                list(
-                    map(
-                        lambda response: (
-                            response["data"]
-                            if (response and "data" in response)
-                            else (response if isinstance(response, list) else None)
-                        ),
-                        responses,
-                    )
-                )
-            )
-        }
-
-        log.debug(f"models: {models}")
-        app.state.MODELS = {model["id"]: model for model in models["data"]}
+    log.debug(f"models: {models}")
+    app.state.MODELS = {model["id"]: model for model in models["data"]}
 
     return models
 
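`get_all_models_raw()` fans out one `fetch_url` task per configured base URL and gathers them concurrently, so one slow or dead upstream no longer serializes the rest. `fetch_url` itself is not part of this diff; a minimal sketch of what such a helper plausibly looks like with the `aiohttp` import shown above (signature inferred from the call site, timeout value assumed):

import aiohttp

async def fetch_url(url: str, key: str):
    # Hypothetical sketch: GET the endpoint with a Bearer token and parse JSON;
    # returning None on failure lets extract_data/merge_models_lists skip the URL.
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            headers = {"Authorization": f"Bearer {key}"}
            async with session.get(url, headers=headers) as r:
                return await r.json()
    except Exception:
        # Connection errors and invalid JSON both collapse to "no models here".
        return None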
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
 
         raise HTTPException(
@@ -358,8 +346,7 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata")
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
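With the `if` guard gone, `payload.pop("metadata")` raises `KeyError` whenever a caller omits the field; presumably every caller in this codebase attaches `metadata`, but `dict.pop` accepts a default that would make the call unconditionally safe:

payload = {"model": "gpt-4", "messages": []}
payload.pop("metadata", None)  # no KeyError even though the key is absent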
@@ -368,70 +355,9 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload.get("temperature") is None
-            ):
-                payload["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
-                payload["top_p"] = int(model_info.params.get("top_p", None))
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload.get("max_tokens") is None
-            ):
-                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload.get("frequency_penalty") is None
-            ):
-                payload["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload.get("seed") is None
-            ):
-                payload["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None) and payload.get("stop") is None:
-                payload["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
-
-    else:
-        pass
+        params = model_info.params.model_dump()
+        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_system_prompt_to_body(params, payload, user)
 
     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
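The sixty-odd deleted lines survive as two helpers in `utils.misc` (the "and utils" half of this commit, not shown in this diff). A sketch of what they plausibly do, reconstructed from the inline logic removed above — the bodies below are an inference, not the committed code, and they fold the old mixture of truthiness and `is not None` checks into a single pattern:

# Hypothetical reconstruction of the new utils.misc helpers, inferred from
# the inline code this hunk deletes from main.py.
from typing import Callable

from utils.task import prompt_template  # previously imported by main.py


def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
    if not params:
        return form_data

    # One caster per known parameter, mirroring the removed inline conversions.
    casts: dict[str, Callable] = {
        "temperature": float,
        "top_p": int,
        "max_tokens": int,
        "frequency_penalty": int,
        "seed": lambda x: x,
        "stop": lambda s: [bytes(v, "utf-8").decode("unicode_escape") for v in s],
    }

    for key, cast in casts.items():
        value = params.get(key)
        # Model defaults only fill gaps; values sent by the client win.
        if value is not None and form_data.get(key) is None:
            form_data[key] = cast(value)
    return form_data


def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    system = params.get("system")
    if not system or not form_data.get("messages"):
        return form_data

    # Fill {{USER_NAME}}/{{USER_LOCATION}} placeholders before injecting the
    # system prompt, exactly as the removed inline block did.
    if user:
        template_params = {
            "user_name": user.name,
            "user_location": user.info.get("location") if user.info else None,
        }
    else:
        template_params = {}
    system = prompt_template(system, **template_params)

    # add_or_update_system_message is assumed to live alongside this helper.
    form_data["messages"] = add_or_update_system_message(system, form_data["messages"])
    return form_data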
@@ -506,7 +432,7 @@ async def generate_chat_completion(
                     print(res)
                     if "error" in res:
                         error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-                except:
+                except Exception:
                     error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                     print(res)
                     if "error" in res:
                         error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-                except:
+                except Exception:
                     error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally: