refac: middleware

This commit is contained in:
Timothy Jaeryang Baek 2025-02-04 21:01:53 -08:00
parent 2e61ea7cc9
commit e7da506add
3 changed files with 58 additions and 39 deletions

View File

@ -889,9 +889,10 @@ async def chat_completion(
} }
form_data["metadata"] = metadata form_data["metadata"] = metadata
form_data, events = await process_chat_payload( form_data, metadata, events = await process_chat_payload(
request, form_data, metadata, user, model request, form_data, metadata, user, model
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -900,6 +901,7 @@ async def chat_completion(
try: try:
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
return await process_chat_response( return await process_chat_response(
request, response, form_data, user, events, metadata, tasks request, response, form_data, user, events, metadata, tasks
) )

View File

@ -183,7 +183,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, extra_params: dict request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]: async def get_content_from_response(response) -> Optional[str]:
content = None content = None
@ -218,35 +218,15 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)}, "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
} }
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})
tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
if not tool_ids:
return body, {}
skip_files = False
sources = []
task_model_id = get_task_model_id( task_model_id = get_task_model_id(
body["model"], body["model"],
request.app.state.config.TASK_MODEL, request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL, request.app.state.config.TASK_MODEL_EXTERNAL,
models, models,
) )
tools = get_tools(
request, skip_files = False
tool_ids, sources = []
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": body["messages"],
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
specs = [tool["spec"] for tool in tools.values()] specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs) tools_specs = json.dumps(specs)
@ -281,6 +261,8 @@ async def chat_completion_tools_handler(
result = json.loads(content) result = json.loads(content)
async def tool_call_handler(tool_call): async def tool_call_handler(tool_call):
nonlocal skip_files
log.debug(f"{tool_call=}") log.debug(f"{tool_call=}")
tool_function_name = tool_call.get("name", None) tool_function_name = tool_call.get("name", None)
@ -725,6 +707,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Initialize events to store additional event to be sent to the client # Initialize events to store additional event to be sent to the client
# Initialize contexts and citation # Initialize contexts and citation
models = request.app.state.MODELS models = request.app.state.MODELS
task_model_id = get_task_model_id(
form_data["model"],
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
events = [] events = []
sources = [] sources = []
@ -809,15 +797,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
} }
form_data["metadata"] = metadata form_data["metadata"] = metadata
if not form_data["metadata"].get("function_calling") == "native": tool_ids = metadata.get("tool_ids", None)
# If the function calling is not native, then call the tools function calling handler log.debug(f"{tool_ids=}")
try:
form_data, flags = await chat_completion_tools_handler( if tool_ids:
request, form_data, user, models, extra_params # If tool_ids field is present, then get the tools
) tools = get_tools(
sources.extend(flags.get("sources", [])) request,
except Exception as e: tool_ids,
log.exception(e) user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": form_data["messages"],
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
try: try:
form_data, flags = await chat_completion_files_handler(request, form_data, user) form_data, flags = await chat_completion_files_handler(request, form_data, user)
@ -833,11 +847,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
if "document" in source: if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]): for doc_idx, doc_context in enumerate(source["document"]):
metadata = source.get("metadata") doc_metadata = source.get("metadata")
doc_source_id = None doc_source_id = None
if metadata: if doc_metadata:
doc_source_id = metadata[doc_idx].get("source", source_id) doc_source_id = doc_metadata[doc_idx].get("source", source_id)
if source_id: if source_id:
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n" context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
@ -894,12 +908,15 @@ async def process_chat_payload(request, form_data, metadata, user, model):
} }
) )
return form_data, events return form_data, metadata, events
async def process_chat_response( async def process_chat_response(
request, response, form_data, user, events, metadata, tasks request, response, form_data, user, events, metadata, tasks
): ):
print("metadata", metadata)
async def background_tasks_handler(): async def background_tasks_handler():
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
message = message_map.get(metadata["message_id"]) if message_map else None message = message_map.get(metadata["message_id"]) if message_map else None

View File

@ -322,7 +322,7 @@ export const generateOpenAIChatCompletion = async (
return res.json(); return res.json();
}) })
.catch((err) => { .catch((err) => {
error = `${err?.detail ?? 'Network Problem'}`; error = `${err?.detail ?? err}`;
return null; return null;
}); });