fix:required params

This commit is contained in:
Samuel 2024-11-17 18:00:11 +00:00
parent b8ca03fafd
commit 3c00adae38

View File

@ -534,7 +534,8 @@ async def handle_streaming_response(request: Request, response: Response,
)
async def handle_nonstreaming_response(request: Request, response: Response,
tools: dict, user: UserModel, data_items: list) -> JSONResponse:
tools: dict, user: UserModel,
data_items: list, models) -> JSONResponse:
# It only should be one response since we are in the non streaming scenario
content = ''
async for data in response.body_iterator:
@ -544,7 +545,7 @@ async def handle_nonstreaming_response(request: Request, response: Response,
body = json.loads(request._body)
is_ollama = False
if app.state.MODELS[body["model"]]["owned_by"] == "ollama":
if models[body["model"]]["owned_by"] == "ollama":
is_ollama = True
is_openai = not is_ollama
@ -620,7 +621,7 @@ def get_task_model_id(
async def chat_completion_tools_handler(
body: dict, user: UserModel, models, extra_params: dict
body: dict, user: UserModel, extra_params: dict, models:dict
) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
@ -879,7 +880,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
del body["tools"] # we won't use those
try:
body, flags = await chat_completion_tools_handler(
body, user, extra_params
body, user, extra_params, models
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
@ -958,7 +959,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
else:
if not body.get("stream", False):
return await handle_nonstreaming_response(request, response, tools, user, data_items)
return await handle_nonstreaming_response(request, response, tools, user, data_items, models)
return await handle_streaming_response(request, response, tools, data_items, call_next, user)
async def _receive(self, body: bytes):