Merge pull request #15122 from tcx4c70/feat/add_stream_options_to_azure

feat(azure): Add stream_options to payload if api_version supports
This commit is contained in:
Tim Jaeryang Baek 2025-06-20 09:57:27 +04:00 committed by GitHub
commit 4e50dd4df6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -633,13 +633,7 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
def convert_to_azure_payload( def get_azure_allowed_params(api_version: str) -> set[str]:
url,
payload: dict,
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = { allowed_params = {
"messages", "messages",
"temperature", "temperature",
@ -668,6 +662,20 @@ def convert_to_azure_payload(
"seed", "seed",
"max_completion_tokens", "max_completion_tokens",
} }
if api_version >= "2024-09-01-preview":
allowed_params.add("stream_options")
return allowed_params
def convert_to_azure_payload(
url,
payload: dict,
api_version: str
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = get_azure_allowed_params(api_version)
# Special handling for o-series models # Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"): if model.startswith("o") and model.endswith("-mini"):
@ -817,8 +825,8 @@ async def generate_chat_completion(
} }
if api_config.get("azure", False): if api_config.get("azure", False):
request_url, payload = convert_to_azure_payload(url, payload) api_version = api_config.get("api_version", "2023-03-15-preview")
api_version = api_config.get("api_version", "") or "2023-03-15-preview" request_url, payload = convert_to_azure_payload(url, payload, api_version)
headers["api-key"] = key headers["api-key"] = key
headers["api-version"] = api_version headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}" request_url = f"{request_url}/chat/completions?api-version={api_version}"
@ -1007,16 +1015,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
} }
if api_config.get("azure", False): if api_config.get("azure", False):
api_version = api_config.get("api_version", "2023-03-15-preview")
headers["api-key"] = key headers["api-key"] = key
headers["api-version"] = ( headers["api-version"] = api_version
api_config.get("api_version", "") or "2023-03-15-preview"
)
payload = json.loads(body) payload = json.loads(body)
url, payload = convert_to_azure_payload(url, payload) url, payload = convert_to_azure_payload(url, payload, api_version)
body = json.dumps(payload).encode() body = json.dumps(payload).encode()
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}" request_url = f"{url}/{path}?api-version={api_version}"
else: else:
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}" request_url = f"{url}/{path}"