diff --git a/backend/main.py b/backend/main.py index e5b7d174a..873116dd7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params): print(f"Error: {e}") raise e - if skip_files: - if "files" in body: - del body["files"] + if skip_files and "files" in body: + del body["files"] return body, {} @@ -431,12 +430,17 @@ def get_configured_tools( ) for spec in toolkit.specs: + # TODO: Fix hack for OpenAI API + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" name = spec["name"] callable = getattr(module, name) # convert to function that takes only model params and inserts custom params custom_callable = get_tool_with_custom_params(callable, extra_params) + # TODO: This needs to be a pydantic model tool_dict = { "spec": spec, "citation": has_citation, @@ -444,6 +448,7 @@ def get_configured_tools( "toolkit_id": tool_id, "callable": custom_callable, } + # TODO: if collision, prepend toolkit name if name in tools: log.warning(f"Tool {name} already exists in another toolkit!") log.warning(f"Collision between {toolkit} and {tool_id}.") @@ -533,9 +538,9 @@ async def chat_completion_tools_handler( return body, {"contexts": contexts, "citations": citations} -async def chat_completion_files_handler(body): +async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]: contexts = [] - citations = None + citations = [] if files := body.pop("files", None): contexts, citations = get_rag_context( @@ -550,10 +555,7 @@ async def chat_completion_files_handler(body): log.debug(f"rag_contexts: {contexts}, citations: {citations}") - return body, { - **({"contexts": contexts} if contexts is not None else {}), - **({"citations": citations} if citations is not None else {}), - } + return body, {"contexts": contexts, "citations": citations} def is_chat_completion_request(request): @@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: - print(e) - pass + log.exception(e) try: body, flags = await chat_completion_files_handler(body) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: - print(e) - pass + log.exception(e) # If context is not empty, insert it into the messages if len(contexts) > 0: