mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
feat: pipelines filter outlet
This commit is contained in:
parent
d9ceb31674
commit
ef8d84296e
@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
return_citations = False
|
||||
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
"/ollama/api/chat" in request.url.path
|
||||
or "/chat/completions" in request.url.path
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
"/ollama/api/chat" in request.url.path
|
||||
or "/chat/completions" in request.url.path
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
else:
|
||||
pass
|
||||
|
||||
if "chat_id" in data:
|
||||
del data["chat_id"]
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
# Replace the request body with the modified one
|
||||
request._body = modified_body_bytes
|
||||
@ -464,6 +469,69 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@app.get("/api/pipelines/list")
|
||||
async def get_pipelines_list(user=Depends(get_admin_user)):
|
||||
responses = await get_openai_models(raw=True)
|
||||
|
@ -49,6 +49,45 @@ export const getModels = async (token: string = '') => {
|
||||
return models;
|
||||
};
|
||||
|
||||
type ChatCompletedForm = {
|
||||
model: string;
|
||||
messages: string[];
|
||||
chat_id: string;
|
||||
};
|
||||
|
||||
export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_BASE_URL}/api/chat/completed`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
...(token && { authorization: `Bearer ${token}` })
|
||||
},
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
if ('detail' in err) {
|
||||
error = err.detail;
|
||||
} else {
|
||||
error = err;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getPipelinesList = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
||||
|
@ -48,6 +48,7 @@
|
||||
import { runWebSearch } from '$lib/apis/rag';
|
||||
import Banner from '../common/Banner.svelte';
|
||||
import { getUserSettings } from '$lib/apis/users';
|
||||
import { chatCompleted } from '$lib/apis';
|
||||
|
||||
const i18n: Writable<i18nType> = getContext('i18n');
|
||||
|
||||
@ -576,7 +577,8 @@
|
||||
format: $settings.requestFormat ?? undefined,
|
||||
keep_alive: $settings.keepAlive ?? undefined,
|
||||
docs: docs.length > 0 ? docs : undefined,
|
||||
citations: docs.length > 0
|
||||
citations: docs.length > 0,
|
||||
chat_id: $chatId
|
||||
});
|
||||
|
||||
if (res && res.ok) {
|
||||
@ -596,6 +598,27 @@
|
||||
if (stopResponseFlag) {
|
||||
controller.abort('User: Stop Response');
|
||||
await cancelOllamaRequest(localStorage.token, currentRequestId);
|
||||
} else {
|
||||
const res = await chatCompleted(localStorage.token, {
|
||||
model: model,
|
||||
messages: messages.map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
timestamp: m.timestamp
|
||||
})),
|
||||
chat_id: $chatId
|
||||
}).catch((error) => {
|
||||
console.error(error);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (res !== null) {
|
||||
// Update chat history with the new messages
|
||||
for (const message of res.messages) {
|
||||
history.messages[message.id] = { ...history.messages[message.id], ...message };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentRequestId = null;
|
||||
@ -829,7 +852,8 @@
|
||||
frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
|
||||
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
||||
docs: docs.length > 0 ? docs : undefined,
|
||||
citations: docs.length > 0
|
||||
citations: docs.length > 0,
|
||||
chat_id: $chatId
|
||||
},
|
||||
`${OPENAI_API_BASE_URL}`
|
||||
);
|
||||
@ -855,6 +879,27 @@
|
||||
|
||||
if (stopResponseFlag) {
|
||||
controller.abort('User: Stop Response');
|
||||
} else {
|
||||
const res = await chatCompleted(localStorage.token, {
|
||||
model: model,
|
||||
messages: messages.map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
timestamp: m.timestamp
|
||||
})),
|
||||
chat_id: $chatId
|
||||
}).catch((error) => {
|
||||
console.error(error);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (res !== null) {
|
||||
// Update chat history with the new messages
|
||||
for (const message of res.messages) {
|
||||
history.messages[message.id] = { ...history.messages[message.id], ...message };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user