This commit is contained in:
Timothy J. Baek 2024-08-22 16:02:29 +02:00
parent 99db82a161
commit 63ba8145b9
2 changed files with 40 additions and 37 deletions

View File

@ -26,7 +26,7 @@ from apps.webui.models.files import (
FileModel, FileModel,
FileModelResponse, FileModelResponse,
) )
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from importlib import util from importlib import util
@ -50,7 +50,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): def upload_file(file: UploadFile = File(...), user=Depends(get_current_user)):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
unsanitized_filename = file.filename unsanitized_filename = file.filename
@ -105,7 +105,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
@router.get("/", response_model=list[FileModel]) @router.get("/", response_model=list[FileModel])
async def list_files(user=Depends(get_verified_user)): async def list_files(user=Depends(get_current_user)):
files = Files.get_files() files = Files.get_files()
return files return files
@ -153,7 +153,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
@router.get("/{id}", response_model=Optional[FileModel]) @router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user)): async def get_file_by_id(id: str, user=Depends(get_current_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file: if file:
@ -171,7 +171,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content", response_model=Optional[FileModel]) @router.get("/{id}/content", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file: if file:
@ -194,7 +194,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) @router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file: if file:
@ -222,7 +222,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}") @router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)): async def delete_file_by_id(id: str, user=Depends(get_current_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file: if file:

View File

@ -299,24 +299,26 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(inlet) sig = inspect.signature(inlet)
params = {"body": body} params = {"body": body} | {
k: v
for k, v in {
**extra_params,
"__model__": model,
"__id__": filter_id,
}.items()
if k in sig.parameters
}
# Extra parameters to be passed to the function if "__user__" in params and hasattr(function_module, "UserValves"):
custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
try: try:
uid = custom_params["__user__"]["id"] params["__user__"]["valves"] = function_module.UserValves(
custom_params["__user__"]["valves"] = function_module.UserValves( **Functions.get_user_valves_by_id_and_user_id(
**Functions.get_user_valves_by_id_and_user_id(filter_id, uid) filter_id, params["__user__"]["id"]
)
) )
except Exception as e: except Exception as e:
print(e) print(e)
# Add extra params in contained in function signature
for key, value in custom_params.items():
if key in sig.parameters:
params[key] = value
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
body = await inlet(**params) body = await inlet(**params)
else: else:
@ -372,7 +374,9 @@ async def chat_completion_tools_handler(
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
metadata = body.get("metadata", {}) metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None) tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
if not tool_ids: if not tool_ids:
return body, {} return body, {}
@ -381,16 +385,17 @@ async def chat_completion_tools_handler(
citations = [] citations = []
task_model_id = get_task_model_id(body["model"]) task_model_id = get_task_model_id(body["model"])
tools = get_tools(
log.debug(f"{tool_ids=}") webui_app,
tool_ids,
custom_params = { user,
**extra_params, {
"__model__": app.state.MODELS[task_model_id], **extra_params,
"__messages__": body["messages"], "__model__": app.state.MODELS[task_model_id],
"__files__": metadata.get("files", []), "__messages__": body["messages"],
} "__files__": metadata.get("files", []),
tools = get_tools(webui_app, tool_ids, user, custom_params) },
)
log.info(f"{tools=}") log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()] specs = [tool["spec"] for tool in tools.values()]
@ -530,17 +535,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
} }
body["metadata"] = metadata body["metadata"] = metadata
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
extra_params = { extra_params = {
"__user__": __user__,
"__event_emitter__": get_event_emitter(metadata), "__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata), "__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
} }
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client