mirror of
https://github.com/open-webui/open-webui
synced 2025-05-02 03:56:09 +00:00
fix more LSP errors
This commit is contained in:
parent
a68b918cbb
commit
ff9d899f9c
130
backend/main.py
130
backend/main.py
@ -261,6 +261,7 @@ def get_filter_function_ids(model):
|
|||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None and hasattr(function, "valves"):
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -322,14 +323,7 @@ async def call_tool_from_completion(
|
|||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_function_call_response(
|
||||||
messages,
|
messages, files, tool_id, template, task_model_id, user, extra_params
|
||||||
files,
|
|
||||||
tool_id,
|
|
||||||
template,
|
|
||||||
task_model_id,
|
|
||||||
user,
|
|
||||||
__event_emitter__=None,
|
|
||||||
__event_call__=None,
|
|
||||||
) -> tuple[Optional[str], Optional[dict], bool]:
|
) -> tuple[Optional[str], Optional[dict], bool]:
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
@ -373,32 +367,22 @@ async def get_function_call_response(
|
|||||||
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
|
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
|
||||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||||
|
|
||||||
__user__ = {
|
custom_params = {
|
||||||
"id": user.id,
|
**extra_params,
|
||||||
"email": user.email,
|
"__model__": app.state.MODELS[task_model_id],
|
||||||
"name": user.name,
|
"__id__": tool_id,
|
||||||
"role": user.role,
|
"__messages__": messages,
|
||||||
|
"__files__": files,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(toolkit_module, "UserValves"):
|
if hasattr(toolkit_module, "UserValves"):
|
||||||
__user__["valves"] = toolkit_module.UserValves(
|
custom_params["__user__"]["valves"] = toolkit_module.UserValves(
|
||||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
extra_params = {
|
|
||||||
"__model__": app.state.MODELS[task_model_id],
|
|
||||||
"__id__": tool_id,
|
|
||||||
"__messages__": messages,
|
|
||||||
"__files__": files,
|
|
||||||
"__event_emitter__": __event_emitter__,
|
|
||||||
"__event_call__": __event_call__,
|
|
||||||
"__user__": __user__,
|
|
||||||
}
|
|
||||||
|
|
||||||
file_handler = hasattr(toolkit_module, "file_handler")
|
file_handler = hasattr(toolkit_module, "file_handler")
|
||||||
|
|
||||||
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
|
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
|
||||||
@ -417,7 +401,7 @@ async def get_function_call_response(
|
|||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
|
|
||||||
function_result = await call_tool_from_completion(
|
function_result = await call_tool_from_completion(
|
||||||
result, extra_params, toolkit_module
|
result, custom_params, toolkit_module
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
||||||
@ -438,9 +422,7 @@ async def get_function_call_response(
|
|||||||
return None, None, False
|
return None, None, False
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_inlets_handler(
|
async def chat_completion_inlets_handler(body, model, extra_params):
|
||||||
body, model, user, __event_emitter__, __event_call__
|
|
||||||
):
|
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
filter_ids = get_filter_function_ids(model)
|
filter_ids = get_filter_function_ids(model)
|
||||||
@ -476,38 +458,18 @@ async def chat_completion_inlets_handler(
|
|||||||
params = {"body": body}
|
params = {"body": body}
|
||||||
|
|
||||||
# Extra parameters to be passed to the function
|
# Extra parameters to be passed to the function
|
||||||
extra_params = {
|
custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
|
||||||
"__model__": model,
|
if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
|
||||||
"__id__": filter_id,
|
uid = custom_params["__user__"]["id"]
|
||||||
"__event_emitter__": __event_emitter__,
|
custom_params["__user__"]["valves"] = function_module.UserValves(
|
||||||
"__event_call__": __event_call__,
|
**Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
|
||||||
}
|
)
|
||||||
|
|
||||||
# Add extra params in contained in function signature
|
# Add extra params in contained in function signature
|
||||||
for key, value in extra_params.items():
|
for key, value in custom_params.items():
|
||||||
if key in sig.parameters:
|
if key in sig.parameters:
|
||||||
params[key] = value
|
params[key] = value
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
|
||||||
__user__ = {
|
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(function_module, "UserValves"):
|
|
||||||
__user__["valves"] = function_module.UserValves(
|
|
||||||
**Functions.get_user_valves_by_id_and_user_id(
|
|
||||||
filter_id, user.id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
if inspect.iscoroutinefunction(inlet):
|
||||||
body = await inlet(**params)
|
body = await inlet(**params)
|
||||||
else:
|
else:
|
||||||
@ -524,7 +486,7 @@ async def chat_completion_inlets_handler(
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
|
async def chat_completion_tools_handler(body, user, extra_params):
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
contexts = []
|
contexts = []
|
||||||
@ -547,8 +509,7 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c
|
|||||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
task_model_id=task_model_id,
|
task_model_id=task_model_id,
|
||||||
user=user,
|
user=user,
|
||||||
__event_emitter__=__event_emitter__,
|
extra_params=extra_params,
|
||||||
__event_call__=__event_call__,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(file_handler)
|
print(file_handler)
|
||||||
@ -584,10 +545,7 @@ async def chat_completion_files_handler(body):
|
|||||||
contexts = []
|
contexts = []
|
||||||
citations = None
|
citations = None
|
||||||
|
|
||||||
if "files" in body:
|
if files := body.pop("files", None):
|
||||||
files = body["files"]
|
|
||||||
del body["files"]
|
|
||||||
|
|
||||||
contexts, citations = get_rag_context(
|
contexts, citations = get_rag_context(
|
||||||
files=files,
|
files=files,
|
||||||
messages=body["messages"],
|
messages=body["messages"],
|
||||||
@ -634,8 +592,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
"valves": body.pop("valves", None),
|
"valves": body.pop("valves", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
__event_emitter__ = get_event_emitter(metadata)
|
__user__ = {
|
||||||
__event_call__ = get_event_call(metadata)
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
extra_params = {
|
||||||
|
"__user__": __user__,
|
||||||
|
"__event_emitter__": get_event_emitter(metadata),
|
||||||
|
"__event_call__": get_event_call(metadata),
|
||||||
|
}
|
||||||
|
|
||||||
# Initialize data_items to store additional data to be sent to the client
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
# Initalize contexts and citation
|
# Initalize contexts and citation
|
||||||
@ -645,7 +613,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_inlets_handler(
|
body, flags = await chat_completion_inlets_handler(
|
||||||
body, model, user, __event_emitter__, __event_call__
|
body, model, extra_params
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -654,10 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_tools_handler(
|
body, flags = await chat_completion_tools_handler(body, user, extra_params)
|
||||||
body, user, __event_emitter__, __event_call__
|
|
||||||
)
|
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
contexts.extend(flags.get("contexts", []))
|
||||||
citations.extend(flags.get("citations", []))
|
citations.extend(flags.get("citations", []))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -666,7 +631,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_files_handler(body)
|
body, flags = await chat_completion_files_handler(body)
|
||||||
|
|
||||||
contexts.extend(flags.get("contexts", []))
|
contexts.extend(flags.get("contexts", []))
|
||||||
citations.extend(flags.get("citations", []))
|
citations.extend(flags.get("citations", []))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -713,7 +677,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
if isinstance(response, StreamingResponse):
|
if isinstance(response, StreamingResponse):
|
||||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||||
content_type = response.headers.get("Content-Type")
|
content_type = response.headers["Content-Type"]
|
||||||
if "text/event-stream" in content_type:
|
if "text/event-stream" in content_type:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||||
@ -832,7 +796,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
user = get_current_user(
|
user = get_current_user(
|
||||||
request,
|
request,
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
get_http_authorization_cred(request.headers["Authorization"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1015,6 +979,8 @@ async def get_all_models():
|
|||||||
model["actions"] = []
|
model["actions"] = []
|
||||||
for action_id in action_ids:
|
for action_id in action_ids:
|
||||||
action = Functions.get_function_by_id(action_id)
|
action = Functions.get_function_by_id(action_id)
|
||||||
|
if action is None:
|
||||||
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
if action_id in webui_app.state.FUNCTIONS:
|
if action_id in webui_app.state.FUNCTIONS:
|
||||||
function_module = webui_app.state.FUNCTIONS[action_id]
|
function_module = webui_app.state.FUNCTIONS[action_id]
|
||||||
@ -1022,6 +988,10 @@ async def get_all_models():
|
|||||||
function_module, _, _ = load_function_module_by_id(action_id)
|
function_module, _, _ = load_function_module_by_id(action_id)
|
||||||
webui_app.state.FUNCTIONS[action_id] = function_module
|
webui_app.state.FUNCTIONS[action_id] = function_module
|
||||||
|
|
||||||
|
icon_url = None
|
||||||
|
if action.meta.manifest is not None:
|
||||||
|
icon_url = action.meta.manifest.get("icon_url", None)
|
||||||
|
|
||||||
if hasattr(function_module, "actions"):
|
if hasattr(function_module, "actions"):
|
||||||
actions = function_module.actions
|
actions = function_module.actions
|
||||||
model["actions"].extend(
|
model["actions"].extend(
|
||||||
@ -1032,9 +1002,7 @@ async def get_all_models():
|
|||||||
"name", f"{action.name} ({_action['id']})"
|
"name", f"{action.name} ({_action['id']})"
|
||||||
),
|
),
|
||||||
"description": action.meta.description,
|
"description": action.meta.description,
|
||||||
"icon_url": _action.get(
|
"icon_url": _action.get("icon_url", icon_url),
|
||||||
"icon_url", action.meta.manifest.get("icon_url", None)
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
for _action in actions
|
for _action in actions
|
||||||
]
|
]
|
||||||
@ -1045,7 +1013,7 @@ async def get_all_models():
|
|||||||
"id": action_id,
|
"id": action_id,
|
||||||
"name": action.name,
|
"name": action.name,
|
||||||
"description": action.meta.description,
|
"description": action.meta.description,
|
||||||
"icon_url": action.meta.manifest.get("icon_url", None),
|
"icon_url": icon_url,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1175,6 +1143,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None and hasattr(function, "valves"):
|
if function is not None and hasattr(function, "valves"):
|
||||||
|
# TODO: Fix FunctionModel to include vavles
|
||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -1631,7 +1600,7 @@ async def upload_pipeline(
|
|||||||
):
|
):
|
||||||
print("upload_pipeline", urlIdx, file.filename)
|
print("upload_pipeline", urlIdx, file.filename)
|
||||||
# Check if the uploaded file is a python file
|
# Check if the uploaded file is a python file
|
||||||
if not file.filename.endswith(".py"):
|
if not (file.filename and file.filename.endswith(".py")):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Only Python (.py) files are allowed.",
|
detail="Only Python (.py) files are allowed.",
|
||||||
@ -2080,7 +2049,10 @@ async def oauth_login(provider: str, request: Request):
|
|||||||
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
||||||
"oauth_callback", provider=provider
|
"oauth_callback", provider=provider
|
||||||
)
|
)
|
||||||
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
client = oauth.create_client(provider)
|
||||||
|
if client is None:
|
||||||
|
raise HTTPException(404)
|
||||||
|
return await client.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
# OAuth login logic is as follows:
|
# OAuth login logic is as follows:
|
||||||
|
Loading…
Reference in New Issue
Block a user