This commit is contained in:
Timothy J. Baek 2024-06-23 19:18:13 -07:00
parent 26e735618e
commit 0cf936f9e8
5 changed files with 36 additions and 8 deletions

View File

@ -143,10 +143,10 @@ class FunctionsTable:
for function in Function.select().where(Function.type == type) for function in Function.select().where(Function.type == type)
] ]
def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
function = Function.get(Function.id == id) function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function)) return function.valves if "valves" in function and function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None

View File

@ -114,10 +114,10 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool)) return tool.valves if "valves" in tool and tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None

View File

@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id)
if function: if function:
try: try:
function_valves = Functions.get_function_valves_by_id(id) valves = Functions.get_function_valves_by_id(id)
return function_valves.valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id) toolkit = Tools.get_tool_by_id(id)
if toolkit: if toolkit:
try: try:
tool_valves = Tools.get_tool_valves_by_id(id) valves = Tools.get_tool_valves_by_id(id)
return tool_valves.valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -262,6 +262,13 @@ async def get_function_call_response(
file_handler = True file_handler = True
print("file_handler: ", file_handler) print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
toolkit_module.valves = toolkit_module.Valves(
**Tools.get_tool_valves_by_id(tool_id)
)
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
@ -402,6 +409,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if hasattr(function_module, "file_handler"): if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(filter_id)
)
try: try:
if hasattr(function_module, "inlet"): if hasattr(function_module, "inlet"):
inlet = function_module.inlet inlet = function_module.inlet
@ -884,6 +898,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
else: else:
function_module = webui_app.state.FUNCTIONS[pipe_id] function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(pipe_id)
)
pipe = function_module.pipe pipe = function_module.pipe
# Get the signature of the function # Get the signature of the function
@ -1105,6 +1126,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
) )
webui_app.state.FUNCTIONS[filter_id] = function_module webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(filter_id)
)
try: try:
if hasattr(function_module, "outlet"): if hasattr(function_module, "outlet"):
outlet = function_module.outlet outlet = function_module.outlet