From a2eadb30f5e847ccfb2b9f3e68ed9caa494a0d6a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 2 Oct 2024 20:42:10 -0700 Subject: [PATCH] refac --- backend/open_webui/apps/webui/models/files.py | 7 + .../open_webui/apps/webui/models/knowledge.py | 33 ++++- .../open_webui/apps/webui/routers/files.py | 3 + .../apps/webui/routers/knowledge.py | 20 ++- src/lib/apis/knowledge/index.ts | 16 +-- .../workspace/Knowledge/Files.svelte | 7 + .../workspace/Knowledge/Item.svelte | 130 ++++++++++++++++-- 7 files changed, 187 insertions(+), 29 deletions(-) diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/apps/webui/models/files.py index 69abf6f1a..ec79a2d9f 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/apps/webui/models/files.py @@ -106,6 +106,13 @@ class FilesTable: with get_db() as db: return [FileModel.model_validate(file) for file in db.query(File).all()] + def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: + with get_db() as db: + return [ + FileModel.model_validate(file) + for file in db.query(File).filter(File.id.in_(ids)).all() + ] + def get_files_by_user_id(self, user_id: str) -> list[FileModel]: with get_db() as db: return [ diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index d5d329f00..828bc3c2a 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -71,6 +71,12 @@ class KnowledgeForm(BaseModel): data: Optional[dict] = None +class KnowledgeUpdateForm(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + data: Optional[dict] = None + + class KnowledgeTable: def insert_new_knowledge( self, user_id: str, form_data: KnowledgeForm @@ -116,18 +122,37 @@ class KnowledgeTable: return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeForm + self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False ) -> Optional[KnowledgeModel]: try: with get_db() as db: db.query(Knowledge).filter_by(id=id).update( { - "name": form_data.name, - "updated_id": int(time.time()), + **({"name": form_data.name} if form_data.name else {}), + **( + {"description": form_data.description} + if form_data.description + else {} + ), + **( + { + "data": ( + form_data.data + if overwrite + else { + **(self.get_knowledge_by_id(id=id)).data, + **form_data.data, + } + ) + } + if form_data.data + else {} + ), + "updated_at": int(time.time()), } ) db.commit() - return self.get_knowledge_by_id(id=form_data.id) + return self.get_knowledge_by_id(id=id) except Exception as e: log.exception(e) return None diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index b39e4f542..1204dea9d 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -6,9 +6,12 @@ from pathlib import Path from typing import Optional from open_webui.apps.webui.models.files import FileForm, FileModel, Files +from open_webui.apps.webui.models.knowledge import Knowledges from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS + + from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.utils import get_admin_user, get_verified_user diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 0f97fb18c..78faae6e1 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -6,10 +6,12 @@ from fastapi import APIRouter, Depends, HTTPException, status from open_webui.apps.webui.models.knowledge import ( Knowledges, - KnowledgeModel, + KnowledgeUpdateForm, KnowledgeForm, KnowledgeResponse, ) +from open_webui.apps.webui.models.files import Files, FileModel + from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user @@ -66,12 +68,22 @@ async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_ ############################ -@router.get("/{id}", response_model=Optional[KnowledgeResponse]) +class KnowledgeFilesResponse(KnowledgeResponse): + files: list[FileModel] + + +@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.get_knowledge_by_id(id=id) if knowledge: - return knowledge + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -87,7 +99,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update", response_model=Optional[KnowledgeResponse]) async def update_knowledge_by_id( id: str, - form_data: KnowledgeForm, + form_data: KnowledgeUpdateForm, user=Depends(get_admin_user), ): knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index 600c3bbac..511924a04 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -95,13 +95,13 @@ export const getKnowledgeById = async (token: string, id: string) => { return res; }; -type KnowledgeForm = { - name: string; - description: string; - data: object; +type KnowledgeUpdateForm = { + name?: string; + description?: string; + data?: object; }; -export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeForm) => { +export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/${id}/update`, { @@ -112,9 +112,9 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl authorization: `Bearer ${token}` }, body: JSON.stringify({ - name: form.name, - description: form.description, - data: form.data + name: form?.name ? form.name : undefined, + description: form?.description ? form.description : undefined, + data: form?.data ? form.data : undefined }) }) .then(async (res) => { diff --git a/src/lib/components/workspace/Knowledge/Files.svelte b/src/lib/components/workspace/Knowledge/Files.svelte index e69de29bb..73109670a 100644 --- a/src/lib/components/workspace/Knowledge/Files.svelte +++ b/src/lib/components/workspace/Knowledge/Files.svelte @@ -0,0 +1,7 @@ + + +
+ {JSON.stringify(files)} +
diff --git a/src/lib/components/workspace/Knowledge/Item.svelte b/src/lib/components/workspace/Knowledge/Item.svelte index 52c92ae41..ef1d2abb5 100644 --- a/src/lib/components/workspace/Knowledge/Item.svelte +++ b/src/lib/components/workspace/Knowledge/Item.svelte @@ -1,31 +1,80 @@ @@ -92,11 +190,14 @@
-
+
{ + changeDebounceHandler(); + }} />
@@ -112,6 +213,9 @@ type="text" class="w-full font-medium text-gray-500 text-sm bg-transparent outline-none" bind:value={knowledge.description} + on:input={() => { + changeDebounceHandler(); + }} />
@@ -119,7 +223,7 @@