feat: user hook

This commit is contained in:
Timothy J. Baek 2024-06-11 10:19:59 -07:00
parent f62b15d8da
commit 8a86f32700

View File

@ -11,6 +11,7 @@ import requests
import mimetypes import mimetypes
import shutil import shutil
import os import os
import inspect
import asyncio import asyncio
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@ -204,6 +205,8 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
# Parse the function response # Parse the function response
if content is not None: if content is not None:
print(content)
result = json.loads(content) result = json.loads(content)
print(result) print(result)
@ -218,6 +221,23 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if 'user' is a parameter of the function
if "user" in sig.parameters:
# Call the function with the 'user' parameter included
function_result = function(
**{
**result["parameters"],
"user": {
"id": user.id,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"]) function_result = function(**result["parameters"])
except Exception as e: except Exception as e:
print(e) print(e)
@ -284,6 +304,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
if "tool_ids" in data: if "tool_ids" in data:
print(data["tool_ids"])
prompt = get_last_user_message(data["messages"]) prompt = get_last_user_message(data["messages"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
@ -299,7 +320,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
del data["tool_ids"] del data["tool_ids"]
print(context) print(f"tool_context: {context}")
# If docs field is present, generate RAG completions # If docs field is present, generate RAG completions
if "docs" in data: if "docs" in data: