mirror of
https://github.com/open-webui/open-webui
synced 2025-02-06 13:10:16 +00:00
feat: user hook
This commit is contained in:
parent
f62b15d8da
commit
8a86f32700
@ -11,6 +11,7 @@ import requests
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import shutil
|
import shutil
|
||||||
import os
|
import os
|
||||||
|
import inspect
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||||
@ -204,6 +205,8 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
|
|||||||
|
|
||||||
# Parse the function response
|
# Parse the function response
|
||||||
if content is not None:
|
if content is not None:
|
||||||
|
|
||||||
|
print(content)
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
@ -218,6 +221,23 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
|
|||||||
function = getattr(toolkit_module, result["name"])
|
function = getattr(toolkit_module, result["name"])
|
||||||
function_result = None
|
function_result = None
|
||||||
try:
|
try:
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
# Check if 'user' is a parameter of the function
|
||||||
|
if "user" in sig.parameters:
|
||||||
|
# Call the function with the 'user' parameter included
|
||||||
|
function_result = function(
|
||||||
|
**{
|
||||||
|
**result["parameters"],
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Call the function without modifying the parameters
|
||||||
function_result = function(**result["parameters"])
|
function_result = function(**result["parameters"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
@ -284,6 +304,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
if "tool_ids" in data:
|
if "tool_ids" in data:
|
||||||
|
print(data["tool_ids"])
|
||||||
prompt = get_last_user_message(data["messages"])
|
prompt = get_last_user_message(data["messages"])
|
||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
print(tool_id)
|
print(tool_id)
|
||||||
@ -299,7 +320,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
del data["tool_ids"]
|
del data["tool_ids"]
|
||||||
|
|
||||||
print(context)
|
print(f"tool_context: {context}")
|
||||||
|
|
||||||
# If docs field is present, generate RAG completions
|
# If docs field is present, generate RAG completions
|
||||||
if "docs" in data:
|
if "docs" in data:
|
||||||
|
Loading…
Reference in New Issue
Block a user