diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index e3210ae5f..c40509fde 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -633,13 +633,7 @@ async def verify_connection( raise HTTPException(status_code=500, detail=error_detail) -def convert_to_azure_payload( - url, - payload: dict, -): - model = payload.get("model", "") - - # Filter allowed parameters based on Azure OpenAI API +def get_azure_allowed_params(api_version: str) -> set[str]: allowed_params = { "messages", "temperature", @@ -668,6 +662,20 @@ def convert_to_azure_payload( "seed", "max_completion_tokens", } + if api_version >= "2024-09-01-preview": + allowed_params.add("stream_options") + return allowed_params + + +def convert_to_azure_payload( + url, + payload: dict, + api_version: str +): + model = payload.get("model", "") + + # Filter allowed parameters based on Azure OpenAI API + allowed_params = get_azure_allowed_params(api_version) # Special handling for o-series models if model.startswith("o") and model.endswith("-mini"): @@ -817,8 +825,8 @@ async def generate_chat_completion( } if api_config.get("azure", False): - request_url, payload = convert_to_azure_payload(url, payload) - api_version = api_config.get("api_version", "") or "2023-03-15-preview" + api_version = api_config.get("api_version", "2023-03-15-preview") + request_url, payload = convert_to_azure_payload(url, payload, api_version) headers["api-key"] = key headers["api-version"] = api_version request_url = f"{request_url}/chat/completions?api-version={api_version}" @@ -1007,16 +1015,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): } if api_config.get("azure", False): + api_version = api_config.get("api_version", "2023-03-15-preview") headers["api-key"] = key - headers["api-version"] = ( - api_config.get("api_version", "") or "2023-03-15-preview" - ) + headers["api-version"] = api_version payload = json.loads(body) - url, payload = convert_to_azure_payload(url, payload) + url, payload = convert_to_azure_payload(url, payload, api_version) body = json.dumps(payload).encode() - request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}" + request_url = f"{url}/{path}?api-version={api_version}" else: headers["Authorization"] = f"Bearer {key}" request_url = f"{url}/{path}"