feat: migrate modelfiles to models

This commit is contained in:
Timothy J. Baek 2024-05-23 23:47:01 -07:00
parent 3be0fa63ee
commit 17e4be49c0
2 changed files with 134 additions and 5 deletions

View File

@ -0,0 +1,125 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
import json
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
modelfiles = ModelFile.select()
for modelfile in modelfiles:
# Extract and transform data in Python
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": "true"},
}
)
# Insert the processed data into the 'model' table
Model.create(
id=modelfile.tag_name,
user_id=modelfile.user_id,
name=modelfile.modelfile.get("title"),
meta=meta,
params="{}",
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
query = """
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator.sql(query)
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
models = Model.select()
for model in models:
# Extract and transform data in Python
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
}
# Insert the processed data back into the 'modelfile' table
Modelfile.create(
user_id=model.user_id,
tag_name=model.id,
modelfile=modelfile_data,
timestamp=model.created_at,
)

View File

@ -4,7 +4,7 @@ from typing import Optional
import peewee as pw import peewee as pw
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField from apps.web.internal.db import DB, JSONField
@ -22,29 +22,34 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table # ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel): class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
pass pass
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
description: str description: Optional[str] = None
""" """
User-facing description of the model. User-facing description of the model.
""" """
vision_capable: bool vision_capable: Optional[bool] = None
""" """
A flag indicating if the model is capable of vision and thus image inputs A flag indicating if the model is capable of vision and thus image inputs
""" """
model_config = ConfigDict(extra="allow")
pass
class Model(pw.Model): class Model(pw.Model):
id = pw.TextField(unique=True) id = pw.TextField(unique=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
user_id = pw.TextField() user_id = pw.TextField()
base_model_id = pw.TextField(null=True) base_model_id = pw.TextField(null=True)
@ -89,7 +94,6 @@ class ModelModel(BaseModel):
class ModelsTable: class ModelsTable:
def __init__( def __init__(
self, self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase, db: pw.SqliteDatabase | pw.PostgresqlDatabase,