From 9fb70969d729af13960ddf9b4be5df753299d57e Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 13:40:04 +0100 Subject: [PATCH] factor out get_content_from_response --- backend/main.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0099aabb8..44fdc6298 100644 --- a/backend/main.py +++ b/backend/main.py @@ -282,6 +282,21 @@ def get_filter_function_ids(model): return filter_ids +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + async def get_function_call_response( messages, files, @@ -293,6 +308,9 @@ async def get_function_call_response( __event_call__=None, ): tool = Tools.get_tool_by_id(tool_id) + if tool is None: + return None, None, False + tools_specs = json.dumps(tool.specs, indent=2) content = tools_function_calling_generation_template(template, tools_specs) @@ -327,21 +345,9 @@ async def get_function_call_response( model = app.state.MODELS[task_model_id] - response = None try: response = await generate_chat_completions(form_data=payload, user=user) - content = None - - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] + content = await get_content_from_response(response) if content is None: return None, None, False @@ -351,8 +357,6 @@ async def get_function_call_response( result = json.loads(content) print(result) - citation = None - if "name" not in result: return None, None, False @@ -375,6 +379,7 @@ async def get_function_call_response( function = getattr(toolkit_module, result["name"]) function_result = None + citation = None try: # Get the signature of the function sig = inspect.signature(function) @@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": - print("generate_ollama_chat_completion") return await generate_ollama_chat_completion(form_data, user=user) else: return await generate_openai_chat_completion(form_data, user=user)