refac: better migration script

This commit is contained in:
Timothy J. Baek 2024-05-24 19:26:27 -07:00
parent 8bca17ee1d
commit e316abcfc8
3 changed files with 81 additions and 2 deletions

View File

@ -30,6 +30,8 @@ import peewee as pw
from peewee_migrate import Migrator from peewee_migrate import Migrator
import json import json
from utils.misc import parse_ollama_modelfile
with suppress(ImportError): with suppress(ImportError):
import playhouse.postgres_ext as pw_pext import playhouse.postgres_ext as pw_pext
@ -64,13 +66,16 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
} }
) )
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
# Insert the processed data into the 'model' table # Insert the processed data into the 'model' table
Model.create( Model.create(
id=modelfile.tag_name, id=modelfile.tag_name,
user_id=modelfile.user_id, user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"),
name=modelfile.modelfile.get("title"), name=modelfile.modelfile.get("title"),
meta=meta, meta=meta,
params="{}", params=json.dumps(info.get("params", {})),
created_at=modelfile.timestamp, created_at=modelfile.timestamp,
updated_at=modelfile.timestamp, updated_at=modelfile.timestamp,
) )

View File

@ -1,5 +1,6 @@
from pathlib import Path from pathlib import Path
import hashlib import hashlib
import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional from typing import Optional
@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]:
total_duration += timedelta(weeks=number) total_duration += timedelta(weeks=number)
return total_duration return total_duration
def parse_ollama_modelfile(model_text):
    """Extract the base model and parameters from Ollama Modelfile text.

    Args:
        model_text: The Modelfile content as a string. ``None`` or an empty
            string is accepted and yields the empty result.

    Returns:
        A dict with two keys:
            ``base_model_id``: the name following the first ``FROM``
                instruction, or ``None`` when there is none.
            ``params``: a dict that may contain ``template``, ``stop``,
                ``adapter``, ``system``, ``messages`` and any of the typed
                ``PARAMETER`` values listed in ``parameters_meta``.
    """
    # Known PARAMETER names and the type their value is converted to.
    parameters_meta = {
        "mirostat": int,
        "mirostat_eta": float,
        "mirostat_tau": float,
        "num_ctx": int,
        "repeat_last_n": int,
        "repeat_penalty": float,
        "temperature": float,
        "seed": int,
        "stop": str,
        "tfs_z": float,
        "num_predict": int,
        "top_k": int,
        "top_p": float,
    }

    data = {"base_model_id": None, "params": {}}

    # Callers may pass a missing "content" field; there is nothing to parse.
    if not model_text:
        return data

    params = data["params"]

    # Parse base model. \S+ (rather than \w+) keeps tags and namespaces
    # such as "llama3:8b" or "library/mistral-7b" intact.
    base_model_match = re.search(
        r"^FROM\s+(\S+)", model_text, re.MULTILINE | re.IGNORECASE
    )
    if base_model_match:
        data["base_model_id"] = base_model_match.group(1)

    # Parse template
    template_match = re.search(
        r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    if template_match:
        params["template"] = template_match.group(1).strip()

    # Parse stops: every quoted stop sequence is collected into a list.
    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
    if stops:
        params["stop"] = stops

    # Parse other parameters from the provided list
    for param, param_type in parameters_meta.items():
        # Do not clobber the list of quoted stops collected above with a
        # single raw string; the generic pattern only serves as a fallback
        # for an unquoted stop value.
        if param in params:
            continue

        param_match = re.search(
            rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE
        )
        if not param_match:
            continue

        value = param_match.group(1).strip()
        try:
            if param_type is int:
                value = int(value)
            elif param_type is float:
                value = float(value)
        except ValueError:
            # A malformed number should not abort the whole parse; the
            # parameter is simply left out.
            continue
        params[param] = value

    # Parse adapter
    adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
    if adapter_match:
        params["adapter"] = adapter_match.group(1).strip()

    # Parse system description
    system_desc_match = re.search(
        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    if system_desc_match:
        params["system"] = system_desc_match.group(1).strip()

    # Parse messages
    messages = [
        {"role": role, "content": content}
        for role, content in re.findall(
            r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE
        )
    ]
    if messages:
        params["messages"] = messages

    return data

View File

@ -139,7 +139,7 @@
</div> </div>
<div class=" flex-1 self-center"> <div class=" flex-1 self-center">
<div class=" font-bold capitalize">{model.name}</div> <div class=" font-bold capitalize line-clamp-1">{model.name}</div>
<div class=" text-sm overflow-hidden text-ellipsis line-clamp-1"> <div class=" text-sm overflow-hidden text-ellipsis line-clamp-1">
{model?.info?.meta?.description ?? model.id} {model?.info?.meta?.description ?? model.id}
</div> </div>