mirror of
https://github.com/open-webui/open-webui
synced 2025-02-16 18:22:29 +00:00
enh: __messages__ support for tools
This commit is contained in:
parent
4a3362f889
commit
55dfc2013a
@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
||||
try:
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
# Check if '__user__' is a parameter of the function
|
||||
params = result["parameters"]
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
# Call the function with the '__user__' parameter included
|
||||
function_result = function(
|
||||
**{
|
||||
**result["parameters"],
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Call the function without modifying the parameters
|
||||
function_result = function(**result["parameters"])
|
||||
params = {
|
||||
**params,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
}
|
||||
|
||||
if "__messages__" in sig.parameters:
|
||||
# Call the function with the '__messages__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__messages__": messages,
|
||||
}
|
||||
|
||||
function_result = function(**params)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
user=user,
|
||||
)
|
||||
|
||||
if response:
|
||||
if isinstance(response, str):
|
||||
context += ("\n" if context != "" else "") + response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
del data["tool_ids"]
|
||||
|
@ -3,7 +3,7 @@ import hashlib
|
||||
import json
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
|
||||
def get_last_user_message(messages: List[dict]) -> str:
|
||||
@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
|
||||
return None
|
||||
|
||||
|
||||
def get_system_message(messages: List[dict]) -> dict:
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def remove_system_message(messages: List[dict]) -> List[dict]:
|
||||
return [message for message in messages if message["role"] != "system"]
|
||||
|
||||
|
||||
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
|
||||
return get_system_message(messages), remove_system_message(messages)
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: List[dict]):
|
||||
"""
|
||||
Adds a new system message at the beginning of the messages list
|
||||
|
Loading…
Reference in New Issue
Block a user