feat: fc integration

This commit is contained in:
Timothy J. Baek 2024-06-10 23:40:27 -07:00
parent ff1cd306d8
commit a27175d672
5 changed files with 215 additions and 40 deletions

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel
import json import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs from utils.tools import get_tools_specs
@ -17,32 +18,13 @@ import os
from config import DATA_DIR from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools" TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True) os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter() router = APIRouter()
def load_toolkit_module_from_path(tools_id, tools_path):
spec = util.spec_from_file_location(tools_id, tools_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {tools_id}")
# Move the file to the error folder
os.rename(tools_path, f"{tools_path}.error")
raise e
############################ ############################
# GetToolkits # GetToolkits
############################ ############################
@ -89,7 +71,7 @@ async def create_new_toolkit(
with open(toolkit_path, "w") as tool_file: with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content) tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module TOOLS[form_data.id] = toolkit_module
@ -149,7 +131,7 @@ async def update_toolkit_by_id(
with open(toolkit_path, "w") as tool_file: with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content) tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_from_path(id, toolkit_path) toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module TOOLS[id] = toolkit_module

View File

@ -0,0 +1,23 @@
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e

View File

@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # LITELLM_CONFIG
#################################### ####################################
@ -669,7 +677,6 @@ Question:
), ),
) )
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold", "task.search.prompt_length_threshold",
@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
), ),
) )
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
#################################### ####################################
# WEBUI_SECRET_KEY # WEBUI_SECRET_KEY
#################################### ####################################

View File

@ -47,15 +47,24 @@ from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
) )
from utils.task import title_generation_template, search_query_generation_template from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages from apps.rag.utils import rag_messages, rag_template
from config import ( from config import (
CONFIG_DATA, CONFIG_DATA,
@ -82,6 +91,7 @@ from config import (
TITLE_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
) )
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {} app.state.MODELS = {}
origins = ["*"] origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware): async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
# async def dispatch(self, request: Request, call_next): tool = Tools.get_tool_by_id(tool_id)
# response: Response = await call_next(request) tools_specs = json.dumps(tool.specs, indent=2)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" content = tools_function_calling_generation_template(template, tools_specs)
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
payload = filter_pipeline(payload, user)
model = app.state.MODELS[task_model_id]
response = None
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
print(response)
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content != "":
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
return None
# app.add_middleware(SecurityHeadersMiddleware) class ChatCompletionMiddleware(BaseHTTPMiddleware):
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False return_citations = False
@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON # Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
# Remove the citations from the body
return_citations = data.get("citations", False) return_citations = data.get("citations", False)
if "citations" in data: if "citations" in data:
del data["citations"] del data["citations"]
# Example: Add a new key-value pair or modify existing ones # Set the task model
# data["modified"] = True # Example modification task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if "tool_ids" in data:
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
prompt = get_last_user_message(data["messages"])
context = ""
for tool_id in data["tool_ids"]:
response = await get_function_call_response(
prompt=prompt,
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
print(response)
if response:
context += f"\n{response}"
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
)
del data["tool_ids"]
# If docs field is present, generate RAG completions
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"], citations = rag_messages( data["messages"], citations = rag_messages(
@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [ request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")), (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data yield data
app.add_middleware(RAGMiddleware) app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user): def filter_pipeline(payload, user):
@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
} }
@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update") @app.post("/api/task/config/update")
@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
) )
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return { return {
"TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL": app.state.config.TASK_MODEL,
@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
} }
@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
return await generate_openai_chat_completion(payload, user=user) return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response(
form_data["prompt"], form_data["tool_id"], template, model_id, user
)
@app.post("/api/chat/completions") @app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"] model_id = form_data["model"]

View File

@ -110,3 +110,8 @@ def search_query_generation_template(
), ),
) )
return template return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template