Merge pull request #3006 from open-webui/tools

feat: tools
This commit is contained in:
Timothy Jaeryang Baek 2024-06-11 00:37:53 -07:00 committed by GitHub
commit e6fe3ada57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1391 additions and 99 deletions

View File

@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")

View File

@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
@ -26,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
@ -38,6 +39,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])

View File

@ -0,0 +1,131 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except:
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)

View File

@ -0,0 +1,177 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_EXISTS,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
del TOOLS[id]
return result

View File

@ -0,0 +1,23 @@
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e

View File

@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
@ -669,7 +677,6 @@ Question:
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################

View File

@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."

View File

@ -47,15 +47,24 @@ from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
)
from utils.task import title_generation_template, search_query_generation_template
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages
from apps.rag.utils import rag_messages, rag_template
from config import (
CONFIG_DATA,
@ -82,6 +91,7 @@ from config import (
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig,
)
from constants import ERROR_MESSAGES
@ -148,24 +158,80 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {}
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
payload = filter_pipeline(payload, user)
model = app.state.MODELS[task_model_id]
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
# Parse the function response
if content is not None:
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
# app.add_middleware(SecurityHeadersMiddleware)
class RAGMiddleware(BaseHTTPMiddleware):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
@ -182,12 +248,68 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if "tool_ids" in data:
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
prompt = get_last_user_message(data["messages"])
context = ""
for tool_id in data["tool_ids"]:
print(tool_id)
response = await get_function_call_response(
prompt=prompt,
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
del data["tool_ids"]
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
@ -210,7 +332,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@ -253,7 +374,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data
app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user):
@ -515,6 +636,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@ -524,6 +646,7 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
@ -539,6 +662,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
@ -546,6 +672,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@ -659,6 +786,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response(
form_data["prompt"], form_data["tool_id"], template, model_id, user
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]

View File

@ -110,3 +110,8 @@ def search_query_generation_template(
),
)
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template

73
backend/utils/tools.py Normal file
View File

@ -0,0 +1,73 @@
import inspect
from typing import get_type_hints, List, Dict, Any
def doc_to_dict(docstring):
lines = docstring.split("\n")
description = lines[1].strip()
param_dict = {}
for line in lines:
if ":param" in line:
line = line.replace(":param", "").strip()
param, desc = line.split(":", 1)
param_dict[param.strip()] = desc.strip()
ret_dict = {"description": description, "params": param_dict}
return ret_dict
def get_tools_specs(tools) -> List[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__")
]
specs = []
for function_item in function_list:
function_name = function_item["name"]
function = function_item["function"]
function_doc = doc_to_dict(function.__doc__ or function_name)
specs.append(
{
"name": function_name,
# TODO: multi-line desc?
"description": function_doc.get("description", function_name),
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_annotation.__name__.lower(),
**(
{
"enum": (
param_annotation.__args__
if hasattr(param_annotation, "__args__")
else None
)
}
if hasattr(param_annotation, "__args__")
else {}
),
"description": function_doc.get("params", {}).get(
param_name, param_name
),
}
for param_name, param_annotation in get_type_hints(
function
).items()
if param_name != "return"
},
"required": [
name
for name, param in inspect.signature(
function
).parameters.items()
if param.default is param.empty
],
},
}
)
return specs

193
src/lib/apis/tools/index.ts Normal file
View File

