From 3a96e1f109a1317ff19aa3e5bc98c9ca39aad07e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 20:39:55 -0700 Subject: [PATCH 01/21] feat: tools backend --- .../internal/migrations/012_add_tools.py | 61 +++++++ backend/apps/webui/main.py | 4 +- backend/apps/webui/models/tools.py | 131 ++++++++++++++ backend/apps/webui/routers/tools.py | 162 +++++++++++++++++ backend/utils/tools.py | 73 ++++++++ src/lib/components/workspace/Tools.svelte | 170 +++++++++++++++++- .../workspace/Tools/ToolkitEditor.svelte | 4 +- src/lib/stores/index.ts | 26 ++- 8 files changed, 611 insertions(+), 20 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/012_add_tools.py create mode 100644 backend/apps/webui/models/tools.py create mode 100644 backend/apps/webui/routers/tools.py create mode 100644 backend/utils/tools.py diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py new file mode 100644 index 000000000..4a68eea55 --- /dev/null +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Tool(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + content = pw.TextField() + specs = pw.TextField() + + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "tool" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("tool") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 6ec9bbace..339fe8a83 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -6,6 +6,7 @@ from apps.webui.routers import ( users, chats, documents, + tools, models, prompts, configs, @@ -26,8 +27,8 @@ from config import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, WEBUI_BANNERS, - AppConfig, ENABLE_COMMUNITY_SHARING, + AppConfig, ) app = FastAPI() @@ -70,6 +71,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py new file mode 100644 index 000000000..9e5e63dc3 --- /dev/null +++ b/backend/apps/webui/models/tools.py @@ -0,0 +1,131 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time +import logging +from apps.webui.internal.db import DB, JSONField + +import json + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# Tools DB Schema +#################### + + +class Tool(Model): + id = CharField(unique=True) + user_id = CharField() + name = TextField() + content = TextField() + specs = JSONField() + meta = JSONField() + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ToolMeta(BaseModel): + description: Optional[str] = None + + +class ToolModel(BaseModel): + id: str + user_id: str + name: str + content: str + specs: dict + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ToolResponse(BaseModel): + id: str + user_id: str + name: str + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ToolForm(BaseModel): + id: str + name: str + content: str + meta: ToolMeta + + +class ToolsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Tool]) + + def insert_new_tool( + self, user_id: str, form_data: ToolForm, specs: dict + ) -> Optional[ToolModel]: + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool.create(**tool.model_dump()) + if result: + return tool + else: + return None + except: + return None + + def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + try: + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def get_tools(self) -> List[ToolModel]: + return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + + def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + try: + query = Tool.update( + **updated, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def delete_tool_by_id(self, id: str) -> bool: + try: + query = Tool.delete().where((Tool.id == id)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Tools = ToolsTable(DB) diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py new file mode 100644 index 000000000..b4964345e --- /dev/null +++ b/backend/apps/webui/routers/tools.py @@ -0,0 +1,162 @@ +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse + +from utils.utils import get_current_user, get_admin_user +from utils.tools import get_tools_specs +from constants import ERROR_MESSAGES + +from importlib import util +import os + +from config import DATA_DIR + +TOOLS_DIR = f"{DATA_DIR}/tools" +os.makedirs(TOOLS_DIR, exist_ok=True) + +TOOLS = {} + + +router = APIRouter() + + +def load_toolkit_module_from_path(tools_id, tools_path): + spec = util.spec_from_file_location(tools_id, tools_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {tools_id}") + + # Move the file to the error folder + os.rename(tools_path, f"{tools_path}.error") + raise e + + +############################ +# GetToolkits +############################ + + +@router.get("/", response_model=List[ToolResponse]) +async def get_toolkits(user=Depends(get_current_user)): + toolkits = [ToolResponse(**toolkit) for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# CreateNewToolKit +############################ + + +@router.post("/create", response_model=Optional[ToolResponse]) +async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(form_data.id) + if toolkit == None: + toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) + TOOLS[form_data.id] = toolkit_module + + specs = get_tools_specs(TOOLS[form_data.id]) + toolkit = Tools.insert_new_tool(user.id, form_data, specs) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_EXISTS, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NAME_TAG_TAKEN, + ) + + +############################ +# GetToolkitById +############################ + + +@router.get("/id/{id}", response_model=Optional[ToolResponse]) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolkitById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[ToolResponse]) +async def update_toolkit_by_id( + id: str, form_data: ToolForm, user=Depends(get_admin_user) +): + toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") + + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_from_path(id, toolkit_path) + TOOLS[id] = toolkit_module + + specs = get_tools_specs(TOOLS[id]) + toolkit = Tools.update_tool_by_id( + id, {**form_data.model_dump(), "specs": specs} + ) + + if toolkit: + return ToolResponse(**toolkit) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteToolkitById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_toolkit_by_id(id: str, user=Depends(get_admin_user)): + result = Tools.delete_tool_by_id(id) + return result diff --git a/backend/utils/tools.py b/backend/utils/tools.py new file mode 100644 index 000000000..91359bad9 --- /dev/null +++ b/backend/utils/tools.py @@ -0,0 +1,73 @@ +import inspect +from typing import get_type_hints, List, Dict, Any + + +def doc_to_dict(docstring): + lines = docstring.split("\n") + description = lines[1].strip() + param_dict = {} + + for line in lines: + if ":param" in line: + line = line.replace(":param", "").strip() + param, desc = line.split(":", 1) + param_dict[param.strip()] = desc.strip() + ret_dict = {"description": description, "params": param_dict} + return ret_dict + + +def get_tools_specs(tools) -> List[dict]: + function_list = [ + {"name": func, "function": getattr(tools, func)} + for func in dir(tools) + if callable(getattr(tools, func)) and not func.startswith("__") + ] + + specs = [] + for function_item in function_list: + function_name = function_item["name"] + function = function_item["function"] + + function_doc = doc_to_dict(function.__doc__ or function_name) + specs.append( + { + "name": function_name, + # TODO: multi-line desc? + "description": function_doc.get("description", function_name), + "parameters": { + "type": "object", + "properties": { + param_name: { + "type": param_annotation.__name__.lower(), + **( + { + "enum": ( + param_annotation.__args__ + if hasattr(param_annotation, "__args__") + else None + ) + } + if hasattr(param_annotation, "__args__") + else {} + ), + "description": function_doc.get("params", {}).get( + param_name, param_name + ), + } + for param_name, param_annotation in get_type_hints( + function + ).items() + if param_name != "return" + }, + "required": [ + name + for name, param in inspect.signature( + function + ).parameters.items() + if param.default is param.empty + ], + }, + } + ) + + return specs diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index b5db60a08..ec76f1933 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -4,12 +4,16 @@ const { saveAs } = fileSaver; import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, prompts } from '$lib/stores'; + import { WEBUI_NAME, prompts, tools } from '$lib/stores'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; const i18n = getContext('i18n'); + + let toolsImportInputElement: HTMLInputElement; + let importFiles; + let query = ''; @@ -65,3 +69,167 @@
+ +
+ {#each $tools.filter((t) => query === '' || t.name.includes(query)) as tool} +
+ +
+ + + + + + + + + +
+
+ {/each} +
+ +
+
+ { + console.log(importFiles); + + const reader = new FileReader(); + reader.onload = async (event) => { + const tools = JSON.parse(event.target.result); + console.log(tools); + }; + + reader.readAsText(importFiles[0]); + }} + /> + + + + +
+
diff --git a/src/lib/components/workspace/Tools/ToolkitEditor.svelte b/src/lib/components/workspace/Tools/ToolkitEditor.svelte index bee31da7e..da59bbeed 100644 --- a/src/lib/components/workspace/Tools/ToolkitEditor.svelte +++ b/src/lib/components/workspace/Tools/ToolkitEditor.svelte @@ -14,7 +14,7 @@ description: '' }; - let code = ''; + let content = ''; $: if (name) { id = name.replace(/\s+/g, '_').toLowerCase(); @@ -97,7 +97,7 @@
diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index 781529f02..08e351238 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -23,24 +23,11 @@ export const chatId = writable(''); export const chats = writable([]); export const tags = writable([]); -export const models: Writable = writable([]); -export const modelfiles = writable([]); +export const models: Writable = writable([]); export const prompts: Writable = writable([]); -export const documents = writable([ - { - collection_name: 'collection_name', - filename: 'filename', - name: 'name', - title: 'title' - }, - { - collection_name: 'collection_name1', - filename: 'filename1', - name: 'name1', - title: 'title1' - } -]); +export const documents: Writable = writable([]); +export const tools = writable([]); export const banners: Writable = writable([]); @@ -135,6 +122,13 @@ type Prompt = { timestamp: number; }; +type Document = { + collection_name: string; + filename: string; + name: string; + title: string; +}; + type Config = { status: boolean; name: string; From e27c264081791d6d2c04fee123e8f4e4caa35cec Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 20:43:11 -0700 Subject: [PATCH 02/21] feat: tools apis --- src/lib/apis/tools/index.ts | 162 ++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 src/lib/apis/tools/index.ts diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts new file mode 100644 index 000000000..7fc0ed03d --- /dev/null +++ b/src/lib/apis/tools/index.ts @@ -0,0 +1,162 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewTool = async (token: string, tool: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...tool + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getTools = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getToolById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateToolById = async (token: string, id: string, tool: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...tool + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteToolById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; From 6589464ddf02bf3f8c8f4756227603200426970f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 20:58:47 -0700 Subject: [PATCH 03/21] refac --- src/lib/components/common/CodeEditor.svelte | 8 +---- .../workspace/Tools/CodeEditor.svelte | 25 ++++++++++---- .../workspace/Tools/ToolkitEditor.svelte | 33 ++++++++++++++----- src/routes/(app)/+layout.svelte | 30 +++++++---------- .../(app)/workspace/tools/create/+page.svelte | 10 +++++- 5 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/lib/components/common/CodeEditor.svelte b/src/lib/components/common/CodeEditor.svelte index 04a7f35ab..4f3e2baef 100644 --- a/src/lib/components/common/CodeEditor.svelte +++ b/src/lib/components/common/CodeEditor.svelte @@ -113,13 +113,7 @@ const handleSave = async (e) => { if ((e.ctrlKey || e.metaKey) && e.key === 's') { e.preventDefault(); - const res = await formatPythonCodeHandler().catch((error) => { - return null; - }); - - if (res) { - dispatch('save'); - } + dispatch('save'); } }; diff --git a/src/lib/components/workspace/Tools/CodeEditor.svelte b/src/lib/components/workspace/Tools/CodeEditor.svelte index 11f074092..f02230822 100644 --- a/src/lib/components/workspace/Tools/CodeEditor.svelte +++ b/src/lib/components/workspace/Tools/CodeEditor.svelte @@ -1,16 +1,15 @@ - + { + saveHandler(e.detail); + }} +/> From c5683dd24cfded161558bf68e41690d6d915761f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 21:05:06 -0700 Subject: [PATCH 04/21] refac --- .../workspace/Tools/CodeEditor.svelte | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/lib/components/workspace/Tools/CodeEditor.svelte b/src/lib/components/workspace/Tools/CodeEditor.svelte index f02230822..a89293dcf 100644 --- a/src/lib/components/workspace/Tools/CodeEditor.svelte +++ b/src/lib/components/workspace/Tools/CodeEditor.svelte @@ -56,6 +56,44 @@ class Tools: print(e) return "Invalid equation" + def get_current_weather(self, city: str) -> str: + """ + Get the current weather for a given city. + :param city: The name of the city to get the weather for. + :return: The current weather information or an error message. + """ + api_key = os.getenv('OPENWEATHER_API_KEY') + if not api_key: + return "API key is not set in the environment variable 'OPENWEATHER_API_KEY'." + + base_url = "http://api.openweathermap.org/data/2.5/weather" + params = { + 'q': city, + 'appid': api_key, + 'units': 'metric' # Optional: Use 'imperial' for Fahrenheit + } + + try: + response = requests.get(base_url, params=params) + response.raise_for_status() # Raise HTTPError for bad responses (4xx and 5xx) + data = response.json() + + if data.get('cod') != 200: + return f"Error fetching weather data: {data.get('message')}" + + weather_description = data['weather'][0]['description'] + temperature = data['main']['temp'] + humidity = data['main']['humidity'] + wind_speed = data['wind']['speed'] + + return (f"Weather in {city}:\n" + f"Description: {weather_description}\n" + f"Temperature: {temperature}°C\n" + f"Humidity: {humidity}%\n" + f"Wind Speed: {wind_speed} m/s") + except requests.RequestException as e: + return f"Error fetching weather data: {str(e)}" + `; export const formatHandler = async () => { From b434ebf3ad99c45122b8ce24ce4fd3a810357ae4 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 21:33:46 -0700 Subject: [PATCH 05/21] feat: tools integration --- backend/apps/webui/models/tools.py | 4 +- backend/apps/webui/routers/tools.py | 21 ++++++-- backend/constants.py | 1 + src/lib/components/workspace/Tools.svelte | 23 ++++++-- .../workspace/Tools/CodeEditor.svelte | 44 +++++++-------- .../workspace/Tools/ToolkitEditor.svelte | 21 +++++--- .../(app)/workspace/tools/create/+page.svelte | 20 +++++++ .../(app)/workspace/tools/edit/+page.svelte | 54 ++++++++++++++++++- 8 files changed, 146 insertions(+), 42 deletions(-) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 9e5e63dc3..99463878b 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -41,7 +41,7 @@ class ToolModel(BaseModel): user_id: str name: str content: str - specs: dict + specs: List[dict] meta: ToolMeta updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -74,7 +74,7 @@ class ToolsTable: self.db.create_tables([Tool]) def insert_new_tool( - self, user_id: str, form_data: ToolForm, specs: dict + self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: tool = ToolModel( **{ diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index b4964345e..dc76bb312 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -52,7 +52,18 @@ def load_toolkit_module_from_path(tools_id, tools_path): @router.get("/", response_model=List[ToolResponse]) async def get_toolkits(user=Depends(get_current_user)): - toolkits = [ToolResponse(**toolkit) for toolkit in Tools.get_tools()] + toolkits = [toolkit for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# ExportToolKits +############################ + + +@router.get("/export", response_model=List[ToolModel]) +async def get_toolkits(user=Depends(get_current_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -77,7 +88,7 @@ async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): toolkit = Tools.insert_new_tool(user.id, form_data, specs) if toolkit: - return ToolResponse(**toolkit) + return toolkit else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -91,7 +102,7 @@ async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.NAME_TAG_TAKEN, + detail=ERROR_MESSAGES.ID_TAKEN, ) @@ -105,7 +116,7 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: - return ToolResponse(**toolkit) + return toolkit else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -137,7 +148,7 @@ async def update_toolkit_by_id( ) if toolkit: - return ToolResponse(**toolkit) + return toolkit else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/constants.py b/backend/constants.py index 0740fa49d..f1eed43d3 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index ec76f1933..b7a5461e6 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -8,6 +8,7 @@ import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; + import { deleteToolById, getTools } from '$lib/apis/tools'; const i18n = getContext('i18n'); @@ -78,7 +79,12 @@
-
{tool.name}
+
+
+ {tool.name} +
+
{tool.id}
+
{tool.meta.description}
@@ -89,7 +95,7 @@
{ - // deletePrompt(prompt.command); - // deleteTool + on:click={async () => { + const res = await deleteToolById(localStorage.token, tool.id).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + toast.success('Tool deleted successfully'); + tools.set(await getTools(localStorage.token)); + } }} > { diff --git a/src/lib/components/workspace/Tools/ToolkitEditor.svelte b/src/lib/components/workspace/Tools/ToolkitEditor.svelte index f046c6e85..ef181d776 100644 --- a/src/lib/components/workspace/Tools/ToolkitEditor.svelte +++ b/src/lib/components/workspace/Tools/ToolkitEditor.svelte @@ -8,15 +8,18 @@ const dispatch = createEventDispatcher(); + let formElement = null; + let loading = false; - let id = ''; - let name = ''; - let meta = { + export let edit = false; + + export let id = ''; + export let name = ''; + export let meta = { description: '' }; - - let content = ''; + export let content = ''; $: if (name) { id = name.replace(/\s+/g, '_').toLowerCase(); @@ -49,6 +52,7 @@
{ submitHandler(); @@ -60,6 +64,7 @@ on:click={() => { goto('/workspace/tools'); }} + type="button" >
{ - // submit form - submitHandler(); + if (formElement) { + formElement.requestSubmit(); + } }} />
diff --git a/src/routes/(app)/workspace/tools/create/+page.svelte b/src/routes/(app)/workspace/tools/create/+page.svelte index 7d201c774..1a9d88021 100644 --- a/src/routes/(app)/workspace/tools/create/+page.svelte +++ b/src/routes/(app)/workspace/tools/create/+page.svelte @@ -1,8 +1,28 @@ diff --git a/src/routes/(app)/workspace/tools/edit/+page.svelte b/src/routes/(app)/workspace/tools/edit/+page.svelte index 91a461df6..2db915719 100644 --- a/src/routes/(app)/workspace/tools/edit/+page.svelte +++ b/src/routes/(app)/workspace/tools/edit/+page.svelte @@ -1,5 +1,57 @@ - +{#if tool} + { + saveHandler(e.detail); + }} + /> +{/if} From 1611a3aa70d6d9827ec5fb3820de68b4ddd039a0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 21:36:13 -0700 Subject: [PATCH 06/21] feat: export tools --- backend/apps/webui/routers/tools.py | 2 +- src/lib/apis/tools/index.ts | 31 +++++++++++++++++++++++ src/lib/components/workspace/Tools.svelte | 15 ++++++++--- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index dc76bb312..048813c72 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -62,7 +62,7 @@ async def get_toolkits(user=Depends(get_current_user)): @router.get("/export", response_model=List[ToolModel]) -async def get_toolkits(user=Depends(get_current_user)): +async def get_toolkits(user=Depends(get_admin_user)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 7fc0ed03d..47a535cdf 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -62,6 +62,37 @@ export const getTools = async (token: string = '') => { return res; }; +export const exportTools = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/export`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getToolById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index b7a5461e6..837ba62dc 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -8,7 +8,7 @@ import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; - import { deleteToolById, getTools } from '$lib/apis/tools'; + import { deleteToolById, exportTools, getTools } from '$lib/apis/tools'; const i18n = getContext('i18n'); @@ -221,10 +221,17 @@
-
+ {/each}
From aa7d25600f2b31d94c8d44b8dcb5fbf3bf2de075 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 22:33:25 -0700 Subject: [PATCH 15/21] refac --- src/lib/components/common/CodeEditor.svelte | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/lib/components/common/CodeEditor.svelte b/src/lib/components/common/CodeEditor.svelte index ef923b09f..0822ac019 100644 --- a/src/lib/components/common/CodeEditor.svelte +++ b/src/lib/components/common/CodeEditor.svelte @@ -110,21 +110,24 @@ attributeFilter: ['class'] }); - // Add a keyboard shortcut to format the code when Ctrl/Cmd + S is pressed - // Override the default browser save functionality - - const handleSave = async (e) => { + const keydownHandler = async (e) => { if ((e.ctrlKey || e.metaKey) && e.key === 's') { e.preventDefault(); dispatch('save'); } + + // Format code when Ctrl + Shift + F is pressed + if ((e.ctrlKey || e.metaKey) && e.shiftKey && e.key === 'f') { + e.preventDefault(); + await formatPythonCodeHandler(); + } }; - document.addEventListener('keydown', handleSave); + document.addEventListener('keydown', keydownHandler); return () => { observer.disconnect(); - document.removeEventListener('keydown', handleSave); + document.removeEventListener('keydown', keydownHandler); }; }); From ff1cd306d8898e8a50f29adff3154df2bcd5ede6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 22:38:48 -0700 Subject: [PATCH 16/21] refac --- backend/apps/webui/main.py | 3 ++- backend/apps/webui/routers/tools.py | 21 +++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 339fe8a83..62a0a7a7b 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -39,6 +39,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS @@ -55,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.TOOLS = {} app.add_middleware( diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 67e391ddd..60e8319e6 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -1,4 +1,4 @@ -from fastapi import Depends, FastAPI, HTTPException, status +from fastapi import Depends, FastAPI, HTTPException, status, Request from datetime import datetime, timedelta from typing import List, Union, Optional @@ -20,8 +20,6 @@ from config import DATA_DIR TOOLS_DIR = f"{DATA_DIR}/tools" os.makedirs(TOOLS_DIR, exist_ok=True) -TOOLS = {} - router = APIRouter() @@ -73,7 +71,9 @@ async def get_toolkits(user=Depends(get_admin_user)): @router.post("/create", response_model=Optional[ToolResponse]) -async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): +async def create_new_toolkit( + request: Request, form_data: ToolForm, user=Depends(get_admin_user) +): if not form_data.id.isidentifier(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -90,6 +90,8 @@ async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)): tool_file.write(form_data.content) toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) + + TOOLS = request.app.state.TOOLS TOOLS[form_data.id] = toolkit_module specs = get_tools_specs(TOOLS[form_data.id]) @@ -139,7 +141,7 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[ToolModel]) async def update_toolkit_by_id( - id: str, form_data: ToolForm, user=Depends(get_admin_user) + request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) ): toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") @@ -148,6 +150,8 @@ async def update_toolkit_by_id( tool_file.write(form_data.content) toolkit_module = load_toolkit_module_from_path(id, toolkit_path) + + TOOLS = request.app.state.TOOLS TOOLS[id] = toolkit_module specs = get_tools_specs(TOOLS[id]) @@ -181,6 +185,11 @@ async def update_toolkit_by_id( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(id: str, user=Depends(get_admin_user)): +async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): result = Tools.delete_tool_by_id(id) + + if result: + TOOLS = request.app.state.TOOLS + del TOOLS[id] + return result From a27175d67252aa01c7be503525c124e8de12abf6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 23:40:27 -0700 Subject: [PATCH 17/21] feat: fc integration --- backend/apps/webui/routers/tools.py | 26 +--- backend/apps/webui/utils.py | 23 ++++ backend/config.py | 20 ++- backend/main.py | 181 +++++++++++++++++++++++++--- backend/utils/task.py | 5 + 5 files changed, 215 insertions(+), 40 deletions(-) create mode 100644 backend/apps/webui/utils.py diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 60e8319e6..c2a815ee7 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -7,6 +7,7 @@ from pydantic import BaseModel import json from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse +from apps.webui.utils import load_toolkit_module_by_id from utils.utils import get_current_user, get_admin_user from utils.tools import get_tools_specs @@ -17,32 +18,13 @@ import os from config import DATA_DIR + TOOLS_DIR = f"{DATA_DIR}/tools" os.makedirs(TOOLS_DIR, exist_ok=True) router = APIRouter() - -def load_toolkit_module_from_path(tools_id, tools_path): - spec = util.spec_from_file_location(tools_id, tools_path) - module = util.module_from_spec(spec) - - try: - spec.loader.exec_module(module) - print(f"Loaded module: {module.__name__}") - if hasattr(module, "Tools"): - return module.Tools() - else: - raise Exception("No Tools class found") - except Exception as e: - print(f"Error loading module: {tools_id}") - - # Move the file to the error folder - os.rename(tools_path, f"{tools_path}.error") - raise e - - ############################ # GetToolkits ############################ @@ -89,7 +71,7 @@ async def create_new_toolkit( with open(toolkit_path, "w") as tool_file: tool_file.write(form_data.content) - toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) + toolkit_module = load_toolkit_module_by_id(form_data.id) TOOLS = request.app.state.TOOLS TOOLS[form_data.id] = toolkit_module @@ -149,7 +131,7 @@ async def update_toolkit_by_id( with open(toolkit_path, "w") as tool_file: tool_file.write(form_data.content) - toolkit_module = load_toolkit_module_from_path(id, toolkit_path) + toolkit_module = load_toolkit_module_by_id(id) TOOLS = request.app.state.TOOLS TOOLS[id] = toolkit_module diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py new file mode 100644 index 000000000..19a8615bc --- /dev/null +++ b/backend/apps/webui/utils.py @@ -0,0 +1,23 @@ +from importlib import util +import os + +from config import TOOLS_DIR + + +def load_toolkit_module_by_id(toolkit_id): + toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py") + spec = util.spec_from_file_location(toolkit_id, toolkit_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {toolkit_id}") + # Move the file to the error folder + os.rename(toolkit_path, f"{toolkit_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index df52a4b69..32e64347e 100644 --- a/backend/config.py +++ b/backend/config.py @@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Tools DIR +#################################### + +TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") +Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### @@ -669,7 +677,6 @@ Question: ), ) - SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "task.search.prompt_length_threshold", @@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( ), ) +TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + "task.tools.prompt_template", + os.environ.get( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + """Tools: {{TOOLS}} +If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""", + ), +) + + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/main.py b/backend/main.py index 99b409983..3f72c5710 100644 --- a/backend/main.py +++ b/backend/main.py @@ -47,15 +47,24 @@ from pydantic import BaseModel from typing import List, Optional from apps.webui.models.models import Models, ModelModel +from apps.webui.models.tools import Tools +from apps.webui.utils import load_toolkit_module_by_id + + from utils.utils import ( get_admin_user, get_verified_user, get_current_user, get_http_authorization_cred, ) -from utils.task import title_generation_template, search_query_generation_template +from utils.task import ( + title_generation_template, + search_query_generation_template, + tools_function_calling_generation_template, +) +from utils.misc import get_last_user_message, add_or_update_system_message -from apps.rag.utils import rag_messages +from apps.rag.utils import rag_messages, rag_template from config import ( CONFIG_DATA, @@ -82,6 +91,7 @@ from config import ( TITLE_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, AppConfig, ) from constants import ERROR_MESSAGES @@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) app.state.MODELS = {} origins = ["*"] -# Custom middleware to add security headers -# class SecurityHeadersMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next): -# response: Response = await call_next(request) -# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" -# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" -# return response + +async def get_function_call_response(prompt, tool_id, template, task_model_id, user): + tool = Tools.get_tool_by_id(tool_id) + tools_specs = json.dumps(tool.specs, indent=2) + content = tools_function_calling_generation_template(template, tools_specs) + + payload = { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + } + + payload = filter_pipeline(payload, user) + model = app.state.MODELS[task_model_id] + + response = None + if model["owned_by"] == "ollama": + response = await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + response = await generate_openai_chat_completion(payload, user=user) + + print(response) + content = response["choices"][0]["message"]["content"] + + # Parse the function response + if content != "": + result = json.loads(content) + print(result) + + # Call the function + if "name" in result: + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + function_result = function(**result["parameters"]) + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result: + return function_result + + return None -# app.add_middleware(SecurityHeadersMiddleware) - - -class RAGMiddleware(BaseHTTPMiddleware): +class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): return_citations = False @@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} + # Remove the citations from the body return_citations = data.get("citations", False) if "citations" in data: del data["citations"] - # Example: Add a new key-value pair or modify existing ones - # data["modified"] = True # Example modification + # Set the task model + task_model_id = data["model"] + if task_model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[task_model_id]["owned_by"] == "ollama": + if ( + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL + else: + if ( + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + + if "tool_ids" in data: + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + prompt = get_last_user_message(data["messages"]) + context = "" + + for tool_id in data["tool_ids"]: + response = await get_function_call_response( + prompt=prompt, + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) + print(response) + + if response: + context += f"\n{response}" + + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) + + data["messages"] = add_or_update_system_message( + system_prompt, data["messages"] + ) + + del data["tool_ids"] + + # If docs field is present, generate RAG completions if "docs" in data: data = {**data} data["messages"], citations = rag_messages( @@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware): # Replace the request body with the modified one request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length request.headers.__dict__["_list"] = [ (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), @@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware): yield data -app.add_middleware(RAGMiddleware) +app.add_middleware(ChatCompletionMiddleware) def filter_pipeline(payload, user): @@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)): "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel): TITLE_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @app.post("/api/task/config/update") @@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) return { "TASK_MODEL": app.state.config.TASK_MODEL, @@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) return await generate_openai_chat_completion(payload, user=user) +@app.post("/api/task/tools/completions") +async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): + print("get_tools_function_calling") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + return await get_function_call_response( + form_data["prompt"], form_data["tool_id"], template, model_id, user + ) + + @app.post("/api/chat/completions") async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): model_id = form_data["model"] diff --git a/backend/utils/task.py b/backend/utils/task.py index 2239de7df..615febcdc 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -110,3 +110,8 @@ def search_query_generation_template( ), ) return template + + +def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: + template = template.replace("{{TOOLS}}", tools_specs) + return template From 3d6f5f418dd8c7d4ac929ed0846cd77165e4c292 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 11 Jun 2024 00:18:45 -0700 Subject: [PATCH 18/21] feat: tools full integration --- backend/main.py | 82 +++++++++++-------- src/lib/components/chat/Chat.svelte | 4 + src/lib/components/chat/MessageInput.svelte | 13 ++- .../chat/MessageInput/InputMenu.svelte | 15 +++- 4 files changed, 75 insertions(+), 39 deletions(-) diff --git a/backend/main.py b/backend/main.py index 3f72c5710..fa9563e13 100644 --- a/backend/main.py +++ b/backend/main.py @@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u model = app.state.MODELS[task_model_id] response = None - if model["owned_by"] == "ollama": - response = await generate_ollama_chat_completion( - OpenAIChatCompletionForm(**payload), user=user - ) - else: - response = await generate_openai_chat_completion(payload, user=user) + try: + if model["owned_by"] == "ollama": + response = await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + response = await generate_openai_chat_completion(payload, user=user) - print(response) - content = response["choices"][0]["message"]["content"] + content = None + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] - # Parse the function response - if content != "": - result = json.loads(content) - print(result) + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() - # Call the function - if "name" in result: - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module + # Parse the function response + if content is not None: + result = json.loads(content) + print(result) - function = getattr(toolkit_module, result["name"]) - function_result = None - try: - function_result = function(**result["parameters"]) - except Exception as e: - print(e) + # Call the function + if "name" in result: + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module - # Add the function result to the system prompt - if function_result: - return function_result + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + function_result = function(**result["parameters"]) + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result: + return function_result + except Exception as e: + print(f"Error: {e}") return None @@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(response) if response: - context += f"\n{response}" + context = ("\n" if context != "" else "") + response - system_prompt = rag_template( - rag_app.state.config.RAG_TEMPLATE, context, prompt - ) + if context != "": + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) - data["messages"] = add_or_update_system_message( - system_prompt, data["messages"] - ) + print(system_prompt) + + data["messages"] = add_or_update_system_message( + f"\n{system_prompt}", data["messages"] + ) del data["tool_ids"] diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 1a0b0d894..3c4c75967 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -73,6 +73,7 @@ let selectedModels = ['']; let atSelectedModel: Model | undefined; + let selectedToolIds = []; let webSearchEnabled = false; let chat = null; @@ -687,6 +688,7 @@ }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0, chat_id: $chatId @@ -948,6 +950,7 @@ top_p: $settings?.params?.top_p ?? undefined, frequency_penalty: $settings?.params?.frequency_penalty ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0, chat_id: $chatId @@ -1274,6 +1277,7 @@ bind:files bind:prompt bind:autoScroll + bind:selectedToolIds bind:webSearchEnabled bind:atSelectedModel {selectedModels} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index b3ceb3e91..c5dc780ab 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -8,7 +8,8 @@ showSidebar, models, config, - showCallOverlay + showCallOverlay, + tools } from '$lib/stores'; import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils'; @@ -57,6 +58,7 @@ let chatInputPlaceholder = ''; export let files = []; + export let selectedToolIds = []; export let webSearchEnabled = false; @@ -653,6 +655,15 @@
{ + a[e.id] = { + name: e.name, + enabled: false + }; + + return a; + }, {})} uploadFilesHandler={() => { filesInputElement.click(); }} diff --git a/src/lib/components/chat/MessageInput/InputMenu.svelte b/src/lib/components/chat/MessageInput/InputMenu.svelte index 811e4d27d..5d43d4648 100644 --- a/src/lib/components/chat/MessageInput/InputMenu.svelte +++ b/src/lib/components/chat/MessageInput/InputMenu.svelte @@ -14,6 +14,8 @@ const i18n = getContext('i18n'); export let uploadFilesHandler: Function; + + export let selectedToolIds: string[] = []; export let webSearchEnabled: boolean; export let tools = {}; @@ -44,16 +46,23 @@ transition={flyAndScale} > {#if Object.keys(tools).length > 0} - {#each Object.keys(tools) as tool} + {#each Object.keys(tools) as toolId}
-
{tool}
+
{tools[toolId].name}
- + { + selectedToolIds = e.detail + ? [...selectedToolIds, toolId] + : selectedToolIds.filter((id) => id !== toolId); + }} + />
{/each}
From 049b3136e8d75b50bc881f91babb1726561caad0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 11 Jun 2024 00:24:25 -0700 Subject: [PATCH 19/21] refac --- backend/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index fa9563e13..4376da288 100644 --- a/backend/main.py +++ b/backend/main.py @@ -284,6 +284,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): context = "" for tool_id in data["tool_ids"]: + print(tool_id) response = await get_function_call_response( prompt=prompt, tool_id=tool_id, @@ -291,10 +292,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): task_model_id=task_model_id, user=user, ) - print(response) if response: - context = ("\n" if context != "" else "") + response + context += ("\n" if context != "" else "") + response if context != "": system_prompt = rag_template( @@ -304,7 +304,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(system_prompt) data["messages"] = add_or_update_system_message( - f"\n{system_prompt}", data["messages"] + f"\n{system_prompt}", data["messages"] ) del data["tool_ids"] From 5237439e297ff4b256085f8cdeacdcdb204c38b3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 11 Jun 2024 00:32:16 -0700 Subject: [PATCH 20/21] feat: tool desc --- src/lib/components/chat/MessageInput.svelte | 1 + src/lib/components/chat/MessageInput/InputMenu.svelte | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index c5dc780ab..e21387607 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -659,6 +659,7 @@ tools={$tools.reduce((a, e, i, arr) => { a[e.id] = { name: e.name, + description: e.meta.description, enabled: false }; diff --git a/src/lib/components/chat/MessageInput/InputMenu.svelte b/src/lib/components/chat/MessageInput/InputMenu.svelte index 5d43d4648..f0bebd8c0 100644 --- a/src/lib/components/chat/MessageInput/InputMenu.svelte +++ b/src/lib/components/chat/MessageInput/InputMenu.svelte @@ -52,7 +52,10 @@ >
-
{tools[toolId].name}
+ + +
{tools[toolId].name}
+
Date: Tue, 11 Jun 2024 00:37:31 -0700 Subject: [PATCH 21/21] refac --- src/lib/components/chat/Chat.svelte | 14 +++++++++++++- src/lib/components/chat/MessageInput.svelte | 13 +++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 3c4c75967..af38d1665 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -24,7 +24,8 @@ banners, user, socket, - showCallOverlay + showCallOverlay, + tools } from '$lib/stores'; import { convertMessagesToHistory, @@ -1280,6 +1281,17 @@ bind:selectedToolIds bind:webSearchEnabled bind:atSelectedModel + availableTools={$user.role === 'admin' + ? $tools.reduce((a, e, i, arr) => { + a[e.id] = { + name: e.name, + description: e.meta.description, + enabled: false + }; + + return a; + }, {}) + : {}} {selectedModels} {messages} {submitPrompt} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index e21387607..871025e63 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -58,8 +58,9 @@ let chatInputPlaceholder = ''; export let files = []; - export let selectedToolIds = []; + export let availableTools = {}; + export let selectedToolIds = []; export let webSearchEnabled = false; export let prompt = ''; @@ -656,15 +657,7 @@ { - a[e.id] = { - name: e.name, - description: e.meta.description, - enabled: false - }; - - return a; - }, {})} + tools={availableTools} uploadFilesHandler={() => { filesInputElement.click(); }}