mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into feat/oauth
This commit is contained in:
@@ -77,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
|
||||
from apps.rag.search.serply import search_serply
|
||||
from apps.rag.search.duckduckgo import search_duckduckgo
|
||||
from apps.rag.search.tavily import search_tavily
|
||||
from apps.rag.search.jina_search import search_jina
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
@@ -856,6 +857,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
)
|
||||
else:
|
||||
raise Exception("No TAVILY_API_KEY found in environment variables")
|
||||
elif engine == "jina":
|
||||
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
|
||||
else:
|
||||
raise Exception("No search engine API key found in environment variables")
|
||||
|
||||
|
||||
41
backend/apps/rag/search/jina_search.py
Normal file
41
backend/apps/rag/search/jina_search.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import requests
|
||||
from yarl import URL
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_jina(query: str, count: int) -> list[SearchResult]:
|
||||
"""
|
||||
Search using Jina's Search API and return the results as a list of SearchResult objects.
|
||||
Args:
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return
|
||||
|
||||
Returns:
|
||||
List[SearchResult]: A list of search results
|
||||
"""
|
||||
jina_search_endpoint = "https://s.jina.ai/"
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
url = str(URL(jina_search_endpoint + query))
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["data"][:count]:
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("content"),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Peewee migrations -- 009_add_models.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields("tool", valves=pw.TextField(null=True))
|
||||
migrator.add_fields("function", valves=pw.TextField(null=True))
|
||||
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("tool", "valves")
|
||||
migrator.remove_fields("function", "valves")
|
||||
migrator.remove_fields("function", "is_active")
|
||||
@@ -105,13 +105,15 @@ async def get_status():
|
||||
|
||||
|
||||
async def get_pipe_models():
|
||||
pipes = Functions.get_functions_by_type("pipe")
|
||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
# Check if function is already loaded
|
||||
if pipe.id not in app.state.FUNCTIONS:
|
||||
function_module, function_type = load_function_module_by_id(pipe.id)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
pipe.id
|
||||
)
|
||||
app.state.FUNCTIONS[pipe.id] = function_module
|
||||
else:
|
||||
function_module = app.state.FUNCTIONS[pipe.id]
|
||||
@@ -132,7 +134,9 @@ async def get_pipe_models():
|
||||
manifold_pipe_name = p["name"]
|
||||
|
||||
if hasattr(function_module, "name"):
|
||||
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
|
||||
manifold_pipe_name = (
|
||||
f"{function_module.name}{manifold_pipe_name}"
|
||||
)
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
|
||||
@@ -5,8 +5,11 @@ from typing import List, Union, Optional
|
||||
import time
|
||||
import logging
|
||||
from apps.webui.internal.db import DB, JSONField
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
import json
|
||||
import copy
|
||||
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
@@ -25,6 +28,8 @@ class Function(Model):
|
||||
type = TextField()
|
||||
content = TextField()
|
||||
meta = JSONField()
|
||||
valves = JSONField()
|
||||
is_active = BooleanField(default=False)
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
@@ -34,6 +39,7 @@ class Function(Model):
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
|
||||
|
||||
class FunctionModel(BaseModel):
|
||||
@@ -43,6 +49,7 @@ class FunctionModel(BaseModel):
|
||||
type: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool = False
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@@ -58,6 +65,7 @@ class FunctionResponse(BaseModel):
|
||||
type: str
|
||||
name: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@@ -69,6 +77,10 @@ class FunctionForm(BaseModel):
|
||||
meta: FunctionMeta
|
||||
|
||||
|
||||
class FunctionValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
|
||||
|
||||
class FunctionsTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
@@ -104,16 +116,98 @@ class FunctionsTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_functions(self) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function)) for function in Function.select()
|
||||
]
|
||||
def get_functions(self, active_only=False) -> List[FunctionModel]:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(Function.is_active == True)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select()
|
||||
]
|
||||
|
||||
def get_functions_by_type(self, type: str) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(Function.type == type)
|
||||
]
|
||||
def get_functions_by_type(
|
||||
self, type: str, active_only=False
|
||||
) -> List[FunctionModel]:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(
|
||||
Function.type == type, Function.is_active == True
|
||||
)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(Function.type == type)
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
function = Function.get(Function.id == id)
|
||||
return function.valves if function.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_function_valves_by_id(
|
||||
self, id: str, valves: dict
|
||||
) -> Optional[FunctionValves]:
|
||||
try:
|
||||
query = Function.update(
|
||||
**{"valves": valves},
|
||||
updated_at=int(time.time()),
|
||||
).where(Function.id == id)
|
||||
query.execute()
|
||||
|
||||
function = Function.get(Function.id == id)
|
||||
return FunctionValves(**model_to_dict(function))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump()
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump()
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
|
||||
user_settings["functions"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
query = Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
query.execute()
|
||||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
try:
|
||||
@@ -128,6 +222,19 @@ class FunctionsTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def deactivate_all_functions(self) -> Optional[bool]:
|
||||
try:
|
||||
query = Function.update(
|
||||
**{"is_active": False},
|
||||
updated_at=int(time.time()),
|
||||
)
|
||||
|
||||
query.execute()
|
||||
|
||||
return True
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_function_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = Function.delete().where((Function.id == id))
|
||||
|
||||
@@ -5,8 +5,11 @@ from typing import List, Union, Optional
|
||||
import time
|
||||
import logging
|
||||
from apps.webui.internal.db import DB, JSONField
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
import json
|
||||
import copy
|
||||
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
@@ -25,6 +28,7 @@ class Tool(Model):
|
||||
content = TextField()
|
||||
specs = JSONField()
|
||||
meta = JSONField()
|
||||
valves = JSONField()
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
@@ -34,6 +38,7 @@ class Tool(Model):
|
||||
|
||||
class ToolMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
|
||||
|
||||
class ToolModel(BaseModel):
|
||||
@@ -68,6 +73,10 @@ class ToolForm(BaseModel):
|
||||
meta: ToolMeta
|
||||
|
||||
|
||||
class ToolValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
|
||||
|
||||
class ToolsTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
@@ -106,6 +115,69 @@ class ToolsTable:
|
||||
def get_tools(self) -> List[ToolModel]:
|
||||
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
tool = Tool.get(Tool.id == id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
try:
|
||||
query = Tool.update(
|
||||
**{"valves": valves},
|
||||
updated_at=int(time.time()),
|
||||
).where(Tool.id == id)
|
||||
query.execute()
|
||||
|
||||
tool = Tool.get(Tool.id == id)
|
||||
return ToolValves(**model_to_dict(tool))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump()
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
|
||||
return user_settings["tools"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump()
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
|
||||
user_settings["tools"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
query = Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
query.execute()
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
try:
|
||||
query = Tool.update(
|
||||
|
||||
@@ -194,6 +194,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file:
|
||||
file_path = Path(file.meta["path"])
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Delete File By Id
|
||||
############################
|
||||
|
||||
@@ -69,7 +69,10 @@ async def create_new_function(
|
||||
with open(function_path, "w") as function_file:
|
||||
function_file.write(form_data.content)
|
||||
|
||||
function_module, function_type = load_function_module_by_id(form_data.id)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
form_data.id
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[form_data.id] = function_module
|
||||
@@ -117,13 +120,40 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ToggleFunctionById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
id, {"is_active": not function.is_active}
|
||||
)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateFunctionById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
|
||||
async def update_toolkit_by_id(
|
||||
async def update_function_by_id(
|
||||
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
|
||||
):
|
||||
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
|
||||
@@ -132,7 +162,8 @@ async def update_toolkit_by_id(
|
||||
with open(function_path, "w") as function_file:
|
||||
function_file.write(form_data.content)
|
||||
|
||||
function_module, function_type = load_function_module_by_id(id)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[id] = function_module
|
||||
@@ -178,3 +209,188 @@ async def delete_function_by_id(
|
||||
os.remove(function_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
# GetFunctionValves
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
try:
|
||||
valves = Functions.get_function_valves_by_id(id)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetFunctionValvesSpec
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||
async def get_function_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
return Valves.schema()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateFunctionValves
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||
async def update_function_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
|
||||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
valves = Valves(**form_data)
|
||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# FunctionUserValves
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
try:
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||
async def get_function_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
return UserValves.schema()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||
async def update_function_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
|
||||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
user_valves = UserValves(**form_data)
|
||||
Functions.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves.model_dump()
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
@@ -6,10 +6,12 @@ from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from utils.utils import get_admin_user, get_verified_user
|
||||
from utils.tools import get_tools_specs
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
@@ -32,7 +34,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ToolResponse])
|
||||
async def get_toolkits(user=Depends(get_current_user)):
|
||||
async def get_toolkits(user=Depends(get_verified_user)):
|
||||
toolkits = [toolkit for toolkit in Tools.get_tools()]
|
||||
return toolkits
|
||||
|
||||
@@ -72,7 +74,8 @@ async def create_new_toolkit(
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
tool_file.write(form_data.content)
|
||||
|
||||
toolkit_module = load_toolkit_module_by_id(form_data.id)
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[form_data.id] = toolkit_module
|
||||
@@ -136,7 +139,8 @@ async def update_toolkit_by_id(
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
tool_file.write(form_data.content)
|
||||
|
||||
toolkit_module = load_toolkit_module_by_id(id)
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
TOOLS[id] = toolkit_module
|
||||
@@ -185,3 +189,187 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
|
||||
os.remove(toolkit_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
# GetToolValves
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
try:
|
||||
valves = Tools.get_tool_valves_by_id(id)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetToolValvesSpec
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
|
||||
async def get_toolkit_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
Valves = toolkit_module.Valves
|
||||
return Valves.schema()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateToolValves
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
|
||||
async def update_toolkit_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
Valves = toolkit_module.Valves
|
||||
|
||||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
valves = Valves(**form_data)
|
||||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ToolUserValves
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
try:
|
||||
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
|
||||
async def get_toolkit_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
if toolkit:
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
UserValves = toolkit_module.UserValves
|
||||
return UserValves.schema()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
|
||||
async def update_toolkit_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
toolkit = Tools.get_tool_by_id(id)
|
||||
|
||||
if toolkit:
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
UserValves = toolkit_module.UserValves
|
||||
|
||||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
user_valves = UserValves(**form_data)
|
||||
Tools.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves.model_dump()
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
@@ -1,19 +1,56 @@
|
||||
from importlib import util
|
||||
import os
|
||||
import re
|
||||
|
||||
from config import TOOLS_DIR, FUNCTIONS_DIR
|
||||
|
||||
|
||||
def extract_frontmatter(file_path):
|
||||
"""
|
||||
Extract frontmatter as a dictionary from the specified file path.
|
||||
"""
|
||||
frontmatter = {}
|
||||
frontmatter_started = False
|
||||
frontmatter_ended = False
|
||||
frontmatter_pattern = re.compile(r"^\s*([a-z_]+):\s*(.*)\s*$", re.IGNORECASE)
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if '"""' in line:
|
||||
if not frontmatter_started:
|
||||
frontmatter_started = True
|
||||
continue # skip the line with the opening triple quotes
|
||||
else:
|
||||
frontmatter_ended = True
|
||||
break
|
||||
|
||||
if frontmatter_started and not frontmatter_ended:
|
||||
match = frontmatter_pattern.match(line)
|
||||
if match:
|
||||
key, value = match.groups()
|
||||
frontmatter[key.strip()] = value.strip()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file {file_path} does not exist.")
|
||||
return {}
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return {}
|
||||
|
||||
return frontmatter
|
||||
|
||||
|
||||
def load_toolkit_module_by_id(toolkit_id):
|
||||
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
|
||||
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
|
||||
module = util.module_from_spec(spec)
|
||||
frontmatter = extract_frontmatter(toolkit_path)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Tools"):
|
||||
return module.Tools()
|
||||
return module.Tools(), frontmatter
|
||||
else:
|
||||
raise Exception("No Tools class found")
|
||||
except Exception as e:
|
||||
@@ -28,14 +65,15 @@ def load_function_module_by_id(function_id):
|
||||
|
||||
spec = util.spec_from_file_location(function_id, function_path)
|
||||
module = util.module_from_spec(spec)
|
||||
frontmatter = extract_frontmatter(function_path)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Pipe"):
|
||||
return module.Pipe(), "pipe"
|
||||
return module.Pipe(), "pipe", frontmatter
|
||||
elif hasattr(module, "Filter"):
|
||||
return module.Filter(), "filter"
|
||||
return module.Filter(), "filter", frontmatter
|
||||
else:
|
||||
raise Exception("No Function class found")
|
||||
except Exception as e:
|
||||
|
||||
@@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
|
||||
####################################
|
||||
# SAFE_MODE
|
||||
####################################
|
||||
|
||||
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
|
||||
|
||||
####################################
|
||||
# WEBUI_BUILD_HASH
|
||||
####################################
|
||||
|
||||
370
backend/main.py
370
backend/main.py
@@ -62,9 +62,7 @@ from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.misc import parse_duration
|
||||
from utils.utils import (
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
@@ -82,6 +80,7 @@ from utils.misc import (
|
||||
get_last_user_message,
|
||||
add_or_update_system_message,
|
||||
stream_message_template,
|
||||
parse_duration,
|
||||
)
|
||||
|
||||
from apps.rag.utils import get_rag_context, rag_template
|
||||
@@ -113,6 +112,7 @@ from config import (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
SAFE_MODE,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||
@@ -124,6 +124,11 @@ from config import (
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from utils.webhook import post_webhook
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
@@ -271,7 +276,7 @@ async def get_function_call_response(
|
||||
if tool_id in webui_app.state.TOOLS:
|
||||
toolkit_module = webui_app.state.TOOLS[tool_id]
|
||||
else:
|
||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||
|
||||
file_handler = False
|
||||
@@ -280,6 +285,14 @@ async def get_function_call_response(
|
||||
file_handler = True
|
||||
print("file_handler: ", file_handler)
|
||||
|
||||
if hasattr(toolkit_module, "valves") and hasattr(
|
||||
toolkit_module, "Valves"
|
||||
):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id)
|
||||
toolkit_module.valves = toolkit_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
function = getattr(toolkit_module, result["name"])
|
||||
function_result = None
|
||||
try:
|
||||
@@ -289,16 +302,24 @@ async def get_function_call_response(
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
# Call the function with the '__user__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
__user__["valves"] = toolkit_module.UserValves(
|
||||
**Tools.get_user_valves_by_id_and_user_id(
|
||||
tool_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
if "__messages__" in sig.parameters:
|
||||
# Call the function with the '__messages__' parameter included
|
||||
params = {
|
||||
@@ -386,54 +407,94 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
return (function.valves if function.valves else {}).get(
|
||||
"priority", 0
|
||||
)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
filter_ids.sort(key=get_priority)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(filter_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
inlet = function_module.inlet
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
data = await inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(inlet)
|
||||
params = {"body": data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if "__id__" in sig.parameters:
|
||||
params = {
|
||||
**params,
|
||||
"__id__": filter_id,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
data = await inlet(**params)
|
||||
else:
|
||||
data = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
@@ -857,12 +918,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
|
||||
pipe = model.get("pipe")
|
||||
if pipe:
|
||||
form_data["user"] = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
async def job():
|
||||
pipe_id = form_data["model"]
|
||||
@@ -870,14 +925,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
|
||||
print(pipe_id)
|
||||
|
||||
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
|
||||
# Check if function is already loaded
|
||||
if pipe_id not in webui_app.state.FUNCTIONS:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(pipe_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[pipe_id] = function_module
|
||||
else:
|
||||
function_module = webui_app.state.FUNCTIONS[pipe_id]
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
pipe = function_module.pipe
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(pipe)
|
||||
params = {"body": form_data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
pipe_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if form_data["stream"]:
|
||||
|
||||
async def stream_content():
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
try:
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(**params)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
|
||||
return
|
||||
|
||||
if isinstance(res, str):
|
||||
message = stream_message_template(form_data["model"], res)
|
||||
@@ -922,10 +1025,20 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
stream_content(), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
|
||||
try:
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(**params)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return {"error": {"detail": str(e)}}
|
||||
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
res = await pipe(**params)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
res = pipe(**params)
|
||||
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
@@ -1008,7 +1121,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
@@ -1033,49 +1151,88 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
else:
|
||||
pass
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
outlet = function_module.outlet
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = (
|
||||
load_function_module_by_id(filter_id)
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(
|
||||
function_module, "Valves"
|
||||
):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
outlet = function_module.outlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(outlet)
|
||||
params = {"body": data}
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if "__id__" in sig.parameters:
|
||||
params = {
|
||||
**params,
|
||||
"__id__": filter_id,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(**params)
|
||||
else:
|
||||
data = outlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -1989,7 +2146,6 @@ async def get_manifest_json():
|
||||
"start_url": "/",
|
||||
"display": "standalone",
|
||||
"background_color": "#343541",
|
||||
"theme_color": "#343541",
|
||||
"orientation": "portrait-primary",
|
||||
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
|
||||
}
|
||||
|
||||
@@ -17,7 +17,9 @@ peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
PyMySQL==1.1.1
|
||||
bcrypt==4.1.3
|
||||
|
||||
SQLAlchemy
|
||||
pymongo
|
||||
redis
|
||||
boto3==1.34.110
|
||||
|
||||
argon2-cffi==23.1.0
|
||||
|
||||
@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]:
|
||||
function_list = [
|
||||
{"name": func, "function": getattr(tools, func)}
|
||||
for func in dir(tools)
|
||||
if callable(getattr(tools, func)) and not func.startswith("__")
|
||||
if callable(getattr(tools, func))
|
||||
and not func.startswith("__")
|
||||
and not inspect.isclass(getattr(tools, func))
|
||||
]
|
||||
|
||||
specs = []
|
||||
@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]:
|
||||
function
|
||||
).parameters.items()
|
||||
if param.default is param.empty
|
||||
and not (name.startswith("__") and name.endswith("__"))
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user