From 120b1857b21ea25f7a2dddc5e6e288f3ce0a6ebe Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:05:33 -0700 Subject: [PATCH] enh: valves --- .../016_add_valves_and_is_active.py | 50 +++++++++++++++++++ backend/apps/webui/models/functions.py | 8 +++ backend/apps/webui/models/tools.py | 26 ++++++++++ src/routes/(app)/workspace/+layout.svelte | 2 +- 4 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py new file mode 100644 index 000000000..e3af521b7 --- /dev/null +++ b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py @@ -0,0 +1,50 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields("tool", valves=pw.TextField(null=True)) + migrator.add_fields("function", valves=pw.TextField(null=True)) + migrator.add_fields("function", is_active=pw.BooleanField(default=False)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("tool", "valves") + migrator.remove_fields("function", "valves") + migrator.remove_fields("function", "is_active") diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index cd6320f95..779eef9ef 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -28,6 +28,8 @@ class Function(Model): type = TextField() content = TextField() meta = JSONField() + valves = JSONField() + is_active = BooleanField(default=False) updated_at = BigIntegerField() created_at = BigIntegerField() @@ -46,6 +48,7 @@ class FunctionModel(BaseModel): type: str content: str meta: FunctionMeta + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -61,6 +64,7 @@ class FunctionResponse(BaseModel): type: str name: str meta: FunctionMeta + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -72,6 +76,10 @@ class FunctionForm(BaseModel): meta: FunctionMeta +class FunctionValves(BaseModel): + valves: Optional[dict] = None + + class FunctionsTable: def __init__(self, db): self.db = db diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index ab322ac14..0f5755e39 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -28,6 +28,7 @@ class Tool(Model): content = TextField() specs = JSONField() meta = JSONField() + valves = JSONField() updated_at = BigIntegerField() created_at = BigIntegerField() @@ -71,6 +72,10 @@ class ToolForm(BaseModel): meta: ToolMeta +class ToolValves(BaseModel): + valves: Optional[dict] = None + + class ToolsTable: def __init__(self, db): self.db = db @@ -109,6 +114,27 @@ class ToolsTable: def get_tools(self) -> List[ToolModel]: return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]: + try: + tool = Tool.get(Tool.id == id) + return ToolValves(**model_to_dict(tool)) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: + try: + query = Tool.update( + **{"valves": valves}, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolValves(**model_to_dict(tool)) + except: + return None + def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: diff --git a/src/routes/(app)/workspace/+layout.svelte b/src/routes/(app)/workspace/+layout.svelte index 46e0f63c4..794bd7ed5 100644 --- a/src/routes/(app)/workspace/+layout.svelte +++ b/src/routes/(app)/workspace/+layout.svelte @@ -9,7 +9,7 @@ const i18n = getContext('i18n'); onMount(async () => { - functions.set(await getFunctions(localStorage.token)); + // functions.set(await getFunctions(localStorage.token)); });