diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index c97202f..c711741 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -21,6 +21,7 @@ from typing import ( Union, Mapping, IO, + Callable, ) from urllib.parse import unquote, urlparse from zipfile import ZipFile, ZIP_BZIP2 @@ -54,6 +55,7 @@ from apiserver.database.model.task.task import ( from apiserver.database.utils import get_options from apiserver.utilities import json from apiserver.utilities.dicts import nested_get, nested_set, nested_delete +from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper class PrePopulate: @@ -744,6 +746,19 @@ class PrePopulate: module = importlib.import_module(module_name) return getattr(module, class_name) + @staticmethod + def _upgrade_model_data(model_data: dict) -> dict: + metadata_key = "metadata" + metadata = model_data.get(metadata_key) + if isinstance(metadata, list): + metadata = { + ParameterKeyEscaper.escape(item["key"]): item + for item in metadata + if isinstance(item, dict) and "key" in item + } + model_data[metadata_key] = metadata + return model_data + @staticmethod def _upgrade_task_data(task_data: dict) -> dict: """ @@ -828,9 +843,14 @@ class PrePopulate: print(f"Writing {cls_.__name__.lower()}s into database") tasks = [] override_project_count = 0 + data_upgrade_funcs: Mapping[Type, Callable] = { + cls.task_cls: cls._upgrade_task_data, + cls.model_cls: cls._upgrade_model_data, + } for item in cls.json_lines(f): - if cls_ == cls.task_cls: - item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item))) + upgrade_func = data_upgrade_funcs.get(cls_) + if upgrade_func: + item = json.dumps(upgrade_func(json.loads(item))) doc = cls_.from_json(item, created=True) if hasattr(doc, "user"):