mirror of
https://github.com/open-webui/open-webui
synced 2025-05-12 09:31:34 +00:00
enh: __messages__ support for tools
This commit is contained in:
parent
4a3362f889
commit
55dfc2013a
@ -244,23 +244,28 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
|||||||
try:
|
try:
|
||||||
# Get the signature of the function
|
# Get the signature of the function
|
||||||
sig = inspect.signature(function)
|
sig = inspect.signature(function)
|
||||||
# Check if '__user__' is a parameter of the function
|
params = result["parameters"]
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
if "__user__" in sig.parameters:
|
||||||
# Call the function with the '__user__' parameter included
|
# Call the function with the '__user__' parameter included
|
||||||
function_result = function(
|
params = {
|
||||||
**{
|
**params,
|
||||||
**result["parameters"],
|
"__user__": {
|
||||||
"__user__": {
|
"id": user.id,
|
||||||
"id": user.id,
|
"email": user.email,
|
||||||
"email": user.email,
|
"name": user.name,
|
||||||
"name": user.name,
|
"role": user.role,
|
||||||
"role": user.role,
|
},
|
||||||
},
|
}
|
||||||
}
|
|
||||||
)
|
if "__messages__" in sig.parameters:
|
||||||
else:
|
# Call the function with the '__messages__' parameter included
|
||||||
# Call the function without modifying the parameters
|
params = {
|
||||||
function_result = function(**result["parameters"])
|
**params,
|
||||||
|
"__messages__": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
function_result = function(**params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
@ -339,8 +344,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response:
|
if isinstance(response, str):
|
||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
del data["tool_ids"]
|
del data["tool_ids"]
|
||||||
|
@ -3,7 +3,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
def get_last_user_message(messages: List[dict]) -> str:
|
def get_last_user_message(messages: List[dict]) -> str:
|
||||||
@ -28,6 +28,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_message(messages: List[dict]) -> dict:
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
return message
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remove_system_message(messages: List[dict]) -> List[dict]:
|
||||||
|
return [message for message in messages if message["role"] != "system"]
|
||||||
|
|
||||||
|
|
||||||
|
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
|
||||||
|
return get_system_message(messages), remove_system_message(messages)
|
||||||
|
|
||||||
|
|
||||||
def add_or_update_system_message(content: str, messages: List[dict]):
|
def add_or_update_system_message(content: str, messages: List[dict]):
|
||||||
"""
|
"""
|
||||||
Adds a new system message at the beginning of the messages list
|
Adds a new system message at the beginning of the messages list
|
||||||
|
Loading…
Reference in New Issue
Block a user