mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
wip: user groups frontend
This commit is contained in:
@@ -13,9 +13,7 @@ import requests
|
||||
from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.config import (
|
||||
CORS_ALLOW_ORIGIN,
|
||||
ENABLE_MODEL_FILTER,
|
||||
ENABLE_OLLAMA_API,
|
||||
MODEL_FILTER_LIST,
|
||||
OLLAMA_BASE_URLS,
|
||||
OLLAMA_API_CONFIGS,
|
||||
UPLOAD_DIR,
|
||||
@@ -66,9 +64,6 @@ app.add_middleware(
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
|
||||
@@ -339,16 +334,18 @@ async def get_ollama_tags(
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
return models
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models["models"] = list(
|
||||
# filter(
|
||||
# lambda model: model["name"]
|
||||
# in app.state.config.MODEL_FILTER_LIST,
|
||||
# models["models"],
|
||||
# )
|
||||
# )
|
||||
# return models
|
||||
|
||||
return models
|
||||
else:
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
@@ -922,12 +919,14 @@ async def generate_chat_completion(
|
||||
|
||||
model_id = form_data.model
|
||||
|
||||
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if not bypass_filter:
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=403,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
@@ -1008,12 +1007,13 @@ async def generate_openai_chat_completion(
|
||||
|
||||
model_id = completion_form.model
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=403,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
@@ -1054,15 +1054,16 @@ async def get_openai_models(
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models["models"] = list(
|
||||
# filter(
|
||||
# lambda model: model["name"]
|
||||
# in app.state.config.MODEL_FILTER_LIST,
|
||||
# models["models"],
|
||||
# )
|
||||
# )
|
||||
|
||||
return {
|
||||
"data": [
|
||||
|
||||
@@ -11,9 +11,7 @@ from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
CORS_ALLOW_ORIGIN,
|
||||
ENABLE_MODEL_FILTER,
|
||||
ENABLE_OPENAI_API,
|
||||
MODEL_FILTER_LIST,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
OPENAI_API_CONFIGS,
|
||||
@@ -61,9 +59,6 @@ app.add_middleware(
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
@@ -372,15 +367,18 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
models["data"],
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models["data"] = list(
|
||||
# filter(
|
||||
# lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
# models["data"],
|
||||
# )
|
||||
# )
|
||||
# return models
|
||||
|
||||
return models
|
||||
else:
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
@@ -492,11 +490,10 @@ async def verify_connection(
|
||||
|
||||
|
||||
@app.post("/chat/completions")
|
||||
@app.post("/chat/completions/{url_idx}")
|
||||
async def generate_chat_completion(
|
||||
form_data: dict,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: Optional[bool] = False,
|
||||
):
|
||||
idx = 0
|
||||
payload = {**form_data}
|
||||
@@ -505,6 +502,16 @@ async def generate_chat_completion(
|
||||
del payload["metadata"]
|
||||
|
||||
model_id = form_data.get("model")
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if not bypass_filter:
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=403,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
|
||||
@@ -183,7 +183,10 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
|
||||
docs_url="/docs" if ENV == "dev" else None,
|
||||
openapi_url="/openapi.json" if ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.state.config = AppConfig()
|
||||
@@ -1081,15 +1084,16 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
||||
]
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models = list(
|
||||
filter(
|
||||
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
models,
|
||||
)
|
||||
)
|
||||
return {"data": models}
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models = list(
|
||||
# filter(
|
||||
# lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
# models,
|
||||
# )
|
||||
# )
|
||||
# return {"data": models}
|
||||
|
||||
return {"data": models}
|
||||
|
||||
@@ -1106,12 +1110,14 @@ async def generate_chat_completions(
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Model not found",
|
||||
)
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if not bypass_filter:
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_403_FORBIDDEN,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
@@ -1161,14 +1167,16 @@ async def generate_chat_completions(
|
||||
),
|
||||
"selected_model_id": selected_model_id,
|
||||
}
|
||||
|
||||
if model.get("pipe"):
|
||||
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||
return await generate_function_chat_completion(form_data, user=user)
|
||||
if model["owned_by"] == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
form_data = GenerateChatCompletionForm(**form_data)
|
||||
response = await generate_ollama_chat_completion(
|
||||
form_data=form_data, user=user, bypass_filter=True
|
||||
form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||
)
|
||||
if form_data.stream:
|
||||
response.headers["content-type"] = "text/event-stream"
|
||||
@@ -1179,7 +1187,9 @@ async def generate_chat_completions(
|
||||
else:
|
||||
return convert_response_ollama_to_openai(response)
|
||||
else:
|
||||
return await generate_openai_chat_completion(form_data, user=user)
|
||||
return await generate_openai_chat_completion(
|
||||
form_data, user=user, bypass_filter=bypass_filter
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
@@ -2297,32 +2307,6 @@ async def get_app_config(request: Request):
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/config/model/filter")
|
||||
async def get_model_filter_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"enabled": app.state.config.ENABLE_MODEL_FILTER,
|
||||
"models": app.state.config.MODEL_FILTER_LIST,
|
||||
}
|
||||
|
||||
|
||||
class ModelFilterConfigForm(BaseModel):
|
||||
enabled: bool
|
||||
models: list[str]
|
||||
|
||||
|
||||
@app.post("/api/config/model/filter")
|
||||
async def update_model_filter_config(
|
||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
||||
app.state.config.MODEL_FILTER_LIST = form_data.models
|
||||
|
||||
return {
|
||||
"enabled": app.state.config.ENABLE_MODEL_FILTER,
|
||||
"models": app.state.config.MODEL_FILTER_LIST,
|
||||
}
|
||||
|
||||
|
||||
# TODO: webhook endpoint should be under config endpoints
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user