From 3a96e1f109a1317ff19aa3e5bc98c9ca39aad07e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 20:39:55 -0700 Subject: [PATCH] feat: tools backend --- .../internal/migrations/012_add_tools.py | 61 +++++++ backend/apps/webui/main.py | 4 +- backend/apps/webui/models/tools.py | 131 ++++++++++++++ backend/apps/webui/routers/tools.py | 162 +++++++++++++++++ backend/utils/tools.py | 73 ++++++++ src/lib/components/workspace/Tools.svelte | 170 +++++++++++++++++- .../workspace/Tools/ToolkitEditor.svelte | 4 +- src/lib/stores/index.ts | 26 ++- 8 files changed, 611 insertions(+), 20 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/012_add_tools.py create mode 100644 backend/apps/webui/models/tools.py create mode 100644 backend/apps/webui/routers/tools.py create mode 100644 backend/utils/tools.py diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py new file mode 100644 index 000000000..4a68eea55 --- /dev/null +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Tool(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + content = pw.TextField() + specs = pw.TextField() + + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "tool" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("tool") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 6ec9bbace..339fe8a83 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -6,6 +6,7 @@ from apps.webui.routers import ( users, chats, documents, + tools, models, prompts, configs, @@ -26,8 +27,8 @@ from config import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, WEBUI_BANNERS, - AppConfig, ENABLE_COMMUNITY_SHARING, + AppConfig, ) app = FastAPI() @@ -70,6 +71,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py new file mode 100644 index 000000000..9e5e63dc3 --- /dev/null +++ b/backend/apps/webui/models/tools.py @@ -0,0 +1,131 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time +import logging +from apps.webui.internal.db import DB, JSONField + +import json + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# Tools DB Schema +#################### + + +class Tool(Model): + id = CharField(unique=True) + user_id = CharField() + name = TextField() + content = TextField() + specs = JSONField() + meta = JSONField() + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ToolMeta(BaseModel): + description: Optional[str] = None + + +class ToolModel(BaseModel): + id: str + user_id: str + name: str + content: str + specs: dict + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ToolResponse(BaseModel): + id: str + user_id: str + name: str + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ToolForm(BaseModel): + id: str + name: str + content: str + meta: ToolMeta + + +class ToolsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Tool]) + + def insert_new_tool( + self, user_id: str, form_data: ToolForm, specs: dict + ) -> Optional[ToolModel]: + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool.create(**tool.model_dump()) + if result: + return tool + else: + return None + except: + return None + + def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + try: + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def get_tools(self) -> List[ToolModel]: + return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + + def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + try: + query = Tool.update( + **updated, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def delete_tool_by_id(self, id: str) -> bool: + try: + query = Tool.delete().where((Tool.id == id)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Tools = ToolsTable(DB) diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py new file mode 100644 index 000000000..b4964345e --- /dev/null +++ b/backend/apps/webui/routers/tools.py @@ -0,0 +1,162 @@ +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse + +from utils.utils import get_current_user, get_admin_user +from utils.tools import get_tools_specs +from constants import ERROR_MESSAGES + +from importlib import util +import os + +from config import DATA_DIR + +TOOLS_DIR = f"{DATA_DIR}/tools" +os.makedirs(TOOLS_DIR, exist_ok=True) + +TOOLS = {} + + +router = APIRouter() + + +def load_toolkit_module_from_path(tools_id, tools_path): + spec = util.spec_from_file_location(tools_id, tools_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {tools_id}") + + # Move the file to the error folder + os.rename(tools_path, f"{tools_path}.error") + raise e + + +############################ +# GetToolkits +############################ + + +@router.get("/", response_model=List[ToolResponse]) +async def get_toolkits(user=Depends(get_current_user)): + toolkits = [ToolResponse(**toolkit) for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# CreateNewToolKit +############################ + + +@router.post("/create", response_model=Optional[ToolResponse]) +async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(form_data.id) + if toolkit == None: + toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) + TOOLS[form_data.id] = toolkit_module + + specs = get_tools_specs(TOOLS[form_data.id]) + toolkit = Tools.insert_new_tool(user.id, form_data, specs) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_EXISTS, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NAME_TAG_TAKEN, + ) + + +############################ +# GetToolkitById +############################ + + +@router.get("/id/{id}", response_model=Optional[ToolResponse]) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolkitById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[ToolResponse]) +async def update_toolkit_by_id( + id: str, form_data: ToolForm, user=Depends(get_admin_user) +): + toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") + + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_from_path(id, toolkit_path) + TOOLS[id] = toolkit_module + + specs = get_tools_specs(TOOLS[id]) + toolkit = Tools.update_tool_by_id( + id, {**form_data.model_dump(), "specs": specs} + ) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteToolkitById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_toolkit_by_id(id: str, user=Depends(get_admin_user)): + result = Tools.delete_tool_by_id(id) + return result diff --git a/backend/utils/tools.py b/backend/utils/tools.py new file mode 100644 index 000000000..91359bad9 --- /dev/null +++ b/backend/utils/tools.py @@ -0,0 +1,73 @@ +import inspect +from typing import get_type_hints, List, Dict, Any + + +def doc_to_dict(docstring): + lines = docstring.split("\n") + description = lines[1].strip() + param_dict = {} + + for line in lines: + if ":param" in line: + line = line.replace(":param", "").strip() + param, desc = line.split(":", 1) + param_dict[param.strip()] = desc.strip() + ret_dict = {"description": description, "params": param_dict} + return ret_dict + + +def get_tools_specs(tools) -> List[dict]: + function_list = [ + {"name": func, "function": getattr(tools, func)} + for func in dir(tools) + if callable(getattr(tools, func)) and not func.startswith("__") + ] + + specs = [] + for function_item in function_list: + function_name = function_item["name"] + function = function_item["function"] + + function_doc = doc_to_dict(function.__doc__ or function_name) + specs.append( + { + "name": function_name, + # TODO: multi-line desc? + "description": function_doc.get("description", function_name), + "parameters": { + "type": "object", + "properties": { + param_name: { + "type": param_annotation.__name__.lower(), + **( + { + "enum": ( + param_annotation.__args__ + if hasattr(param_annotation, "__args__") + else None + ) + } + if hasattr(param_annotation, "__args__") + else {} + ), + "description": function_doc.get("params", {}).get( + param_name, param_name + ), + } + for param_name, param_annotation in get_type_hints( + function + ).items() + if param_name != "return" + }, + "required": [ + name + for name, param in inspect.signature( + function + ).parameters.items() + if param.default is param.empty + ], + }, + } + ) + + return specs diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index b5db60a08..ec76f1933 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -4,12 +4,16 @@ const { saveAs } = fileSaver; import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, prompts } from '$lib/stores'; + import { WEBUI_NAME, prompts, tools } from '$lib/stores'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; const i18n = getContext('i18n'); + + let toolsImportInputElement: HTMLInputElement; + let importFiles; + let query = ''; @@ -65,3 +69,167 @@
+ +
+ {#each $tools.filter((t) => query === '' || t.name.includes(query)) as tool} +
+ +
+ + + + + + + + + +
+
+ {/each} +
+ +
+
+ { + console.log(importFiles); + + const reader = new FileReader(); + reader.onload = async (event) => { + const tools = JSON.parse(event.target.result); + console.log(tools); + }; + + reader.readAsText(importFiles[0]); + }} + /> + + + + +
+
diff --git a/src/lib/components/workspace/Tools/ToolkitEditor.svelte b/src/lib/components/workspace/Tools/ToolkitEditor.svelte index bee31da7e..da59bbeed 100644 --- a/src/lib/components/workspace/Tools/ToolkitEditor.svelte +++ b/src/lib/components/workspace/Tools/ToolkitEditor.svelte @@ -14,7 +14,7 @@ description: '' }; - let code = ''; + let content = ''; $: if (name) { id = name.replace(/\s+/g, '_').toLowerCase(); @@ -97,7 +97,7 @@
diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index 781529f02..08e351238 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -23,24 +23,11 @@ export const chatId = writable(''); export const chats = writable([]); export const tags = writable([]); -export const models: Writable = writable([]); -export const modelfiles = writable([]); +export const models: Writable = writable([]); export const prompts: Writable = writable([]); -export const documents = writable([ - { - collection_name: 'collection_name', - filename: 'filename', - name: 'name', - title: 'title' - }, - { - collection_name: 'collection_name1', - filename: 'filename1', - name: 'name1', - title: 'title1' - } -]); +export const documents: Writable = writable([]); +export const tools = writable([]); export const banners: Writable = writable([]); @@ -135,6 +122,13 @@ type Prompt = { timestamp: number; }; +type Document = { + collection_name: string; + filename: string; + name: string; + title: string; +}; + type Config = { status: boolean; name: string;