factor out get_content_from_response

This commit is contained in:
Michael Poluektov 2024-08-10 13:40:04 +01:00
parent 0c9119d619
commit 9fb70969d7

View File

@ -282,6 +282,21 @@ def get_filter_function_ids(model):
return filter_ids return filter_ids
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
return content
async def get_function_call_response( async def get_function_call_response(
messages, messages,
files, files,
@ -293,6 +308,9 @@ async def get_function_call_response(
__event_call__=None, __event_call__=None,
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
if tool is None:
return None, None, False
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs) content = tools_function_calling_generation_template(template, tools_specs)
@ -327,21 +345,9 @@ async def get_function_call_response(
model = app.state.MODELS[task_model_id] model = app.state.MODELS[task_model_id]
response = None
try: try:
response = await generate_chat_completions(form_data=payload, user=user) response = await generate_chat_completions(form_data=payload, user=user)
content = None content = await get_content_from_response(response)
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
if content is None: if content is None:
return None, None, False return None, None, False
@ -351,8 +357,6 @@ async def get_function_call_response(
result = json.loads(content) result = json.loads(content)
print(result) print(result)
citation = None
if "name" not in result: if "name" not in result:
return None, None, False return None, None, False
@ -375,6 +379,7 @@ async def get_function_call_response(
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
citation = None
try: try:
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function) sig = inspect.signature(function)
@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
print("generate_ollama_chat_completion")
return await generate_ollama_chat_completion(form_data, user=user) return await generate_ollama_chat_completion(form_data, user=user)
else: else:
return await generate_openai_chat_completion(form_data, user=user) return await generate_openai_chat_completion(form_data, user=user)