Code cleanup

This commit is contained in:
allegroai 2021-05-03 17:48:24 +03:00
parent f4d5168a20
commit 174f692edf

View File

@ -25,7 +25,6 @@ from typing import (
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
@ -47,7 +46,6 @@ from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
@ -344,31 +342,10 @@ class PrePopulate:
return upadated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
Build content array that contains upgraded tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
@ -759,15 +736,35 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@classmethod
def _upgrade_task_data(cls, task_data: dict) -> dict:
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
"""
Migrate from execution/parameters and model_desc to hyperparams and configuration fields
Upgrade artifacts list to dict
Migrate from execution.model and output.model to the new models field
Move docker_cmd contents into the container field
:param task_data:
:return:
:param task_data: Upgraded in place
:return: The upgraded task data
"""
for old_param_field, new_param_field, default_section in (
("execution.parameters", "hyperparams", hyperparams_default_section),
("execution.model_desc", "configuration", None),
):
legacy_path = old_param_field.split(".")
legacy = nested_get(task_data, legacy_path)
if legacy:
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not nested_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
nested_set(task_data, path=new_path, value=new_param)
nested_delete(task_data, legacy_path)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
@ -801,15 +798,10 @@ class PrePopulate:
task_data["models"] = models
docker_cmd_path = ("execution", "docker_cmd")
container_path = ("execution", "container")
docker_cmd = nested_get(task_data, docker_cmd_path)
if docker_cmd and not nested_get(task_data, container_path):
if docker_cmd and not task_data.get("container"):
image, _, arguments = docker_cmd.partition(" ")
nested_set(
task_data,
path=container_path,
value={"image": image, "arguments": arguments},
)
task_data["container"] = {"image": image, "arguments": arguments}
nested_delete(task_data, docker_cmd_path)
return task_data
@ -830,7 +822,6 @@ class PrePopulate:
for item in cls.json_lines(f):
if cls_ == cls.task_cls:
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
print(item)
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):