From f68aba687e14cc0b539a01ee7665746def64bd01 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 00:37:02 -0700 Subject: [PATCH] feat: functions router --- backend/apps/webui/main.py | 2 +- backend/apps/webui/models/functions.py | 4 +- backend/apps/webui/routers/functions.py | 180 ++++++++++++++++++++++++ backend/apps/webui/utils.py | 24 +++- backend/config.py | 8 ++ 5 files changed, 214 insertions(+), 4 deletions(-) create mode 100644 backend/apps/webui/routers/functions.py diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index bdc6ec4f4..ee5957224 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -60,7 +60,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} app.state.TOOLS = {} - +app.state.FUNCTIONS = {} app.add_middleware( CORSMiddleware, diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index cd877434d..ac12ab9e3 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -69,7 +69,7 @@ class FunctionForm(BaseModel): meta: FunctionMeta -class ToolsTable: +class FunctionsTable: def __init__(self, db): self.db = db self.db.create_tables([Function]) @@ -137,4 +137,4 @@ class ToolsTable: return False -Tools = ToolsTable(DB) +Functions = FunctionsTable(DB) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py new file mode 100644 index 000000000..1021cc10a --- /dev/null +++ b/backend/apps/webui/routers/functions.py @@ -0,0 +1,180 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.functions import ( + Functions, + FunctionForm, + FunctionModel, + FunctionResponse, +) +from apps.webui.utils import load_function_module_by_id +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +from importlib import util +import os +from pathlib import Path + +from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR + + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=List[FunctionResponse]) +async def get_functions(user=Depends(get_verified_user)): + return Functions.get_functions() + + +############################ +# ExportFunctions +############################ + + +@router.get("/export", response_model=List[FunctionModel]) +async def get_functions(user=Depends(get_admin_user)): + return Functions.get_functions() + + +############################ +# CreateNewFunction +############################ + + +@router.post("/create", response_model=Optional[FunctionResponse]) +async def create_new_function( + request: Request, form_data: FunctionForm, user=Depends(get_admin_user) +): + if not form_data.id.isidentifier(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only alphanumeric characters and underscores are allowed in the id", + ) + + form_data.id = form_data.id.lower() + + function = Functions.get_function_by_id(form_data.id) + if function == None: + function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module = load_function_module_by_id(form_data.id) + + FUNCTIONS = request.app.state.FUNCTIONS + FUNCTIONS[form_data.id] = function_module + + function = Functions.insert_new_function(user.id, form_data) + + function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id + function_cache_dir.mkdir(parents=True, exist_ok=True) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating function"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + +############################ +# GetFunctionById +############################ + + +@router.get("/id/{id}", response_model=Optional[FunctionModel]) +async def get_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateFunctionById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) +async def update_toolkit_by_id( + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) +): + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module = load_function_module_by_id(id) + + FUNCTIONS = request.app.state.FUNCTIONS + FUNCTIONS[id] = function_module + + updated = {**form_data.model_dump(exclude={"id"})} + print(updated) + + function = Functions.update_function_by_id(id, updated) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteFunctionById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_function_by_id( + request: Request, id: str, user=Depends(get_admin_user) +): + result = Functions.delete_function_by_id(id) + + if result: + FUNCTIONS = request.app.state.FUNCTIONS + if id in FUNCTIONS: + del FUNCTIONS[id] + + # delete the function file + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + os.remove(function_path) + + return result diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 19a8615bc..64d116f11 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -1,7 +1,7 @@ from importlib import util import os -from config import TOOLS_DIR +from config import TOOLS_DIR, FUNCTIONS_DIR def load_toolkit_module_by_id(toolkit_id): @@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id): # Move the file to the error folder os.rename(toolkit_path, f"{toolkit_path}.error") raise e + + +def load_function_module_by_id(function_id): + function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py") + + spec = util.spec_from_file_location(function_id, function_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Pipe"): + return module.Pipe() + elif hasattr(module, "Filter"): + return module.Filter() + else: + raise Exception("No Function class found") + except Exception as e: + print(f"Error loading module: {function_id}") + # Move the file to the error folder + os.rename(function_path, f"{function_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index 01ce060a3..842cea1ba 100644 --- a/backend/config.py +++ b/backend/config.py @@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Functions DIR +#################################### + +FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") +Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG ####################################