diff --git a/backend/open_webui/apps/webui/models/models.py b/backend/open_webui/apps/webui/models/models.py index 77b7c5f67..6434cfb16 100644 --- a/backend/open_webui/apps/webui/models/models.py +++ b/backend/open_webui/apps/webui/models/models.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy import or_, and_, func from sqlalchemy.dialects import postgresql, sqlite -from sqlalchemy import BigInteger, Column, Text, JSON +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from open_webui.utils.utils import has_access @@ -95,6 +95,8 @@ class Model(Base): # } # } + is_active = Column(Boolean, default=True) + updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -110,6 +112,7 @@ class ModelModel(BaseModel): access_control: Optional[dict] = None + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -131,6 +134,8 @@ class ModelResponse(BaseModel): meta: ModelMeta access_control: Optional[dict] = None + + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -141,6 +146,8 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams + access_control: Optional[dict] = None + is_active: bool = True class ModelsTable: @@ -200,6 +207,23 @@ class ModelsTable: except Exception: return None + def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: + with get_db() as db: + try: + is_active = db.query(Model).filter_by(id=id).first().is_active + + db.query(Model).filter_by(id=id).update( + { + "is_active": not is_active, + "updated_at": int(time.time()), + } + ) + db.commit() + + return self.get_model_by_id(id) + except Exception: + return None + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 86b8515fd..7ba8d8190 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -79,6 +79,41 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ) +############################ +# ToggelModelById +############################ + + +@router.post("/id/{id}/toggle", response_model=Optional[ModelResponse]) +async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if ( + user.role == "admin" + or model.user_id == user.id + or has_access(user.id, "write", model.access_control) + ): + model = Models.toggle_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateModelById ############################ diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py index f349e3593..bdab303a7 100644 --- a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -35,6 +35,17 @@ def upgrade(): sa.Column("access_control", sa.JSON(), nullable=True), ) + # Add 'is_active' column to 'model' table + op.add_column( + "model", + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.sql.expression.true(), + ), + ) + # Add 'access_control' column to 'knowledge' table op.add_column( "knowledge", @@ -60,6 +71,9 @@ def downgrade(): # Drop 'access_control' column from 'model' table op.drop_column("model", "access_control") + # Drop 'is_active' column from 'model' table + op.drop_column("model", "is_active") + # Drop 'access_control' column from 'knowledge' table op.drop_column("knowledge", "access_control") diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 14014ce99..86aeb2d89 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -100,6 +100,43 @@ export const getModelById = async (token: string, id: string) => { return res; }; + +export const toggleModelById = async (token: string, id: string) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/id/${id}/toggle`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + export const updateModelById = async (token: string, id: string, model: object) => { let error = null; diff --git a/src/lib/components/workspace/Knowledge/Collection.svelte b/src/lib/components/workspace/Knowledge/Collection.svelte index cf69f7f37..d2f0318bf 100644 --- a/src/lib/components/workspace/Knowledge/Collection.svelte +++ b/src/lib/components/workspace/Knowledge/Collection.svelte @@ -39,6 +39,7 @@ import Drawer from '$lib/components/common/Drawer.svelte'; import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte'; import MenuLines from '$lib/components/icons/MenuLines.svelte'; + import AccessControl from '../common/AccessControl.svelte'; let largeScreen = true; @@ -687,7 +688,17 @@ /> {:else} -