Mirror of https://github.com/clearml/clearml-server (synced 2025-03-03 18:54:20 +00:00)

Improve pre-populate on server startup (including sync lock)

parent 21f2ea8b17
commit 901ec37290
@@ -10,6 +10,7 @@ services:
    volumes:
    - /opt/trains/logs:/var/log/trains
    - /opt/trains/config:/opt/trains/config
    - /opt/trains/data/fileserver:/mnt/fileserver
    depends_on:
    - redis
    - mongo
@@ -23,8 +24,9 @@ services:
      TRAINS_REDIS_SERVICE_HOST: redis
      TRAINS_REDIS_SERVICE_PORT: 6379
      TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
      TRAINS__apiserver__mongo__pre_populate__enabled: "true"
      TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
      TRAINS__apiserver__pre_populate__enabled: "true"
      TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
      TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
    ports:
    - "8008:8008"
    networks:
@@ -552,7 +552,7 @@ class EventBLL(object):
        )

        events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
        next_scroll_id = es_res["_scroll_id"]
        next_scroll_id = es_res.get("_scroll_id")
        total_events = es_res["hits"]["total"]

        return TaskEventsResult(
@@ -32,7 +32,6 @@ from database.model.task.task import (
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from services.utils import validate_tags
from timing_context import TimingContext
from utilities.dicts import deep_merge
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
@@ -182,7 +181,6 @@ class TaskBLL(object):
        execution_overrides: Optional[dict] = None,
        validate_references: bool = False,
    ) -> Task:
        validate_tags(tags, system_tags)
        task = cls.get_by_id(company_id=company_id, task_id=task_id)
        execution_dict = task.execution.to_proper_dict() if task.execution else {}
        execution_model_overriden = False
@@ -26,6 +26,13 @@
    check_max_version: false
}

pre_populate {
    enabled: false
    zip_files: ["/path/to/export.zip"]
    fail_on_error: false
    artifacts_path: "/mnt/fileserver"
}

mongo {
    # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
    # but not declared in a data model
@@ -34,12 +41,6 @@
    aggregate {
        allow_disk_use: true
    }

    pre_populate {
        enabled: false
        zip_file: "/path/to/export.zip"
        fail_on_error: false
    }
}

auth {
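Aside (not part of the diff): the new top-level pre_populate block replaces the old mongo.pre_populate block, and the TRAINS__apiserver__pre_populate__* variables added to docker-compose above appear to override the same keys, with the double underscores presumably standing for key nesting. A minimal sketch of the lookups the server code in this commit performs, using the same config facade:

from config import config  # the api server's config facade used elsewhere in this commit

enabled = config.get("apiserver.pre_populate.enabled", False)
zip_files = config.get("apiserver.pre_populate.zip_files")  # a zip file, or a directory of *.zip files
fail_on_error = config.get("apiserver.pre_populate.fail_on_error", False)
artifacts_path = config.get("apiserver.pre_populate.artifacts_path", None)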
@@ -85,6 +85,7 @@ class Artifact(EmbeddedDocument):


class Execution(EmbeddedDocument, ProperDictMixin):
    meta = {"strict": strict}
    test_split = IntField(default=0)
    parameters = SafeDictField(default=dict)
    model = StringField(reference_field="Model")
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import Sequence, Union

from config import config
from config.info import get_default_company
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
@@ -11,7 +13,48 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)


def init_mongo_data():
def _pre_populate(company_id: str, zip_file: str):
    if not zip_file or not Path(zip_file).is_file():
        msg = f"Invalid pre-populate zip file: {zip_file}"
        if config.get("apiserver.pre_populate.fail_on_error", False):
            log.error(msg)
            raise ValueError(msg)
        else:
            log.warning(msg)
    else:
        log.info(f"Pre-populating using {zip_file}")

        user_id = _ensure_backend_user(
            "__allegroai__", company_id, "Allegro.ai"
        )

        PrePopulate.import_from_zip(
            zip_file,
            company_id="",
            user_id=user_id,
            artifacts_path=config.get(
                "apiserver.pre_populate.artifacts_path", None
            ),
        )


def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
    if isinstance(zip_files, str):
        zip_files = [zip_files]
    for p in map(Path, zip_files):
        if p.is_file():
            yield p
        elif p.is_dir():
            yield from p.glob("*.zip")
        else:
            log.warning(f"Invalid pre-populate entry {str(p)}, skipping")


def pre_populate_data():
    for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
        _pre_populate(company_id=get_default_company(), zip_file=zip_file)


def init_mongo_data() -> bool:
    try:
        empty_dbs = _apply_migrations(log)

@@ -21,23 +64,6 @@ def init_mongo_data():

        _ensure_default_queue(company_id)

        if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
            zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
            if not zip_file or not Path(zip_file).is_file():
                msg = f"Failed pre-populating database: invalid zip file {zip_file}"
                if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
                    log.error(msg)
                    raise ValueError(msg)
                else:
                    log.warning(msg)
            else:
                user_id = _ensure_backend_user(
                    "__allegroai__", company_id, "Allegro.ai"
                )

                PrePopulate.import_from_zip(zip_file, user_id=user_id)

        fixed_mode = FixedUser.enabled()

        for user, credentials in config.get("secure.credentials", {}).items():
@@ -61,5 +87,7 @@ def init_mongo_data():
                ensure_fixed_user(user, company_id, log=log)
            except Exception as ex:
                log.error(f"Failed creating fixed user {user.name}: {ex}")

        return empty_dbs
    except Exception as ex:
        log.exception("Failed initializing mongodb")
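Aside (not part of the diff): with the compose settings above, apiserver.pre_populate.zip_files points at a directory; _resolve_zip_files picks up every *.zip inside it and import_from_zip then looks for a sibling .artifacts archive next to each zip. A hypothetical layout (the hash suffix is made up):

# /opt/trains/db-pre-populate/
#     export_1a2b3c4d.zip        <- entities and events, imported by PrePopulate.import_from_zip
#     export_1a2b3c4d.artifacts  <- optional, extracted into apiserver.pre_populate.artifacts_path
from pathlib import Path

zips = list(Path("/opt/trains/db-pre-populate").glob("*.zip"))  # roughly what _resolve_zip_files yields for a directory entry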
@@ -1,31 +1,199 @@
import hashlib
import importlib
import os
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timezone
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from pathlib import Path
from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2

import attr
import mongoengine
from tqdm import tqdm
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q

from bll.event import EventBLL
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from utilities import json


class PrePopulate:
    @classmethod
    def export_to_zip(
        cls, filename: str, experiments: List[str] = None, projects: List[str] = None
    ):
        with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
            cls._export(zfile, experiments, projects)
    event_bll = EventBLL()
    events_file_suffix = "_events"
    export_tag_prefix = "Exported:"
    export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"

    class JsonLinesWriter:
        def __init__(self, file: BinaryIO):
            self.file = file
            self.empty = True

        def __enter__(self):
            self._write("[")
            return self

        def __exit__(self, exc_type, exc_value, exc_traceback):
            self._write("\n]")

        def _write(self, data: str):
            self.file.write(data.encode("utf-8"))

        def write(self, line: str):
            if not self.empty:
                self._write(",")
            self._write("\n" + line)
            self.empty = False

    @attr.s(auto_attribs=True)
    class _MapData:
        files: Sequence[str] = None
        entities: Dict[str, datetime] = None

    @staticmethod
    def _get_last_update_time(entity) -> datetime:
        return getattr(entity, "last_update", None) or getattr(entity, "created")

    @classmethod
    def import_from_zip(cls, filename: str, user_id: str = None):
    def _check_for_update(
        cls, map_file: Path, entities: dict
    ) -> Tuple[bool, Sequence[str]]:
        if not map_file.is_file():
            return True, []

        files = []
        try:
            map_data = cls._MapData(**json.loads(map_file.read_text()))
            files = map_data.files
            for file in files:
                if not Path(file).is_file():
                    return True, files

            new_times = {
                item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
                for item in chain.from_iterable(entities.values())
            }
            old_times = map_data.entities

            if set(new_times.keys()) != set(old_times.keys()):
                return True, files

            for id_, new_timestamp in new_times.items():
                if new_timestamp != old_times[id_]:
                    return True, files
        except Exception as ex:
            print("Error reading map file. " + str(ex))
            return True, files

        return False, files

    @classmethod
    def _write_update_file(
        cls, map_file: Path, entities: dict, created_files: Sequence[str]
    ):
        map_data = cls._MapData(
            files=created_files,
            entities={
                entity.id: cls._get_last_update_time(entity)
                for entity in chain.from_iterable(entities.values())
            },
        )
        map_file.write_text(json.dumps(attr.asdict(map_data)))

    @staticmethod
    def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
        def is_fileserver_link(a: str) -> bool:
            a = a.lower()
            if a.startswith("https://files."):
                return True
            if a.startswith("http"):
                parsed = urlparse(a)
                if parsed.scheme in {"http", "https"} and parsed.port == 8081:
                    return True
            return False

        fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
        print(
            f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
        )

        return fileserver_links

    @classmethod
    def export_to_zip(
        cls,
        filename: str,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        artifacts_path: str = None,
        task_statuses: Sequence[str] = None,
        tag_exported_entities: bool = False,
    ) -> Sequence[str]:
        if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
            raise ValueError("Invalid task statuses")

        file = Path(filename)
        entities = cls._resolve_entities(
            experiments=experiments, projects=projects, task_statuses=task_statuses
        )

        map_file = file.with_suffix(".map")
        updated, old_files = cls._check_for_update(map_file, entities)
        if not updated:
            print(f"There are no updates from the last export")
            return old_files
        for old in old_files:
            old_path = Path(old)
            if old_path.is_file():
                old_path.unlink()

        zip_args = dict(mode="w", compression=ZIP_BZIP2)
        with ZipFile(file, **zip_args) as zfile:
            artifacts, hash_ = cls._export(
                zfile, entities, tag_entities=tag_exported_entities
            )
        file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
        file.replace(file_with_hash)
        created_files = [str(file_with_hash)]

        artifacts = cls._filter_artifacts(artifacts)
        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = file_with_hash.with_suffix(".artifacts")
            with ZipFile(artifacts_file, **zip_args) as zfile:
                cls._export_artifacts(zfile, artifacts, artifacts_path)
            created_files.append(str(artifacts_file))

        cls._write_update_file(map_file, entities, created_files)

        return created_files

    @classmethod
    def import_from_zip(
        cls, filename: str, company_id: str, user_id: str, artifacts_path: str
    ):
        with ZipFile(filename) as zfile:
            cls._import(zfile, user_id)
            cls._import(zfile, company_id, user_id)

        if artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = Path(filename).with_suffix(".artifacts")
            if artifacts_file.is_file():
                print(f"Unzipping artifacts into {artifacts_path}")
                with ZipFile(artifacts_file) as zfile:
                    zfile.extractall(artifacts_path)

    @staticmethod
    def _resolve_type(
        cls: Type[mongoengine.Document], ids: Optional[List[str]]
    ) -> List[Any]:
        cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
    ) -> Sequence[Any]:
        ids = set(ids)
        items = list(cls.objects(id__in=list(ids)))
        resolved = {i.id for i in items}
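Aside (not part of the diff): the reworked export_to_zip writes a sibling .map file through _write_update_file and consults it via _check_for_update, so re-running an export when no entity changed simply returns the previously created files. A rough sketch of the map file content (field names come from _MapData; the values are invented):

# json.loads(map_file.read_text()) would yield something shaped like:
map_data = {
    "files": [
        "/opt/trains/db-pre-populate/export_1a2b3c4d.zip",
        "/opt/trains/db-pre-populate/export_1a2b3c4d.artifacts",
    ],
    "entities": {
        # entity id -> last_update (or created) timestamp used for change detection
        "4d9e1c2b8a7f4e3d": "2020-06-01T10:20:30.123456",
    },
}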
@@ -43,20 +211,24 @@ class PrePopulate:

    @classmethod
    def _resolve_entities(
        cls, experiments: List[str] = None, projects: List[str] = None
        cls,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        task_statuses: Sequence[str] = None,
    ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
        from database.model.project import Project
        from database.model.task.task import Task

        entities = defaultdict(set)

        if projects:
            print("Reading projects...")
            entities[Project].update(cls._resolve_type(Project, projects))
            print("--> Reading project experiments...")
            objs = Task.objects(
                project__in=list(set(filter(None, (p.id for p in entities[Project]))))
            query = Q(
                project__in=list(set(filter(None, (p.id for p in entities[Project])))),
                system_tags__nin=[EntityVisibility.archived.value],
            )
            if task_statuses:
                query &= Q(status__in=list(set(task_statuses)))
            objs = Task.objects(query)
            entities[Task].update(o for o in objs if o.id not in (experiments or []))

        if experiments:
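Aside (not part of the diff): _resolve_entities now builds the task filter from mongoengine Q objects so that the optional status filter is AND-ed in only when requested; roughly equivalent to this sketch (the project id and statuses are hypothetical):

from mongoengine import Q

from database.model import EntityVisibility
from database.model.task.task import Task

query = Q(
    project__in=["<project-id>"],
    system_tags__nin=[EntityVisibility.archived.value],  # skip archived experiments
)
task_statuses = ["published", "stopped"]  # hypothetical filter
if task_statuses:
    query &= Q(status__in=list(set(task_statuses)))
objs = Task.objects(query)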
@@ -69,85 +241,256 @@ class PrePopulate:
            project_ids = {p.id for p in entities[Project]}
            entities[Project].update(o for o in objs if o.id not in project_ids)

        model_ids = {
            model_id
            for task in entities[Task]
            for model_id in (task.output.model, task.execution.model)
            if model_id
        }
        if model_ids:
            print("Reading models...")
            entities[Model] = set(Model.objects(id__in=list(model_ids)))

        return entities

    @classmethod
    def _cleanup_task(cls, task):
        from database.model.task.task import TaskStatus
    def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
        if not tags:
            return tags
        return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]

        task.completed = None
        task.started = None
        if task.execution:
            task.execution.model = None
            task.execution.model_desc = None
            task.execution.model_labels = None
        if task.output:
            task.output.model = None
    @classmethod
    def _cleanup_model(cls, model: Model):
        model.company = ""
        model.user = ""
        model.tags = cls._filter_out_export_tags(model.tags)

        task.status = TaskStatus.created
    @classmethod
    def _cleanup_task(cls, task: Task):
        task.comment = "Auto generated by Allegro.ai"
        task.created = datetime.utcnow()
        task.last_iteration = 0
        task.last_update = task.created
        task.status_changed = task.created
        task.status_message = ""
        task.status_reason = ""
        task.user = ""
        task.company = ""
        task.tags = cls._filter_out_export_tags(task.tags)
        if task.output:
            task.output.destination = None

    @classmethod
    def _cleanup_project(cls, project: Project):
        project.user = ""
        project.company = ""
        project.tags = cls._filter_out_export_tags(project.tags)

    @classmethod
    def _cleanup_entity(cls, entity_cls, entity):
        from database.model.task.task import Task
        if entity_cls == Task:
            cls._cleanup_task(entity)
        elif entity_cls == Model:
            cls._cleanup_model(entity)
        elif entity == Project:
            cls._cleanup_project(entity)

    @classmethod
    def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
        try:
            for item in items:
                item.update(upsert=False, tags=sorted(item.tags + [tag]))
        except AttributeError:
            pass

    @classmethod
    def _export_task_events(
        cls, task: Task, base_filename: str, writer: ZipFile, hash_
    ) -> Sequence[str]:
        artifacts = []
        filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
        print(f"Writing task events into {writer.filename}:{filename}")
        with BytesIO() as f:
            with cls.JsonLinesWriter(f) as w:
                scroll_id = None
                while True:
                    res = cls.event_bll.get_task_events(
                        task.company, task.id, scroll_id=scroll_id
                    )
                    if not res.events:
                        break
                    scroll_id = res.next_scroll_id
                    for event in res.events:
                        if event.get("type") == "training_debug_image":
                            url = cls._get_fixed_url(event.get("url"))
                            if url:
                                event["url"] = url
                                artifacts.append(url)
                        w.write(json.dumps(event))
            data = f.getvalue()
            hash_.update(data)
            writer.writestr(filename, data)

        return artifacts

    @staticmethod
    def _get_fixed_url(url: Optional[str]) -> Optional[str]:
        if not (url and url.lower().startswith("s3://")):
            return url
        try:
            fixed = furl(url)
            fixed.scheme = "https"
            fixed.host += ".s3.amazonaws.com"
            return fixed.url
        except Exception as ex:
            print(f"Failed processing link {url}. " + str(ex))
            return url

    @classmethod
    def _export_entity_related_data(
        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
    ):
        if entity_cls == Task:
            return [
                *cls._get_task_output_artifacts(entity),
                *cls._export_task_events(entity, base_filename, writer, hash_),
            ]

        if entity_cls == Model:
            entity.uri = cls._get_fixed_url(entity.uri)
            return [entity.uri] if entity.uri else []

        return []

    @classmethod
    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
        if not task.execution.artifacts:
            return []

        for a in task.execution.artifacts:
            if a.mode == ArtifactModes.output:
                a.uri = cls._get_fixed_url(a.uri)

        return [
            a.uri
            for a in task.execution.artifacts
            if a.mode == ArtifactModes.output and a.uri
        ]

    @classmethod
    def _export_artifacts(
        cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
    ):
        unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
        for path in unique_paths:
            path = path.lstrip("/")
            full_path = os.path.join(artifacts_path, path)
            if os.path.isfile(full_path):
                writer.write(full_path, path)
            else:
                print(f"Artifact {full_path} not found")

    @classmethod
    def _export(
        cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
    ):
        entities = cls._resolve_entities(experiments, projects)

        for cls_, items in entities.items():
        cls, writer: ZipFile, entities: dict, tag_entities: bool = False
    ) -> Tuple[Sequence[str], str]:
        """
        Export the requested experiments, projects and models and return the list of artifact files
        Always do the export on sorted items since the order of items influence hash
        """
        artifacts = []
        now = datetime.utcnow()
        hash_ = hashlib.md5()
        for cls_ in sorted(entities, key=attrgetter("__name__")):
            items = sorted(entities[cls_], key=attrgetter("id"))
            if not items:
                continue
            filename = f"{cls_.__module__}.{cls_.__name__}.json"
            base_filename = f"{cls_.__module__}.{cls_.__name__}"
            for item in items:
                artifacts.extend(
                    cls._export_entity_related_data(
                        cls_, item, base_filename, writer, hash_
                    )
                )
            filename = base_filename + ".json"
            print(f"Writing {len(items)} items into {writer.filename}:{filename}")
            with writer.open(filename, "w") as f:
                f.write("[\n".encode("utf-8"))
                last = len(items) - 1
                for i, item in enumerate(items):
                    cls._cleanup_entity(cls_, item)
                    f.write(item.to_json().encode("utf-8"))
                    if i != last:
                        f.write(",".encode("utf-8"))
                    f.write("\n".encode("utf-8"))
                f.write("]\n".encode("utf-8"))
            with BytesIO() as f:
                with cls.JsonLinesWriter(f) as w:
                    for item in items:
                        cls._cleanup_entity(cls_, item)
                        w.write(item.to_json())
                data = f.getvalue()
                hash_.update(data)
                writer.writestr(filename, data)

            if tag_entities:
                cls._add_tag(items, now.strftime(cls.export_tag))

        return artifacts, hash_.hexdigest()

    @staticmethod
    def _import(reader: ZipFile, user_id: str = None):
        for file_info in reader.filelist:
            full_name = splitext(file_info.orig_filename)[0]
            print(f"Reading {reader.filename}:{full_name}...")
            module_name, _, class_name = full_name.rpartition(".")
            module = importlib.import_module(module_name)
            cls_: Type[mongoengine.Document] = getattr(module, class_name)
    def json_lines(file: BinaryIO):
        for line in file:
            clean = (
                line.decode("utf-8")
                .rstrip("\r\n")
                .strip()
                .lstrip("[")
                .rstrip(",]")
                .strip()
            )
            if not clean:
                continue
            yield clean

            with reader.open(file_info) as f:
                for item in tqdm(
                    f.readlines(),
                    desc=f"Writing {cls_.__name__.lower()}s into database",
                    unit="doc",
                ):
                    item = (
                        item.decode("utf-8")
                        .strip()
                        .lstrip("[")
                        .rstrip("]")
                        .rstrip(",")
                        .strip()
                    )
                    if not item:
                        continue
                    doc = cls_.from_json(item)
                    if user_id is not None and hasattr(doc, "user"):
                        doc.user = user_id
                    doc.save(force_insert=True)
    @classmethod
    def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None):
        """
        Import entities and events from the zip file
        Start from entities since event import will require the tasks already in DB
        """
        event_file_ending = cls.events_file_suffix + ".json"
        entity_files = (
            fi
            for fi in reader.filelist
            if not fi.orig_filename.endswith(event_file_ending)
        )
        event_files = (
            fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
        )
        for files, reader_func in (
            (entity_files, cls._import_entity),
            (event_files, cls._import_events),
        ):
            for file_info in files:
                with reader.open(file_info) as f:
                    full_name = splitext(file_info.orig_filename)[0]
                    print(f"Reading {reader.filename}:{full_name}...")
                    reader_func(f, full_name, company_id, user_id)

    @classmethod
    def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str):
        module_name, _, class_name = full_name.rpartition(".")
        module = importlib.import_module(module_name)
        cls_: Type[mongoengine.Document] = getattr(module, class_name)
        print(f"Writing {cls_.__name__.lower()}s into database")
        for item in cls.json_lines(f):
            doc = cls_.from_json(item, created=True)
            if hasattr(doc, "user"):
                doc.user = user_id
            if hasattr(doc, "company"):
                doc.company = company_id
            if isinstance(doc, Project):
                cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
                    set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
                )
            doc.save()
            if isinstance(doc, Task):
                cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)

    @classmethod
    def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
        _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
        print(f"Writing events for task {task_id} into database")
        for events_chunk in chunked_iter(cls.json_lines(f), 1000):
            events = [json.loads(item) for item in events_chunk]
            cls.event_bll.add_events(
                company_id, events=events, worker="", allow_locked_tasks=True
            )
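Aside (not part of the diff): _get_fixed_url above only rewrites s3:// links; a worked example of what the furl-based rewrite produces (bucket and key are made up):

from furl import furl

url = "s3://my-bucket/artifacts/model.pkl"
fixed = furl(url)
fixed.scheme = "https"
fixed.host += ".s3.amazonaws.com"
print(fixed.url)  # https://my-bucket.s3.amazonaws.com/artifacts/model.pkl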
@@ -1,5 +1,6 @@
import atexit
from argparse import ArgumentParser
from hashlib import md5

from flask import Flask, request, Response
from flask_compress import Compress
@@ -11,10 +12,11 @@ from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data
from mongo.initialize import init_mongo_data, pre_populate_data
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from sync import distributed_lock
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
@@ -33,8 +35,16 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")

database.initialize()

init_es_data()
init_mongo_data()
# build a key that uniquely identifies specific mongo instance
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
    print(key)
    init_es_data()
    empty_db = init_mongo_data()
    if empty_db and config.get("apiserver.pre_populate.enabled", False):
        pre_populate_data()


ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
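Aside (not part of the diff): the startup block above serializes Elasticsearch/Mongo initialization across api-server replicas that share the same MongoDB; the lock key is derived from the sorted host list so every replica computes the same name. A small sketch with made-up hosts:

from hashlib import md5

hosts = ["mongo-1:27017", "mongo-2:27017"]  # hypothetical database.get_hosts() result
hosts_string = ";".join(sorted(hosts))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
print(key)  # db_init_<32 hex chars>, identical on every replica configured with these hosts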
@@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
def get_all_ex(call: APICall):
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
@@ -200,7 +201,7 @@ def get_all_ex(call: APICall):
        company=call.identity.company,
        query_dict=call.data,
        query_options=get_all_query_options,
        allow_public=True,
        allow_public=allow_public,
    )
    conform_output_tags(call, projects)

@@ -56,7 +56,7 @@ from database.model.task.task import (
)
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
from timing_context import TimingContext
from utilities import safe_get

@@ -377,6 +377,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
    "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
def clone_task(call: APICall, company_id, request: CloneRequest):
    validate_tags(request.new_task_tags, request.new_task_system_tags)
    task = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
server/sync.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import time
from contextlib import contextmanager
from time import sleep

from redis_manager import redman

_redis = redman.connection("apiserver")


@contextmanager
def distributed_lock(name: str, timeout: int, max_wait: int = 0):
    """
    Context manager that acquires a distributed lock on enter
    and releases it on exit. The lock has a ttl equal to timeout seconds.
    If the lock can not be acquired for max_wait seconds (defaults to timeout * 2)
    then an exception is thrown
    """
    lock_name = f"dist_lock_{name}"
    start = time.time()
    max_wait = max_wait or timeout * 2
    while not _redis.set(lock_name, value="", ex=timeout, nx=True):
        sleep(1)
        if time.time() - start > max_wait:
            raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
    try:
        yield
    finally:
        _redis.delete(lock_name)
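Aside (not part of the diff): a minimal usage sketch of the helper above, mirroring how the server wraps DB initialization (the lock name and timeout are illustrative):

from sync import distributed_lock

# Only one process at a time runs the critical section; the others poll Redis once per
# second until the lock is released or max_wait (here 2 * timeout) expires.
with distributed_lock("db_init_example", timeout=30):
    pass  # e.g. init_es_data(); init_mongo_data()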