mirror of
https://github.com/open-webui/open-webui
synced 2025-03-16 02:17:33 +00:00
refac
This commit is contained in:
parent
99db82a161
commit
63ba8145b9
@ -26,7 +26,7 @@ from apps.webui.models.files import (
|
|||||||
FileModel,
|
FileModel,
|
||||||
FileModelResponse,
|
FileModelResponse,
|
||||||
)
|
)
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
from importlib import util
|
from importlib import util
|
||||||
@ -50,7 +50,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/")
|
@router.post("/")
|
||||||
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
def upload_file(file: UploadFile = File(...), user=Depends(get_current_user)):
|
||||||
log.info(f"file.content_type: {file.content_type}")
|
log.info(f"file.content_type: {file.content_type}")
|
||||||
try:
|
try:
|
||||||
unsanitized_filename = file.filename
|
unsanitized_filename = file.filename
|
||||||
@ -105,7 +105,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=list[FileModel])
|
@router.get("/", response_model=list[FileModel])
|
||||||
async def list_files(user=Depends(get_verified_user)):
|
async def list_files(user=Depends(get_current_user)):
|
||||||
files = Files.get_files()
|
files = Files.get_files()
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@ -153,7 +153,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Optional[FileModel])
|
@router.get("/{id}", response_model=Optional[FileModel])
|
||||||
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_by_id(id: str, user=Depends(get_current_user)):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id)
|
||||||
|
|
||||||
if file:
|
if file:
|
||||||
@ -171,7 +171,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/content", response_model=Optional[FileModel])
|
@router.get("/{id}/content", response_model=Optional[FileModel])
|
||||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id)
|
||||||
|
|
||||||
if file:
|
if file:
|
||||||
@ -194,7 +194,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
|
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
|
||||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id)
|
||||||
|
|
||||||
if file:
|
if file:
|
||||||
@ -222,7 +222,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}")
|
@router.delete("/{id}")
|
||||||
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
async def delete_file_by_id(id: str, user=Depends(get_current_user)):
|
||||||
file = Files.get_file_by_id(id)
|
file = Files.get_file_by_id(id)
|
||||||
|
|
||||||
if file:
|
if file:
|
||||||
|
@ -299,24 +299,26 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
|
|||||||
|
|
||||||
# Get the signature of the function
|
# Get the signature of the function
|
||||||
sig = inspect.signature(inlet)
|
sig = inspect.signature(inlet)
|
||||||
params = {"body": body}
|
params = {"body": body} | {
|
||||||
|
k: v
|
||||||
|
for k, v in {
|
||||||
|
**extra_params,
|
||||||
|
"__model__": model,
|
||||||
|
"__id__": filter_id,
|
||||||
|
}.items()
|
||||||
|
if k in sig.parameters
|
||||||
|
}
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||||
custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
|
|
||||||
if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
|
|
||||||
try:
|
try:
|
||||||
uid = custom_params["__user__"]["id"]
|
params["__user__"]["valves"] = function_module.UserValves(
|
||||||
custom_params["__user__"]["valves"] = function_module.UserValves(
|
**Functions.get_user_valves_by_id_and_user_id(
|
||||||
**Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
|
filter_id, params["__user__"]["id"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
|
||||||
for key, value in custom_params.items():
|
|
||||||
if key in sig.parameters:
|
|
||||||
params[key] = value
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
if inspect.iscoroutinefunction(inlet):
|
||||||
body = await inlet(**params)
|
body = await inlet(**params)
|
||||||
else:
|
else:
|
||||||
@ -372,7 +374,9 @@ async def chat_completion_tools_handler(
|
|||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
metadata = body.get("metadata", {})
|
metadata = body.get("metadata", {})
|
||||||
|
|
||||||
tool_ids = metadata.get("tool_ids", None)
|
tool_ids = metadata.get("tool_ids", None)
|
||||||
|
log.debug(f"{tool_ids=}")
|
||||||
if not tool_ids:
|
if not tool_ids:
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
@ -381,16 +385,17 @@ async def chat_completion_tools_handler(
|
|||||||
citations = []
|
citations = []
|
||||||
|
|
||||||
task_model_id = get_task_model_id(body["model"])
|
task_model_id = get_task_model_id(body["model"])
|
||||||
|
tools = get_tools(
|
||||||
log.debug(f"{tool_ids=}")
|
webui_app,
|
||||||
|
tool_ids,
|
||||||
custom_params = {
|
user,
|
||||||
**extra_params,
|
{
|
||||||
"__model__": app.state.MODELS[task_model_id],
|
**extra_params,
|
||||||
"__messages__": body["messages"],
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
"__files__": metadata.get("files", []),
|
"__messages__": body["messages"],
|
||||||
}
|
"__files__": metadata.get("files", []),
|
||||||
tools = get_tools(webui_app, tool_ids, user, custom_params)
|
},
|
||||||
|
)
|
||||||
log.info(f"{tools=}")
|
log.info(f"{tools=}")
|
||||||
|
|
||||||
specs = [tool["spec"] for tool in tools.values()]
|
specs = [tool["spec"] for tool in tools.values()]
|
||||||
@ -530,17 +535,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
}
|
}
|
||||||
body["metadata"] = metadata
|
body["metadata"] = metadata
|
||||||
|
|
||||||
__user__ = {
|
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
extra_params = {
|
extra_params = {
|
||||||
"__user__": __user__,
|
|
||||||
"__event_emitter__": get_event_emitter(metadata),
|
"__event_emitter__": get_event_emitter(metadata),
|
||||||
"__event_call__": get_event_call(metadata),
|
"__event_call__": get_event_call(metadata),
|
||||||
|
"__user__": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize data_items to store additional data to be sent to the client
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
|
Loading…
Reference in New Issue
Block a user