Mirror of https://github.com/open-webui/open-webui (synced 2024-11-16 21:42:58 +00:00)
feat: fc integration
parent ff1cd306d8
commit a27175d672
@@ -7,6 +7,7 @@ from pydantic import BaseModel
 import json
 
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+from apps.webui.utils import load_toolkit_module_by_id
 
 from utils.utils import get_current_user, get_admin_user
 from utils.tools import get_tools_specs
@@ -17,32 +18,13 @@ import os
 
 from config import DATA_DIR
 
 
 TOOLS_DIR = f"{DATA_DIR}/tools"
 os.makedirs(TOOLS_DIR, exist_ok=True)
 
 
 router = APIRouter()
 
-
-def load_toolkit_module_from_path(tools_id, tools_path):
-    spec = util.spec_from_file_location(tools_id, tools_path)
-    module = util.module_from_spec(spec)
-
-    try:
-        spec.loader.exec_module(module)
-        print(f"Loaded module: {module.__name__}")
-        if hasattr(module, "Tools"):
-            return module.Tools()
-        else:
-            raise Exception("No Tools class found")
-    except Exception as e:
-        print(f"Error loading module: {tools_id}")
-
-        # Move the file to the error folder
-        os.rename(tools_path, f"{tools_path}.error")
-        raise e
-
 
 ############################
 # GetToolkits
 ############################
@@ -89,7 +71,7 @@ async def create_new_toolkit(
             with open(toolkit_path, "w") as tool_file:
                 tool_file.write(form_data.content)
 
-            toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path)
+            toolkit_module = load_toolkit_module_by_id(form_data.id)
 
             TOOLS = request.app.state.TOOLS
             TOOLS[form_data.id] = toolkit_module
@@ -149,7 +131,7 @@ async def update_toolkit_by_id(
         with open(toolkit_path, "w") as tool_file:
             tool_file.write(form_data.content)
 
-        toolkit_module = load_toolkit_module_from_path(id, toolkit_path)
+        toolkit_module = load_toolkit_module_by_id(id)
 
         TOOLS = request.app.state.TOOLS
         TOOLS[id] = toolkit_module

23	backend/apps/webui/utils.py (new file)
@@ -0,0 +1,23 @@
+from importlib import util
+import os
+
+from config import TOOLS_DIR
+
+
+def load_toolkit_module_by_id(toolkit_id):
+    toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
+    spec = util.spec_from_file_location(toolkit_id, toolkit_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Tools"):
+            return module.Tools()
+        else:
+            raise Exception("No Tools class found")
+    except Exception as e:
+        print(f"Error loading module: {toolkit_id}")
+        # Move the file to the error folder
+        os.rename(toolkit_path, f"{toolkit_path}.error")
+        raise e
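
Note: load_toolkit_module_by_id expects a file named {toolkit_id}.py under TOOLS_DIR that exposes a Tools class; the function-calling code in this commit looks up methods on that instance by name, and get_tools_specs (imported in the tools router) derives the callable specs from it. A minimal sketch of the kind of toolkit file this loader assumes (file name and method are hypothetical):

# <TOOLS_DIR>/get_time.py  (hypothetical example toolkit)
from datetime import datetime


class Tools:
    def get_current_time(self) -> str:
        """Return the current server time as an ISO 8601 string."""
        return datetime.now().isoformat()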
@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
 Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
 
 
+####################################
+# Tools DIR
+####################################
+
+TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
+Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # LITELLM_CONFIG
 ####################################
@@ -669,7 +677,6 @@ Question:
     ),
 )
 
-
 SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
     "task.search.prompt_length_threshold",
@@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     ),
 )
 
+TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
+    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+    "task.tools.prompt_template",
+    os.environ.get(
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+        """Tools: {{TOOLS}}
+If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################
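
Note: the template above asks the task model to answer with either an empty string or a single JSON object; get_function_call_response in backend/main.py parses that object with json.loads and dispatches on it. An illustrative reply and the shape the parser assumes (values are made up):

import json

reply = '{ "name": "get_current_time", "parameters": {} }'  # hypothetical model output
result = json.loads(reply)
assert "name" in result and isinstance(result["parameters"], dict)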

181	backend/main.py
@@ -47,15 +47,24 @@ from pydantic import BaseModel
 from typing import List, Optional
 
 from apps.webui.models.models import Models, ModelModel
+from apps.webui.models.tools import Tools
+from apps.webui.utils import load_toolkit_module_by_id
+
+
 from utils.utils import (
     get_admin_user,
     get_verified_user,
     get_current_user,
     get_http_authorization_cred,
 )
-from utils.task import title_generation_template, search_query_generation_template
+from utils.task import (
+    title_generation_template,
+    search_query_generation_template,
+    tools_function_calling_generation_template,
+)
+from utils.misc import get_last_user_message, add_or_update_system_message
 
-from apps.rag.utils import rag_messages
+from apps.rag.utils import rag_messages, rag_template
 
 from config import (
     CONFIG_DATA,
@@ -82,6 +91,7 @@ from config import (
     TITLE_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     AppConfig,
 )
 from constants import ERROR_MESSAGES
@@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
 )
+app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+)
+
 
 app.state.MODELS = {}
 
 origins = ["*"]
 
-# Custom middleware to add security headers
-# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
-#     async def dispatch(self, request: Request, call_next):
-#         response: Response = await call_next(request)
-#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
-#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
-#         return response
+async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
+    tool = Tools.get_tool_by_id(tool_id)
+    tools_specs = json.dumps(tool.specs, indent=2)
+    content = tools_function_calling_generation_template(template, tools_specs)
+
+    payload = {
+        "model": task_model_id,
+        "messages": [
+            {"role": "system", "content": content},
+            {"role": "user", "content": f"Query: {prompt}"},
+        ],
+        "stream": False,
+    }
+
+    payload = filter_pipeline(payload, user)
+    model = app.state.MODELS[task_model_id]
+
+    response = None
+    if model["owned_by"] == "ollama":
+        response = await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**payload), user=user
+        )
+    else:
+        response = await generate_openai_chat_completion(payload, user=user)
+
+    print(response)
+    content = response["choices"][0]["message"]["content"]
+
+    # Parse the function response
+    if content != "":
+        result = json.loads(content)
+        print(result)
+
+        # Call the function
+        if "name" in result:
+            if tool_id in webui_app.state.TOOLS:
+                toolkit_module = webui_app.state.TOOLS[tool_id]
+            else:
+                toolkit_module = load_toolkit_module_by_id(tool_id)
+                webui_app.state.TOOLS[tool_id] = toolkit_module
+
+            function = getattr(toolkit_module, result["name"])
+            function_result = None
+            try:
+                function_result = function(**result["parameters"])
+            except Exception as e:
+                print(e)
+
+            # Add the function result to the system prompt
+            if function_result:
+                return function_result
+
+    return None
 
 
-# app.add_middleware(SecurityHeadersMiddleware)
-
-
-class RAGMiddleware(BaseHTTPMiddleware):
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         return_citations = False
 
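
Note: the {{TOOLS}} placeholder in the template is filled with the toolkit's stored specs (tool.specs), serialized as JSON, before the prompt is sent to the task model. A rough sketch of that substitution with a made-up spec:

import json

specs = [  # hypothetical spec for a single tool method
    {"name": "get_current_time", "description": "Return the current time", "parameters": {}}
]
template = "Tools: {{TOOLS}}\nIf a function tool doesn't match the query, return an empty string. ..."
system_prompt = template.replace("{{TOOLS}}", json.dumps(specs, indent=2))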
@@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            # Remove the citations from the body
             return_citations = data.get("citations", False)
             if "citations" in data:
                 del data["citations"]
 
-            # Example: Add a new key-value pair or modify existing ones
-            # data["modified"] = True  # Example modification
+            # Set the task model
+            task_model_id = data["model"]
+            if task_model_id not in app.state.MODELS:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail="Model not found",
+                )
+
+            # Check if the user has a custom task model
+            # If the user has a custom task model, use that model
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+                if (
+                    app.state.config.TASK_MODEL
+                    and app.state.config.TASK_MODEL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL
+            else:
+                if (
+                    app.state.config.TASK_MODEL_EXTERNAL
+                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+            if "tool_ids" in data:
+                user = get_current_user(
+                    get_http_authorization_cred(request.headers.get("Authorization"))
+                )
+                prompt = get_last_user_message(data["messages"])
+                context = ""
+
+                for tool_id in data["tool_ids"]:
+                    response = await get_function_call_response(
+                        prompt=prompt,
+                        tool_id=tool_id,
+                        template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                        task_model_id=task_model_id,
+                        user=user,
+                    )
+                    print(response)
+
+                    if response:
+                        context += f"\n{response}"
+
+                system_prompt = rag_template(
+                    rag_app.state.config.RAG_TEMPLATE, context, prompt
+                )
+
+                data["messages"] = add_or_update_system_message(
+                    system_prompt, data["messages"]
+                )
+
+                del data["tool_ids"]
+
+            # If docs field is present, generate RAG completions
             if "docs" in data:
                 data = {**data}
                 data["messages"], citations = rag_messages(
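
Note: the tool branch above only runs when the incoming chat request carries a tool_ids field; that field is deleted before the request is forwarded, and any tool output is injected as a system message via rag_template. A hypothetical request body that would take this path (model and toolkit ids are made up):

body = {
    "model": "llama3",                 # hypothetical model id
    "tool_ids": ["get_time"],          # hypothetical toolkit id
    "messages": [{"role": "user", "content": "What time is it?"}],
    "stream": True,
}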
@@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
-
             # Replace the request body with the modified one
             request._body = modified_body_bytes
 
             # Set custom header to ensure content-length matches new body length
             request.headers.__dict__["_list"] = [
                 (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 yield data
 
 
-app.add_middleware(RAGMiddleware)
+app.add_middleware(ChatCompletionMiddleware)
 
 
 def filter_pipeline(payload, user):
@@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
 
 
 @app.post("/api/task/config/update")
@@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
     app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
         form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
     )
+    app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    )
 
     return {
         "TASK_MODEL": app.state.config.TASK_MODEL,
@@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
     return await generate_openai_chat_completion(payload, user=user)
 
 
+@app.post("/api/task/tools/completions")
+async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
+    print("get_tools_function_calling")
+
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+
+    return await get_function_call_response(
+        form_data["prompt"], form_data["tool_id"], template, model_id, user
+    )
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]
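
Note: a rough client-side sketch of the new endpoint (host, token, and ids are made up); the response body is whatever the selected tool function returned, or null if no tool was called:

import requests

resp = requests.post(
    "http://localhost:8080/api/task/tools/completions",
    headers={"Authorization": "Bearer <token>"},
    json={"model": "llama3", "prompt": "What time is it?", "tool_id": "get_time"},
)
print(resp.json())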
@@ -110,3 +110,8 @@ def search_query_generation_template(
         ),
     )
     return template
+
+
+def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+    template = template.replace("{{TOOLS}}", tools_specs)
+    return template
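
Note: a quick illustration of the helper above (assuming it is imported from utils.task; values are made up):

from utils.task import tools_function_calling_generation_template

assert (
    tools_function_calling_generation_template("Tools: {{TOOLS}}", '[{"name": "x"}]')
    == 'Tools: [{"name": "x"}]'
)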