From 55dfc2013a2be993e33615f7be82f45b3801667e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 17 Jun 2024 13:28:29 -0700 Subject: [PATCH] enh: __messages__ support for tools --- backend/main.py | 38 ++++++++++++++++++++++---------------- backend/utils/misc.py | 17 ++++++++++++++++- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/backend/main.py b/backend/main.py index 04f886162..40793867d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, try: # Get the signature of the function sig = inspect.signature(function) - # Check if '__user__' is a parameter of the function + params = result["parameters"] + if "__user__" in sig.parameters: # Call the function with the '__user__' parameter included - function_result = function( - **{ - **result["parameters"], - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - } - ) - else: - # Call the function without modifying the parameters - function_result = function(**result["parameters"]) + params = { + **params, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + } + + if "__messages__" in sig.parameters: + # Call the function with the '__messages__' parameter included + params = { + **params, + "__messages__": messages, + } + + function_result = function(**params) except Exception as e: print(e) @@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): user=user, ) - if response: + if isinstance(response, str): context += ("\n" if context != "" else "") + response + except Exception as e: print(f"Error: {e}") del data["tool_ids"] diff --git a/backend/utils/misc.py b/backend/utils/misc.py index c3c65d3f5..41fbdcc75 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -3,7 +3,7 @@ import hashlib import json import re from datetime import timedelta -from typing import Optional, List +from typing import Optional, List, Tuple def get_last_user_message(messages: List[dict]) -> str: @@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str: return None +def get_system_message(messages: List[dict]) -> dict: + for message in messages: + if message["role"] == "system": + return message + return None + + +def remove_system_message(messages: List[dict]) -> List[dict]: + return [message for message in messages if message["role"] != "system"] + + +def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: + return get_system_message(messages), remove_system_message(messages) + + def add_or_update_system_message(content: str, messages: List[dict]): """ Adds a new system message at the beginning of the messages list