From a2a25fb5711e399dfe038e0a6d0df8450e023a01 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 18 Nov 2024 05:37:04 -0800 Subject: [PATCH] enh: models display author name --- backend/open_webui/apps/socket/main.py | 2 ++ .../open_webui/apps/webui/models/models.py | 31 +++++++++---------- backend/open_webui/apps/webui/models/users.py | 8 +++++ .../open_webui/apps/webui/routers/models.py | 3 +- src/lib/components/workspace/Models.svelte | 9 ++++-- 5 files changed, 33 insertions(+), 20 deletions(-) diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index fca268a6b..5c284f18d 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -1,3 +1,5 @@ +# TODO: move socket to webui app + import asyncio import socketio import logging diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 46591bd95..52ed8fe43 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS -from open_webui.apps.webui.models.groups import Groups +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict @@ -124,20 +124,12 @@ class ModelModel(BaseModel): #################### -class ModelResponse(BaseModel): - id: str - user_id: str - base_model_id: Optional[str] = None +class ModelUserResponse(ModelModel): + user: Optional[UserResponse] = None - name: str - params: ModelParams - meta: ModelMeta - access_control: Optional[dict] = None - - is_active: bool - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch +class ModelResponse(ModelModel): + pass class ModelForm(BaseModel): @@ -181,10 +173,15 @@ class ModelsTable: with get_db() as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_models(self) -> list[ModelModel]: + def get_models(self) -> list[ModelUserResponse]: with get_db() as db: return [ - ModelModel.model_validate(model) + ModelUserResponse.model_validate( + { + **ModelModel.model_validate(model).model_dump(), + "user": Users.get_user_by_id(model.user_id).model_dump(), + } + ) for model in db.query(Model).filter(Model.base_model_id != None).all() ] @@ -197,8 +194,8 @@ class ModelsTable: def get_models_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[ModelModel]: - models = self.get_all_models() + ) -> list[ModelUserResponse]: + models = self.get_models() return [ model for model in models diff --git a/backend/open_webui/apps/webui/models/users.py b/backend/open_webui/apps/webui/models/users.py index 328618a67..5bbcc3099 100644 --- a/backend/open_webui/apps/webui/models/users.py +++ b/backend/open_webui/apps/webui/models/users.py @@ -62,6 +62,14 @@ class UserModel(BaseModel): #################### +class UserResponse(BaseModel): + id: str + name: str + email: str + role: str + profile_image_url: str + + class UserRoleUpdateForm(BaseModel): id: str role: str diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 634630622..ad5221cf8 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -4,6 +4,7 @@ from open_webui.apps.webui.models.models import ( ModelForm, ModelModel, ModelResponse, + ModelUserResponse, Models, ) from open_webui.constants import ERROR_MESSAGES @@ -22,7 +23,7 @@ router = APIRouter() ########################### -@router.get("/", response_model=list[ModelResponse]) +@router.get("/", response_model=list[ModelUserResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): if user.role == "admin": return Models.get_models() diff --git a/src/lib/components/workspace/Models.svelte b/src/lib/components/workspace/Models.svelte index 3c77abb70..d48eadb94 100644 --- a/src/lib/components/workspace/Models.svelte +++ b/src/lib/components/workspace/Models.svelte @@ -284,8 +284,13 @@ >
{model.name}
-
- {model?.meta?.description ?? model.id} + +
+ +
+ By {model?.user?.name ?? model?.user?.email} +
+