diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 36dfa4f85..b6e85e2ca 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, ConfigDict, parse_obj_as -from typing import Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import Optional import time from sqlalchemy import String, Column, BigInteger, Text -from utils.misc import get_gravatar_url - -from apps.webui.internal.db import Base, JSONField, Session, get_db +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.chats import Chats #################### @@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel): class UsersTable: - def insert_new_user( self, id: str, @@ -122,7 +119,6 @@ class UsersTable: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) except Exception: @@ -131,7 +127,6 @@ class UsersTable: def get_user_by_email(self, email: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: @@ -140,7 +135,6 @@ class UsersTable: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) except Exception: @@ -195,7 +189,6 @@ class UsersTable: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: - db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) diff --git a/backend/main.py b/backend/main.py index 838556c40..8ddaffcac 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,13 +51,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional +from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users +from apps.webui.models.users import Users, UserModel from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id @@ -72,7 +72,7 @@ from utils.utils import ( from utils.task import ( title_generation_template, search_query_generation_template, - tools_function_calling_generation_template, + tool_calling_generation_template, ) from utils.misc import ( get_last_user_message, @@ -261,6 +261,7 @@ def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -282,164 +283,42 @@ def get_filter_function_ids(model): return filter_ids -async def get_function_call_response( - messages, - files, - tool_id, - template, - task_model_id, - user, - __event_emitter__=None, - __event_call__=None, -): - tool = Tools.get_tool_by_id(tool_id) - tools_specs = json.dumps(tool.specs, indent=2) - content = tools_function_calling_generation_template(template, tools_specs) +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = 
json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + +def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) - prompt = ( - "History:\n" - + "\n".join( - [ - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ] - ) - + f"\nQuery: {user_message}" + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] ) - print(prompt) + prompt = f"History:\n{history}\nQuery: {user_message}" - payload = { + return { "model": task_model_id, "messages": [ {"role": "system", "content": content}, {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, - "task": str(TASKS.FUNCTION_CALLING), + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - model = app.state.MODELS[task_model_id] - - response = None - try: - response = await generate_chat_completions(form_data=payload, user=user) - content = None - - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - - if content is None: - return None, None, False - - # Parse the function response - print(f"content: {content}") - result = json.loads(content) - print(result) - - citation = None - - if "name" not in result: - return None, None, False - - # Call the function - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - file_handler = False - # check if toolkit_module has file_handler self variable - if hasattr(toolkit_module, "file_handler"): - file_handler = True - print("file_handler: ", file_handler) - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - function = getattr(toolkit_module, result["name"]) - function_result = None - try: - # Get the signature of the function - sig = inspect.signature(function) - params = result["parameters"] - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - # Call the function with the '__user__' parameter included - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - except Exception as e: - print(e) - - params = 
{**params, "__user__": __user__} - - if inspect.iscoroutinefunction(function): - function_result = await function(**params) - else: - function_result = function(**params) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - except Exception as e: - print(e) - - # Add the function result to the system prompt - if function_result is not None: - return function_result, citation, file_handler - except Exception as e: - print(f"Error: {e}") - - return None, None, False - - -async def chat_completion_functions_handler( - body, model, user, __event_emitter__, __event_call__ -): +async def chat_completion_inlets_handler(body, model, extra_params): skip_files = None filter_ids = get_filter_function_ids(model) @@ -475,37 +354,20 @@ async def chat_completion_functions_handler( params = {"body": body} # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - + custom_params = {**extra_params, "__model__": model, "__id__": filter_id} + if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) + uid = custom_params["__user__"]["id"] + custom_params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + ) except Exception as e: print(e) - params = {**params, "__user__": __user__} + # Add extra params in contained in function signature + for key, value in custom_params.items(): + if key in sig.parameters: + params[key] = value if inspect.iscoroutinefunction(inlet): body = await inlet(**params) @@ -516,74 +378,171 @@ async def chat_completion_functions_handler( print(f"Error: {e}") raise e - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body: + del body["files"] return body, {} -async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): - skip_files = None +def get_tool_with_custom_params( + tool: Callable, custom_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(tool) + extra_params = { + key: value for key, value in custom_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(tool) + async def new_tool(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await tool(**extra_kwargs) + return tool(**extra_kwargs) + + return new_tool + + +# Mutation on extra_params +def get_configured_tools( + tool_ids: list[str], extra_params: dict, user: UserModel +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + has_citation = 
hasattr(module, "citation") and module.citation + handles_files = hasattr(module, "file_handler") and module.file_handler + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + name = spec["name"] + callable = getattr(module, name) + + # convert to function that takes only model params and inserts custom params + custom_callable = get_tool_with_custom_params(callable, extra_params) + + # TODO: This needs to be a pydantic model + tool_dict = { + "spec": spec, + "citation": has_citation, + "file_handler": handles_files, + "toolkit_id": tool_id, + "callable": custom_callable, + } + # TODO: if collision, prepend toolkit name + if name in tools: + log.warning(f"Tool {name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{name}") + else: + tools[name] = tool_dict + + return tools + + +async def chat_completion_tools_handler( + body: dict, user: UserModel, extra_params: dict +) -> tuple[dict, dict]: + skip_files = False contexts = [] - citations = None - + citations = [] task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions - if "tool_ids" in body: - print(body["tool_ids"]) - for tool_id in body["tool_ids"]: - print(tool_id) - try: - response, citation, file_handler = await get_function_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, - ) + tool_ids = body.pop("tool_ids", None) + if not tool_ids: + return body, {} - print(file_handler) - if isinstance(response, str): - contexts.append(response) - - if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) - - if file_handler: - skip_files = True - - except Exception as e: - print(f"Error: {e}") - del body["tool_ids"] - print(f"tool_contexts: {contexts}") - - if skip_files: - if "files" in body: - del body["files"] - - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), + log.debug(f"{tool_ids=}") + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": body.get("files", []), } + configured_tools = get_configured_tools(tool_ids, custom_params, user) + log.info(f"{configured_tools=}") -async def chat_completion_files_handler(body): - contexts = [] - citations = None + specs = [tool["spec"] for tool in configured_tools.values()] + tools_specs = json.dumps(specs) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + content = tool_calling_generation_template(template, tools_specs) + payload = get_tool_call_payload(body["messages"], task_model_id, content) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e - if "files" in body: - files = body["files"] + try: + response = await 
generate_chat_completions(form_data=payload, user=user)
+        log.debug(f"{response=}")
+        content = await get_content_from_response(response)
+        log.debug(f"{content=}")
+        if content is None:
+            return body, {}
+
+        result = json.loads(content)
+        tool_name = result.get("name", None)
+        if tool_name not in configured_tools:
+            return body, {}
+
+        tool_params = result.get("parameters", {})
+        toolkit_id = configured_tools[tool_name]["toolkit_id"]
+        try:
+            tool_output = await configured_tools[tool_name]["callable"](**tool_params)
+        except Exception as e:
+            tool_output = str(e)
+        if configured_tools[tool_name]["citation"]:
+            citations.append(
+                {
+                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "document": [tool_output],
+                    "metadata": [{"source": tool_name}],
+                }
+            )
+        if configured_tools[tool_name]["file_handler"]:
+            skip_files = True
+
+        if isinstance(tool_output, str):
+            contexts.append(tool_output)
+
+    except Exception as e:
+        print(f"Error: {e}")
+        content = None
+
+    log.debug(f"tool_contexts: {contexts}")
+
+    if skip_files and "files" in body:
        del body["files"]
 
+    return body, {"contexts": contexts, "citations": citations}
+
+
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+    contexts = []
+    citations = []
+
+    if files := body.pop("files", None):
        contexts, citations = get_rag_context(
            files=files,
            messages=body["messages"],
@@ -596,134 +555,130 @@
 
        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}
+
+
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
 
 
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
 
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
            )
+
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
+
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }
+
+        # Initialize data_items to store additional data to be sent to the client
+        # Initialize contexts and citations
+        data_items = []
+        contexts = []
+        citations = []
+
+        try:
+            body, flags = await chat_completion_inlets_handler(
+                body, model, extra_params
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        try:
+            
body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        try:
+            body, flags = await chat_completion_files_handler(body)
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            log.exception(e)
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+            if prompt is None:
+                raise Exception("No user message found")
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
                )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
-
-            # Initialize context, and citations
-            contexts = []
-            citations = []
-
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            try:
-                body, flags = await chat_completion_files_handler(body)
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-
-                # If there are citations, add them to the data_items
-                if len(citations) > 0:
-                    data_items.append({"citations": citations})
-
-                body["metadata"] = metadata
-                modified_body_bytes = json.dumps(body).encode("utf-8")
-                # Replace the request body with the modified one
-                request._body = modified_body_bytes
-                # Set custom header to ensure content-length matches new body length
-                request.headers.__dict__["_list"] = [
-                    (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                    *[
-                        (k, v)
-                        for k, v in request.headers.raw
-                        if k.lower() != b"content-length"
-                    ],
-                ]
-
-                response = await call_next(request)
-                if isinstance(response, StreamingResponse):
-                    # If it's a streaming response, inject it as SSE event or NDJSON line
-                    content_type = 
response.headers.get("Content-Type") - if "text/event-stream" in content_type: - return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, data_items), - ) - if "application/x-ndjson" in content_type: - return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, data_items), - ) - - return response else: - return response + body["messages"] = add_or_update_system_message( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + + # If there are citations, add them to the data_items + if len(citations) > 0: + data_items.append({"citations": citations}) + + body["metadata"] = metadata + modified_body_bytes = json.dumps(body).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] - # If it's not a chat completion request, just pass it through response = await call_next(request) + if isinstance(response, StreamingResponse): + # If it's a streaming response, inject it as SSE event or NDJSON line + content_type = response.headers["Content-Type"] + if "text/event-stream" in content_type: + return StreamingResponse( + self.openai_stream_wrapper(response.body_iterator, data_items), + ) + if "application/x-ndjson" in content_type: + return StreamingResponse( + self.ollama_stream_wrapper(response.body_iterator, data_items), + ) + return response async def _receive(self, body: bytes): @@ -790,19 +745,21 @@ def filter_pipeline(payload, user): url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) + if key == "": + continue - r.raise_for_status() - payload = r.json() + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() except Exception as e: # Handle connection error here print(f"Connection error: {e}") @@ -817,44 +774,39 @@ def filter_pipeline(payload, user): class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path - ): - log.debug(f"request.url.path: {request.url.path}") + if not is_chat_completion_request(request): + return await call_next(request) - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} + log.debug(f"request.url.path: {request.url.path}") - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + user = get_current_user( + request, + get_http_authorization_cred(request.headers["Authorization"]), + ) + + try: + data = 
filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
            )
 
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
-
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
        response = await call_next(request)
        return response
@@ -1019,6 +971,8 @@ async def get_all_models():
            model["actions"] = []
            for action_id in action_ids:
                action = Functions.get_function_by_id(action_id)
+                if action is None:
+                    raise Exception(f"Action not found: {action_id}")
 
                if action_id in webui_app.state.FUNCTIONS:
                    function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1099,22 +1053,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
    )
 
    model = app.state.MODELS[model_id]
-
-    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
-    task = None
-    if "task" in form_data:
-        task = form_data["task"]
-        del form_data["task"]
-
-    if task:
-        if "metadata" in form_data:
-            form_data["metadata"]["task"] = task
-        else:
-            form_data["metadata"] = {"task": task}
-
    if model.get("pipe"):
        return await generate_function_chat_completion(form_data, user=user)
    if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
        return await generate_ollama_chat_completion(form_data, user=user)
    else:
        return await generate_openai_chat_completion(form_data, user=user)
@@ -1198,6 +1139,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include valves
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0
@@ -1487,7 +1429,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.TITLE_GENERATION),
+        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
    }
 
    log.debug(payload)
@@ -1540,7 +1482,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 30,
-        "task": str(TASKS.QUERY_GENERATION),
+        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
    }
 
    print(payload)
@@ -1597,7 +1539,7 @@ Message: """{{prompt}}"""
        "stream": False,
        "max_tokens": 4,
        "chat_id": form_data.get("chat_id", None),
-        "task": str(TASKS.EMOJI_GENERATION),
+        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
    }
 
    log.debug(payload)
@@ -1616,41 +1558,6 @@ Message: """{{prompt}}"""
 
    return await 
generate_chat_completions(form_data=payload, user=user) -@app.post("/api/task/tools/completions") -async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): - print("get_tools_function_calling") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - model_id = get_task_model_id(model_id) - - print(model_id) - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - - try: - context, _, _ = await get_function_call_response( - form_data["messages"], - form_data.get("files", []), - form_data["tool_id"], - template, - model_id, - user, - ) - return context - except Exception as e: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - - ################################## # # Pipelines Endpoints @@ -1689,7 +1596,7 @@ async def upload_pipeline( ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file - if not file.filename.endswith(".py"): + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Only Python (.py) files are allowed.", @@ -2138,7 +2045,10 @@ async def oauth_login(provider: str, request: Request): redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( "oauth_callback", provider=provider ) - return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + client = oauth.create_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) # OAuth login logic is as follows: diff --git a/backend/utils/task.py b/backend/utils/task.py index 1b2276c9c..37c174d3d 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -121,6 +121,6 @@ def search_query_generation_template( return template -def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: +def tool_calling_generation_template(template: str, tools_specs: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) return template
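
A quick illustration of the closure pattern this diff introduces in `get_tool_with_custom_params`: framework-supplied values (`__id__`, `__user__`, and so on) are injected into a tool call only when the tool's signature declares them, so the model-facing spec stays free of internal parameters. Below is a minimal, self-contained sketch of that same technique; the `search` tool and the user dict are hypothetical stand-ins, not part of this diff.

import asyncio
import inspect
from typing import Awaitable, Callable


def get_tool_with_custom_params(
    tool: Callable, custom_params: dict
) -> Callable[..., Awaitable]:
    # Keep only the framework params the tool actually declares in its signature.
    sig = inspect.signature(tool)
    extra_params = {
        key: value for key, value in custom_params.items() if key in sig.parameters
    }
    is_coroutine = inspect.iscoroutinefunction(tool)

    async def new_tool(**kwargs):
        # Merge model-provided kwargs with the injected framework params.
        extra_kwargs = kwargs | extra_params
        if is_coroutine:
            return await tool(**extra_kwargs)
        return tool(**extra_kwargs)

    return new_tool


# Hypothetical tool: it declares __user__, so the wrapper injects it;
# `query` is the only parameter the model has to supply.
def search(query: str, __user__: dict) -> str:
    return f"{__user__['name']} searched for {query!r}"


wrapped = get_tool_with_custom_params(
    search, {"__user__": {"name": "alice"}, "__id__": "toolkit-1"}
)
print(asyncio.run(wrapped(query="weather")))  # alice searched for 'weather'

Note that `__id__` is silently dropped for this tool because `search` does not declare it — the same signature filtering that lets toolkit functions opt in to extras like `__event_emitter__` in the diff above.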