@ -0,0 +1,193 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewTool = async (token: string, tool: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...tool
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getTools = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportTools = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getToolById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateToolById = async (token: string, id: string, tool: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...tool
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteToolById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/delete`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View File

@ -24,7 +24,8 @@
banners,
user,
socket,
showCallOverlay
showCallOverlay,
tools
} from '$lib/stores';
import {
convertMessagesToHistory,
@ -73,6 +74,7 @@
let selectedModels = [''];
let atSelectedModel: Model | undefined;
let selectedToolIds = [];
let webSearchEnabled = false;
let chat = null;
@ -687,6 +689,7 @@
},
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0,
chat_id: $chatId
@ -948,6 +951,7 @@
top_p: $settings?.params?.top_p ?? undefined,
frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0,
chat_id: $chatId
@ -1274,8 +1278,20 @@
bind:files
bind:prompt
bind:autoScroll
bind:selectedToolIds
bind:webSearchEnabled
bind:atSelectedModel
availableTools={$user.role === 'admin'
? $tools.reduce((a, e, i, arr) => {
a[e.id] = {
name: e.name,
description: e.meta.description,
enabled: false
};
return a;
}, {})
: {}}
{selectedModels}
{messages}
{submitPrompt}

View File

@ -8,7 +8,8 @@
showSidebar,
models,
config,
showCallOverlay
showCallOverlay,
tools
} from '$lib/stores';
import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
@ -58,6 +59,8 @@
export let files = [];
export let availableTools = {};
export let selectedToolIds = [];
export let webSearchEnabled = false;
export let prompt = '';
@ -653,6 +656,8 @@
<div class=" ml-0.5 self-end mb-1.5 flex space-x-1">
<InputMenu
bind:webSearchEnabled
bind:selectedToolIds
tools={availableTools}
uploadFilesHandler={() => {
filesInputElement.click();
}}

View File

@ -4,22 +4,21 @@
import { getContext } from 'svelte';
import Dropdown from '$lib/components/common/Dropdown.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import Pencil from '$lib/components/icons/Pencil.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Tags from '$lib/components/chat/Tags.svelte';
import Share from '$lib/components/icons/Share.svelte';
import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
import DocumentArrowUpSolid from '$lib/components/icons/DocumentArrowUpSolid.svelte';
import Switch from '$lib/components/common/Switch.svelte';
import GlobeAltSolid from '$lib/components/icons/GlobeAltSolid.svelte';
import { config } from '$lib/stores';
import WrenchSolid from '$lib/components/icons/WrenchSolid.svelte';
const i18n = getContext('i18n');
export let uploadFilesHandler: Function;
export let selectedToolIds: string[] = [];
export let webSearchEnabled: boolean;
export let tools = {};
export let onClose: Function;
let show = false;
@ -46,6 +45,32 @@
align="start"
transition={flyAndScale}
>
{#if Object.keys(tools).length > 0}
{#each Object.keys(tools) as toolId}
<div
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
>
<div class="flex-1 flex items-center gap-2">
<WrenchSolid />
<Tooltip content={tools[toolId]?.description ?? ''}>
<div class="flex items-center line-clamp-1">{tools[toolId].name}</div>
</Tooltip>
</div>
<Switch
bind:state={tools[toolId].enabled}
on:change={(e) => {
selectedToolIds = e.detail
? [...selectedToolIds, toolId]
: selectedToolIds.filter((id) => id !== toolId);
}}
/>
</div>
{/each}
<hr class="border-gray-100 dark:border-gray-800 my-1" />
{/if}
{#if $config?.features?.enable_web_search}
<div
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"

View File

@ -60,7 +60,10 @@
];
onMount(() => {
console.log(value);
if (value === '') {
value = boilerplate;
}
// Check if html class has dark mode
isDarkMode = document.documentElement.classList.contains('dark');
@ -107,27 +110,24 @@
attributeFilter: ['class']
});
// Add a keyboard shortcut to format the code when Ctrl/Cmd + S is pressed
// Override the default browser save functionality
const handleSave = async (e) => {
const keydownHandler = async (e) => {
if ((e.ctrlKey || e.metaKey) && e.key === 's') {
e.preventDefault();
const res = await formatPythonCodeHandler().catch((error) => {
return null;
});
if (res) {
dispatch('save');
}
// Format code when Ctrl + Shift + F is pressed
if ((e.ctrlKey || e.metaKey) && e.shiftKey && e.key === 'f') {
e.preventDefault();
await formatPythonCodeHandler();
}
};
document.addEventListener('keydown', handleSave);
document.addEventListener('keydown', keydownHandler);
return () => {
observer.disconnect();
document.removeEventListener('keydown', handleSave);
document.removeEventListener('keydown', keydownHandler);
};
});
</script>

View File

@ -0,0 +1,11 @@
<script lang="ts">
export let className = 'size-4';
</script>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class={className}>
<path
fill-rule="evenodd"
d="M12 6.75a5.25 5.25 0 0 1 6.775-5.025.75.75 0 0 1 .313 1.248l-3.32 3.319c.063.475.276.934.641 1.299.365.365.824.578 1.3.64l3.318-3.319a.75.75 0 0 1 1.248.313 5.25 5.25 0 0 1-5.472 6.756c-1.018-.086-1.87.1-2.309.634L7.344 21.3A3.298 3.298 0 1 1 2.7 16.657l8.684-7.151c.533-.44.72-1.291.634-2.309A5.342 5.342 0 0 1 12 6.75ZM4.117 19.125a.75.75 0 0 1 .75-.75h.008a.75.75 0 0 1 .75.75v.008a.75.75 0 0 1-.75.75h-.008a.75.75 0 0 1-.75-.75v-.008Z"
clip-rule="evenodd"
/>
</svg>

View File

@ -4,12 +4,23 @@
const { saveAs } = fileSaver;
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, prompts } from '$lib/stores';
import { WEBUI_NAME, prompts, tools } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { goto } from '$app/navigation';
import {
createNewTool,
deleteToolById,
exportTools,
getToolById,
getTools
} from '$lib/apis/tools';
const i18n = getContext('i18n');
let toolsImportInputElement: HTMLInputElement;
let importFiles;
let query = '';
</script>
@ -65,3 +76,216 @@
</div>
</div>
<hr class=" dark:border-gray-850 my-2.5" />
<div class="my-3 mb-5">
{#each $tools.filter((t) => query === '' || t.name
.toLowerCase()
.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool}
<button
class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
type="button"
on:click={() => {
goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`);
}}
>
<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
<a
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
class="flex items-center text-left"
>
<div class=" flex-1 self-center pl-5">
<div class=" font-semibold flex items-center gap-1.5">
<div>
{tool.name}
</div>
<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
</div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{tool.meta.description}
</div>
</div>
</a>
</div>
<div class="flex flex-row space-x-1 self-center">
<a
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L6.832 19.82a4.5 4.5 0 01-1.897 1.13l-2.685.8.8-2.685a4.5 4.5 0 011.13-1.897L16.863 4.487zm0 0L19.5 7.125"
/>
</svg>
</a>
<button
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
on:click={async () => {
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
toast.error(error);
return null;
});
if (_tool) {
sessionStorage.tool = JSON.stringify({
..._tool,
id: `${_tool.id}_clone`,
name: `${_tool.name} (Clone)`
});
goto('/workspace/tools/create');
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15.75 17.25v3.375c0 .621-.504 1.125-1.125 1.125h-9.75a1.125 1.125 0 0 1-1.125-1.125V7.875c0-.621.504-1.125 1.125-1.125H6.75a9.06 9.06 0 0 1 1.5.124m7.5 10.376h3.375c.621 0 1.125-.504 1.125-1.125V11.25c0-4.46-3.243-8.161-7.5-8.876a9.06 9.06 0 0 0-1.5-.124H9.375c-.621 0-1.125.504-1.125 1.125v3.5m7.5 10.375H9.375a1.125 1.125 0 0 1-1.125-1.125v-9.25m12 6.625v-1.875a3.375 3.375 0 0 0-3.375-3.375h-1.5a1.125 1.125 0 0 1-1.125-1.125v-1.5a3.375 3.375 0 0 0-3.375-3.375H9.75"
/>
</svg>
</button>
<button
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button"
on:click={async () => {
const res = await deleteToolById(localStorage.token, tool.id).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success('Tool deleted successfully');
tools.set(await getTools(localStorage.token));
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0"
/>
</svg>
</button>
</div>
</button>
{/each}
</div>
<div class=" flex justify-end w-full mb-2">
<div class="flex space-x-2">
<input
id="documents-import-input"
bind:this={toolsImportInputElement}
bind:files={importFiles}
type="file"
accept=".json"
hidden
on:change={() => {
console.log(importFiles);
const reader = new FileReader();
reader.onload = async (event) => {
const _tools = JSON.parse(event.target.result);
console.log(_tools);
for (const tool of _tools) {
const res = await createNewTool(localStorage.token, tool).catch((error) => {
toast.error(error);
return null;
});
}
toast.success('Tool imported successfully');
tools.set(await getTools(localStorage.token));
};
reader.readAsText(importFiles[0]);
}}
/>
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={() => {
toolsImportInputElement.click();
}}
>
<div class=" self-center mr-2 font-medium">{$i18n.t('Import Tools')}</div>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 9.5a.75.75 0 0 1-.75-.75V8.06l-.72.72a.75.75 0 0 1-1.06-1.06l2-2a.75.75 0 0 1 1.06 0l2 2a.75.75 0 1 1-1.06 1.06l-.72-.72v2.69a.75.75 0 0 1-.75.75Z"
clip-rule="evenodd"
/>
</svg>
</div>
</button>
<button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={async () => {
const _tools = await exportTools(localStorage.token).catch((error) => {
toast.error(error);
return null;
});
if (_tools) {
let blob = new Blob([JSON.stringify(_tools)], {
type: 'application/json'
});
saveAs(blob, `tools-export-${Date.now()}.json`);
}
}}
>
<div class=" self-center mr-2 font-medium">{$i18n.t('Export Tools')}</div>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 3.5a.75.75 0 0 1 .75.75v2.69l.72-.72a.75.75 0 1 1 1.06 1.06l-2 2a.75.75 0 0 1-1.06 0l-2-2a.75.75 0 0 1 1.06-1.06l.72.72V6.25A.75.75 0 0 1 8 5.5Z"
clip-rule="evenodd"
/>
</svg>
</div>
</button>
</div>
</div>

View File

@ -1,15 +1,15 @@
<script lang="ts">
import CodeEditor from '$lib/components/common/CodeEditor.svelte';
import { createEventDispatcher } from 'svelte';
const dispatch = createEventDispatcher();
export let saveHandler: Function;
export let value = '';
let codeEditor;
let boilerplate = `# Tip: Use Ctrl/Cmd + S to format the code
from datetime import datetime
let boilerplate = `import os
import requests
from datetime import datetime
class Tools:
@ -20,6 +20,20 @@ class Tools:
# Use Sphinx-style docstrings to document your tools, they will be used for generating tools specifications
# Please refer to function_calling_filter_pipeline.py file from pipelines project for an example
def get_environment_variable(self, variable_name: str) -> str:
"""
Get the value of an environment variable.
:param variable_name: The name of the environment variable.
:return: The value of the environment variable or a message if it doesn't exist.
"""
value = os.getenv(variable_name)
if value is not None:
return (
f"The value of the environment variable '{variable_name}' is '{value}'"
)
else:
return f"The environment variable '{variable_name}' does not exist."
def get_current_time(self) -> str:
"""
Get the current time.
@ -45,6 +59,41 @@ class Tools:
print(e)
return "Invalid equation"
def get_current_weather(self, city: str) -> str:
"""
Get the current weather for a given city.
:param city: The name of the city to get the weather for.
:return: The current weather information or an error message.
"""
api_key = os.getenv("OPENWEATHER_API_KEY")
if not api_key:
return (
"API key is not set in the environment variable 'OPENWEATHER_API_KEY'."
)
base_url = "http://api.openweathermap.org/data/2.5/weather"
params = {
"q": city,
"appid": api_key,
"units": "metric", # Optional: Use 'imperial' for Fahrenheit
}
try:
response = requests.get(base_url, params=params)
response.raise_for_status() # Raise HTTPError for bad responses (4xx and 5xx)
data = response.json()
if data.get("cod") != 200:
return f"Error fetching weather data: {data.get('message')}"
weather_description = data["weather"][0]["description"]
temperature = data["main"]["temp"]
humidity = data["main"]["humidity"]
wind_speed = data["wind"]["speed"]
return f"Weather in {city}: {temperature}°C"
except requests.RequestException as e:
return f"Error fetching weather data: {str(e)}"
`;
export const formatHandler = async () => {
@ -60,6 +109,6 @@ class Tools:
{boilerplate}
bind:this={codeEditor}
on:save={() => {
saveHandler();
dispatch('save');
}}
/>

View File

@ -1,22 +1,27 @@
<script>
import { getContext } from 'svelte';
import { getContext, createEventDispatcher, onMount } from 'svelte';
const i18n = getContext('i18n');
import CodeEditor from './CodeEditor.svelte';
import { goto } from '$app/navigation';
const dispatch = createEventDispatcher();
let formElement = null;
let loading = false;
let id = '';
let name = '';
let meta = {
export let edit = false;
export let clone = false;
export let id = '';
export let name = '';
export let meta = {
description: ''
};
export let content = '';
let code = '';
$: if (name) {
$: if (name && !edit && !clone) {
id = name.replace(/\s+/g, '_').toLowerCase();
}
@ -24,8 +29,12 @@
const saveHandler = async () => {
loading = true;
// Call the API to save the toolkit
console.log('saveHandler');
dispatch('save', {
id,
name,
meta,
content
});
};
const submitHandler = async () => {
@ -42,13 +51,20 @@
<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
<div class="mx-auto w-full md:px-0 h-full">
<div class=" flex flex-col max-h-[100dvh] h-full">
<form
bind:this={formElement}
class=" flex flex-col max-h-[100dvh] h-full"
on:submit|preventDefault={() => {
submitHandler();
}}
>
<div class="mb-2.5">
<button
class="flex space-x-1"
on:click={() => {
goto('/workspace/tools');
}}
type="button"
>
<div class=" self-center">
<svg
@ -80,11 +96,12 @@
/>
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Toolkit ID (e.g. my_toolkit)"
bind:value={id}
required
disabled={edit}
/>
</div>
<input
@ -97,20 +114,25 @@
</div>
<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
<CodeEditor bind:value={code} bind:this={codeEditor} {saveHandler} />
<CodeEditor
bind:value={content}
bind:this={codeEditor}
on:save={() => {
if (formElement) {
formElement.requestSubmit();
}
}}
/>
</div>
<div class="pb-3 flex justify-end">
<button
class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
on:click={() => {
submitHandler();
}}
>
{$i18n.t('Save')}
</button>
</div>
</div>
</div>
</form>
</div>
</div>

View File

@ -23,24 +23,11 @@ export const chatId = writable('');
export const chats = writable([]);
export const tags = writable([]);
export const models: Writable<Model[]> = writable([]);
export const modelfiles = writable([]);
export const models: Writable<Model[]> = writable([]);
export const prompts: Writable<Prompt[]> = writable([]);
export const documents = writable([
{
collection_name: 'collection_name',
filename: 'filename',
name: 'name',
title: 'title'
},
{
collection_name: 'collection_name1',
filename: 'filename1',
name: 'name1',
title: 'title1'
}
]);
export const documents: Writable<Document[]> = writable([]);
export const tools = writable([]);
export const banners: Writable<Banner[]> = writable([]);
@ -135,6 +122,13 @@ type Prompt = {
timestamp: number;
};
type Document = {
collection_name: string;
filename: string;
name: string;
title: string;
};
type Config = {
status: boolean;
name: string;

View File

@ -8,12 +8,15 @@
import { goto } from '$app/navigation';
import { getModels as _getModels } from '$lib/apis';
import { getOllamaVersion } from '$lib/apis/ollama';
import { getPrompts } from '$lib/apis/prompts';
import { getDocs } from '$lib/apis/documents';
import { getAllChatTags } from '$lib/apis/chats';
import { getPrompts } from '$lib/apis/prompts';
import { getDocs } from '$lib/apis/documents';
import { getTools } from '$lib/apis/tools';
import { getBanners } from '$lib/apis/configs';
import { getUserSettings } from '$lib/apis/users';
import {
user,
showSettings,
@ -25,33 +28,21 @@
banners,
showChangelog,
config,
showCallOverlay
showCallOverlay,
tools
} from '$lib/stores';
import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
import { compareVersion } from '$lib/utils';
import SettingsModal from '$lib/components/chat/SettingsModal.svelte';
import Sidebar from '$lib/components/layout/Sidebar.svelte';
import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
import ChangelogModal from '$lib/components/ChangelogModal.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import { getBanners } from '$lib/apis/configs';
import { getUserSettings } from '$lib/apis/users';
import Help from '$lib/components/layout/Help.svelte';
import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
import { error } from '@sveltejs/kit';
import CallOverlay from '$lib/components/chat/MessageInput/CallOverlay.svelte';
const i18n = getContext('i18n');
let ollamaVersion = '';
let loaded = false;
let showShortcutsButtonElement: HTMLButtonElement;
let DB = null;
let localDBChats = [];
let showShortcuts = false;
const getModels = async () => {
return _getModels(localStorage.token);
};
@ -99,6 +90,9 @@
(async () => {
documents.set(await getDocs(localStorage.token));
})(),
(async () => {
tools.set(await getTools(localStorage.token));
})(),
(async () => {
banners.set(await getBanners(localStorage.token));
})(),

View File

@ -1,5 +1,52 @@
<script>
import { goto } from '$app/navigation';
import { createNewTool, getTools } from '$lib/apis/tools';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
let clone = false;
let tool = null;
const saveHandler = async (data) => {
console.log(data);
const res = await createNewTool(localStorage.token, {
id: data.id,
name: data.name,
meta: data.meta,
content: data.content
}).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success('Tool created successfully');
tools.set(await getTools(localStorage.token));
await goto('/workspace/tools');
}
};
onMount(() => {
console.log('mounted');
if (sessionStorage.tool) {
tool = JSON.parse(sessionStorage.tool);
sessionStorage.removeItem('tool');
clone = true;
}
});
</script>
<ToolkitEditor />
<ToolkitEditor
id={tool?.id ?? ''}
name={tool?.name ?? ''}
meta={tool?.meta ?? { description: '' }}
content={tool?.content ?? ''}
{clone}
on:save={(e) => {
saveHandler(e.detail);
}}
/>

View File

@ -1,5 +1,66 @@
<script>
import { goto } from '$app/navigation';
import { page } from '$app/stores';
import { getToolById, getTools, updateToolById } from '$lib/apis/tools';
import Spinner from '$lib/components/common/Spinner.svelte';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
let tool = null;
const saveHandler = async (data) => {
console.log(data);
const res = await updateToolById(localStorage.token, tool.id, {
id: data.id,
name: data.name,
meta: data.meta,
content: data.content
}).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success('Tool updated successfully');
tools.set(await getTools(localStorage.token));
await goto('/workspace/tools');
}
};
onMount(async () => {
console.log('mounted');
const id = $page.url.searchParams.get('id');
if (id) {
tool = await getToolById(localStorage.token, id).catch((error) => {
toast.error(error);
goto('/workspace/tools');
return null;
});
console.log(tool);
}
});
</script>
<ToolkitEditor />
{#if tool}
<ToolkitEditor
edit={true}
id={tool.id}
name={tool.name}
meta={tool.meta}
content={tool.content}
on:save={(e) => {
saveHandler(e.detail);
}}
/>
{:else}
<div class="flex items-center justify-center h-full">
<div class=" pb-16">
<Spinner />
</div>
</div>
{/if}