enh: __messages__ support for tools

This commit is contained in:
Timothy J. Baek 2024-06-17 13:28:29 -07:00
parent 4a3362f889
commit 55dfc2013a
2 changed files with 38 additions and 17 deletions

View File

@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
try: try:
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function) sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function params = result["parameters"]
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included # Call the function with the '__user__' parameter included
function_result = function( params = {
**{ **params,
**result["parameters"], "__user__": {
"__user__": { "id": user.id,
"id": user.id, "email": user.email,
"email": user.email, "name": user.name,
"name": user.name, "role": user.role,
"role": user.role, },
}, }
}
) if "__messages__" in sig.parameters:
else: # Call the function with the '__messages__' parameter included
# Call the function without modifying the parameters params = {
function_result = function(**result["parameters"]) **params,
"__messages__": messages,
}
function_result = function(**params)
except Exception as e: except Exception as e:
print(e) print(e)
@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user=user, user=user,
) )
if response: if isinstance(response, str):
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
del data["tool_ids"] del data["tool_ids"]

View File

@ -3,7 +3,7 @@ import hashlib
import json import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List from typing import Optional, List, Tuple
def get_last_user_message(messages: List[dict]) -> str: def get_last_user_message(messages: List[dict]) -> str:
@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None return None
def get_system_message(messages: List[dict]) -> dict:
for message in messages:
if message["role"] == "system":
return message
return None
def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
return get_system_message(messages), remove_system_message(messages)
def add_or_update_system_message(content: str, messages: List[dict]): def add_or_update_system_message(content: str, messages: List[dict]):
""" """
Adds a new system message at the beginning of the messages list Adds a new system message at the beginning of the messages list