From c50b678dce2c06bec9e677f0a49efaf8e5880e3d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 18 Nov 2024 06:19:34 -0800 Subject: [PATCH] enh: tools user info --- .../open_webui/apps/webui/models/prompts.py | 18 ++++-- backend/open_webui/apps/webui/models/tools.py | 20 +++++-- .../open_webui/apps/webui/routers/prompts.py | 9 ++- .../open_webui/apps/webui/routers/tools.py | 12 +++- src/lib/components/workspace/Knowledge.svelte | 8 +-- src/lib/components/workspace/Models.svelte | 12 ++-- src/lib/components/workspace/Prompts.svelte | 28 ++++++++-- src/lib/components/workspace/Tools.svelte | 55 ++++++++++++------- 8 files changed, 113 insertions(+), 49 deletions(-) diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index ea4a229f7..fe2ea87da 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -2,7 +2,7 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db -from open_webui.apps.webui.models.groups import Groups +from open_webui.apps.webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -57,6 +57,10 @@ class PromptModel(BaseModel): #################### +class PromptUserResponse(PromptModel): + user: Optional[UserResponse] = None + + class PromptForm(BaseModel): command: str title: str @@ -97,15 +101,21 @@ class PromptsTable: except Exception: return None - def get_prompts(self) -> list[PromptModel]: + def get_prompts(self) -> list[PromptUserResponse]: with get_db() as db: return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + PromptUserResponse.model_validate( + { + **PromptModel.model_validate(prompt).model_dump(), + "user": Users.get_user_by_id(prompt.user_id).model_dump(), + } + ) + for prompt in db.query(Prompt).all() ] def get_prompts_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[PromptModel]: + ) -> list[PromptUserResponse]: prompts = self.get_prompts() return [ diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index 63570bee6..0044de218 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -3,7 +3,7 @@ import time from typing import Optional from open_webui.apps.webui.internal.db import Base, JSONField, get_db -from open_webui.apps.webui.models.users import Users +from open_webui.apps.webui.models.users import Users, UserResponse from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -86,6 +86,10 @@ class ToolResponse(BaseModel): created_at: int # timestamp in epoch +class ToolUserResponse(ToolResponse): + user: Optional[UserResponse] = None + + class ToolForm(BaseModel): id: str name: str @@ -134,13 +138,21 @@ class ToolsTable: except Exception: return None - def get_tools(self) -> list[ToolModel]: + def get_tools(self) -> list[ToolUserResponse]: with get_db() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + return [ + ToolUserResponse.model_validate( + { + **ToolModel.model_validate(tool).model_dump(), + "user": Users.get_user_by_id(tool.user_id).model_dump(), + } + ) + for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all() + ] def get_tools_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[ToolModel]: + ) -> list[ToolUserResponse]: tools = self.get_tools() return [ diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index e3aab4043..7cacde606 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -1,6 +1,11 @@ from typing import Optional -from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts +from open_webui.apps.webui.models.prompts import ( + PromptForm, + PromptUserResponse, + PromptModel, + Prompts, +) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status, Request from open_webui.utils.utils import get_admin_user, get_verified_user @@ -23,7 +28,7 @@ async def get_prompts(user=Depends(get_verified_user)): return prompts -@router.get("/list", response_model=list[PromptModel]) +@router.get("/list", response_model=list[PromptUserResponse]) async def get_prompt_list(user=Depends(get_verified_user)): if user.role == "admin": prompts = Prompts.get_prompts() diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index fb6292f2f..883c34405 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -2,7 +2,13 @@ import os from pathlib import Path from typing import Optional -from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools +from open_webui.apps.webui.models.tools import ( + ToolForm, + ToolModel, + ToolResponse, + ToolUserResponse, + Tools, +) from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.constants import ERROR_MESSAGES @@ -19,7 +25,7 @@ router = APIRouter() ############################ -@router.get("/", response_model=list[ToolResponse]) +@router.get("/", response_model=list[ToolUserResponse]) async def get_tools(user=Depends(get_verified_user)): if user.role == "admin": tools = Tools.get_tools() @@ -33,7 +39,7 @@ async def get_tools(user=Depends(get_verified_user)): ############################ -@router.get("/list", response_model=list[ToolResponse]) +@router.get("/list", response_model=list[ToolUserResponse]) async def get_tool_list(user=Depends(get_verified_user)): if user.role == "admin": tools = Tools.get_tools() diff --git a/src/lib/components/workspace/Knowledge.svelte b/src/lib/components/workspace/Knowledge.svelte index b2fc94614..77723532c 100644 --- a/src/lib/components/workspace/Knowledge.svelte +++ b/src/lib/components/workspace/Knowledge.svelte @@ -84,7 +84,7 @@ }} /> -
+
{$i18n.t('Knowledge')} @@ -121,10 +121,10 @@
-
+
{#each filteredItems as item}
-
+
-
+
{$i18n.t('Models')} @@ -230,14 +230,14 @@
-
+
{#each filteredModels as model}
-
+
@@ -261,7 +261,7 @@ className=" w-fit" placement="top-start" > -
{model.name}
+
{model.name}
@@ -278,7 +278,7 @@
-
+
{$i18n.t('By {{name}}', { diff --git a/src/lib/components/workspace/Prompts.svelte b/src/lib/components/workspace/Prompts.svelte index 8996b33ac..534b21af3 100644 --- a/src/lib/components/workspace/Prompts.svelte +++ b/src/lib/components/workspace/Prompts.svelte @@ -21,6 +21,8 @@ import Plus from '../icons/Plus.svelte'; import ChevronRight from '../icons/ChevronRight.svelte'; import Spinner from '../common/Spinner.svelte'; + import Tooltip from '../common/Tooltip.svelte'; + import { capitalizeFirstLetter } from '$lib/utils'; const i18n = getContext('i18n'); let promptsImportInputElement: HTMLInputElement; @@ -103,7 +105,7 @@
-
+
{$i18n.t('Prompts')} @@ -137,19 +139,33 @@
-
+
{#each filteredItems as prompt}
-
diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index 358b06f67..9252ea010 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -30,6 +30,7 @@ import Plus from '../icons/Plus.svelte'; import ChevronRight from '../icons/ChevronRight.svelte'; import Spinner from '../common/Spinner.svelte'; + import { capitalizeFirstLetter } from '$lib/utils'; const i18n = getContext('i18n'); @@ -172,7 +173,7 @@ {#if loaded} -