enh: __messages__ support for tools

This commit is contained in:
Timothy J. Baek 2024-06-17 13:28:29 -07:00
parent 4a3362f889
commit 55dfc2013a
2 changed files with 38 additions and 17 deletions

View File

@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
params = result["parameters"]
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
params = {
**params,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
function_result = function(**params)
except Exception as e:
print(e)
@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user=user,
)
if response:
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]

View File

@ -3,7 +3,7 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional, List
from typing import Optional, List, Tuple
def get_last_user_message(messages: List[dict]) -> str:
@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None
def get_system_message(messages: List[dict]) -> dict:
for message in messages:
if message["role"] == "system":
return message
return None
def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
return get_system_message(messages), remove_system_message(messages)
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list