From a27175d67252aa01c7be503525c124e8de12abf6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 10 Jun 2024 23:40:27 -0700 Subject: [PATCH] feat: fc integration --- backend/apps/webui/routers/tools.py | 26 +--- backend/apps/webui/utils.py | 23 ++++ backend/config.py | 20 ++- backend/main.py | 181 +++++++++++++++++++++++++--- backend/utils/task.py | 5 + 5 files changed, 215 insertions(+), 40 deletions(-) create mode 100644 backend/apps/webui/utils.py diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 60e8319e6..c2a815ee7 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -7,6 +7,7 @@ from pydantic import BaseModel import json from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse +from apps.webui.utils import load_toolkit_module_by_id from utils.utils import get_current_user, get_admin_user from utils.tools import get_tools_specs @@ -17,32 +18,13 @@ import os from config import DATA_DIR + TOOLS_DIR = f"{DATA_DIR}/tools" os.makedirs(TOOLS_DIR, exist_ok=True) router = APIRouter() - -def load_toolkit_module_from_path(tools_id, tools_path): - spec = util.spec_from_file_location(tools_id, tools_path) - module = util.module_from_spec(spec) - - try: - spec.loader.exec_module(module) - print(f"Loaded module: {module.__name__}") - if hasattr(module, "Tools"): - return module.Tools() - else: - raise Exception("No Tools class found") - except Exception as e: - print(f"Error loading module: {tools_id}") - - # Move the file to the error folder - os.rename(tools_path, f"{tools_path}.error") - raise e - - ############################ # GetToolkits ############################ @@ -89,7 +71,7 @@ async def create_new_toolkit( with open(toolkit_path, "w") as tool_file: tool_file.write(form_data.content) - toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path) + toolkit_module = load_toolkit_module_by_id(form_data.id) TOOLS = request.app.state.TOOLS TOOLS[form_data.id] = toolkit_module @@ -149,7 +131,7 @@ async def update_toolkit_by_id( with open(toolkit_path, "w") as tool_file: tool_file.write(form_data.content) - toolkit_module = load_toolkit_module_from_path(id, toolkit_path) + toolkit_module = load_toolkit_module_by_id(id) TOOLS = request.app.state.TOOLS TOOLS[id] = toolkit_module diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py new file mode 100644 index 000000000..19a8615bc --- /dev/null +++ b/backend/apps/webui/utils.py @@ -0,0 +1,23 @@ +from importlib import util +import os + +from config import TOOLS_DIR + + +def load_toolkit_module_by_id(toolkit_id): + toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py") + spec = util.spec_from_file_location(toolkit_id, toolkit_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {toolkit_id}") + # Move the file to the error folder + os.rename(toolkit_path, f"{toolkit_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index df52a4b69..32e64347e 100644 --- a/backend/config.py +++ b/backend/config.py @@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Tools DIR +#################################### + +TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") +Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### @@ -669,7 +677,6 @@ Question: ), ) - SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "task.search.prompt_length_threshold", @@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( ), ) +TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + "task.tools.prompt_template", + os.environ.get( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + """Tools: {{TOOLS}} +If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""", + ), +) + + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/main.py b/backend/main.py index 99b409983..3f72c5710 100644 --- a/backend/main.py +++ b/backend/main.py @@ -47,15 +47,24 @@ from pydantic import BaseModel from typing import List, Optional from apps.webui.models.models import Models, ModelModel +from apps.webui.models.tools import Tools +from apps.webui.utils import load_toolkit_module_by_id + + from utils.utils import ( get_admin_user, get_verified_user, get_current_user, get_http_authorization_cred, ) -from utils.task import title_generation_template, search_query_generation_template +from utils.task import ( + title_generation_template, + search_query_generation_template, + tools_function_calling_generation_template, +) +from utils.misc import get_last_user_message, add_or_update_system_message -from apps.rag.utils import rag_messages +from apps.rag.utils import rag_messages, rag_template from config import ( CONFIG_DATA, @@ -82,6 +91,7 @@ from config import ( TITLE_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, AppConfig, ) from constants import ERROR_MESSAGES @@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) app.state.MODELS = {} origins = ["*"] -# Custom middleware to add security headers -# class SecurityHeadersMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next): -# response: Response = await call_next(request) -# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" -# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" -# return response + +async def get_function_call_response(prompt, tool_id, template, task_model_id, user): + tool = Tools.get_tool_by_id(tool_id) + tools_specs = json.dumps(tool.specs, indent=2) + content = tools_function_calling_generation_template(template, tools_specs) + + payload = { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + } + + payload = filter_pipeline(payload, user) + model = app.state.MODELS[task_model_id] + + response = None + if model["owned_by"] == "ollama": + response = await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + response = await generate_openai_chat_completion(payload, user=user) + + print(response) + content = response["choices"][0]["message"]["content"] + + # Parse the function response + if content != "": + result = json.loads(content) + print(result) + + # Call the function + if "name" in result: + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + function_result = function(**result["parameters"]) + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result: + return function_result + + return None -# app.add_middleware(SecurityHeadersMiddleware) - - -class RAGMiddleware(BaseHTTPMiddleware): +class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): return_citations = False @@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} + # Remove the citations from the body return_citations = data.get("citations", False) if "citations" in data: del data["citations"] - # Example: Add a new key-value pair or modify existing ones - # data["modified"] = True # Example modification + # Set the task model + task_model_id = data["model"] + if task_model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[task_model_id]["owned_by"] == "ollama": + if ( + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL + else: + if ( + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + + if "tool_ids" in data: + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + prompt = get_last_user_message(data["messages"]) + context = "" + + for tool_id in data["tool_ids"]: + response = await get_function_call_response( + prompt=prompt, + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) + print(response) + + if response: + context += f"\n{response}" + + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) + + data["messages"] = add_or_update_system_message( + system_prompt, data["messages"] + ) + + del data["tool_ids"] + + # If docs field is present, generate RAG completions if "docs" in data: data = {**data} data["messages"], citations = rag_messages( @@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware): # Replace the request body with the modified one request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length request.headers.__dict__["_list"] = [ (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), @@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware): yield data -app.add_middleware(RAGMiddleware) +app.add_middleware(ChatCompletionMiddleware) def filter_pipeline(payload, user): @@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)): "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel): TITLE_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str @app.post("/api/task/config/update") @@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD ) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) return { "TASK_MODEL": app.state.config.TASK_MODEL, @@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, } @@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) return await generate_openai_chat_completion(payload, user=user) +@app.post("/api/task/tools/completions") +async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): + print("get_tools_function_calling") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + return await get_function_call_response( + form_data["prompt"], form_data["tool_id"], template, model_id, user + ) + + @app.post("/api/chat/completions") async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): model_id = form_data["model"] diff --git a/backend/utils/task.py b/backend/utils/task.py index 2239de7df..615febcdc 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -110,3 +110,8 @@ def search_query_generation_template( ), ) return template + + +def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: + template = template.replace("{{TOOLS}}", tools_specs) + return template