mirror of
https://github.com/open-webui/open-webui
synced 2025-04-04 04:51:27 +00:00
commit
c0a06f7db4
@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = filter_pipeline(payload, user)
|
try:
|
||||||
|
payload = filter_pipeline(payload, user)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
model = app.state.MODELS[task_model_id]
|
model = app.state.MODELS[task_model_id]
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
print(data["tool_ids"])
|
print(data["tool_ids"])
|
||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
print(tool_id)
|
print(tool_id)
|
||||||
response = await get_function_call_response(
|
try:
|
||||||
messages=data["messages"],
|
response = await get_function_call_response(
|
||||||
tool_id=tool_id,
|
messages=data["messages"],
|
||||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
tool_id=tool_id,
|
||||||
task_model_id=task_model_id,
|
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
user=user,
|
task_model_id=task_model_id,
|
||||||
)
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
del data["tool_ids"]
|
del data["tool_ids"]
|
||||||
|
|
||||||
print(f"tool_context: {context}")
|
print(f"tool_context: {context}")
|
||||||
@ -472,13 +479,10 @@ def filter_pipeline(payload, user):
|
|||||||
if r is not None:
|
if r is not None:
|
||||||
try:
|
try:
|
||||||
res = r.json()
|
res = r.json()
|
||||||
if "detail" in res:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=r.status_code,
|
|
||||||
content=res,
|
|
||||||
)
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
if "detail" in res:
|
||||||
|
raise Exception(r.status_code, res["detail"])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@ -489,6 +493,7 @@ def filter_pipeline(payload, user):
|
|||||||
|
|
||||||
if "title" in payload:
|
if "title" in payload:
|
||||||
del payload["title"]
|
del payload["title"]
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
@ -510,7 +515,14 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|||||||
user = get_current_user(
|
user = get_current_user(
|
||||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||||
)
|
)
|
||||||
data = filter_pipeline(data, user)
|
|
||||||
|
try:
|
||||||
|
data = filter_pipeline(data, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
|
)
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
# Replace the request body with the modified one
|
# Replace the request body with the modified one
|
||||||
@ -762,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
}
|
}
|
||||||
|
|
||||||
print(payload)
|
print(payload)
|
||||||
payload = filter_pipeline(payload, user)
|
|
||||||
|
try:
|
||||||
|
payload = filter_pipeline(payload, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
|
)
|
||||||
|
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
return await generate_ollama_chat_completion(
|
return await generate_ollama_chat_completion(
|
||||||
@ -819,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|||||||
}
|
}
|
||||||
|
|
||||||
print(payload)
|
print(payload)
|
||||||
payload = filter_pipeline(payload, user)
|
|
||||||
|
try:
|
||||||
|
payload = filter_pipeline(payload, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
|
)
|
||||||
|
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
return await generate_ollama_chat_completion(
|
return await generate_ollama_chat_completion(
|
||||||
@ -856,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|||||||
print(model_id)
|
print(model_id)
|
||||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||||
|
|
||||||
return await get_function_call_response(
|
try:
|
||||||
form_data["messages"], form_data["tool_id"], template, model_id, user
|
context = await get_function_call_response(
|
||||||
)
|
form_data["messages"], form_data["tool_id"], template, model_id, user
|
||||||
|
)
|
||||||
|
return context
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=e.args[0],
|
||||||
|
content={"detail": e.args[1]},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completions")
|
@app.post("/api/chat/completions")
|
||||||
|
Loading…
Reference in New Issue
Block a user