fix more LSP errors

This commit is contained in:
Michael Poluektov 2024-08-11 08:31:40 +01:00
parent a68b918cbb
commit ff9d899f9c

View File

@ -261,6 +261,7 @@ def get_filter_function_ids(model):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel
return (function.valves if function.valves else {}).get("priority", 0) return (function.valves if function.valves else {}).get("priority", 0)
return 0 return 0
@ -322,14 +323,7 @@ async def call_tool_from_completion(
async def get_function_call_response( async def get_function_call_response(
messages, messages, files, tool_id, template, task_model_id, user, extra_params
files,
tool_id,
template,
task_model_id,
user,
__event_emitter__=None,
__event_call__=None,
) -> tuple[Optional[str], Optional[dict], bool]: ) -> tuple[Optional[str], Optional[dict], bool]:
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
if tool is None: if tool is None:
@ -373,32 +367,22 @@ async def get_function_call_response(
toolkit_module, _ = load_toolkit_module_by_id(tool_id) toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module webui_app.state.TOOLS[tool_id] = toolkit_module
__user__ = { custom_params = {
"id": user.id, **extra_params,
"email": user.email, "__model__": app.state.MODELS[task_model_id],
"name": user.name, "__id__": tool_id,
"role": user.role, "__messages__": messages,
"__files__": files,
} }
try: try:
if hasattr(toolkit_module, "UserValves"): if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves( custom_params["__user__"]["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
) )
except Exception as e: except Exception as e:
print(e) print(e)
extra_params = {
"__model__": app.state.MODELS[task_model_id],
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__user__": __user__,
}
file_handler = hasattr(toolkit_module, "file_handler") file_handler = hasattr(toolkit_module, "file_handler")
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
@ -417,7 +401,7 @@ async def get_function_call_response(
result = json.loads(content) result = json.loads(content)
function_result = await call_tool_from_completion( function_result = await call_tool_from_completion(
result, extra_params, toolkit_module result, custom_params, toolkit_module
) )
if hasattr(toolkit_module, "citation") and toolkit_module.citation: if hasattr(toolkit_module, "citation") and toolkit_module.citation:
@ -438,9 +422,7 @@ async def get_function_call_response(
return None, None, False return None, None, False
async def chat_completion_inlets_handler( async def chat_completion_inlets_handler(body, model, extra_params):
body, model, user, __event_emitter__, __event_call__
):
skip_files = None skip_files = None
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
@ -476,38 +458,18 @@ async def chat_completion_inlets_handler(
params = {"body": body} params = {"body": body}
# Extra parameters to be passed to the function # Extra parameters to be passed to the function
extra_params = { custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
"__model__": model, if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
"__id__": filter_id, uid = custom_params["__user__"]["id"]
"__event_emitter__": __event_emitter__, custom_params["__user__"]["valves"] = function_module.UserValves(
"__event_call__": __event_call__, **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
} )
# Add extra params in contained in function signature # Add extra params in contained in function signature
for key, value in extra_params.items(): for key, value in custom_params.items():
if key in sig.parameters: if key in sig.parameters:
params[key] = value params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
body = await inlet(**params) body = await inlet(**params)
else: else:
@ -524,7 +486,7 @@ async def chat_completion_inlets_handler(
return body, {} return body, {}
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): async def chat_completion_tools_handler(body, user, extra_params):
skip_files = None skip_files = None
contexts = [] contexts = []
@ -547,8 +509,7 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id, task_model_id=task_model_id,
user=user, user=user,
__event_emitter__=__event_emitter__, extra_params=extra_params,
__event_call__=__event_call__,
) )
print(file_handler) print(file_handler)
@ -584,10 +545,7 @@ async def chat_completion_files_handler(body):
contexts = [] contexts = []
citations = None citations = None
if "files" in body: if files := body.pop("files", None):
files = body["files"]
del body["files"]
contexts, citations = get_rag_context( contexts, citations = get_rag_context(
files=files, files=files,
messages=body["messages"], messages=body["messages"],
@ -634,8 +592,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
"valves": body.pop("valves", None), "valves": body.pop("valves", None),
} }
__event_emitter__ = get_event_emitter(metadata) __user__ = {
__event_call__ = get_event_call(metadata) "id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
extra_params = {
"__user__": __user__,
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
}
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client
# Initalize contexts and citation # Initalize contexts and citation
@ -645,7 +613,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_inlets_handler( body, flags = await chat_completion_inlets_handler(
body, model, user, __event_emitter__, __event_call__ body, model, extra_params
) )
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(
@ -654,10 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
) )
try: try:
body, flags = await chat_completion_tools_handler( body, flags = await chat_completion_tools_handler(body, user, extra_params)
body, user, __event_emitter__, __event_call__
)
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", [])) citations.extend(flags.get("citations", []))
except Exception as e: except Exception as e:
@ -666,7 +631,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_files_handler(body) body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", [])) citations.extend(flags.get("citations", []))
except Exception as e: except Exception as e:
@ -713,7 +677,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line # If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type") content_type = response.headers["Content-Type"]
if "text/event-stream" in content_type: if "text/event-stream" in content_type:
return StreamingResponse( return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, data_items), self.openai_stream_wrapper(response.body_iterator, data_items),
@ -832,7 +796,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
user = get_current_user( user = get_current_user(
request, request,
get_http_authorization_cred(request.headers.get("Authorization")), get_http_authorization_cred(request.headers["Authorization"]),
) )
try: try:
@ -1015,6 +979,8 @@ async def get_all_models():
model["actions"] = [] model["actions"] = []
for action_id in action_ids: for action_id in action_ids:
action = Functions.get_function_by_id(action_id) action = Functions.get_function_by_id(action_id)
if action is None:
raise Exception(f"Action not found: {action_id}")
if action_id in webui_app.state.FUNCTIONS: if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id] function_module = webui_app.state.FUNCTIONS[action_id]
@ -1022,6 +988,10 @@ async def get_all_models():
function_module, _, _ = load_function_module_by_id(action_id) function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module webui_app.state.FUNCTIONS[action_id] = function_module
icon_url = None
if action.meta.manifest is not None:
icon_url = action.meta.manifest.get("icon_url", None)
if hasattr(function_module, "actions"): if hasattr(function_module, "actions"):
actions = function_module.actions actions = function_module.actions
model["actions"].extend( model["actions"].extend(
@ -1032,9 +1002,7 @@ async def get_all_models():
"name", f"{action.name} ({_action['id']})" "name", f"{action.name} ({_action['id']})"
), ),
"description": action.meta.description, "description": action.meta.description,
"icon_url": _action.get( "icon_url": _action.get("icon_url", icon_url),
"icon_url", action.meta.manifest.get("icon_url", None)
),
} }
for _action in actions for _action in actions
] ]
@ -1045,7 +1013,7 @@ async def get_all_models():
"id": action_id, "id": action_id,
"name": action.name, "name": action.name,
"description": action.meta.description, "description": action.meta.description,
"icon_url": action.meta.manifest.get("icon_url", None), "icon_url": icon_url,
} }
) )
@ -1175,6 +1143,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0) return (function.valves if function.valves else {}).get("priority", 0)
return 0 return 0
@ -1631,7 +1600,7 @@ async def upload_pipeline(
): ):
print("upload_pipeline", urlIdx, file.filename) print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file # Check if the uploaded file is a python file
if not file.filename.endswith(".py"): if not (file.filename and file.filename.endswith(".py")):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.", detail="Only Python (.py) files are allowed.",
@ -2080,7 +2049,10 @@ async def oauth_login(provider: str, request: Request):
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider "oauth_callback", provider=provider
) )
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows: # OAuth login logic is as follows: