Merge pull request #3103 from open-webui/dev

fix: filter pipeline
This commit is contained in:
Timothy Jaeryang Baek 2024-06-12 13:35:43 -07:00 committed by GitHub
commit c0a06f7db4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
"stream": False,
}
payload = filter_pipeline(payload, user)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
response = None
@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
try:
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
if response:
context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
@ -472,13 +479,10 @@ def filter_pipeline(payload, user):
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
@ -489,6 +493,7 @@ def filter_pipeline(payload, user):
if "title" in payload:
del payload["title"]
return payload
@ -510,7 +515,14 @@ class PipelineMiddleware(BaseHTTPMiddleware):
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
data = filter_pipeline(data, user)
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
@ -762,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
}
print(payload)
payload = filter_pipeline(payload, user)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
@ -819,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
}
print(payload)
payload = filter_pipeline(payload, user)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
@ -856,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
@app.post("/api/chat/completions")