factor out get_content_from_response

Michael Poluektov 2024-08-10 13:40:04 +01:00
parent 0c9119d619
commit 9fb70969d7


@@ -282,6 +282,21 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
+
 async def get_function_call_response(
     messages,
     files,
@@ -293,6 +308,9 @@ async def get_function_call_response(
     __event_call__=None,
 ):
     tool = Tools.get_tool_by_id(tool_id)
+    if tool is None:
+        return None, None, False
+
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tools_function_calling_generation_template(template, tools_specs)
@@ -327,21 +345,9 @@ async def get_function_call_response(
     model = app.state.MODELS[task_model_id]
 
     response = None
     try:
         response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
+        content = await get_content_from_response(response)
 
         if content is None:
             return None, None, False
@@ -351,8 +357,6 @@ async def get_function_call_response(
         result = json.loads(content)
         print(result)
 
-        citation = None
-
         if "name" not in result:
             return None, None, False
@@ -375,6 +379,7 @@ async def get_function_call_response(
         function = getattr(toolkit_module, result["name"])
         function_result = None
+        citation = None
         try:
             # Get the signature of the function
             sig = inspect.signature(function)
@@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
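
The helper factored out above covers two cases: a streaming response object exposing a body_iterator of JSON-encoded chunks, and a plain dict already containing the completion. Below is a minimal standalone sketch of that behavior; the FakeStreamingResponse class and the sample payload are illustrative assumptions and are not part of this commit.

import asyncio
import json
from typing import Optional


async def get_content_from_response(response) -> Optional[str]:
    # Same logic as the helper added in this commit.
    content = None
    if hasattr(response, "body_iterator"):
        async for chunk in response.body_iterator:
            data = json.loads(chunk.decode("utf-8"))
            content = data["choices"][0]["message"]["content"]

        # Cleanup any remaining background tasks if necessary
        if response.background is not None:
            await response.background()
    else:
        content = response["choices"][0]["message"]["content"]
    return content


class FakeStreamingResponse:
    # Illustrative stand-in for a streaming response object; not part of the commit.
    def __init__(self, payload):
        self._payload = payload
        self.background = None  # no background task to await in this sketch

    @property
    def body_iterator(self):
        async def _iter():
            yield json.dumps(self._payload).encode("utf-8")

        return _iter()


async def main():
    payload = {"choices": [{"message": {"content": "hello"}}]}

    # Streaming-style response: content is pulled from the body iterator.
    print(await get_content_from_response(FakeStreamingResponse(payload)))

    # Plain dict response: content is read directly.
    print(await get_content_from_response(payload))


asyncio.run(main())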