mirror of https://github.com/clearml/clearml-server
Improve pre-populate on server startup (including sync lock)
This commit is contained in:
parent 21f2ea8b17
commit 901ec37290
@@ -10,6 +10,7 @@ services:
     volumes:
     - /opt/trains/logs:/var/log/trains
     - /opt/trains/config:/opt/trains/config
+    - /opt/trains/data/fileserver:/mnt/fileserver
     depends_on:
     - redis
     - mongo
@@ -23,8 +24,9 @@ services:
       TRAINS_REDIS_SERVICE_HOST: redis
       TRAINS_REDIS_SERVICE_PORT: 6379
       TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
-      TRAINS__apiserver__mongo__pre_populate__enabled: "true"
-      TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
+      TRAINS__apiserver__pre_populate__enabled: "true"
+      TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
+      TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
     ports:
     - "8008:8008"
     networks:
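The renamed variables above appear to follow the double-underscore override convention: each TRAINS__ segment corresponds to one level of the dotted config path read later in this commit (for example apiserver.pre_populate.enabled). A minimal standalone sketch of that mapping, assuming only what the names here and the config.get() calls below imply:

def env_to_config_path(var_name: str, prefix: str = "TRAINS__") -> str:
    # "TRAINS__apiserver__pre_populate__enabled" -> "apiserver.pre_populate.enabled"
    return ".".join(var_name[len(prefix):].split("__"))

print(env_to_config_path("TRAINS__apiserver__pre_populate__enabled"))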
@@ -552,7 +552,7 @@ class EventBLL(object):
         )

         events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
-        next_scroll_id = es_res["_scroll_id"]
+        next_scroll_id = es_res.get("_scroll_id")
         total_events = es_res["hits"]["total"]

         return TaskEventsResult(
@@ -32,7 +32,6 @@ from database.model.task.task import (
 )
 from database.utils import get_company_or_none_constraint, id as create_id
 from service_repo import APICall
-from services.utils import validate_tags
 from timing_context import TimingContext
 from utilities.dicts import deep_merge
 from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
@@ -182,7 +181,6 @@ class TaskBLL(object):
         execution_overrides: Optional[dict] = None,
         validate_references: bool = False,
     ) -> Task:
-        validate_tags(tags, system_tags)
         task = cls.get_by_id(company_id=company_id, task_id=task_id)
         execution_dict = task.execution.to_proper_dict() if task.execution else {}
         execution_model_overriden = False
@@ -26,6 +26,13 @@
         check_max_version: false
     }

+    pre_populate {
+        enabled: false
+        zip_files: ["/path/to/export.zip"]
+        fail_on_error: false
+        artifacts_path: "/mnt/fileserver"
+    }
+
     mongo {
         # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
         # but not declared in a data model
@@ -34,12 +41,6 @@
         aggregate {
             allow_disk_use: true
         }
-
-        pre_populate {
-            enabled: false
-            zip_file: "/path/to/export.zip"
-            fail_on_error: false
-        }
     }

     auth {
@@ -85,6 +85,7 @@ class Artifact(EmbeddedDocument):


 class Execution(EmbeddedDocument, ProperDictMixin):
+    meta = {"strict": strict}
     test_split = IntField(default=0)
     parameters = SafeDictField(default=dict)
     model = StringField(reference_field="Model")
@@ -1,6 +1,8 @@
 from pathlib import Path
+from typing import Sequence, Union

 from config import config
+from config.info import get_default_company
 from database.model.auth import Role
 from service_repo.auth.fixed_user import FixedUser
 from .migration import _apply_migrations
@@ -11,7 +13,48 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
 log = config.logger(__package__)


-def init_mongo_data():
+def _pre_populate(company_id: str, zip_file: str):
+    if not zip_file or not Path(zip_file).is_file():
+        msg = f"Invalid pre-populate zip file: {zip_file}"
+        if config.get("apiserver.pre_populate.fail_on_error", False):
+            log.error(msg)
+            raise ValueError(msg)
+        else:
+            log.warning(msg)
+    else:
+        log.info(f"Pre-populating using {zip_file}")
+
+        user_id = _ensure_backend_user(
+            "__allegroai__", company_id, "Allegro.ai"
+        )
+
+        PrePopulate.import_from_zip(
+            zip_file,
+            company_id="",
+            user_id=user_id,
+            artifacts_path=config.get(
+                "apiserver.pre_populate.artifacts_path", None
+            ),
+        )
+
+
+def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
+    if isinstance(zip_files, str):
+        zip_files = [zip_files]
+    for p in map(Path, zip_files):
+        if p.is_file():
+            yield p
+        if p.is_dir():
+            yield from p.glob("*.zip")
+        log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
+
+
+def pre_populate_data():
+    for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
+        _pre_populate(company_id=get_default_company(), zip_file=zip_file)
+
+
+def init_mongo_data() -> bool:
     try:
         empty_dbs = _apply_migrations(log)

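A short standalone sketch (temporary paths invented for illustration) of the resolution rule that _resolve_zip_files implements above: a configured entry that is a file is yielded as-is, while a directory entry expands to every *.zip inside it:

import tempfile
from pathlib import Path


def resolve(entries):
    # same rule as _resolve_zip_files above, minus logging
    for p in map(Path, entries):
        if p.is_file():
            yield p
        if p.is_dir():
            yield from p.glob("*.zip")


tmp = Path(tempfile.mkdtemp())
(tmp / "export_1.zip").touch()
(tmp / "export_2.zip").touch()

print(sorted(str(p) for p in resolve([str(tmp / "export_1.zip"), str(tmp)])))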
@@ -21,23 +64,6 @@ def init_mongo_data():

         _ensure_default_queue(company_id)

-        if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
-            zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
-            if not zip_file or not Path(zip_file).is_file():
-                msg = f"Failed pre-populating database: invalid zip file {zip_file}"
-                if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
-                    log.error(msg)
-                    raise ValueError(msg)
-                else:
-                    log.warning(msg)
-            else:
-
-                user_id = _ensure_backend_user(
-                    "__allegroai__", company_id, "Allegro.ai"
-                )
-
-                PrePopulate.import_from_zip(zip_file, user_id=user_id)
-
         fixed_mode = FixedUser.enabled()

         for user, credentials in config.get("secure.credentials", {}).items():
@@ -61,5 +87,7 @@ def init_mongo_data():
                 ensure_fixed_user(user, company_id, log=log)
             except Exception as ex:
                 log.error(f"Failed creating fixed user {user.name}: {ex}")
+
+        return empty_dbs
     except Exception as ex:
         log.exception("Failed initializing mongodb")
@@ -1,31 +1,199 @@
+import hashlib
 import importlib
+import os
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timezone
+from io import BytesIO
+from itertools import chain
+from operator import attrgetter
 from os.path import splitext
-from typing import List, Optional, Any, Type, Set, Dict
+from pathlib import Path
+from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union
+from urllib.parse import unquote, urlparse
 from zipfile import ZipFile, ZIP_BZIP2

+import attr
 import mongoengine
-from tqdm import tqdm
+from boltons.iterutils import chunked_iter
+from furl import furl
+from mongoengine import Q
+
+from bll.event import EventBLL
+from database.model import EntityVisibility
+from database.model.model import Model
+from database.model.project import Project
+from database.model.task.task import Task, ArtifactModes, TaskStatus
+from database.utils import get_options
+from utilities import json


 class PrePopulate:
-    @classmethod
-    def export_to_zip(
-        cls, filename: str, experiments: List[str] = None, projects: List[str] = None
-    ):
-        with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
-            cls._export(zfile, experiments, projects)
+    event_bll = EventBLL()
+    events_file_suffix = "_events"
+    export_tag_prefix = "Exported:"
+    export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
+
+    class JsonLinesWriter:
+        def __init__(self, file: BinaryIO):
+            self.file = file
+            self.empty = True
+
+        def __enter__(self):
+            self._write("[")
+            return self
+
+        def __exit__(self, exc_type, exc_value, exc_traceback):
+            self._write("\n]")
+
+        def _write(self, data: str):
+            self.file.write(data.encode("utf-8"))
+
+        def write(self, line: str):
+            if not self.empty:
+                self._write(",")
+            self._write("\n" + line)
+            self.empty = False
+
+    @attr.s(auto_attribs=True)
+    class _MapData:
+        files: Sequence[str] = None
+        entities: Dict[str, datetime] = None
+
+    @staticmethod
+    def _get_last_update_time(entity) -> datetime:
+        return getattr(entity, "last_update", None) or getattr(entity, "created")

     @classmethod
-    def import_from_zip(cls, filename: str, user_id: str = None):
+    def _check_for_update(
+        cls, map_file: Path, entities: dict
+    ) -> Tuple[bool, Sequence[str]]:
+        if not map_file.is_file():
+            return True, []
+
+        files = []
+        try:
+            map_data = cls._MapData(**json.loads(map_file.read_text()))
+            files = map_data.files
+            for file in files:
+                if not Path(file).is_file():
+                    return True, files
+
+            new_times = {
+                item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
+                for item in chain.from_iterable(entities.values())
+            }
+            old_times = map_data.entities
+
+            if set(new_times.keys()) != set(old_times.keys()):
+                return True, files
+
+            for id_, new_timestamp in new_times.items():
+                if new_timestamp != old_times[id_]:
+                    return True, files
+        except Exception as ex:
+            print("Error reading map file. " + str(ex))
+            return True, files
+
+        return False, files
+
+    @classmethod
+    def _write_update_file(
+        cls, map_file: Path, entities: dict, created_files: Sequence[str]
+    ):
+        map_data = cls._MapData(
+            files=created_files,
+            entities={
+                entity.id: cls._get_last_update_time(entity)
+                for entity in chain.from_iterable(entities.values())
+            },
+        )
+        map_file.write_text(json.dumps(attr.asdict(map_data)))
+
+    @staticmethod
+    def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
+        def is_fileserver_link(a: str) -> bool:
+            a = a.lower()
+            if a.startswith("https://files."):
+                return True
+            if a.startswith("http"):
+                parsed = urlparse(a)
+                if parsed.scheme in {"http", "https"} and parsed.port == 8081:
+                    return True
+            return False
+
+        fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
+        print(
+            f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
+        )
+
+        return fileserver_links
+
+    @classmethod
+    def export_to_zip(
+        cls,
+        filename: str,
+        experiments: Sequence[str] = None,
+        projects: Sequence[str] = None,
+        artifacts_path: str = None,
+        task_statuses: Sequence[str] = None,
+        tag_exported_entities: bool = False,
+    ) -> Sequence[str]:
+        if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
+            raise ValueError("Invalid task statuses")
+
+        file = Path(filename)
+        entities = cls._resolve_entities(
+            experiments=experiments, projects=projects, task_statuses=task_statuses
+        )
+
+        map_file = file.with_suffix(".map")
+        updated, old_files = cls._check_for_update(map_file, entities)
+        if not updated:
+            print(f"There are no updates from the last export")
+            return old_files
+        for old in old_files:
+            old_path = Path(old)
+            if old_path.is_file():
+                old_path.unlink()
+
+        zip_args = dict(mode="w", compression=ZIP_BZIP2)
+        with ZipFile(file, **zip_args) as zfile:
+            artifacts, hash_ = cls._export(
+                zfile, entities, tag_entities=tag_exported_entities
+            )
+        file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
+        file.replace(file_with_hash)
+        created_files = [str(file_with_hash)]
+
+        artifacts = cls._filter_artifacts(artifacts)
+        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
+            artifacts_file = file_with_hash.with_suffix(".artifacts")
+            with ZipFile(artifacts_file, **zip_args) as zfile:
+                cls._export_artifacts(zfile, artifacts, artifacts_path)
+            created_files.append(str(artifacts_file))
+
+        cls._write_update_file(map_file, entities, created_files)
+
+        return created_files
+
+    @classmethod
+    def import_from_zip(
+        cls, filename: str, company_id: str, user_id: str, artifacts_path: str
+    ):
         with ZipFile(filename) as zfile:
-            cls._import(zfile, user_id)
+            cls._import(zfile, company_id, user_id)
+
+        if artifacts_path and os.path.isdir(artifacts_path):
+            artifacts_file = Path(filename).with_suffix(".artifacts")
+            if artifacts_file.is_file():
+                print(f"Unzipping artifacts into {artifacts_path}")
+                with ZipFile(artifacts_file) as zfile:
+                    zfile.extractall(artifacts_path)

     @staticmethod
     def _resolve_type(
-        cls: Type[mongoengine.Document], ids: Optional[List[str]]
-    ) -> List[Any]:
+        cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
+    ) -> Sequence[Any]:
         ids = set(ids)
         items = list(cls.objects(id__in=list(ids)))
         resolved = {i.id for i in items}
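A tiny standalone illustration (paths and hash are made up) of the pathlib naming scheme export_to_zip uses above for its companion files: the content hash is appended to the zip name, while the .map freshness file and the .artifacts archive are derived with with_suffix:

from pathlib import Path

file = Path("/opt/trains/db-pre-populate/export.zip")
hash_ = "d41d8cd98f00b204e9800998ecf8427e"  # placeholder md5 hex digest
file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")

print(file.with_suffix(".map"))                  # .../export.map
print(file_with_hash)                            # .../export_<hash>.zip
print(file_with_hash.with_suffix(".artifacts"))  # .../export_<hash>.artifacts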
@@ -43,20 +211,24 @@ class PrePopulate:

     @classmethod
     def _resolve_entities(
-        cls, experiments: List[str] = None, projects: List[str] = None
+        cls,
+        experiments: Sequence[str] = None,
+        projects: Sequence[str] = None,
+        task_statuses: Sequence[str] = None,
     ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
-        from database.model.project import Project
-        from database.model.task.task import Task
-
         entities = defaultdict(set)

         if projects:
             print("Reading projects...")
             entities[Project].update(cls._resolve_type(Project, projects))
             print("--> Reading project experiments...")
-            objs = Task.objects(
-                project__in=list(set(filter(None, (p.id for p in entities[Project]))))
+            query = Q(
+                project__in=list(set(filter(None, (p.id for p in entities[Project])))),
+                system_tags__nin=[EntityVisibility.archived.value],
             )
+            if task_statuses:
+                query &= Q(status__in=list(set(task_statuses)))
+            objs = Task.objects(query)
             entities[Task].update(o for o in objs if o.id not in (experiments or []))

         if experiments:
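A brief standalone sketch (field values invented) of how mongoengine Q objects compose, which is what lets the hunk above add the optional status filter before issuing a single Task.objects(query) call:

from mongoengine import Q

query = Q(project__in=["project-1"], system_tags__nin=["archived"])
task_statuses = ["published", "stopped"]
if task_statuses:
    query &= Q(status__in=list(set(task_statuses)))

# the combined filter would be evaluated by a queryset, e.g. Task.objects(query)
print(query)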
@@ -69,85 +241,256 @@ class PrePopulate:
             project_ids = {p.id for p in entities[Project]}
             entities[Project].update(o for o in objs if o.id not in project_ids)

+        model_ids = {
+            model_id
+            for task in entities[Task]
+            for model_id in (task.output.model, task.execution.model)
+            if model_id
+        }
+        if model_ids:
+            print("Reading models...")
+            entities[Model] = set(Model.objects(id__in=list(model_ids)))
+
         return entities

     @classmethod
-    def _cleanup_task(cls, task):
-        from database.model.task.task import TaskStatus
+    def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
+        if not tags:
+            return tags
+        return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]

-        task.completed = None
-        task.started = None
-        if task.execution:
-            task.execution.model = None
-            task.execution.model_desc = None
-            task.execution.model_labels = None
-        if task.output:
-            task.output.model = None
-
-        task.status = TaskStatus.created
+    @classmethod
+    def _cleanup_model(cls, model: Model):
+        model.company = ""
+        model.user = ""
+        model.tags = cls._filter_out_export_tags(model.tags)
+
+    @classmethod
+    def _cleanup_task(cls, task: Task):
         task.comment = "Auto generated by Allegro.ai"
-        task.created = datetime.utcnow()
-        task.last_iteration = 0
-        task.last_update = task.created
-        task.status_changed = task.created
         task.status_message = ""
         task.status_reason = ""
         task.user = ""
+        task.company = ""
+        task.tags = cls._filter_out_export_tags(task.tags)
+        if task.output:
+            task.output.destination = None
+
+    @classmethod
+    def _cleanup_project(cls, project: Project):
+        project.user = ""
+        project.company = ""
+        project.tags = cls._filter_out_export_tags(project.tags)

     @classmethod
     def _cleanup_entity(cls, entity_cls, entity):
-        from database.model.task.task import Task
         if entity_cls == Task:
             cls._cleanup_task(entity)
+        elif entity_cls == Model:
+            cls._cleanup_model(entity)
+        elif entity == Project:
+            cls._cleanup_project(entity)
+
+    @classmethod
+    def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
+        try:
+            for item in items:
+                item.update(upsert=False, tags=sorted(item.tags + [tag]))
+        except AttributeError:
+            pass
+
+    @classmethod
+    def _export_task_events(
+        cls, task: Task, base_filename: str, writer: ZipFile, hash_
+    ) -> Sequence[str]:
+        artifacts = []
+        filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
+        print(f"Writing task events into {writer.filename}:{filename}")
+        with BytesIO() as f:
+            with cls.JsonLinesWriter(f) as w:
+                scroll_id = None
+                while True:
+                    res = cls.event_bll.get_task_events(
+                        task.company, task.id, scroll_id=scroll_id
+                    )
+                    if not res.events:
+                        break
+                    scroll_id = res.next_scroll_id
+                    for event in res.events:
+                        if event.get("type") == "training_debug_image":
+                            url = cls._get_fixed_url(event.get("url"))
+                            if url:
+                                event["url"] = url
+                                artifacts.append(url)
+                        w.write(json.dumps(event))
+            data = f.getvalue()
+            hash_.update(data)
+            writer.writestr(filename, data)
+
+        return artifacts
+
+    @staticmethod
+    def _get_fixed_url(url: Optional[str]) -> Optional[str]:
+        if not (url and url.lower().startswith("s3://")):
+            return url
+        try:
+            fixed = furl(url)
+            fixed.scheme = "https"
+            fixed.host += ".s3.amazonaws.com"
+            return fixed.url
+        except Exception as ex:
+            print(f"Failed processing link {url}. " + str(ex))
+            return url
+
+    @classmethod
+    def _export_entity_related_data(
+        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
+    ):
+        if entity_cls == Task:
+            return [
+                *cls._get_task_output_artifacts(entity),
+                *cls._export_task_events(entity, base_filename, writer, hash_),
+            ]
+
+        if entity_cls == Model:
+            entity.uri = cls._get_fixed_url(entity.uri)
+            return [entity.uri] if entity.uri else []
+
+        return []
+
+    @classmethod
+    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
+        if not task.execution.artifacts:
+            return []
+
+        for a in task.execution.artifacts:
+            if a.mode == ArtifactModes.output:
+                a.uri = cls._get_fixed_url(a.uri)
+
+        return [
+            a.uri
+            for a in task.execution.artifacts
+            if a.mode == ArtifactModes.output and a.uri
+        ]
+
+    @classmethod
+    def _export_artifacts(
+        cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
+    ):
+        unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
+        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
+        for path in unique_paths:
+            path = path.lstrip("/")
+            full_path = os.path.join(artifacts_path, path)
+            if os.path.isfile(full_path):
+                writer.write(full_path, path)
+            else:
+                print(f"Artifact {full_path} not found")

     @classmethod
     def _export(
-        cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
-    ):
-        entities = cls._resolve_entities(experiments, projects)
-
-        for cls_, items in entities.items():
+        cls, writer: ZipFile, entities: dict, tag_entities: bool = False
+    ) -> Tuple[Sequence[str], str]:
+        """
+        Export the requested experiments, projects and models and return the list of artifact files
+        Always do the export on sorted items since the order of items influence hash
+        """
+        artifacts = []
+        now = datetime.utcnow()
+        hash_ = hashlib.md5()
+        for cls_ in sorted(entities, key=attrgetter("__name__")):
+            items = sorted(entities[cls_], key=attrgetter("id"))
             if not items:
                 continue
-            filename = f"{cls_.__module__}.{cls_.__name__}.json"
+            base_filename = f"{cls_.__module__}.{cls_.__name__}"
+            for item in items:
+                artifacts.extend(
+                    cls._export_entity_related_data(
+                        cls_, item, base_filename, writer, hash_
+                    )
+                )
+            filename = base_filename + ".json"
             print(f"Writing {len(items)} items into {writer.filename}:{filename}")
-            with writer.open(filename, "w") as f:
-                f.write("[\n".encode("utf-8"))
-                last = len(items) - 1
-                for i, item in enumerate(items):
-                    cls._cleanup_entity(cls_, item)
-                    f.write(item.to_json().encode("utf-8"))
-                    if i != last:
-                        f.write(",".encode("utf-8"))
-                    f.write("\n".encode("utf-8"))
-                f.write("]\n".encode("utf-8"))
+            with BytesIO() as f:
+                with cls.JsonLinesWriter(f) as w:
+                    for item in items:
+                        cls._cleanup_entity(cls_, item)
+                        w.write(item.to_json())
+                data = f.getvalue()
+                hash_.update(data)
+                writer.writestr(filename, data)
+
+            if tag_entities:
+                cls._add_tag(items, now.strftime(cls.export_tag))
+
+        return artifacts, hash_.hexdigest()

     @staticmethod
-    def _import(reader: ZipFile, user_id: str = None):
-        for file_info in reader.filelist:
-            full_name = splitext(file_info.orig_filename)[0]
-            print(f"Reading {reader.filename}:{full_name}...")
-            module_name, _, class_name = full_name.rpartition(".")
-            module = importlib.import_module(module_name)
-            cls_: Type[mongoengine.Document] = getattr(module, class_name)
-            with reader.open(file_info) as f:
-                for item in tqdm(
-                    f.readlines(),
-                    desc=f"Writing {cls_.__name__.lower()}s into database",
-                    unit="doc",
-                ):
-                    item = (
-                        item.decode("utf-8")
-                        .strip()
-                        .lstrip("[")
-                        .rstrip("]")
-                        .rstrip(",")
-                        .strip()
-                    )
-                    if not item:
-                        continue
-                    doc = cls_.from_json(item)
-                    if user_id is not None and hasattr(doc, "user"):
-                        doc.user = user_id
-                    doc.save(force_insert=True)
+    def json_lines(file: BinaryIO):
+        for line in file:
+            clean = (
+                line.decode("utf-8")
+                .rstrip("\r\n")
+                .strip()
+                .lstrip("[")
+                .rstrip(",]")
+                .strip()
+            )
+            if not clean:
+                continue
+            yield clean
+
+    @classmethod
+    def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None):
+        """
+        Import entities and events from the zip file
+        Start from entities since event import will require the tasks already in DB
+        """
+        event_file_ending = cls.events_file_suffix + ".json"
+        entity_files = (
+            fi
+            for fi in reader.filelist
+            if not fi.orig_filename.endswith(event_file_ending)
+        )
+        event_files = (
+            fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
+        )
+        for files, reader_func in (
+            (entity_files, cls._import_entity),
+            (event_files, cls._import_events),
+        ):
+            for file_info in files:
+                with reader.open(file_info) as f:
+                    full_name = splitext(file_info.orig_filename)[0]
+                    print(f"Reading {reader.filename}:{full_name}...")
+                    reader_func(f, full_name, company_id, user_id)
+
+    @classmethod
+    def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str):
+        module_name, _, class_name = full_name.rpartition(".")
+        module = importlib.import_module(module_name)
+        cls_: Type[mongoengine.Document] = getattr(module, class_name)
+        print(f"Writing {cls_.__name__.lower()}s into database")
+        for item in cls.json_lines(f):
+            doc = cls_.from_json(item, created=True)
+            if hasattr(doc, "user"):
+                doc.user = user_id
+            if hasattr(doc, "company"):
+                doc.company = company_id
+            if isinstance(doc, Project):
+                cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
+                    set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
+                )
+            doc.save()
+            if isinstance(doc, Task):
+                cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
+
+    @classmethod
+    def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
+        _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
+        print(f"Writing events for task {task_id} into database")
+        for events_chunk in chunked_iter(cls.json_lines(f), 1000):
+            events = [json.loads(item) for item in events_chunk]
+            cls.event_bll.add_events(
+                company_id, events=events, worker="", allow_locked_tasks=True
+            )
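A standalone round-trip sketch (local, trimmed copies of the two helpers above, using the stdlib json module) showing that the bracketed, comma-separated stream produced by JsonLinesWriter is exactly what json_lines parses back:

import json
from io import BytesIO


class JsonLinesWriter:
    def __init__(self, file):
        self.file = file
        self.empty = True

    def __enter__(self):
        self._write("[")
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._write("\n]")

    def _write(self, data: str):
        self.file.write(data.encode("utf-8"))

    def write(self, line: str):
        if not self.empty:
            self._write(",")
        self._write("\n" + line)
        self.empty = False


def json_lines(file):
    for line in file:
        clean = line.decode("utf-8").rstrip("\r\n").strip().lstrip("[").rstrip(",]").strip()
        if clean:
            yield clean


buf = BytesIO()
with JsonLinesWriter(buf) as w:
    for event in ({"iter": 1}, {"iter": 2}):
        w.write(json.dumps(event))

buf.seek(0)
print([json.loads(item) for item in json_lines(buf)])  # [{'iter': 1}, {'iter': 2}]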
@@ -1,5 +1,6 @@
 import atexit
 from argparse import ArgumentParser
+from hashlib import md5

 from flask import Flask, request, Response
 from flask_compress import Compress
@@ -11,10 +12,11 @@ from apierrors.base import BaseError
 from bll.statistics.stats_reporter import StatisticsReporter
 from config import config
 from elastic.initialize import init_es_data
-from mongo.initialize import init_mongo_data
+from mongo.initialize import init_mongo_data, pre_populate_data
 from service_repo import ServiceRepo, APICall
 from service_repo.auth import AuthType
 from service_repo.errors import PathParsingError
+from sync import distributed_lock
 from timing_context import TimingContext
 from updates import check_updates_thread
 from utilities import json
@@ -33,8 +35,16 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")

 database.initialize()

-init_es_data()
-init_mongo_data()
+# build a key that uniquely identifies specific mongo instance
+hosts_string = ";".join(sorted(database.get_hosts()))
+key = "db_init_" + md5(hosts_string.encode()).hexdigest()
+with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
+    print(key)
+    init_es_data()
+    empty_db = init_mongo_data()
+    if empty_db and config.get("apiserver.pre_populate.enabled", False):
+        pre_populate_data()
+

 ServiceRepo.load("services")
 log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
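A small standalone sketch (host names invented) of how the startup code above derives its cluster-wide lock key: any apiserver instance pointed at the same set of mongo hosts computes the same db_init_ key, so they all contend for one lock:

from hashlib import md5

hosts = ["mongo-b:27017", "mongo-a:27017"]  # hypothetical hosts, any order
hosts_string = ";".join(sorted(hosts))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
print(key)  # "db_init_" followed by 32 hex characters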
@@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
 def get_all_ex(call: APICall):
     include_stats = call.data.get("include_stats")
     stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
+    allow_public = not call.data.get("non_public", False)

     if stats_for_state:
         try:
@@ -200,7 +201,7 @@ def get_all_ex(call: APICall):
         company=call.identity.company,
         query_dict=call.data,
         query_options=get_all_query_options,
-        allow_public=True,
+        allow_public=allow_public,
     )
     conform_output_tags(call, projects)

@@ -56,7 +56,7 @@ from database.model.task.task import (
 )
 from database.utils import get_fields, parse_from_call
 from service_repo import APICall, endpoint
-from services.utils import conform_tag_fields, conform_output_tags
+from services.utils import conform_tag_fields, conform_output_tags, validate_tags
 from timing_context import TimingContext
 from utilities import safe_get

@@ -377,6 +377,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
     "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
 )
 def clone_task(call: APICall, company_id, request: CloneRequest):
+    validate_tags(request.new_task_tags, request.new_task_system_tags)
     task = task_bll.clone_task(
         company_id=company_id,
         user_id=call.identity.user,
server/sync.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import time
+from contextlib import contextmanager
+from time import sleep
+
+from redis_manager import redman
+
+_redis = redman.connection("apiserver")
+
+
+@contextmanager
+def distributed_lock(name: str, timeout: int, max_wait: int = 0):
+    """
+    Context manager that acquires a distributed lock on enter
+    and releases it on exit. The lock has a ttl equal to timeout seconds.
+    If the lock can not be acquired for max_wait seconds (defaults to timeout * 2)
+    then an exception is thrown
+    """
+    lock_name = f"dist_lock_{name}"
+    start = time.time()
+    max_wait = max_wait or timeout * 2
+    while not _redis.set(lock_name, value="", ex=timeout, nx=True):
+        sleep(1)
+        if time.time() - start > max_wait:
+            raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
+    try:
+        yield
+    finally:
+        _redis.delete(lock_name)
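The lock above rests on a single Redis primitive: SET with NX (only set if the key is missing) and EX (time-to-live). A standalone sketch using plain redis-py rather than the server's redis_manager wrapper (connection details are assumptions, not part of this commit):

from time import sleep

from redis import Redis

r = Redis(host="localhost", port=6379)
lock_name = "dist_lock_demo"

if r.set(lock_name, value="", ex=30, nx=True):  # acquired: key created with a 30s TTL
    try:
        pass  # critical section, e.g. one-time DB initialization
    finally:
        r.delete(lock_name)
else:
    sleep(1)  # another instance holds the lock; distributed_lock() polls like this until max_wait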