From 1bfcd801b798c21b65623f353168fe83b0ec6260 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Mon, 18 Mar 2024 01:11:48 -0700
Subject: [PATCH] fix: multiple openai issue

---
 backend/apps/ollama/main.py | 18 +++++++++---------
 backend/apps/openai/main.py | 30 ++++++++++++++++++------------
 backend/config.py           |  6 ++++--
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 2e236f343..154be97c9 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -98,13 +98,14 @@ def merge_models_lists(model_lists):
     merged_models = {}
 
     for idx, model_list in enumerate(model_lists):
-        for model in model_list:
-            digest = model["digest"]
-            if digest not in merged_models:
-                model["urls"] = [idx]
-                merged_models[digest] = model
-            else:
-                merged_models[digest]["urls"].append(idx)
+        if model_list is not None:
+            for model in model_list:
+                digest = model["digest"]
+                if digest not in merged_models:
+                    model["urls"] = [idx]
+                    merged_models[digest] = model
+                else:
+                    merged_models[digest]["urls"].append(idx)
 
     return list(merged_models.values())
 
@@ -116,11 +117,10 @@ async def get_all_models():
     print("get_all_models")
     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
     responses = await asyncio.gather(*tasks)
-    responses = list(filter(lambda x: x is not None, responses))
 
     models = {
         "models": merge_models_lists(
-            map(lambda response: response["models"], responses)
+            map(lambda response: response["models"] if response else None, responses)
         )
     }
 
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index 40bcdc0c3..b012f237f 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -168,14 +168,15 @@ def merge_models_lists(model_lists):
     merged_list = []
 
     for idx, models in enumerate(model_lists):
-        merged_list.extend(
-            [
-                {**model, "urlIdx": idx}
-                for model in models
-                if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
-                or "gpt" in model["id"]
-            ]
-        )
+        if models is not None and "error" not in models:
+            merged_list.extend(
+                [
+                    {**model, "urlIdx": idx}
+                    for model in models
+                    if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
+                    or "gpt" in model["id"]
+                ]
+            )
 
     return merged_list
 
@@ -190,15 +191,20 @@ async def get_all_models():
             fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
             for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
         ]
+
         responses = await asyncio.gather(*tasks)
-        responses = list(
-            filter(lambda x: x is not None and "error" not in x, responses)
-        )
         models = {
             "data": merge_models_lists(
-                list(map(lambda response: response["data"], responses))
+                list(
+                    map(
+                        lambda response: response["data"] if response else None,
+                        responses,
+                    )
+                )
             )
         }
+
+        print(models)
         app.state.MODELS = {model["id"]: model for model in models["data"]}
 
         return models
diff --git a/backend/config.py b/backend/config.py
index 831371bb7..099f726ca 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -250,8 +250,10 @@ OPENAI_API_BASE_URLS = (
     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
 )
 
-OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")]
-
+OPENAI_API_BASE_URLS = [
+    url.strip() if url != "" else "https://api.openai.com/v1"
+    for url in OPENAI_API_BASE_URLS.split(";")
+]
 
 ####################################
 # WEBUI
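
Reviewer note: a minimal, standalone sketch of the Ollama-side behavior this patch introduces. The function body mirrors the patched merge_models_lists; the sample responses list is hypothetical and only illustrates one endpoint being unreachable, so its entry arrives as None instead of being filtered out before the merge.

# Sketch only: mirrors the patched merge_models_lists from backend/apps/ollama/main.py.
# A failed /api/tags fetch now surfaces as None and is skipped here, instead of the
# merge trying to iterate it and raising a TypeError.
def merge_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                digest = model["digest"]
                if digest not in merged_models:
                    model["urls"] = [idx]
                    merged_models[digest] = model
                else:
                    merged_models[digest]["urls"].append(idx)

    return list(merged_models.values())

# Hypothetical input: endpoint 1 is down, endpoints 0 and 2 serve the same model.
responses = [
    [{"digest": "abc", "name": "llama2:latest"}],
    None,
    [{"digest": "abc", "name": "llama2:latest"}],
]
print(merge_models_lists(responses))
# -> [{'digest': 'abc', 'name': 'llama2:latest', 'urls': [0, 2]}]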
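
Likewise, a small sketch of the new OPENAI_API_BASE_URLS parsing in backend/config.py: empty entries in the semicolon-separated value now fall back to the official OpenAI endpoint instead of ending up as empty strings. The helper name and sample value below are hypothetical; the real config reads the value from the environment at import time.

# Sketch only: same comprehension as the patched backend/config.py.
def parse_openai_base_urls(raw):
    return [
        url.strip() if url != "" else "https://api.openai.com/v1"
        for url in raw.split(";")
    ]

print(parse_openai_base_urls(";http://localhost:8080/v1"))
# -> ['https://api.openai.com/v1', 'http://localhost:8080/v1']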