refac: connections handling

This commit is contained in:
Timothy Jaeryang Baek
2025-01-18 17:10:15 -08:00
parent 430854e223
commit ca0285fc91
3 changed files with 113 additions and 56 deletions

View File

@@ -145,11 +145,13 @@ async def update_config(
request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
# Remove any extra configs
config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in config_urls:
request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
# Remove the API configs that are not in the API URLS
keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS))))
request.app.state.config.OPENAI_API_CONFIGS = {
key: value
for key, value in request.app.state.config.OPENAI_API_CONFIGS.items()
if key in keys
}
return {
"ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
@@ -264,14 +266,21 @@ async def get_all_models_responses(request: Request) -> list:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in request.app.state.config.OPENAI_API_CONFIGS:
if (idx not in request.app.state.config.OPENAI_API_CONFIGS) or (
url not in request.app.state.config.OPENAI_API_CONFIGS # Legacy support
):
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
)
)
else:
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
idx,
request.app.state.config.OPENAI_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
model_ids = api_config.get("model_ids", [])
@@ -310,7 +319,12 @@ async def get_all_models_responses(request: Request) -> list:
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
idx,
request.app.state.config.OPENAI_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
@@ -585,7 +599,10 @@ async def generate_chat_completion(
# Get the API config for the model
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
idx,
request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)