diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py new file mode 100644 index 000000000..4a68eea55 --- /dev/null +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Tool(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + content = pw.TextField() + specs = pw.TextField() + + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "tool" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("tool") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 6ec9bbace..62a0a7a7b 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -6,6 +6,7 @@ from apps.webui.routers import ( users, chats, documents, + tools, models, prompts, configs, @@ -26,8 +27,8 @@ from config import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, WEBUI_BANNERS, - AppConfig, ENABLE_COMMUNITY_SHARING, + AppConfig, ) app = FastAPI() @@ -38,6 +39,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS @@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.TOOLS = {} app.add_middleware( @@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py new file mode 100644 index 000000000..99463878b --- /dev/null +++ b/backend/apps/webui/models/tools.py @@ -0,0 +1,131 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time +import logging +from apps.webui.internal.db import DB, JSONField + +import json + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# Tools DB Schema +#################### + + +class Tool(Model): + id = CharField(unique=True) + user_id = CharField() + name = TextField() + content = TextField() + specs = JSONField() + meta = JSONField() + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ToolMeta(BaseModel): + description: Optional[str] = None + + +class ToolModel(BaseModel): + id: str + user_id: str + name: str + content: str + specs: List[dict] + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ToolResponse(BaseModel): + id: str + user_id: str + name: str + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ToolForm(BaseModel): + id: str + name: str + content: str + meta: ToolMeta + + +class ToolsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Tool]) + + def insert_new_tool( + self, user_id: str, form_data: ToolForm, specs: List[dict] + ) -> Optional[ToolModel]: + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool.create(**tool.model_dump()) + if result: + return tool + else: + return None + except: + return None + + def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + try: + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def get_tools(self) -> List[ToolModel]: + return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + + def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + try: + query = Tool.update( + **updated, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def delete_tool_by_id(self, id: str) -> bool: + try: + query = Tool.delete().where((Tool.id == id)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Tools = ToolsTable(DB) diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py new file mode 100644 index 000000000..c2a815ee7 --- /dev/null +++ b/backend/apps/webui/routers/tools.py @@ -0,0 +1,177 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse +from apps.webui.utils import load_toolkit_module_by_id + +from utils.utils import get_current_user, get_admin_user +from utils.tools import get_tools_specs +from constants import ERROR_MESSAGES + +from importlib import util +import os + +from config import DATA_DIR + + +TOOLS_DIR = f"{DATA_DIR}/tools" +os.makedirs(TOOLS_DIR, exist_ok=True) + + +router = APIRouter() + +############################ +# GetToolkits +############################ + + +@router.get("/", response_model=List[ToolResponse]) +async def get_toolkits(user=Depends(get_current_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# ExportToolKits +############################ + + +@router.get("/export", response_model=List[ToolModel]) +async def get_toolkits(user=Depends(get_admin_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# CreateNewToolKit +############################ + + +@router.post("/create", response_model=Optional[ToolResponse]) +async def create_new_toolkit( + request: Request, form_data: ToolForm, user=Depends(get_admin_user) +): + if not form_data.id.isidentifier(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only alphanumeric characters and underscores are allowed in the id", + ) + + form_data.id = form_data.id.lower() + + toolkit = Tools.get_tool_by_id(form_data.id) + if toolkit == None: + toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_by_id(form_data.id) + + TOOLS = request.app.state.TOOLS + TOOLS[form_data.id] = toolkit_module + + specs = get_tools_specs(TOOLS[form_data.id]) + toolkit = Tools.insert_new_tool(user.id, form_data, specs) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_EXISTS, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + +############################ +# GetToolkitById +############################ + + +@router.get("/id/{id}", response_model=Optional[ToolModel]) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolkitById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[ToolModel]) +async def update_toolkit_by_id( + request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) +): + toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") + + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_by_id(id) + + TOOLS = request.app.state.TOOLS + TOOLS[id] = toolkit_module + + specs = get_tools_specs(TOOLS[id]) + + updated = { + **form_data.model_dump(exclude={"id"}), + "specs": specs, + } + + print(updated) + toolkit = Tools.update_tool_by_id(id, updated) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteToolkitById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): + result = Tools.delete_tool_by_id(id) + + if result: + TOOLS = request.app.state.TOOLS + del TOOLS[id] + + return result diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py new file mode 100644 index 000000000..19a8615bc --- /dev/null +++ b/backend/apps/webui/utils.py @@ -0,0 +1,23 @@ +from importlib import util +import os + +from config import TOOLS_DIR + + +def load_toolkit_module_by_id(toolkit_id): + toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py") + spec = util.spec_from_file_location(toolkit_id, toolkit_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {toolkit_id}") + # Move the file to the error folder + os.rename(toolkit_path, f"{toolkit_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index df52a4b69..32e64347e 100644 --- a/backend/config.py +++ b/backend/config.py @@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Tools DIR +#################################### + +TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") +Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### @@ -669,7 +677,6 @@ Question: ), ) - SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "task.search.prompt_length_threshold", @@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( ), ) +TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + "task.tools.prompt_template", + os.environ.get( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + """Tools: {{TOOLS}} +If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""", + ), +) + + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/constants.py b/backend/constants.py index 0740fa49d..f1eed43d3 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." diff --git a/backend/main.py b/backend/main.py index 99b409983..4376da288 100644 --- a/backend/main.py +++ b/backend/main.py @@ -47,15 +47,24 @@ from pydantic import BaseModel from typing import List, Optional from apps.webui.models.models import Models, ModelModel +from apps.webui.models.tools import Tools +from apps.webui.utils import load_toolkit_module_by_id + + from utils.utils import ( get_admin_user, get_verified_user, get_current_user, get_http_authorization_cred, ) -from utils.task import title_generation_template, search_query_generation_template +from utils.task import ( + title_generation_template, + search_query_generation_template, + tools_function_calling_generation_template, +) +from utils.misc import get_last_user_message, add_or_update_system_message -from apps.rag.utils import rag_messages +from apps.rag.utils import rag_messages, rag_template from config import ( CONFIG_DATA, @@ -82,6 +91,7 @@ from config import ( TITLE_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, AppConfig, ) from constants import ERROR_MESSAGES @@ -148,24 +158,80 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) app.state.MODELS = {} origins = ["*"] -# Custom middleware to add security headers -# class SecurityHeadersMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next): -# response: Response = await call_next(request) -# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" -# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" -# return response + +async def get_function_call_response(prompt, tool_id, template, task_model_id, user): + tool = Tools.get_tool_by_id(tool_id) + tools_specs = json.dumps(tool.specs, indent=2) + content = tools_function_calling_generation_template(template, tools_specs) + + payload = { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + } + + payload = filter_pipeline(payload, user) + model = app.state.MODELS[task_model_id] + + response = None + try: + if model["owned_by"] == "ollama": + response = await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + response = await generate_openai_chat_completion(payload, user=user) + + content = None + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + + # Parse the function response + if content is not None: + result = json.loads(content) + print(result) + + # Call the function + if "name" in result: + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + function_result = function(**result["parameters"]) + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result: + return function_result + except Exception as e: + print(f"Error: {e}") + + return None -# app.add_middleware(SecurityHeadersMiddleware) - - -class RAGMiddleware(BaseHTTPMiddleware): +class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): return_citations = False @@ -182,12 +248,68 @@ class RAGMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} + # Remove the citations from the body return_citations = data.get("citations", False) if "citations" in data: del data["citations"] - # Example: Add a new key-value pair or modify existing ones - # data["modified"] = True # Example modification + # Set the task model + task_model_id = data["model"] + if task_model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[task_model_id]["owned_by"] == "ollama": + if ( + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL + else: + if ( + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + + if "tool_ids" in data: + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + prompt = get_last_user_message(data["messages"]) + context = "" + + for tool_id in data["tool_ids"]: + print(tool_id) + response = await get_function_call_response( + prompt=prompt, + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) + + if response: + context += ("\n" if context != "" else "") + response + + if context != "": + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) + + print(system_prompt) + + data["messages"] = add_or_update_system_message( + f"\n{system_prompt}", data["messages"] + ) + + del data["tool_ids"] + + # If docs field is present, generate RAG completions if "docs" in data: data = {**data} data["messages"], citations = rag_messages( @@ -210,7 +332,6 @@ class RAGMiddleware(BaseHTTPMiddleware): # Replace the request body with the modified one request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length request.headers.__dict__["_list"] = [ (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), @@ -253,7 +374,7 @@ class RAGMiddleware(BaseHTTPMiddleware): yield data -app.add_middleware(RAGMiddleware) +app.add_middleware(ChatCompletionMiddleware) def filter_pipeline(payload, user): @@ -515,6 +636,7 @@ async def get_task_config(user=Depends(get_verified_user)): "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -524,6 +646,7 @@ class TaskConfigForm(BaseModel): TITLE_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @app.post("/api/task/config/update") @@ -539,6 +662,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) return { "TASK_MODEL": app.state.config.TASK_MODEL, @@ -546,6 +672,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -659,6 +786,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) return await generate_openai_chat_completion(payload, user=user) +@app.post("/api/task/tools/completions") +async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): + print("get_tools_function_calling") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + return await get_function_call_response( + form_data["prompt"], form_data["tool_id"], template, model_id, user + ) + + @app.post("/api/chat/completions") async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): model_id = form_data["model"] diff --git a/backend/utils/task.py b/backend/utils/task.py index 2239de7df..615febcdc 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -110,3 +110,8 @@ def search_query_generation_template( ), ) return template + + +def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: + template = template.replace("{{TOOLS}}", tools_specs) + return template diff --git a/backend/utils/tools.py b/backend/utils/tools.py new file mode 100644 index 000000000..91359bad9 --- /dev/null +++ b/backend/utils/tools.py @@ -0,0 +1,73 @@ +import inspect +from typing import get_type_hints, List, Dict, Any + + +def doc_to_dict(docstring): + lines = docstring.split("\n") + description = lines[1].strip() + param_dict = {} + + for line in lines: + if ":param" in line: + line = line.replace(":param", "").strip() + param, desc = line.split(":", 1) + param_dict[param.strip()] = desc.strip() + ret_dict = {"description": description, "params": param_dict} + return ret_dict + + +def get_tools_specs(tools) -> List[dict]: + function_list = [ + {"name": func, "function": getattr(tools, func)} + for func in dir(tools) + if callable(getattr(tools, func)) and not func.startswith("__") + ] + + specs = [] + for function_item in function_list: + function_name = function_item["name"] + function = function_item["function"] + + function_doc = doc_to_dict(function.__doc__ or function_name) + specs.append( + { + "name": function_name, + # TODO: multi-line desc? + "description": function_doc.get("description", function_name), + "parameters": { + "type": "object", + "properties": { + param_name: { + "type": param_annotation.__name__.lower(), + **( + { + "enum": ( + param_annotation.__args__ + if hasattr(param_annotation, "__args__") + else None + ) + } + if hasattr(param_annotation, "__args__") + else {} + ), + "description": function_doc.get("params", {}).get( + param_name, param_name + ), + } + for param_name, param_annotation in get_type_hints( + function + ).items() + if param_name != "return" + }, + "required": [ + name + for name, param in inspect.signature( + function + ).parameters.items() + if param.default is param.empty + ], + }, + } + ) + + return specs diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts new file mode 100644 index 000000000..47a535cdf --- /dev/null +++ b/src/lib/apis/tools/index.ts @@ -0,0 +1,193 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewTool = async (token: string, tool: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...tool + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getTools = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const exportTools = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/export`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getToolById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateToolById = async (token: string, id: string, tool: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...tool + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteToolById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 1a0b0d894..af38d1665 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -24,7 +24,8 @@ banners, user, socket, - showCallOverlay + showCallOverlay, + tools } from '$lib/stores'; import { convertMessagesToHistory, @@ -73,6 +74,7 @@ let selectedModels = ['']; let atSelectedModel: Model | undefined; + let selectedToolIds = []; let webSearchEnabled = false; let chat = null; @@ -687,6 +689,7 @@ }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0, chat_id: $chatId @@ -948,6 +951,7 @@ top_p: $settings?.params?.top_p ?? undefined, frequency_penalty: $settings?.params?.frequency_penalty ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined, + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0, chat_id: $chatId @@ -1274,8 +1278,20 @@ bind:files bind:prompt bind:autoScroll + bind:selectedToolIds bind:webSearchEnabled bind:atSelectedModel + availableTools={$user.role === 'admin' + ? $tools.reduce((a, e, i, arr) => { + a[e.id] = { + name: e.name, + description: e.meta.description, + enabled: false + }; + + return a; + }, {}) + : {}} {selectedModels} {messages} {submitPrompt} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index b3ceb3e91..871025e63 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -8,7 +8,8 @@ showSidebar, models, config, - showCallOverlay + showCallOverlay, + tools } from '$lib/stores'; import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils'; @@ -58,6 +59,8 @@ export let files = []; + export let availableTools = {}; + export let selectedToolIds = []; export let webSearchEnabled = false; export let prompt = ''; @@ -653,6 +656,8 @@