diff --git a/backend/main.py b/backend/main.py index e058bc58d..d141437b2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -218,25 +218,6 @@ origins = ["*"] ################################## -async def get_body_and_model_and_user(request): - # Read the original request body - body = await request.body() - body_str = body.decode("utf-8") - body = json.loads(body_str) if body_str else {} - - model_id = body["model"] - if model_id not in app.state.MODELS: - raise Exception("Model not found") - model = app.state.MODELS[model_id] - - user = get_current_user( - request, - get_http_authorization_cred(request.headers.get("Authorization")), - ) - - return body, model, user - - def get_task_model_id(default_model_id): # Set the task model task_model_id = default_model_id @@ -283,26 +264,6 @@ def get_filter_function_ids(model): return filter_ids -def get_tools_function_calling_payload(messages, task_model_id, content): - user_message = get_last_user_message(messages) - history = "\n".join( - f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" - for message in messages[::-1][:4] - ) - - prompt = f"History:\n{history}\nQuery: {user_message}" - - return { - "model": task_model_id, - "messages": [ - {"role": "system", "content": content}, - {"role": "user", "content": f"Query: {prompt}"}, - ], - "stream": False, - "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, - } - - async def chat_completion_filter_functions_handler(body, model, extra_params): skip_files = None @@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): return body, {} +def get_tools_function_calling_payload(messages, task_model_id, content): + user_message = get_last_user_message(messages) + history = "\n".join( + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ) + + prompt = f"History:\n{history}\nQuery: {user_message}" + + return { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, + } + + def apply_extra_params_to_tool_function( - function: Callable, custom_params: dict + function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = { - key: value for key, value in custom_params.items() if key in sig.parameters + key: value for key, value in extra_params.items() if key in sig.parameters } is_coroutine = inspect.iscoroutinefunction(function) @@ -511,27 +492,27 @@ async def chat_completion_tools_handler( return body, {} result = json.loads(content) - tool_name = result.get("name", None) - if tool_name not in tools: + + tool_function_name = result.get("name", None) + if tool_function_name not in tools: return body, {} - tool_params = result.get("parameters", {}) - toolkit_id = tools[tool_name]["toolkit_id"] + tool_function_params = result.get("parameters", {}) try: - tool_output = await tools[tool_name]["callable"](**tool_params) + tool_output = await tools[tool_function_name]["callable"](**tool_function_params) except Exception as e: tool_output = str(e) - if tools[tool_name]["citation"]: + if tools[tool_function_name]["citation"]: citations.append( { - "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"}, + "source": {"name": f"TOOL:{tools[tool_function_name]["toolkit_id"]}/{tool_function_name}"}, "document": [tool_output], - "metadata": [{"source": tool_name}], + "metadata": [{"source": tool_function_name}], } ) - if tools[tool_name]["file_handler"]: + if tools[tool_function_name]["file_handler"]: skip_files = True if isinstance(tool_output, str): @@ -576,6 +557,25 @@ def is_chat_completion_request(request): ) +async def get_body_and_model_and_user(request): + # Read the original request body + body = await request.body() + body_str = body.decode("utf-8") + body = json.loads(body_str) if body_str else {} + + model_id = body["model"] + if model_id not in app.state.MODELS: + raise Exception("Model not found") + model = app.state.MODELS[model_id] + + user = get_current_user( + request, + get_http_authorization_cred(request.headers.get("Authorization")), + ) + + return body, model, user + + class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if not is_chat_completion_request(request):