mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
feat: fc integration
This commit is contained in:
parent
ff1cd306d8
commit
a27175d672
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from utils.tools import get_tools_specs
|
||||
@ -17,32 +18,13 @@ import os
|
||||
|
||||
from config import DATA_DIR
|
||||
|
||||
|
||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
||||
os.makedirs(TOOLS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def load_toolkit_module_from_path(tools_id, tools_path):
|
||||
spec = util.spec_from_file_location(tools_id, tools_path)
|
||||
module = util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Tools"):
|
||||
return module.Tools()
|
||||
else:
|
||||
raise Exception("No Tools class found")
|
||||
except Exception as e:
|
||||
print(f"Error loading module: {tools_id}")
|
||||
|
||||
# Move the file to the error folder
|
||||
os.rename(tools_path, f"{tools_path}.error")
|
||||
raise e
|
||||
|
||||
|
||||
############################
|
||||
# GetToolkits
|
||||
############################
|
||||
@ -89,7 +71,7 @@ async def create_new_toolkit(
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
tool_file.write(form_data.content)
|
||||
|
||||
toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path)
|
||||
toolkit_module = load_toolkit_module_by_id(form_data.id)
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[form_data.id] = toolkit_module
|
||||
@ -149,7 +131,7 @@ async def update_toolkit_by_id(
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
tool_file.write(form_data.content)
|
||||
|
||||
toolkit_module = load_toolkit_module_from_path(id, toolkit_path)
|
||||
toolkit_module = load_toolkit_module_by_id(id)
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[id] = toolkit_module
|
||||
|
23
backend/apps/webui/utils.py
Normal file
23
backend/apps/webui/utils.py
Normal file
@ -0,0 +1,23 @@
|
||||
from importlib import util
|
||||
import os
|
||||
|
||||
from config import TOOLS_DIR
|
||||
|
||||
|
||||
def load_toolkit_module_by_id(toolkit_id):
|
||||
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
|
||||
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
|
||||
module = util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Tools"):
|
||||
return module.Tools()
|
||||
else:
|
||||
raise Exception("No Tools class found")
|
||||
except Exception as e:
|
||||
print(f"Error loading module: {toolkit_id}")
|
||||
# Move the file to the error folder
|
||||
os.rename(toolkit_path, f"{toolkit_path}.error")
|
||||
raise e
|
@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
|
||||
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Tools DIR
|
||||
####################################
|
||||
|
||||
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
|
||||
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# LITELLM_CONFIG
|
||||
####################################
|
||||
@ -669,7 +677,6 @@ Question:
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
|
||||
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
|
||||
"task.search.prompt_length_threshold",
|
||||
@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
|
||||
),
|
||||
)
|
||||
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
||||
"task.tools.prompt_template",
|
||||
os.environ.get(
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
|
||||
"""Tools: {{TOOLS}}
|
||||
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI_SECRET_KEY
|
||||
####################################
|
||||
|
181
backend/main.py
181
backend/main.py
@ -47,15 +47,24 @@ from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from apps.webui.models.models import Models, ModelModel
|
||||
from apps.webui.models.tools import Tools
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
|
||||
from utils.utils import (
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
)
|
||||
from utils.task import title_generation_template, search_query_generation_template
|
||||
from utils.task import (
|
||||
title_generation_template,
|
||||
search_query_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||
|
||||
from apps.rag.utils import rag_messages
|
||||
from apps.rag.utils import rag_messages, rag_template
|
||||
|
||||
from config import (
|
||||
CONFIG_DATA,
|
||||
@ -82,6 +91,7 @@ from config import (
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
AppConfig,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
||||
)
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
# Custom middleware to add security headers
|
||||
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
# async def dispatch(self, request: Request, call_next):
|
||||
# response: Response = await call_next(request)
|
||||
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
|
||||
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
|
||||
# return response
|
||||
|
||||
async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
tools_specs = json.dumps(tool.specs, indent=2)
|
||||
content = tools_function_calling_generation_template(template, tools_specs)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": content},
|
||||
{"role": "user", "content": f"Query: {prompt}"},
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
payload = filter_pipeline(payload, user)
|
||||
model = app.state.MODELS[task_model_id]
|
||||
|
||||
response = None
|
||||
if model["owned_by"] == "ollama":
|
||||
response = await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
else:
|
||||
response = await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
print(response)
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
# Parse the function response
|
||||
if content != "":
|
||||
result = json.loads(content)
|
||||
print(result)
|
||||
|
||||
# Call the function
|
||||
if "name" in result:
|
||||
if tool_id in webui_app.state.TOOLS:
|
||||
toolkit_module = webui_app.state.TOOLS[tool_id]
|
||||
else:
|
||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||
|
||||
function = getattr(toolkit_module, result["name"])
|
||||
function_result = None
|
||||
try:
|
||||
function_result = function(**result["parameters"])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# Add the function result to the system prompt
|
||||
if function_result:
|
||||
return function_result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
class RAGMiddleware(BaseHTTPMiddleware):
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
|
||||
@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
# Remove the citations from the body
|
||||
return_citations = data.get("citations", False)
|
||||
if "citations" in data:
|
||||
del data["citations"]
|
||||
|
||||
# Example: Add a new key-value pair or modify existing ones
|
||||
# data["modified"] = True # Example modification
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
if task_model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
||||
if (
|
||||
app.state.config.TASK_MODEL
|
||||
and app.state.config.TASK_MODEL in app.state.MODELS
|
||||
):
|
||||
task_model_id = app.state.config.TASK_MODEL
|
||||
else:
|
||||
if (
|
||||
app.state.config.TASK_MODEL_EXTERNAL
|
||||
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
||||
):
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
|
||||
if "tool_ids" in data:
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
)
|
||||
prompt = get_last_user_message(data["messages"])
|
||||
context = ""
|
||||
|
||||
for tool_id in data["tool_ids"]:
|
||||
response = await get_function_call_response(
|
||||
prompt=prompt,
|
||||
tool_id=tool_id,
|
||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
task_model_id=task_model_id,
|
||||
user=user,
|
||||
)
|
||||
print(response)
|
||||
|
||||
if response:
|
||||
context += f"\n{response}"
|
||||
|
||||
system_prompt = rag_template(
|
||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
||||
)
|
||||
|
||||
data["messages"] = add_or_update_system_message(
|
||||
system_prompt, data["messages"]
|
||||
)
|
||||
|
||||
del data["tool_ids"]
|
||||
|
||||
# If docs field is present, generate RAG completions
|
||||
if "docs" in data:
|
||||
data = {**data}
|
||||
data["messages"], citations = rag_messages(
|
||||
@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Replace the request body with the modified one
|
||||
request._body = modified_body_bytes
|
||||
|
||||
# Set custom header to ensure content-length matches new body length
|
||||
request.headers.__dict__["_list"] = [
|
||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||
@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
yield data
|
||||
|
||||
|
||||
app.add_middleware(RAGMiddleware)
|
||||
app.add_middleware(ChatCompletionMiddleware)
|
||||
|
||||
|
||||
def filter_pipeline(payload, user):
|
||||
@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)):
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel):
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
||||
|
||||
|
||||
@app.post("/api/task/config/update")
|
||||
@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
||||
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
||||
)
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
return {
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/tools/completions")
|
||||
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("get_tools_function_calling")
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
||||
if app.state.config.TASK_MODEL:
|
||||
task_model_id = app.state.config.TASK_MODEL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
else:
|
||||
if app.state.config.TASK_MODEL_EXTERNAL:
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
|
||||
print(model_id)
|
||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
|
||||
return await get_function_call_response(
|
||||
form_data["prompt"], form_data["tool_id"], template, model_id, user
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||
model_id = form_data["model"]
|
||||
|
@ -110,3 +110,8 @@ def search_query_generation_template(
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
|
||||
template = template.replace("{{TOOLS}}", tools_specs)
|
||||
return template
|
||||
|
Loading…
Reference in New Issue
Block a user