feat: pipelines filter outlet

Timothy J. Baek 2024-05-30 02:04:29 -07:00
parent d9ceb31674
commit ef8d84296e
3 changed files with 156 additions and 4 deletions
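
This adds a post-completion "outlet" pass for pipelines: the frontend now sends its chat_id with each completion request, both backend middlewares strip that field before the body is forwarded upstream, and a new POST /api/chat/completed endpoint applies every matching "filter"-type pipeline's outlet stage to the finished conversation, in ascending priority order. When a response completes normally (i.e. was not stopped by the user), the client posts the chat to this endpoint and merges the returned messages back into its history by message id.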

View File

@@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
         return_citations = False

         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
@@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
@@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 else:
                     pass

+            if "chat_id" in data:
+                del data["chat_id"]
+
             modified_body_bytes = json.dumps(data).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
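
The chat_id field that the frontend now attaches to completion requests (see the Svelte changes below) is stripped here before the body is re-serialized, presumably so the extra field never reaches the Ollama/OpenAI upstreams, which don't expect it.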
@@ -464,6 +469,69 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}


+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    return data
+
+
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
     responses = await get_openai_models(raw=True)
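
The new handler relies on only a handful of fields on each entry in app.state.MODELS. A hypothetical filter entry, reconstructed from the checks in the list comprehension above (the id and the values are made up for illustration):

# Hypothetical shape of a "filter"-type pipeline entry in app.state.MODELS,
# inferred from the checks in chat_completed(); values are illustrative only.
example_filter = {
    "id": "my_filter",         # interpolated into f"{url}/{filter['id']}/filter/outlet"
    "urlIdx": 0,               # selects the pipelines server's base URL and API key
    "pipeline": {
        "type": "filter",      # only "filter"-type pipelines are considered
        "pipelines": ["*"],    # "*" targets every model; otherwise a list of model ids
        "priority": 0,         # filters run in ascending priority order
    },
}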

View File

@@ -49,6 +49,45 @@ export const getModels = async (token: string = '') => {
 	return models;
 };

+type ChatCompletedForm = {
+	model: string;
+	messages: string[];
+	chat_id: string;
+};
+
+export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/chat/completed`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(body)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
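
Note: ChatCompletedForm declares messages: string[], but both call sites in the chat component (next file) pass objects of the shape { id, role, content, timestamp }. The backend forwards the body to the filters without inspecting the element type, so this works at runtime; the declared element type simply doesn't match what is actually sent.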

View File

@@ -48,6 +48,7 @@
 	import { runWebSearch } from '$lib/apis/rag';
 	import Banner from '../common/Banner.svelte';
 	import { getUserSettings } from '$lib/apis/users';
+	import { chatCompleted } from '$lib/apis';

 	const i18n: Writable<i18nType> = getContext('i18n');
@@ -576,7 +577,8 @@
 				format: $settings.requestFormat ?? undefined,
 				keep_alive: $settings.keepAlive ?? undefined,
 				docs: docs.length > 0 ? docs : undefined,
-				citations: docs.length > 0
+				citations: docs.length > 0,
+				chat_id: $chatId
 			});

 			if (res && res.ok) {
@@ -596,6 +598,27 @@
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 						await cancelOllamaRequest(localStorage.token, currentRequestId);
+					} else {
+						const res = await chatCompleted(localStorage.token, {
+							model: model,
+							messages: messages.map((m) => ({
+								id: m.id,
+								role: m.role,
+								content: m.content,
+								timestamp: m.timestamp
+							})),
+							chat_id: $chatId
+						}).catch((error) => {
+							console.error(error);
+							return null;
+						});
+
+						if (res !== null) {
+							// Update chat history with the new messages
+							for (const message of res.messages) {
+								history.messages[message.id] = { ...history.messages[message.id], ...message };
+							}
+						}
 					}

 					currentRequestId = null;
@@ -829,7 +852,8 @@
 					frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					docs: docs.length > 0 ? docs : undefined,
-					citations: docs.length > 0
+					citations: docs.length > 0,
+					chat_id: $chatId
 				},
 				`${OPENAI_API_BASE_URL}`
 			);
@@ -855,6 +879,27 @@
 				if (stopResponseFlag) {
 					controller.abort('User: Stop Response');
+				} else {
+					const res = await chatCompleted(localStorage.token, {
+						model: model,
+						messages: messages.map((m) => ({
+							id: m.id,
+							role: m.role,
+							content: m.content,
+							timestamp: m.timestamp
+						})),
+						chat_id: $chatId
+					}).catch((error) => {
+						console.error(error);
+						return null;
+					});
+
+					if (res !== null) {
+						// Update chat history with the new messages
+						for (const message of res.messages) {
+							history.messages[message.id] = { ...history.messages[message.id], ...message };
+						}
+					}
 				}

 				break;
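
For context, the /filter/outlet URL that chat_completed() posts to is served by a separate pipelines server. A minimal sketch of what that side might look like, assuming a FastAPI app; the route shape and payload mirror the requests.post() call in the first file, but the handler body (the tagging transform) is purely illustrative:

# Sketch of a pipelines server exposing the outlet stage this commit calls.
# Assumes FastAPI; the payload shape ({"user": ..., "body": ...}) comes from
# the requests.post() call in chat_completed(); everything else is hypothetical.
from fastapi import FastAPI

app = FastAPI()


@app.post("/{pipeline_id}/filter/outlet")
async def filter_outlet(pipeline_id: str, payload: dict) -> dict:
    body = payload["body"]  # the completed chat: model, messages, chat_id
    user = payload["user"]  # {"id", "name", "role"}, as sent by chat_completed()

    # Illustrative transform: tag each assistant message with the filter's id.
    for message in body.get("messages", []):
        if message.get("role") == "assistant":
            message["content"] = f"{message['content']}\n\n[filtered by {pipeline_id}]"

    # chat_completed() replaces its working body with whatever JSON is returned,
    # and the frontend merges the returned messages back into its history by id.
    return body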