diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index be8ec6489..a0d8f3750 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -30,7 +30,7 @@ from config import ( MODEL_FILTER_LIST, AppConfig, ) -from typing import List, Optional +from typing import List, Optional, Literal, overload import hashlib @@ -262,12 +262,22 @@ async def get_all_models_raw() -> list: return responses -async def get_all_models() -> dict[str, list]: +@overload +async def get_all_models(raw: Literal[True]) -> list: ... + + +@overload +async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... + + +async def get_all_models(raw=False) -> dict[str, list] | list: log.info("get_all_models()") if is_openai_api_disabled(): - return {"data": []} + return [] if raw else {"data": []} responses = await get_all_models_raw() + if raw: + return responses def extract_data(response): if response and "data" in response: @@ -370,13 +380,6 @@ async def generate_chat_completion( "role": user.role, } - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if payload.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in payload: - payload["max_tokens"] = 4000 - log.debug("Modified payload:", payload) - # Convert the modified body back to JSON payload = json.dumps(payload) diff --git a/backend/main.py b/backend/main.py index 3e1a58d90..181944606 100644 --- a/backend/main.py +++ b/backend/main.py @@ -36,7 +36,6 @@ from apps.ollama.main import ( from apps.openai.main import ( app as openai_app, get_all_models as get_openai_models, - get_all_models_raw as get_openai_models_raw, generate_chat_completion as generate_openai_chat_completion, ) @@ -1657,7 +1656,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ @app.get("/api/pipelines/list") async def get_pipelines_list(user=Depends(get_admin_user)): - responses = 
await get_openai_models_raw() + responses = await get_openai_models(raw=True) print(responses) urlIdxs = [