Mirror of https://github.com/open-webui/open-webui (synced 2024-11-07 00:59:52 +00:00)
refac: chat completion middleware

parent b62d2a9b28, commit c7a9b5ccfa
@@ -294,14 +294,16 @@ def get_rag_context(
         extracted_collections.extend(collection_names)
 
-    context_string = ""
+    contexts = []
     citations = []
 
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                context_string += "\n\n".join(
-                    [text for text in context["documents"][0] if text is not None]
+                contexts.append(
+                    "\n\n".join(
+                        [text for text in context["documents"][0] if text is not None]
+                    )
                 )
 
                 if "metadatas" in context:
@@ -315,9 +317,7 @@ def get_rag_context(
         except Exception as e:
             log.exception(e)
 
-    context_string = context_string.strip()
-
-    return context_string, citations
+    return contexts, citations
 
 
 def get_model_path(model: str, update_model: bool = False):
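The behavioral change in this hunk: get_rag_context now returns the list of per-source context strings instead of one pre-joined, pre-stripped blob, so the caller decides how to join them. A minimal, self-contained sketch of the new shape (the sample data is illustrative):

relevant_contexts = [
    {"documents": [["chunk A", None, "chunk B"]]},
    {"documents": [["chunk C"]]},
]

contexts = []
for context in relevant_contexts:
    if "documents" in context:
        contexts.append(
            "\n\n".join(t for t in context["documents"][0] if t is not None)
        )

print(contexts)               # ['chunk A\n\nchunk B', 'chunk C'] -- the new return shape
print("\n\n".join(contexts))  # the old single-string result, now built by the caller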
501 backend/main.py
@@ -213,7 +213,7 @@ origins = ["*"]
 
 
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user
+    messages, files, tool_id, template, task_model_id, user, model
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -373,233 +373,308 @@ async def get_function_call_response(
     return None, None, False
 
 
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    return task_model_id
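get_task_model_id centralizes the task-model fallback that previously lived inline in the middleware: Ollama-owned chat models prefer TASK_MODEL, everything else prefers TASK_MODEL_EXTERNAL, and both fall back to the chat model itself. A hedged, self-contained sketch of the same rule, where models and config are stand-ins for app.state.MODELS and app.state.config:

def resolve_task_model(default_model_id, models, config):
    # Ollama-owned models use TASK_MODEL; external ones use TASK_MODEL_EXTERNAL.
    if models[default_model_id]["owned_by"] == "ollama":
        override = config.get("TASK_MODEL")
    else:
        override = config.get("TASK_MODEL_EXTERNAL")
    # Fall back to the chat model when the override is unset or unknown.
    return override if override and override in models else default_model_id

models = {"llama3": {"owned_by": "ollama"}, "gpt-4o": {"owned_by": "openai"}}
assert resolve_task_model("llama3", models, {}) == "llama3"
assert resolve_task_model("llama3", models, {"TASK_MODEL": "gpt-4o"}) == "gpt-4o"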
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+    filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
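get_filter_function_ids merges the global filters with the model's meta filterIds, dedupes, keeps only enabled filter functions, and sorts by each filter's valve priority, lowest first. The same pipeline on plain data, purely for illustration:

global_filters = ["profanity", "rate_limit"]
model_filters = ["rate_limit", "summarize"]          # model["info"]["meta"]["filterIds"]
enabled = {"profanity", "rate_limit", "summarize"}   # active filter functions
priority = {"profanity": 10, "rate_limit": 0, "summarize": 5}

filter_ids = list(set(global_filters + model_filters))   # merge and dedupe
filter_ids = [f for f in filter_ids if f in enabled]     # enabled only
filter_ids.sort(key=lambda f: priority.get(f, 0))        # lowest priority runs first
print(filter_ids)  # ['rate_limit', 'summarize', 'profanity']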
+async def chat_completion_functions_handler(body, model, user):
+    skip_files = None
+
+    filter_ids = get_filter_function_ids(model)
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type, frontmatter = (
+                    load_function_module_by_id(filter_id)
+                )
+                webui_app.state.FUNCTIONS[filter_id] = function_module
+
+            # Check if the function has a file_handler variable
+            if hasattr(function_module, "file_handler"):
+                skip_files = function_module.file_handler
+
+            if hasattr(function_module, "valves") and hasattr(
+                function_module, "Valves"
+            ):
+                valves = Functions.get_function_valves_by_id(filter_id)
+                function_module.valves = function_module.Valves(
+                    **(valves if valves else {})
+                )
+
+            try:
+                if hasattr(function_module, "inlet"):
+                    inlet = function_module.inlet
+
+                    # Get the signature of the function
+                    sig = inspect.signature(inlet)
+                    params = {"body": body}
+
+                    if "__user__" in sig.parameters:
+                        __user__ = {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        }
+
+                        try:
+                            if hasattr(function_module, "UserValves"):
+                                __user__["valves"] = function_module.UserValves(
+                                    **Functions.get_user_valves_by_id_and_user_id(
+                                        filter_id, user.id
+                                    )
+                                )
+                        except Exception as e:
+                            print(e)
+
+                        params = {**params, "__user__": __user__}
+
+                    if "__id__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__id__": filter_id,
+                        }
+
+                    if "__model__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__model__": model,
+                        }
+
+                    if inspect.iscoroutinefunction(inlet):
+                        body = await inlet(**params)
+                    else:
+                        body = inlet(**params)
+
+            except Exception as e:
+                print(f"Error: {e}")
+                raise e
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {}
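chat_completion_functions_handler inspects each filter's inlet signature and passes only the parameters it declares (__user__, __id__, __model__). A hedged sketch of a filter module this loop would accept; the class and field names are illustrative, not part of the commit:

import inspect

class ExampleFilter:
    file_handler = False  # True would make the middleware drop body["files"]

    def inlet(self, body, __user__=None):
        # e.g. stamp the requesting user onto the payload
        body.setdefault("metadata", {})["user_id"] = (__user__ or {}).get("id")
        return body

filter_module = ExampleFilter()
sig = inspect.signature(filter_module.inlet)

params = {"body": {"messages": []}}
if "__user__" in sig.parameters:  # the same signature probing as above
    params["__user__"] = {"id": "u1"}

print(filter_module.inlet(**params))  # {'messages': [], 'metadata': {'user_id': 'u1'}}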
+async def chat_completion_tools_handler(body, model, user):
+    skip_files = None
+
+    contexts = []
+    citations = None
+
+    task_model_id = get_task_model_id(body["model"])
+
+    # If tool_ids field is present, call the functions
+    if "tool_ids" in body:
+        print(body["tool_ids"])
+        for tool_id in body["tool_ids"]:
+            print(tool_id)
+            try:
+                response, citation, file_handler = await get_function_call_response(
+                    messages=body["messages"],
+                    files=body.get("files", []),
+                    tool_id=tool_id,
+                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                    task_model_id=task_model_id,
+                    user=user,
+                    model=model,
+                )
+
+                print(file_handler)
+                if isinstance(response, str):
+                    contexts.append(response)
+
+                if citation:
+                    if citations is None:
+                        citations = [citation]
+                    else:
+                        citations.append(citation)
+
+                if file_handler:
+                    skip_files = True
+
+            except Exception as e:
+                print(f"Error: {e}")
+        del body["tool_ids"]
+        print(f"tool_contexts: {contexts}")
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
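Each handler returns (body, flags), and the double-splat guard keeps a key out of flags entirely when its value is None; note that contexts is initialized to a list here, so in practice only citations can be omitted. The idiom in isolation:

contexts = ["tool output"]
citations = None

flags = {
    **({"contexts": contexts} if contexts is not None else {}),
    **({"citations": citations} if citations is not None else {}),
}
print(flags)  # {'contexts': ['tool output']} -- no 'citations' key at all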
+async def chat_completion_files_handler(body):
+    contexts = []
+    citations = None
+
+    if "files" in body:
+        files = body["files"]
+        del body["files"]
+
+        contexts, citations = get_rag_context(
+            files=files,
+            messages=body["messages"],
+            embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+            k=rag_app.state.config.TOP_K,
+            reranking_function=rag_app.state.sentence_transformer_rf,
+            r=rag_app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+        )
+
+        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
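This is where the get_rag_context change from the first hunk pays off: both handlers emit lists, so dispatch() can accumulate contexts and citations with flags.get defaults, and a handler that omits a key contributes nothing. Illustrative data only:

handler_flags = [
    {"contexts": ["tool context"], "citations": [{"source": "tool"}]},  # tools handler
    {"contexts": ["rag chunk 1", "rag chunk 2"]},                       # files handler
]

contexts, citations = [], []
for flags in handler_flags:
    contexts.extend(flags.get("contexts", []))
    citations.extend(flags.get("citations", []))

print(len(contexts), len(citations))  # 3 1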
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise Exception("Model not found")
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
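get_body_and_model_and_user folds the body parsing, model lookup, and auth that dispatch() previously inlined into a single awaitable. The empty-body guard is worth noting; a minimal sketch of just that part:

import json

def parse_body(raw: bytes) -> dict:
    # An empty body becomes {} rather than raising, matching the helper above.
    body_str = raw.decode("utf-8")
    return json.loads(body_str) if body_str else {}

assert parse_body(b"") == {}
assert parse_body(b'{"model": "llama3"}') == {"model": "llama3"}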
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        data_items = []
-
-        show_citations = False
-        citations = []
-
         if request.method == "POST" and any(
             endpoint in request.url.path
             for endpoint in ["/ollama/api/chat", "/chat/completions"]
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-            # Flag to skip RAG completions if file_handler is present in tools/functions
-            skip_files = False
-            if data.get("citations"):
-                show_citations = True
-                del data["citations"]
-
-            model_id = data["model"]
-            if model_id not in app.state.MODELS:
-                raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Model not found",
-                )
-            model = app.state.MODELS[model_id]
+            try:
+                body, model, user = await get_body_and_model_and_user(request)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
 
-            def get_priority(function_id):
-                function = Functions.get_function_by_id(function_id)
-                if function is not None and hasattr(function, "valves"):
-                    return (function.valves if function.valves else {}).get(
-                        "priority", 0
-                    )
-                return 0
+            # Extract chat_id and message_id from the request body
+            chat_id = None
+            if "chat_id" in body:
+                chat_id = body["chat_id"]
+                del body["chat_id"]
+            message_id = None
+            if "id" in body:
+                message_id = body["id"]
+                del body["id"]
 
-            filter_ids = [
-                function.id for function in Functions.get_global_filter_functions()
-            ]
-            if "info" in model and "meta" in model["info"]:
-                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-            filter_ids = list(set(filter_ids))
+            # Initialize data_items to store additional data to be sent to the client
+            data_items = []
 
-            enabled_filter_ids = [
-                function.id
-                for function in Functions.get_functions_by_type(
-                    "filter", active_only=True
-                )
-            ]
-            filter_ids = [
-                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-            ]
-
-            filter_ids.sort(key=get_priority)
+            # Initialize context, and citations
+            contexts = []
+            citations = []
+
+            print(body)
+
+            try:
+                body, flags = await chat_completion_functions_handler(body, model, user)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
 
-            for filter_id in filter_ids:
-                filter = Functions.get_function_by_id(filter_id)
-                if filter:
-                    if filter_id in webui_app.state.FUNCTIONS:
-                        function_module = webui_app.state.FUNCTIONS[filter_id]
-                    else:
-                        function_module, function_type, frontmatter = (
-                            load_function_module_by_id(filter_id)
-                        )
-                        webui_app.state.FUNCTIONS[filter_id] = function_module
-
-                    # Check if the function has a file_handler variable
-                    if hasattr(function_module, "file_handler"):
-                        skip_files = function_module.file_handler
-
-                    if hasattr(function_module, "valves") and hasattr(
-                        function_module, "Valves"
-                    ):
-                        valves = Functions.get_function_valves_by_id(filter_id)
-                        function_module.valves = function_module.Valves(
-                            **(valves if valves else {})
-                        )
-
-                    try:
-                        if hasattr(function_module, "inlet"):
-                            inlet = function_module.inlet
-
-                            # Get the signature of the function
-                            sig = inspect.signature(inlet)
-                            params = {"body": data}
-
-                            if "__user__" in sig.parameters:
-                                __user__ = {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                }
-
-                                try:
-                                    if hasattr(function_module, "UserValves"):
-                                        __user__["valves"] = function_module.UserValves(
-                                            **Functions.get_user_valves_by_id_and_user_id(
-                                                filter_id, user.id
-                                            )
-                                        )
-                                except Exception as e:
-                                    print(e)
-
-                                params = {**params, "__user__": __user__}
-
-                            if "__id__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__id__": filter_id,
-                                }
-
-                            if "__model__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__model__": model,
-                                }
-
-                            if inspect.iscoroutinefunction(inlet):
-                                data = await inlet(**params)
-                            else:
-                                data = inlet(**params)
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        return JSONResponse(
-                            status_code=status.HTTP_400_BAD_REQUEST,
-                            content={"detail": str(e)},
-                        )
+            try:
+                body, flags = await chat_completion_tools_handler(body, model, user)
+
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
 
-            # Set the task model
-            task_model_id = data["model"]
-            # Check if the user has a custom task model and use that model
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-                if (
-                    app.state.config.TASK_MODEL
-                    and app.state.config.TASK_MODEL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL
-            else:
-                if (
-                    app.state.config.TASK_MODEL_EXTERNAL
-                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-            prompt = get_last_user_message(data["messages"])
-            context = ""
-
-            # If tool_ids field is present, call the functions
-            if "tool_ids" in data:
-                print(data["tool_ids"])
-                for tool_id in data["tool_ids"]:
-                    print(tool_id)
-                    try:
-                        response, citation, file_handler = (
-                            await get_function_call_response(
-                                messages=data["messages"],
-                                files=data.get("files", []),
-                                tool_id=tool_id,
-                                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                                task_model_id=task_model_id,
-                                user=user,
-                            )
-                        )
-
-                        print(file_handler)
-                        if isinstance(response, str):
-                            context += ("\n" if context != "" else "") + response
-
-                        if citation:
-                            citations.append(citation)
-                            show_citations = True
-
-                        if file_handler:
-                            skip_files = True
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                del data["tool_ids"]
-
-                print(f"tool_context: {context}")
+            try:
+                body, flags = await chat_completion_files_handler(body)
+
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
 
-            # If files field is present, generate RAG completions
-            # If skip_files is True, skip the RAG completions
-            if "files" in data:
-                if not skip_files:
-                    data = {**data}
-                    rag_context, rag_citations = get_rag_context(
-                        files=data["files"],
-                        messages=data["messages"],
-                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                        k=rag_app.state.config.TOP_K,
-                        reranking_function=rag_app.state.sentence_transformer_rf,
-                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    )
-                    if rag_context:
-                        context += ("\n" if context != "" else "") + rag_context
-
-                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
-
-                    if rag_citations:
-                        citations.extend(rag_citations)
-
-                del data["files"]
+            # If context is not empty, insert it into the messages
+            if len(contexts) > 0:
+                context_string = "\n".join(contexts).strip()
+                prompt = get_last_user_message(body["messages"])
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
 
-            if show_citations and len(citations) > 0:
-                # If there are citations, add them to the data_items
+            # If there are citations, add them to the data_items
+            if len(citations) > 0:
                 data_items.append({"citations": citations})
 
-            if context != "":
-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
-                )
-                print(system_prompt)
-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
-                )
-
-            modified_body_bytes = json.dumps(data).encode("utf-8")
+            modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
             # Set custom header to ensure content-length matches new body length
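The middleware's core trick is unchanged: mutate the parsed payload, then write it back through Starlette's private request._body so downstream handlers see the modified body. A minimal, hedged sketch of that pattern outside Open WebUI:

import json

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

class BodyRewriteMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            raw = await request.body()
            data = json.loads(raw.decode("utf-8")) if raw else {}
            data["rewritten"] = True  # any payload mutation goes here
            # Private attribute; works because Request.body() caches in _body,
            # which is the same trick the diff relies on.
            request._body = json.dumps(data).encode("utf-8")
        return await call_next(request)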
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
             pass
 
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "chat_id" in payload:
-            del payload["chat_id"]
-
         if "title" in payload:
             del payload["title"]
 
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
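The same two-line guard is repeated in all three task endpoints because filter_pipeline no longer strips chat_id itself; any payload re-dispatched through generate_chat_completions must shed it first. A hypothetical consolidation of the repeated guard, not part of the commit:

def strip_chat_id(payload: dict) -> dict:
    payload.pop("chat_id", None)  # safe whether or not the key exists
    return payload

assert strip_chat_id({"model": "m", "chat_id": "c1"}) == {"model": "m"}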
@@ -665,6 +665,7 @@
 		await tick();
 
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
+			stream: true,
 			model: model.id,
 			messages: messagesBody,
 			options: {
@@ -682,8 +683,8 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0 ? true : undefined,
-			chat_id: $chatId
+			chat_id: $chatId,
+			id: responseMessageId
 		});
 
 		if (res && res.ok) {
@@ -912,8 +913,8 @@
 			const [res, controller] = await generateOpenAIChatCompletion(
 				localStorage.token,
 				{
-					model: model.id,
 					stream: true,
+					model: model.id,
 					stream_options:
 						model.info?.meta?.capabilities?.usage ?? false
 							? {
@@ -983,9 +984,8 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
-					citations: files.length > 0 ? true : undefined,
-
-					chat_id: $chatId
+					chat_id: $chatId,
+					id: responseMessageId
 				},
 				`${WEBUI_BASE_URL}/api`
 			);