mirror of
https://github.com/open-webui/open-webui
synced 2025-05-17 20:05:08 +00:00
fix: rag
This commit is contained in:
parent
bcc27e3852
commit
514c7f1520
@ -170,7 +170,9 @@ app.state.MODELS = {}
|
|||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
|
async def get_function_call_response(
|
||||||
|
messages, files, tool_id, template, task_model_id, user
|
||||||
|
):
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
tools_specs = json.dumps(tool.specs, indent=2)
|
tools_specs = json.dumps(tool.specs, indent=2)
|
||||||
content = tools_function_calling_generation_template(template, tools_specs)
|
content = tools_function_calling_generation_template(template, tools_specs)
|
||||||
@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
|||||||
"__messages__": messages,
|
"__messages__": messages,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "__files__" in sig.parameters:
|
||||||
|
# Call the function with the '__files__' parameter included
|
||||||
|
params = {
|
||||||
|
**params,
|
||||||
|
"__files__": files,
|
||||||
|
}
|
||||||
|
|
||||||
function_result = function(**params)
|
function_result = function(**params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
try:
|
try:
|
||||||
response = await get_function_call_response(
|
response = await get_function_call_response(
|
||||||
messages=data["messages"],
|
messages=data["messages"],
|
||||||
|
files=data.get("files", []),
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
task_model_id=task_model_id,
|
task_model_id=task_model_id,
|
||||||
@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
print(f"tool_context: {context}")
|
print(f"tool_context: {context}")
|
||||||
|
|
||||||
# If docs field is present, generate RAG completions
|
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
|
||||||
|
# If files field is present, generate RAG completions
|
||||||
if "files" in data:
|
if "files" in data:
|
||||||
data = {**data}
|
data = {**data}
|
||||||
rag_context, citations = get_rag_context(
|
rag_context, citations = get_rag_context(
|
||||||
@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
system_prompt = rag_template(
|
system_prompt = rag_template(
|
||||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
print(system_prompt)
|
print(system_prompt)
|
||||||
|
|
||||||
data["messages"] = add_or_update_system_message(
|
data["messages"] = add_or_update_system_message(
|
||||||
f"\n{system_prompt}", data["messages"]
|
f"\n{system_prompt}", data["messages"]
|
||||||
)
|
)
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
|
|
||||||
# Replace the request body with the modified one
|
# Replace the request body with the modified one
|
||||||
request._body = modified_body_bytes
|
request._body = modified_body_bytes
|
||||||
# Set custom header to ensure content-length matches new body length
|
# Set custom header to ensure content-length matches new body length
|
||||||
@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
context = await get_function_call_response(
|
context = await get_function_call_response(
|
||||||
form_data["messages"], form_data["tool_id"], template, model_id, user
|
form_data["messages"],
|
||||||
|
form_data.get("files", []),
|
||||||
|
form_data["tool_id"],
|
||||||
|
template,
|
||||||
|
model_id,
|
||||||
|
user,
|
||||||
)
|
)
|
||||||
return context
|
return context
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -587,22 +587,17 @@
|
|||||||
});
|
});
|
||||||
|
|
||||||
let files = [];
|
let files = [];
|
||||||
|
|
||||||
if (model?.info?.meta?.knowledge ?? false) {
|
if (model?.info?.meta?.knowledge ?? false) {
|
||||||
files = model.info.meta.knowledge;
|
files = model.info.meta.knowledge;
|
||||||
}
|
}
|
||||||
|
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
|
||||||
files = [
|
files = [
|
||||||
...files,
|
...files,
|
||||||
...messages
|
...(lastUserMessage?.files?.filter((item) =>
|
||||||
.filter((message) => message?.files ?? null)
|
|
||||||
.map((message) =>
|
|
||||||
message.files.filter((item) =>
|
|
||||||
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
|
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
|
||||||
)
|
) ?? [])
|
||||||
)
|
|
||||||
.flat(1)
|
|
||||||
].filter(
|
].filter(
|
||||||
|
// Remove duplicates
|
||||||
(item, index, array) =>
|
(item, index, array) =>
|
||||||
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
|
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
|
||||||
);
|
);
|
||||||
@ -832,22 +827,17 @@
|
|||||||
const responseMessage = history.messages[responseMessageId];
|
const responseMessage = history.messages[responseMessageId];
|
||||||
|
|
||||||
let files = [];
|
let files = [];
|
||||||
|
|
||||||
if (model?.info?.meta?.knowledge ?? false) {
|
if (model?.info?.meta?.knowledge ?? false) {
|
||||||
files = model.info.meta.knowledge;
|
files = model.info.meta.knowledge;
|
||||||
}
|
}
|
||||||
|
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
|
||||||
files = [
|
files = [
|
||||||
...files,
|
...files,
|
||||||
...messages
|
...(lastUserMessage?.files?.filter((item) =>
|
||||||
.filter((message) => message?.files ?? null)
|
|
||||||
.map((message) =>
|
|
||||||
message.files.filter((item) =>
|
|
||||||
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
|
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
|
||||||
)
|
) ?? [])
|
||||||
)
|
|
||||||
.flat(1)
|
|
||||||
].filter(
|
].filter(
|
||||||
|
// Remove duplicates
|
||||||
(item, index, array) =>
|
(item, index, array) =>
|
||||||
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
|
array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
|
||||||
);
|
);
|
||||||
|
@ -153,6 +153,7 @@
|
|||||||
|
|
||||||
if (res) {
|
if (res) {
|
||||||
fileItem.status = 'processed';
|
fileItem.status = 'processed';
|
||||||
|
fileItem.collection_name = res.collection_name;
|
||||||
files = files;
|
files = files;
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
Loading…
Reference in New Issue
Block a user