refac: apps/openai/main.py and utils

Michael Poluektov
2024-08-03 14:24:26 +01:00
parent 774defd184
commit 12c21fac22
8 changed files with 148 additions and 230 deletions


@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+from fastapi.responses import StreamingResponse, FileResponse
 import requests
 import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
 from starlette.background import BackgroundTask
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template
-from utils.misc import add_or_update_system_message
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 from config import (
     SRC_LOG_LEVELS,
@@ -69,8 +65,6 @@ app.state.MODELS = {}
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
         await get_all_models()
-    else:
-        pass

     response = await call_next(request)
     return response
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"

         raise HTTPException(
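
The recurring change from a bare `except:` to `except Exception:` (it appears in several hunks below as well) narrows the handler: a bare `except:` also catches `KeyboardInterrupt` and `SystemExit`, which derive from `BaseException` rather than `Exception`. A standalone illustration of the difference, not part of the commit:

    def classify(exc: BaseException) -> str:
        """Report which handler style would catch exc."""
        try:
            raise exc
        except Exception:
            return "caught by `except Exception:`"
        except BaseException:
            return "only a bare `except:` would have caught this"

    print(classify(ValueError("boom")))   # caught by `except Exception:`
    print(classify(KeyboardInterrupt()))  # only a bare `except:` would have caught this
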
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
     return merged_list


-async def get_all_models(raw: bool = False):
+def is_openai_api_disabled():
+    api_keys = app.state.config.OPENAI_API_KEYS
+    no_keys = len(api_keys) == 1 and api_keys[0] == ""
+    return no_keys or not app.state.config.ENABLE_OPENAI_API
+
+
+async def get_all_models_raw() -> list:
+    if is_openai_api_disabled():
+        return []
+
+    # Check if API KEYS length is the same as API URLS length
+    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(app.state.config.OPENAI_API_KEYS)
+
+    if num_keys != num_urls:
+        # if there are more keys than urls, remove the extra keys
+        if num_keys > num_urls:
+            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
+            app.state.config.OPENAI_API_KEYS = new_keys
+        # if there are more urls than keys, add empty keys
+        else:
+            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    tasks = [
+        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
+    ]
+
+    responses = await asyncio.gather(*tasks)
+    log.debug(f"get_all_models:responses() {responses}")
+
+    return responses
+
+
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
+    if is_openai_api_disabled():
+        return {"data": []}

-    if (
-        len(app.state.config.OPENAI_API_KEYS) == 1
-        and app.state.config.OPENAI_API_KEYS[0] == ""
-    ) or not app.state.config.ENABLE_OPENAI_API:
-        models = {"data": []}
-    else:
-        # Check if API KEYS length is the same as API URLS length
-        if len(app.state.config.OPENAI_API_KEYS) != len(
-            app.state.config.OPENAI_API_BASE_URLS
-        ):
-            # if there are more keys than urls, remove the extra keys
-            if len(app.state.config.OPENAI_API_KEYS) > len(
-                app.state.config.OPENAI_API_BASE_URLS
-            ):
-                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                    : len(app.state.config.OPENAI_API_BASE_URLS)
-                ]
-            # if there are more urls than keys, add empty keys
-            else:
-                app.state.config.OPENAI_API_KEYS += [
-                    ""
-                    for _ in range(
-                        len(app.state.config.OPENAI_API_BASE_URLS)
-                        - len(app.state.config.OPENAI_API_KEYS)
-                    )
-                ]
+    responses = await get_all_models_raw()

-        tasks = [
-            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-        ]
+    def extract_data(response):
+        if response and "data" in response:
+            return response["data"]
+        if isinstance(response, list):
+            return response
+        return None

-        responses = await asyncio.gather(*tasks)
-        log.debug(f"get_all_models:responses() {responses}")
+    models = {"data": merge_models_lists(map(extract_data, responses))}

-        if raw:
-            return responses
-
-        models = {
-            "data": merge_models_lists(
-                list(
-                    map(
-                        lambda response: (
-                            response["data"]
-                            if (response and "data" in response)
-                            else (response if isinstance(response, list) else None)
-                        ),
-                        responses,
-                    )
-                )
-            )
-        }
-
-        log.debug(f"models: {models}")
-        app.state.MODELS = {model["id"]: model for model in models["data"]}
+    log.debug(f"models: {models}")
+    app.state.MODELS = {model["id"]: model for model in models["data"]}

     return models
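
The refactor splits model listing into `get_all_models_raw` (fetch) and `get_all_models` (normalize and cache), replacing the nested if/else and the inline lambda. Below is a standalone sketch of the two normalization steps, using hypothetical config values rather than code from the commit:

    # Key-list reconciliation: trim or pad the key list to match the URL list.
    urls = ["https://api.openai.com/v1", "http://localhost:8080/v1"]
    keys = ["sk-example"]  # hypothetical; one key fewer than there are URLs
    keys = keys[: len(urls)] + [""] * (len(urls) - len(keys))
    assert keys == ["sk-example", ""]

    # extract_data: accept the shapes a /models endpoint may return.
    def extract_data(response):
        if response and "data" in response:
            return response["data"]  # OpenAI-style {"object": "list", "data": [...]}
        if isinstance(response, list):
            return response  # bare list of model dicts
        return None  # failed fetch, assuming fetch_url yields None on error

    assert extract_data({"data": [{"id": "gpt-4o"}]}) == [{"id": "gpt-4o"}]
    assert extract_data([{"id": "llama3"}]) == [{"id": "llama3"}]
    assert extract_data(None) is None
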
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
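
Replacing `== None` with `is None` follows PEP 8: `==` dispatches to the left operand's `__eq__`, which a class can override, while `is` compares object identity and cannot be fooled. A standalone illustration, not part of the commit:

    class AlwaysEqual:
        def __eq__(self, other):
            return True  # claims equality with everything, including None

    obj = AlwaysEqual()
    print(obj == None)  # True -- misleading result from the overridden __eq__
    print(obj is None)  # False -- identity comparison is unambiguous
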
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"

         raise HTTPException(
@@ -358,8 +346,7 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata")

     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
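
The guarded `del` collapses into a single `dict.pop` call. Note the semantics: `pop(key)` raises `KeyError` when the key is absent, while `pop(key, default)` does not, so the one-argument form above assumes the key is always present. A standalone illustration with a hypothetical payload:

    payload = {"model": "gpt-4o", "metadata": {"chat_id": "abc123"}}
    payload.pop("metadata")        # removes the key and returns its value
    payload.pop("metadata", None)  # key already gone: returns None, no error
    # payload.pop("metadata")      # would raise KeyError: 'metadata'
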
@@ -368,70 +355,9 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload.get("temperature") is None
-            ):
-                payload["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
-                payload["top_p"] = int(model_info.params.get("top_p", None))
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload.get("max_tokens") is None
-            ):
-                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload.get("frequency_penalty") is None
-            ):
-                payload["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload.get("seed") is None
-            ):
-                payload["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None) and payload.get("stop") is None:
-                payload["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-        system = model_info.params.get("system", None)
-        if system:
-            system = prompt_template(
-                system,
-                **(
-                    {
-                        "user_name": user.name,
-                        "user_location": (
-                            user.info.get("location") if user.info else None
-                        ),
-                    }
-                    if user
-                    else {}
-                ),
-            )
-            if payload.get("messages"):
-                payload["messages"] = add_or_update_system_message(
-                    system, payload["messages"]
-                )
-        else:
-            pass
+        params = model_info.params.model_dump()
+        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_system_prompt_to_body(params, payload, user)

     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
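
Roughly sixty lines of inline parameter and system-prompt plumbing move into two `utils.misc` helpers. A minimal sketch of their plausible shape, inferred from the removed inline code above rather than from utils/misc.py itself, so the real implementations may differ:

    # Project helpers used by the old inline code (see the imports hunk above).
    from utils.task import prompt_template
    from utils.misc import add_or_update_system_message

    def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
        # Copy per-model defaults into the request body, mirroring the casts
        # used by the removed inline code.
        if not params:
            return form_data
        mappings = {
            "temperature": float,
            "top_p": int,
            "max_tokens": int,
            "frequency_penalty": int,
            "seed": lambda x: x,
            "stop": lambda s: [bytes(x, "utf-8").decode("unicode_escape") for x in s],
        }
        for key, cast in mappings.items():
            value = params.get(key)
            if value is not None and form_data.get(key) is None:
                form_data[key] = cast(value)
        return form_data

    def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
        # Render the model's system prompt template and merge it into messages.
        system = params.get("system")
        if not system or not form_data.get("messages"):
            return form_data
        template_params = (
            {
                "user_name": user.name,
                "user_location": user.info.get("location") if user.info else None,
            }
            if user
            else {}
        )
        system = prompt_template(system, **template_params)
        form_data["messages"] = add_or_update_system_message(
            system, form_data["messages"]
        )
        return form_data
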
@@ -506,7 +432,7 @@ async def generate_chat_completion(
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally: