feat: pipeline valves user field

This commit is contained in:
Timothy J. Baek 2024-05-27 19:16:07 -07:00
parent cc6d9bb8c0
commit 4fac99c5b3

View File

@ -32,7 +32,12 @@ from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models, ModelModel
from utils.utils import get_admin_user, get_verified_user from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
)
from apps.rag.utils import rag_messages from apps.rag.utils import rag_messages
from config import ( from config import (
@ -244,7 +249,6 @@ class PipelineMiddleware(BaseHTTPMiddleware):
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
model_id = data["model"] model_id = data["model"]
valves = [ valves = [
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
@ -258,7 +262,20 @@ class PipelineMiddleware(BaseHTTPMiddleware):
] ]
sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"]) sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
user = None
if len(sorted_valves) > 0:
try:
user = get_current_user(
get_http_authorization_cred(
request.headers.get("Authorization")
)
)
user = {"id": user.id, "name": user.name, "role": user.role}
except:
pass
for valve in sorted_valves: for valve in sorted_valves:
try: try:
urlIdx = valve["urlIdx"] urlIdx = valve["urlIdx"]
@ -271,6 +288,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
f"{url}/valve", f"{url}/valve",
headers=headers, headers=headers,
json={ json={
"user": user,
"model": valve["id"], "model": valve["id"],
"body": data, "body": data,
}, },