From 3befadb29f61df4c363697829cc1be0d820a0d50 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 11:58:18 +0100 Subject: [PATCH 01/16] remove unnecessary nesting, remove unused endpoint --- backend/main.py | 96 +++++++++++++++++-------------------------------- 1 file changed, 32 insertions(+), 64 deletions(-) diff --git a/backend/main.py b/backend/main.py index d8ce5f5d7..3fc6b8db5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -532,39 +532,42 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions - if "tool_ids" in body: - print(body["tool_ids"]) - for tool_id in body["tool_ids"]: - print(tool_id) - try: - response, citation, file_handler = await get_function_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, - ) + if "tool_ids" not in body: + return body, {} - print(file_handler) - if isinstance(response, str): - contexts.append(response) + print(body["tool_ids"]) + for tool_id in body["tool_ids"]: + print(tool_id) + try: + response, citation, file_handler = await get_function_call_response( + messages=body["messages"], + files=body.get("files", []), + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + __event_emitter__=__event_emitter__, + __event_call__=__event_call__, + ) - if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) + print(file_handler) + if isinstance(response, str): + contexts.append(response) - if file_handler: - skip_files = True + if citation: + if citations is None: + citations = [citation] + else: + citations.append(citation) - except Exception as e: - print(f"Error: {e}") - del body["tool_ids"] - print(f"tool_contexts: {contexts}") + if file_handler: + skip_files = True + + except Exception as e: + print(f"Error: {e}") + + del body["tool_ids"] + print(f"tool_contexts: {contexts}") if skip_files: if "files" in body: @@ -1610,41 +1613,6 @@ Message: """{{prompt}}""" return await generate_chat_completions(form_data=payload, user=user) -@app.post("/api/task/tools/completions") -async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): - print("get_tools_function_calling") - - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - # Check if the user has a custom task model - # If the user has a custom task model, use that model - model_id = get_task_model_id(model_id) - - print(model_id) - template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - - try: - context, _, _ = await get_function_call_response( - form_data["messages"], - form_data.get("files", []), - form_data["tool_id"], - template, - model_id, - user, - ) - return context - except Exception as e: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - - ################################## # # Pipelines Endpoints From 589efcdc5fe61754112312ef275fa6f164362efc Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 12:03:47 +0100 Subject: [PATCH 02/16] is_chat_completion_request helper, remove nesting --- backend/main.py | 292 +++++++++++++++++++++++------------------------- 1 file changed, 142 insertions(+), 150 deletions(-) diff --git a/backend/main.py b/backend/main.py index 3fc6b8db5..b1cd298a2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -605,129 +605,126 @@ async def chat_completion_files_handler(body): } +def is_chat_completion_request(request): + return request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] + ) + + class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and any( - endpoint in request.url.path - for endpoint in ["/ollama/api/chat", "/chat/completions"] - ): - log.debug(f"request.url.path: {request.url.path}") + if not is_chat_completion_request(request): + return await call_next(request) + log.debug(f"request.url.path: {request.url.path}") - try: - body, model, user = await get_body_and_model_and_user(request) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, + try: + body, model, user = await get_body_and_model_and_user(request) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + metadata = { + "chat_id": body.pop("chat_id", None), + "message_id": body.pop("id", None), + "session_id": body.pop("session_id", None), + "valves": body.pop("valves", None), + } + + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + + # Initialize data_items to store additional data to be sent to the client + data_items = [] + + # Initialize context, and citations + contexts = [] + citations = [] + + try: + body, flags = await chat_completion_functions_handler( + body, model, user, __event_emitter__, __event_call__ + ) + except Exception as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + try: + body, flags = await chat_completion_tools_handler( + body, user, __event_emitter__, __event_call__ + ) + + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + print(e) + pass + + try: + body, flags = await chat_completion_files_handler(body) + + contexts.extend(flags.get("contexts", [])) + citations.extend(flags.get("citations", [])) + except Exception as e: + print(e) + pass + + # If context is not empty, insert it into the messages + if len(contexts) > 0: + context_string = "/n".join(contexts).strip() + prompt = get_last_user_message(body["messages"]) + + # Workaround for Ollama 2.0+ system prompt issue + # TODO: replace with add_or_update_system_message + if model["owned_by"] == "ollama": + body["messages"] = prepend_to_first_user_message_content( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], ) - - metadata = { - "chat_id": body.pop("chat_id", None), - "message_id": body.pop("id", None), - "session_id": body.pop("session_id", None), - "valves": body.pop("valves", None), - } - - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - - # Initialize data_items to store additional data to be sent to the client - data_items = [] - - # Initialize context, and citations - contexts = [] - citations = [] - - try: - body, flags = await chat_completion_functions_handler( - body, model, user, __event_emitter__, __event_call__ - ) - except Exception as e: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - try: - body, flags = await chat_completion_tools_handler( - body, user, __event_emitter__, __event_call__ - ) - - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - print(e) - pass - - try: - body, flags = await chat_completion_files_handler(body) - - contexts.extend(flags.get("contexts", [])) - citations.extend(flags.get("citations", [])) - except Exception as e: - print(e) - pass - - # If context is not empty, insert it into the messages - if len(contexts) > 0: - context_string = "/n".join(contexts).strip() - prompt = get_last_user_message(body["messages"]) - - # Workaround for Ollama 2.0+ system prompt issue - # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": - body["messages"] = prepend_to_first_user_message_content( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - else: - body["messages"] = add_or_update_system_message( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], - ) - - # If there are citations, add them to the data_items - if len(citations) > 0: - data_items.append({"citations": citations}) - - body["metadata"] = metadata - modified_body_bytes = json.dumps(body).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[ - (k, v) - for k, v in request.headers.raw - if k.lower() != b"content-length" - ], - ] - - response = await call_next(request) - if isinstance(response, StreamingResponse): - # If it's a streaming response, inject it as SSE event or NDJSON line - content_type = response.headers.get("Content-Type") - if "text/event-stream" in content_type: - return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, data_items), - ) - if "application/x-ndjson" in content_type: - return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, data_items), - ) - - return response else: - return response + body["messages"] = add_or_update_system_message( + rag_template( + rag_app.state.config.RAG_TEMPLATE, context_string, prompt + ), + body["messages"], + ) + + # If there are citations, add them to the data_items + if len(citations) > 0: + data_items.append({"citations": citations}) + + body["metadata"] = metadata + modified_body_bytes = json.dumps(body).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] - # If it's not a chat completion request, just pass it through response = await call_next(request) - return response + if isinstance(response, StreamingResponse): + # If it's a streaming response, inject it as SSE event or NDJSON line + content_type = response.headers.get("Content-Type") + if "text/event-stream" in content_type: + return StreamingResponse( + self.openai_stream_wrapper(response.body_iterator, data_items), + ) + if "application/x-ndjson" in content_type: + return StreamingResponse( + self.ollama_stream_wrapper(response.body_iterator, data_items), + ) + + return response + else: + return response async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} @@ -820,44 +817,39 @@ def filter_pipeline(payload, user): class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path - ): - log.debug(f"request.url.path: {request.url.path}") + if not is_chat_completion_request(request): + return await call_next(request) - # Read the original request body - body = await request.body() - # Decode body to string - body_str = body.decode("utf-8") - # Parse string to JSON - data = json.loads(body_str) if body_str else {} + log.debug(f"request.url.path: {request.url.path}") - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) + + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, ) - try: - data = filter_pipeline(data, user) - except Exception as e: - return JSONResponse( - status_code=e.args[0], - content={"detail": e.args[1]}, - ) - - modified_body_bytes = json.dumps(data).encode("utf-8") - # Replace the request body with the modified one - request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length - request.headers.__dict__["_list"] = [ - (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), - *[ - (k, v) - for k, v in request.headers.raw - if k.lower() != b"content-length" - ], - ] + modified_body_bytes = json.dumps(data).encode("utf-8") + # Replace the request body with the modified one + request._body = modified_body_bytes + # Set custom header to ensure content-length matches new body length + request.headers.__dict__["_list"] = [ + (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), + *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"], + ] response = await call_next(request) return response From 23f1bee7bd3af6bc935934e5d31240724928cf81 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 12:11:41 +0100 Subject: [PATCH 03/16] cleanup --- backend/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index b1cd298a2..d48f3445b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -677,7 +677,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if len(contexts) > 0: context_string = "/n".join(contexts).strip() prompt = get_last_user_message(body["messages"]) - + if prompt is None: + raise Exception("No user message found") # Workaround for Ollama 2.0+ system prompt issue # TODO: replace with add_or_update_system_message if model["owned_by"] == "ollama": @@ -722,9 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): self.ollama_stream_wrapper(response.body_iterator, data_items), ) - return response - else: - return response + return response async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} From 60003c976aadfd44f8ff25850872301e564888e6 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 12:33:27 +0100 Subject: [PATCH 04/16] rename to chat_completions_inlet_handler for clarity --- backend/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index d48f3445b..b110e9631 100644 --- a/backend/main.py +++ b/backend/main.py @@ -437,7 +437,7 @@ async def get_function_call_response( return None, None, False -async def chat_completion_functions_handler( +async def chat_completion_inlets_handler( body, model, user, __event_emitter__, __event_call__ ): skip_files = None @@ -637,14 +637,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): __event_call__ = get_event_call(metadata) # Initialize data_items to store additional data to be sent to the client + # Initalize contexts and citation data_items = [] - - # Initialize context, and citations contexts = [] citations = [] try: - body, flags = await chat_completion_functions_handler( + body, flags = await chat_completion_inlets_handler( body, model, user, __event_emitter__, __event_call__ ) except Exception as e: From 556141cdd84db9483afeb3c68d4dc795847b5fc8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 12:50:09 +0100 Subject: [PATCH 05/16] refactor task --- backend/main.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/backend/main.py b/backend/main.py index b110e9631..bd7b6c8f6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1087,12 +1087,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u model = app.state.MODELS[model_id] # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. - task = None - if "task" in form_data: - task = form_data["task"] - del form_data["task"] - - if task: + if task := form_data.pop("task", None): if "metadata" in form_data: form_data["metadata"]["task"] = task else: From 0c9119d6199f61c623ea06e0899a2f36c7ecc09d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 13:04:01 +0100 Subject: [PATCH 06/16] move task to metadata --- backend/main.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/backend/main.py b/backend/main.py index bd7b6c8f6..0099aabb8 100644 --- a/backend/main.py +++ b/backend/main.py @@ -317,7 +317,7 @@ async def get_function_call_response( {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, - "task": str(TASKS.FUNCTION_CALLING), + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } try: @@ -788,19 +788,21 @@ def filter_pipeline(payload, user): url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) + if key == "": + continue - r.raise_for_status() - payload = r.json() + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() except Exception as e: # Handle connection error here print(f"Connection error: {e}") @@ -1086,13 +1088,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) model = app.state.MODELS[model_id] - # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. - if task := form_data.pop("task", None): - if "metadata" in form_data: - form_data["metadata"]["task"] = task - else: - form_data["metadata"] = {"task": task} - if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": @@ -1469,7 +1464,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "stream": False, "max_tokens": 50, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.TITLE_GENERATION), + "metadata": {"task": str(TASKS.TITLE_GENERATION)}, } log.debug(payload) @@ -1522,7 +1517,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, - "task": str(TASKS.QUERY_GENERATION), + "metadata": {"task": str(TASKS.QUERY_GENERATION)}, } print(payload) @@ -1579,7 +1574,7 @@ Message: """{{prompt}}""" "stream": False, "max_tokens": 4, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.EMOJI_GENERATION), + "metadata": {"task": str(TASKS.EMOJI_GENERATION)}, } log.debug(payload) From 9fb70969d729af13960ddf9b4be5df753299d57e Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 13:40:04 +0100 Subject: [PATCH 07/16] factor out get_content_from_response --- backend/main.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0099aabb8..44fdc6298 100644 --- a/backend/main.py +++ b/backend/main.py @@ -282,6 +282,21 @@ def get_filter_function_ids(model): return filter_ids +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + async def get_function_call_response( messages, files, @@ -293,6 +308,9 @@ async def get_function_call_response( __event_call__=None, ): tool = Tools.get_tool_by_id(tool_id) + if tool is None: + return None, None, False + tools_specs = json.dumps(tool.specs, indent=2) content = tools_function_calling_generation_template(template, tools_specs) @@ -327,21 +345,9 @@ async def get_function_call_response( model = app.state.MODELS[task_model_id] - response = None try: response = await generate_chat_completions(form_data=payload, user=user) - content = None - - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] + content = await get_content_from_response(response) if content is None: return None, None, False @@ -351,8 +357,6 @@ async def get_function_call_response( result = json.loads(content) print(result) - citation = None - if "name" not in result: return None, None, False @@ -375,6 +379,7 @@ async def get_function_call_response( function = getattr(toolkit_module, result["name"]) function_result = None + citation = None try: # Get the signature of the function sig = inspect.signature(function) @@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": - print("generate_ollama_chat_completion") return await generate_ollama_chat_completion(form_data, user=user) else: return await generate_openai_chat_completion(form_data, user=user) From a68b918cbbc84e4b1ecc806c00c7d425c60d31b8 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 14:25:20 +0100 Subject: [PATCH 08/16] refactor get_function_call_response --- backend/main.py | 150 +++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 77 deletions(-) diff --git a/backend/main.py b/backend/main.py index 44fdc6298..512f3d006 100644 --- a/backend/main.py +++ b/backend/main.py @@ -297,6 +297,30 @@ async def get_content_from_response(response) -> Optional[str]: return content +async def call_tool_from_completion( + result: dict, extra_params: dict, toolkit_module +) -> Optional[str]: + if "name" not in result: + return None + + tool = getattr(toolkit_module, result["name"]) + try: + # Get the signature of the function + sig = inspect.signature(tool) + params = result["parameters"] + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if inspect.iscoroutinefunction(tool): + return await tool(**params) + else: + return tool(**params) + except Exception as e: + print(f"Error: {e}") + return None + + async def get_function_call_response( messages, files, @@ -306,7 +330,7 @@ async def get_function_call_response( user, __event_emitter__=None, __event_call__=None, -): +) -> tuple[Optional[str], Optional[dict], bool]: tool = Tools.get_tool_by_id(tool_id) if tool is None: return None, None, False @@ -343,7 +367,43 @@ async def get_function_call_response( except Exception as e: raise e - model = app.state.MODELS[task_model_id] + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(toolkit_module, "UserValves"): + __user__["valves"] = toolkit_module.UserValves( + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + except Exception as e: + print(e) + + extra_params = { + "__model__": app.state.MODELS[task_model_id], + "__id__": tool_id, + "__messages__": messages, + "__files__": files, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__user__": __user__, + } + + file_handler = hasattr(toolkit_module, "file_handler") + + if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) + toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) try: response = await generate_chat_completions(form_data=payload, user=user) @@ -353,85 +413,21 @@ async def get_function_call_response( return None, None, False # Parse the function response - print(f"content: {content}") + log.debug(f"content: {content}") result = json.loads(content) - print(result) - if "name" not in result: - return None, None, False + function_result = await call_tool_from_completion( + result, extra_params, toolkit_module + ) - # Call the function - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - file_handler = False - # check if toolkit_module has file_handler self variable - if hasattr(toolkit_module, "file_handler"): - file_handler = True - print("file_handler: ", file_handler) - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - function = getattr(toolkit_module, result["name"]) - function_result = None - citation = None - try: - # Get the signature of the function - sig = inspect.signature(function) - params = result["parameters"] - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, + if hasattr(toolkit_module, "citation") and toolkit_module.citation: + citation = { + "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, + "document": [function_result], + "metadata": [{"source": result["name"]}], } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - # Call the function with the '__user__' parameter included - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(function): - function_result = await function(**params) - else: - function_result = function(**params) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - except Exception as e: - print(e) + else: + citation = None # Add the function result to the system prompt if function_result is not None: From ff9d899f9c5d8318cacc17c23b1df64a64213eea Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 08:31:40 +0100 Subject: [PATCH 09/16] fix more LSP errors --- backend/main.py | 130 +++++++++++++++++++----------------------------- 1 file changed, 51 insertions(+), 79 deletions(-) diff --git a/backend/main.py b/backend/main.py index 512f3d006..0aa6bf167 100644 --- a/backend/main.py +++ b/backend/main.py @@ -261,6 +261,7 @@ def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -322,14 +323,7 @@ async def call_tool_from_completion( async def get_function_call_response( - messages, - files, - tool_id, - template, - task_model_id, - user, - __event_emitter__=None, - __event_call__=None, + messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: tool = Tools.get_tool_by_id(tool_id) if tool is None: @@ -373,32 +367,22 @@ async def get_function_call_response( toolkit_module, _ = load_toolkit_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = toolkit_module - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__id__": tool_id, + "__messages__": messages, + "__files__": files, } - try: if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( + custom_params["__user__"]["valves"] = toolkit_module.UserValves( **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) except Exception as e: print(e) - extra_params = { - "__model__": app.state.MODELS[task_model_id], - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__user__": __user__, - } - file_handler = hasattr(toolkit_module, "file_handler") if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): @@ -417,7 +401,7 @@ async def get_function_call_response( result = json.loads(content) function_result = await call_tool_from_completion( - result, extra_params, toolkit_module + result, custom_params, toolkit_module ) if hasattr(toolkit_module, "citation") and toolkit_module.citation: @@ -438,9 +422,7 @@ async def get_function_call_response( return None, None, False -async def chat_completion_inlets_handler( - body, model, user, __event_emitter__, __event_call__ -): +async def chat_completion_inlets_handler(body, model, extra_params): skip_files = None filter_ids = get_filter_function_ids(model) @@ -476,38 +458,18 @@ async def chat_completion_inlets_handler( params = {"body": body} # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } + custom_params = {**extra_params, "__model__": model, "__id__": filter_id} + if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: + uid = custom_params["__user__"]["id"] + custom_params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + ) # Add extra params in contained in function signature - for key, value in extra_params.items(): + for key, value in custom_params.items(): if key in sig.parameters: params[key] = value - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - if inspect.iscoroutinefunction(inlet): body = await inlet(**params) else: @@ -524,7 +486,7 @@ async def chat_completion_inlets_handler( return body, {} -async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): +async def chat_completion_tools_handler(body, user, extra_params): skip_files = None contexts = [] @@ -547,8 +509,7 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, task_model_id=task_model_id, user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, + extra_params=extra_params, ) print(file_handler) @@ -584,10 +545,7 @@ async def chat_completion_files_handler(body): contexts = [] citations = None - if "files" in body: - files = body["files"] - del body["files"] - + if files := body.pop("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -634,8 +592,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "valves": body.pop("valves", None), } - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + extra_params = { + "__user__": __user__, + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + } # Initialize data_items to store additional data to be sent to the client # Initalize contexts and citation @@ -645,7 +613,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_inlets_handler( - body, model, user, __event_emitter__, __event_call__ + body, model, extra_params ) except Exception as e: return JSONResponse( @@ -654,10 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) try: - body, flags = await chat_completion_tools_handler( - body, user, __event_emitter__, __event_call__ - ) - + body, flags = await chat_completion_tools_handler(body, user, extra_params) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -666,7 +631,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_files_handler(body) - contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -713,7 +677,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): response = await call_next(request) if isinstance(response, StreamingResponse): # If it's a streaming response, inject it as SSE event or NDJSON line - content_type = response.headers.get("Content-Type") + content_type = response.headers["Content-Type"] if "text/event-stream" in content_type: return StreamingResponse( self.openai_stream_wrapper(response.body_iterator, data_items), @@ -832,7 +796,7 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( request, - get_http_authorization_cred(request.headers.get("Authorization")), + get_http_authorization_cred(request.headers["Authorization"]), ) try: @@ -1015,6 +979,8 @@ async def get_all_models(): model["actions"] = [] for action_id in action_ids: action = Functions.get_function_by_id(action_id) + if action is None: + raise Exception(f"Action not found: {action_id}") if action_id in webui_app.state.FUNCTIONS: function_module = webui_app.state.FUNCTIONS[action_id] @@ -1022,6 +988,10 @@ async def get_all_models(): function_module, _, _ = load_function_module_by_id(action_id) webui_app.state.FUNCTIONS[action_id] = function_module + icon_url = None + if action.meta.manifest is not None: + icon_url = action.meta.manifest.get("icon_url", None) + if hasattr(function_module, "actions"): actions = function_module.actions model["actions"].extend( @@ -1032,9 +1002,7 @@ async def get_all_models(): "name", f"{action.name} ({_action['id']})" ), "description": action.meta.description, - "icon_url": _action.get( - "icon_url", action.meta.manifest.get("icon_url", None) - ), + "icon_url": _action.get("icon_url", icon_url), } for _action in actions ] @@ -1045,7 +1013,7 @@ async def get_all_models(): "id": action_id, "name": action.name, "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), + "icon_url": icon_url, } ) @@ -1175,6 +1143,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -1631,7 +1600,7 @@ async def upload_pipeline( ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file - if not file.filename.endswith(".py"): + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Only Python (.py) files are allowed.", @@ -2080,7 +2049,10 @@ async def oauth_login(provider: str, request: Request): redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( "oauth_callback", provider=provider ) - return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + client = oauth.create_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) # OAuth login logic is as follows: From e86688284add6a5a6c37584f666019d4d15745b5 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 09:05:22 +0100 Subject: [PATCH 10/16] factor out get_function_calling_payload --- backend/main.py | 45 +++++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0aa6bf167..fbfbf2f2d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -322,6 +322,26 @@ async def call_tool_from_completion( return None +def get_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + async def get_function_call_response( messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: @@ -331,30 +351,7 @@ async def get_function_call_response( tools_specs = json.dumps(tool.specs, indent=2) content = tools_function_calling_generation_template(template, tools_specs) - - user_message = get_last_user_message(messages) - prompt = ( - "History:\n" - + "\n".join( - [ - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ] - ) - + f"\nQuery: {user_message}" - ) - - print(prompt) - - payload = { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } + payload = get_function_calling_payload(messages, task_model_id, content) try: payload = filter_pipeline(payload, user) From 2efcda837cf3ef6934f2666e841a67e8231bd76c Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 09:07:12 +0100 Subject: [PATCH 11/16] add try: except back --- backend/main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index fbfbf2f2d..14d5a604a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -457,10 +457,13 @@ async def chat_completion_inlets_handler(body, model, extra_params): # Extra parameters to be passed to the function custom_params = {**extra_params, "__model__": model, "__id__": filter_id} if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: - uid = custom_params["__user__"]["id"] - custom_params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) - ) + try: + uid = custom_params["__user__"]["id"] + custom_params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + ) + except Exception as e: + print(e) # Add extra params in contained in function signature for key, value in custom_params.items(): From 790bdcf9fcd64c0d033b263647d6e619384a9e41 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 14:56:16 +0100 Subject: [PATCH 12/16] rename tool calling helpers to use 'tool' instead of 'function' --- backend/main.py | 12 ++++++------ backend/utils/task.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/main.py b/backend/main.py index 14d5a604a..c992f175f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -72,7 +72,7 @@ from utils.utils import ( from utils.task import ( title_generation_template, search_query_generation_template, - tools_function_calling_generation_template, + tool_calling_generation_template, ) from utils.misc import ( get_last_user_message, @@ -322,7 +322,7 @@ async def call_tool_from_completion( return None -def get_function_calling_payload(messages, task_model_id, content): +def get_tool_calling_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" @@ -342,7 +342,7 @@ def get_function_calling_payload(messages, task_model_id, content): } -async def get_function_call_response( +async def get_tool_call_response( messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: tool = Tools.get_tool_by_id(tool_id) @@ -350,8 +350,8 @@ async def get_function_call_response( return None, None, False tools_specs = json.dumps(tool.specs, indent=2) - content = tools_function_calling_generation_template(template, tools_specs) - payload = get_function_calling_payload(messages, task_model_id, content) + content = tool_calling_generation_template(template, tools_specs) + payload = get_tool_calling_payload(messages, task_model_id, content) try: payload = filter_pipeline(payload, user) @@ -502,7 +502,7 @@ async def chat_completion_tools_handler(body, user, extra_params): for tool_id in body["tool_ids"]: print(tool_id) try: - response, citation, file_handler = await get_function_call_response( + response, citation, file_handler = await get_tool_call_response( messages=body["messages"], files=body.get("files", []), tool_id=tool_id, diff --git a/backend/utils/task.py b/backend/utils/task.py index 1b2276c9c..37c174d3d 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -121,6 +121,6 @@ def search_query_generation_template( return template -def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: +def tool_calling_generation_template(template: str, tools_specs: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) return template From d598d4bb9397effab7df147ca3ef0913a787c028 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 15:16:57 +0100 Subject: [PATCH 13/16] typing and tweaks --- backend/apps/webui/models/users.py | 13 +++------- backend/main.py | 41 ++++++++++++++++++------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 36dfa4f85..b6e85e2ca 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, ConfigDict, parse_obj_as -from typing import Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import Optional import time from sqlalchemy import String, Column, BigInteger, Text -from utils.misc import get_gravatar_url - -from apps.webui.internal.db import Base, JSONField, Session, get_db +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.chats import Chats #################### @@ -78,7 +76,6 @@ class UserUpdateForm(BaseModel): class UsersTable: - def insert_new_user( self, id: str, @@ -122,7 +119,6 @@ class UsersTable: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user) except Exception: @@ -131,7 +127,6 @@ class UsersTable: def get_user_by_email(self, email: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: @@ -140,7 +135,6 @@ class UsersTable: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: with get_db() as db: - user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user) except Exception: @@ -195,7 +189,6 @@ class UsersTable: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: with get_db() as db: - db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) diff --git a/backend/main.py b/backend/main.py index c992f175f..0a1768fd6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -57,7 +57,7 @@ from apps.webui.models.auths import Auths from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users +from apps.webui.models.users import Users, User from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id @@ -322,7 +322,7 @@ async def call_tool_from_completion( return None -def get_tool_calling_payload(messages, task_model_id, content): +def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" @@ -345,13 +345,19 @@ def get_tool_calling_payload(messages, task_model_id, content): async def get_tool_call_response( messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: + """ + return: tuple of (function_result, citation, file_handler) where + - function_result: Optional[str] is the result of the tool call if successful + - citation: Optional[dict] is the citation object if the tool has citation + - file_handler: bool, True if tool handles files + """ tool = Tools.get_tool_by_id(tool_id) if tool is None: return None, None, False tools_specs = json.dumps(tool.specs, indent=2) content = tool_calling_generation_template(template, tools_specs) - payload = get_tool_calling_payload(messages, task_model_id, content) + payload = get_tool_call_payload(messages, task_model_id, content) try: payload = filter_pipeline(payload, user) @@ -486,7 +492,9 @@ async def chat_completion_inlets_handler(body, model, extra_params): return body, {} -async def chat_completion_tools_handler(body, user, extra_params): +async def chat_completion_tools_handler( + body: dict, user: User, extra_params: dict +) -> tuple[dict, dict]: skip_files = None contexts = [] @@ -498,21 +506,22 @@ async def chat_completion_tools_handler(body, user, extra_params): if "tool_ids" not in body: return body, {} - print(body["tool_ids"]) + log.debug(f"tool_ids: {body['tool_ids']}") + kwargs = { + "messages": body["messages"], + "files": body.get("files", []), + "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + "task_model_id": task_model_id, + "user": user, + "extra_params": extra_params, + } for tool_id in body["tool_ids"]: - print(tool_id) + log.debug(f"{tool_id=}") try: response, citation, file_handler = await get_tool_call_response( - messages=body["messages"], - files=body.get("files", []), - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - extra_params=extra_params, + tool_id=tool_id, **kwargs ) - print(file_handler) if isinstance(response, str): contexts.append(response) @@ -526,10 +535,10 @@ async def chat_completion_tools_handler(body, user, extra_params): skip_files = True except Exception as e: - print(f"Error: {e}") + log.exception(f"Error: {e}") del body["tool_ids"] - print(f"tool_contexts: {contexts}") + log.debug(f"tool_contexts: {contexts}") if skip_files: if "files" in body: From 6df6170c4493020baf937f18b609281523dec582 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 12 Aug 2024 14:48:57 +0100 Subject: [PATCH 14/16] add get_configured_tools --- backend/main.py | 99 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 18 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0a1768fd6..50b83f437 100644 --- a/backend/main.py +++ b/backend/main.py @@ -51,13 +51,13 @@ from apps.webui.internal.db import Session from pydantic import BaseModel -from typing import Optional +from typing import Optional, Callable, Awaitable from apps.webui.models.auths import Auths from apps.webui.models.models import Models from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from apps.webui.models.users import Users, User +from apps.webui.models.users import Users, UserModel from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id @@ -356,6 +356,7 @@ async def get_tool_call_response( return None, None, False tools_specs = json.dumps(tool.specs, indent=2) + log.debug(f"{tool.specs=}") content = tool_calling_generation_template(template, tools_specs) payload = get_tool_call_payload(messages, task_model_id, content) @@ -492,14 +493,81 @@ async def chat_completion_inlets_handler(body, model, extra_params): return body, {} +def get_tool_with_custom_params( + tool: Callable, custom_params: dict +) -> Callable[..., Awaitable]: + sig = inspect.signature(tool) + extra_params = { + key: value for key, value in custom_params.items() if key in sig.parameters + } + is_coroutine = inspect.iscoroutinefunction(tool) + + async def new_tool(**kwargs): + extra_kwargs = kwargs | extra_params + if is_coroutine: + return await tool(**extra_kwargs) + return tool(**extra_kwargs) + + return new_tool + + +def get_configured_tools( + tool_ids: list[str], extra_params: dict, user: UserModel +) -> dict[str, dict]: + tools = {} + for tool_id in tool_ids: + toolkit = Tools.get_tool_by_id(tool_id) + if toolkit is None: + continue + + module = webui_app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = module + + more_params = {"__id__": tool_id} + custom_params = more_params | extra_params + has_citation = hasattr(module, "citation") and module.citation + handles_files = hasattr(module, "file_handler") and module.file_handler + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + + if hasattr(module, "UserValves"): + custom_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in toolkit.specs: + name = spec["name"] + callable = getattr(module, name) + # convert to function that takes only model params and inserts custom params + custom_callable = get_tool_with_custom_params(callable, custom_params) + + tool_dict = { + "spec": spec, + "citation": has_citation, + "file_handler": handles_files, + "toolkit_module": module, + "callable": custom_callable, + } + if name in tools: + log.warning(f"Tool {name} already exists in another toolkit!") + mod_name = tools[name]["toolkit_module"].__name__ + log.warning(f"Collision between {toolkit} and {mod_name}.") + log.warning(f"Discarding {toolkit}.{name}") + else: + tools[name] = tool_dict + + return tools + + async def chat_completion_tools_handler( - body: dict, user: User, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: - skip_files = None - + skip_files = False contexts = [] - citations = None - + citations = [] task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions @@ -507,6 +575,7 @@ async def chat_completion_tools_handler( return body, {} log.debug(f"tool_ids: {body['tool_ids']}") + log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}") kwargs = { "messages": body["messages"], "files": body.get("files", []), @@ -515,6 +584,7 @@ async def chat_completion_tools_handler( "user": user, "extra_params": extra_params, } + for tool_id in body["tool_ids"]: log.debug(f"{tool_id=}") try: @@ -526,10 +596,7 @@ async def chat_completion_tools_handler( contexts.append(response) if citation: - if citations is None: - citations = [citation] - else: - citations.append(citation) + citations.append(citation) if file_handler: skip_files = True @@ -540,14 +607,10 @@ async def chat_completion_tools_handler( del body["tool_ids"] log.debug(f"tool_contexts: {contexts}") - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body: + del body["files"] - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} async def chat_completion_files_handler(body): From fdc89cbceeca5286d83dd5717c20052a73e7ab3a Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 12 Aug 2024 15:53:47 +0100 Subject: [PATCH 15/16] tool calling refactor --- backend/main.py | 204 +++++++++++++++--------------------------------- 1 file changed, 62 insertions(+), 142 deletions(-) diff --git a/backend/main.py b/backend/main.py index 50b83f437..e5b7d174a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -298,30 +298,6 @@ async def get_content_from_response(response) -> Optional[str]: return content -async def call_tool_from_completion( - result: dict, extra_params: dict, toolkit_module -) -> Optional[str]: - if "name" not in result: - return None - - tool = getattr(toolkit_module, result["name"]) - try: - # Get the signature of the function - sig = inspect.signature(tool) - params = result["parameters"] - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if inspect.iscoroutinefunction(tool): - return await tool(**params) - else: - return tool(**params) - except Exception as e: - print(f"Error: {e}") - return None - - def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( @@ -342,90 +318,6 @@ def get_tool_call_payload(messages, task_model_id, content): } -async def get_tool_call_response( - messages, files, tool_id, template, task_model_id, user, extra_params -) -> tuple[Optional[str], Optional[dict], bool]: - """ - return: tuple of (function_result, citation, file_handler) where - - function_result: Optional[str] is the result of the tool call if successful - - citation: Optional[dict] is the citation object if the tool has citation - - file_handler: bool, True if tool handles files - """ - tool = Tools.get_tool_by_id(tool_id) - if tool is None: - return None, None, False - - tools_specs = json.dumps(tool.specs, indent=2) - log.debug(f"{tool.specs=}") - content = tool_calling_generation_template(template, tools_specs) - payload = get_tool_call_payload(messages, task_model_id, content) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - custom_params = { - **extra_params, - "__model__": app.state.MODELS[task_model_id], - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - } - try: - if hasattr(toolkit_module, "UserValves"): - custom_params["__user__"]["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - except Exception as e: - print(e) - - file_handler = hasattr(toolkit_module, "file_handler") - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - try: - response = await generate_chat_completions(form_data=payload, user=user) - content = await get_content_from_response(response) - - if content is None: - return None, None, False - - # Parse the function response - log.debug(f"content: {content}") - result = json.loads(content) - - function_result = await call_tool_from_completion( - result, custom_params, toolkit_module - ) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - else: - citation = None - - # Add the function result to the system prompt - if function_result is not None: - return function_result, citation, file_handler - except Exception as e: - print(f"Error: {e}") - - return None, None, False - - async def chat_completion_inlets_handler(body, model, extra_params): skip_files = None @@ -511,6 +403,7 @@ def get_tool_with_custom_params( return new_tool +# Mutation on extra_params def get_configured_tools( tool_ids: list[str], extra_params: dict, user: UserModel ) -> dict[str, dict]: @@ -525,8 +418,7 @@ def get_configured_tools( module, _ = load_toolkit_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = module - more_params = {"__id__": tool_id} - custom_params = more_params | extra_params + extra_params["__id__"] = tool_id has_citation = hasattr(module, "citation") and module.citation handles_files = hasattr(module, "file_handler") and module.file_handler if hasattr(module, "valves") and hasattr(module, "Valves"): @@ -534,27 +426,27 @@ def get_configured_tools( module.valves = module.Valves(**valves) if hasattr(module, "UserValves"): - custom_params["__user__"]["valves"] = module.UserValves( # type: ignore + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) for spec in toolkit.specs: name = spec["name"] callable = getattr(module, name) + # convert to function that takes only model params and inserts custom params - custom_callable = get_tool_with_custom_params(callable, custom_params) + custom_callable = get_tool_with_custom_params(callable, extra_params) tool_dict = { "spec": spec, "citation": has_citation, "file_handler": handles_files, - "toolkit_module": module, + "toolkit_id": tool_id, "callable": custom_callable, } if name in tools: log.warning(f"Tool {name} already exists in another toolkit!") - mod_name = tools[name]["toolkit_module"].__name__ - log.warning(f"Collision between {toolkit} and {mod_name}.") + log.warning(f"Collision between {toolkit} and {tool_id}.") log.warning(f"Discarding {toolkit}.{name}") else: tools[name] = tool_dict @@ -571,40 +463,68 @@ async def chat_completion_tools_handler( task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions - if "tool_ids" not in body: + tool_ids = body.pop("tool_ids", None) + if not tool_ids: return body, {} - log.debug(f"tool_ids: {body['tool_ids']}") - log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}") - kwargs = { - "messages": body["messages"], - "files": body.get("files", []), - "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - "task_model_id": task_model_id, - "user": user, - "extra_params": extra_params, + log.debug(f"{tool_ids=}") + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": body.get("files", []), } + configured_tools = get_configured_tools(tool_ids, custom_params, user) - for tool_id in body["tool_ids"]: - log.debug(f"{tool_id=}") + log.info(f"{configured_tools=}") + + specs = [tool["spec"] for tool in configured_tools.values()] + tools_specs = json.dumps(specs) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + content = tool_calling_generation_template(template, tools_specs) + payload = get_tool_call_payload(body["messages"], task_model_id, content) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + + try: + response = await generate_chat_completions(form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + if content is None: + return body, {} + + result = json.loads(content) + tool_name = result.get("name", None) + if tool_name not in configured_tools: + return body, {} + + tool_params = result.get("parameters", {}) + toolkit_id = configured_tools[tool_name]["toolkit_id"] try: - response, citation, file_handler = await get_tool_call_response( - tool_id=tool_id, **kwargs - ) - - if isinstance(response, str): - contexts.append(response) - - if citation: - citations.append(citation) - - if file_handler: - skip_files = True - + tool_output = await configured_tools[tool_name]["callable"](**tool_params) except Exception as e: - log.exception(f"Error: {e}") + tool_output = str(e) + if configured_tools[tool_name]["citation"]: + citations.append( + { + "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"}, + "document": [tool_output], + "metadata": [{"source": tool_name}], + } + ) + if configured_tools[tool_name]["file_handler"]: + skip_files = True + + if isinstance(tool_output, str): + contexts.append(tool_output) + + except Exception as e: + print(f"Error: {e}") + content = None - del body["tool_ids"] log.debug(f"tool_contexts: {contexts}") if skip_files and "files" in body: From 4042219b3e3a28f65116baf2d90fb8d399e03f38 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 14 Aug 2024 20:40:10 +0100 Subject: [PATCH 16/16] minor refac --- backend/main.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/backend/main.py b/backend/main.py index e5b7d174a..873116dd7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params): print(f"Error: {e}") raise e - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body: + del body["files"] return body, {} @@ -431,12 +430,17 @@ def get_configured_tools( ) for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" name = spec["name"] callable = getattr(module, name) # convert to function that takes only model params and inserts custom params custom_callable = get_tool_with_custom_params(callable, extra_params) + # TODO: This needs to be a pydantic model tool_dict = { "spec": spec, "citation": has_citation, @@ -444,6 +448,7 @@ def get_configured_tools( "toolkit_id": tool_id, "callable": custom_callable, } + # TODO: if collision, prepend toolkit name if name in tools: log.warning(f"Tool {name} already exists in another toolkit!") log.warning(f"Collision between {toolkit} and {tool_id}.") @@ -533,9 +538,9 @@ async def chat_completion_tools_handler( return body, {"contexts": contexts, "citations": citations} -async def chat_completion_files_handler(body): +async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts = [] - citations = None + citations = [] if files := body.pop("files", None): contexts, citations = get_rag_context( @@ -550,10 +555,7 @@ async def chat_completion_files_handler(body): log.debug(f"rag_contexts: {contexts}, citations: {citations}") - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} def is_chat_completion_request(request): @@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: - print(e) - pass + log.exception(e) try: body, flags = await chat_completion_files_handler(body) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: - print(e) - pass + log.exception(e) # If context is not empty, insert it into the messages if len(contexts) > 0: