mirror of
https://github.com/open-webui/open-webui
synced 2025-05-24 14:54:33 +00:00
refac: middleware
This commit is contained in:
parent
2e61ea7cc9
commit
e7da506add
@ -889,9 +889,10 @@ async def chat_completion(
|
|||||||
}
|
}
|
||||||
form_data["metadata"] = metadata
|
form_data["metadata"] = metadata
|
||||||
|
|
||||||
form_data, events = await process_chat_payload(
|
form_data, metadata, events = await process_chat_payload(
|
||||||
request, form_data, metadata, user, model
|
request, form_data, metadata, user, model
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -900,6 +901,7 @@ async def chat_completion(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = await chat_completion_handler(request, form_data, user)
|
response = await chat_completion_handler(request, form_data, user)
|
||||||
|
|
||||||
return await process_chat_response(
|
return await process_chat_response(
|
||||||
request, response, form_data, user, events, metadata, tasks
|
request, response, form_data, user, events, metadata, tasks
|
||||||
)
|
)
|
||||||
|
@ -183,7 +183,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p
|
|||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
request: Request, body: dict, user: UserModel, models, extra_params: dict
|
request: Request, body: dict, user: UserModel, models, tools
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
async def get_content_from_response(response) -> Optional[str]:
|
async def get_content_from_response(response) -> Optional[str]:
|
||||||
content = None
|
content = None
|
||||||
@ -218,35 +218,15 @@ async def chat_completion_tools_handler(
|
|||||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||||
}
|
}
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
|
||||||
metadata = body.get("metadata", {})
|
|
||||||
|
|
||||||
tool_ids = metadata.get("tool_ids", None)
|
|
||||||
log.debug(f"{tool_ids=}")
|
|
||||||
if not tool_ids:
|
|
||||||
return body, {}
|
|
||||||
|
|
||||||
skip_files = False
|
|
||||||
sources = []
|
|
||||||
|
|
||||||
task_model_id = get_task_model_id(
|
task_model_id = get_task_model_id(
|
||||||
body["model"],
|
body["model"],
|
||||||
request.app.state.config.TASK_MODEL,
|
request.app.state.config.TASK_MODEL,
|
||||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||||
models,
|
models,
|
||||||
)
|
)
|
||||||
tools = get_tools(
|
|
||||||
request,
|
skip_files = False
|
||||||
tool_ids,
|
sources = []
|
||||||
user,
|
|
||||||
{
|
|
||||||
**extra_params,
|
|
||||||
"__model__": models[task_model_id],
|
|
||||||
"__messages__": body["messages"],
|
|
||||||
"__files__": metadata.get("files", []),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
log.info(f"{tools=}")
|
|
||||||
|
|
||||||
specs = [tool["spec"] for tool in tools.values()]
|
specs = [tool["spec"] for tool in tools.values()]
|
||||||
tools_specs = json.dumps(specs)
|
tools_specs = json.dumps(specs)
|
||||||
@ -281,6 +261,8 @@ async def chat_completion_tools_handler(
|
|||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
|
|
||||||
async def tool_call_handler(tool_call):
|
async def tool_call_handler(tool_call):
|
||||||
|
nonlocal skip_files
|
||||||
|
|
||||||
log.debug(f"{tool_call=}")
|
log.debug(f"{tool_call=}")
|
||||||
|
|
||||||
tool_function_name = tool_call.get("name", None)
|
tool_function_name = tool_call.get("name", None)
|
||||||
@ -725,6 +707,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
# Initialize events to store additional event to be sent to the client
|
# Initialize events to store additional event to be sent to the client
|
||||||
# Initialize contexts and citation
|
# Initialize contexts and citation
|
||||||
models = request.app.state.MODELS
|
models = request.app.state.MODELS
|
||||||
|
task_model_id = get_task_model_id(
|
||||||
|
form_data["model"],
|
||||||
|
request.app.state.config.TASK_MODEL,
|
||||||
|
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||||
|
models,
|
||||||
|
)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
sources = []
|
sources = []
|
||||||
@ -809,15 +797,41 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
}
|
}
|
||||||
form_data["metadata"] = metadata
|
form_data["metadata"] = metadata
|
||||||
|
|
||||||
if not form_data["metadata"].get("function_calling") == "native":
|
tool_ids = metadata.get("tool_ids", None)
|
||||||
# If the function calling is not native, then call the tools function calling handler
|
log.debug(f"{tool_ids=}")
|
||||||
try:
|
|
||||||
form_data, flags = await chat_completion_tools_handler(
|
if tool_ids:
|
||||||
request, form_data, user, models, extra_params
|
# If tool_ids field is present, then get the tools
|
||||||
)
|
tools = get_tools(
|
||||||
sources.extend(flags.get("sources", []))
|
request,
|
||||||
except Exception as e:
|
tool_ids,
|
||||||
log.exception(e)
|
user,
|
||||||
|
{
|
||||||
|
**extra_params,
|
||||||
|
"__model__": models[task_model_id],
|
||||||
|
"__messages__": form_data["messages"],
|
||||||
|
"__files__": metadata.get("files", []),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
log.info(f"{tools=}")
|
||||||
|
|
||||||
|
if metadata.get("function_calling") == "native":
|
||||||
|
# If the function calling is native, then call the tools function calling handler
|
||||||
|
metadata["tools"] = tools
|
||||||
|
form_data["tools"] = [
|
||||||
|
{"type": "function", "function": tool.get("spec", {})}
|
||||||
|
for tool in tools.values()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# If the function calling is not native, then call the tools function calling handler
|
||||||
|
try:
|
||||||
|
form_data, flags = await chat_completion_tools_handler(
|
||||||
|
request, form_data, user, models, tools
|
||||||
|
)
|
||||||
|
sources.extend(flags.get("sources", []))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
form_data, flags = await chat_completion_files_handler(request, form_data, user)
|
form_data, flags = await chat_completion_files_handler(request, form_data, user)
|
||||||
@ -833,11 +847,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
|
|
||||||
if "document" in source:
|
if "document" in source:
|
||||||
for doc_idx, doc_context in enumerate(source["document"]):
|
for doc_idx, doc_context in enumerate(source["document"]):
|
||||||
metadata = source.get("metadata")
|
doc_metadata = source.get("metadata")
|
||||||
doc_source_id = None
|
doc_source_id = None
|
||||||
|
|
||||||
if metadata:
|
if doc_metadata:
|
||||||
doc_source_id = metadata[doc_idx].get("source", source_id)
|
doc_source_id = doc_metadata[doc_idx].get("source", source_id)
|
||||||
|
|
||||||
if source_id:
|
if source_id:
|
||||||
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
|
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||||
@ -894,12 +908,15 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return form_data, events
|
return form_data, metadata, events
|
||||||
|
|
||||||
|
|
||||||
async def process_chat_response(
|
async def process_chat_response(
|
||||||
request, response, form_data, user, events, metadata, tasks
|
request, response, form_data, user, events, metadata, tasks
|
||||||
):
|
):
|
||||||
|
|
||||||
|
print("metadata", metadata)
|
||||||
|
|
||||||
async def background_tasks_handler():
|
async def background_tasks_handler():
|
||||||
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
|
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
|
||||||
message = message_map.get(metadata["message_id"]) if message_map else None
|
message = message_map.get(metadata["message_id"]) if message_map else None
|
||||||
|
@ -322,7 +322,7 @@ export const generateOpenAIChatCompletion = async (
|
|||||||
return res.json();
|
return res.json();
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
error = `${err?.detail ?? 'Network Problem'}`;
|
error = `${err?.detail ?? err}`;
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user