Improve pre-populate on server startup (including sync lock)

allegroai 2020-07-06 22:05:36 +03:00
parent 21f2ea8b17
commit 901ec37290
11 changed files with 525 additions and 112 deletions

View File

@@ -10,6 +10,7 @@ services:
     volumes:
     - /opt/trains/logs:/var/log/trains
     - /opt/trains/config:/opt/trains/config
+    - /opt/trains/data/fileserver:/mnt/fileserver
     depends_on:
     - redis
     - mongo
@@ -23,8 +24,9 @@ services:
       TRAINS_REDIS_SERVICE_HOST: redis
       TRAINS_REDIS_SERVICE_PORT: 6379
       TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
-      TRAINS__apiserver__mongo__pre_populate__enabled: "true"
-      TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
+      TRAINS__apiserver__pre_populate__enabled: "true"
+      TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
+      TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
     ports:
     - "8008:8008"
     networks:
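The renamed variables follow the server's environment-override convention: a TRAINS__ prefix plus double underscores as path separators, so each entry above maps onto a nested key under apiserver.pre_populate. A minimal sketch of the mapping (the convention as used here; the derivation below is illustrative, not server code):

    env_name = "TRAINS__apiserver__pre_populate__zip_files"
    config_key = env_name[len("TRAINS__"):].replace("__", ".")
    assert config_key == "apiserver.pre_populate.zip_files"
    # likewise:
    #   TRAINS__apiserver__pre_populate__enabled        -> apiserver.pre_populate.enabled
    #   TRAINS__apiserver__pre_populate__artifacts_path -> apiserver.pre_populate.artifacts_path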

View File

@@ -552,7 +552,7 @@ class EventBLL(object):
         )
         events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
-        next_scroll_id = es_res["_scroll_id"]
+        next_scroll_id = es_res.get("_scroll_id")
         total_events = es_res["hits"]["total"]
         return TaskEventsResult(
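Using .get() tolerates Elasticsearch responses that carry no _scroll_id (for example when the events query was issued without requesting a scroll), where direct indexing raised KeyError. A minimal sketch of the difference (illustrative response):

    es_res = {"hits": {"total": 0, "hits": []}}  # no "_scroll_id" in this response
    # es_res["_scroll_id"]                       would raise KeyError
    next_scroll_id = es_res.get("_scroll_id")    # -> None; the caller simply stops scrolling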

View File

@@ -32,7 +32,6 @@ from database.model.task.task import (
 )
 from database.utils import get_company_or_none_constraint, id as create_id
 from service_repo import APICall
-from services.utils import validate_tags
 from timing_context import TimingContext
 from utilities.dicts import deep_merge
 from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
@@ -182,7 +181,6 @@ class TaskBLL(object):
         execution_overrides: Optional[dict] = None,
         validate_references: bool = False,
     ) -> Task:
-        validate_tags(tags, system_tags)
         task = cls.get_by_id(company_id=company_id, task_id=task_id)
         execution_dict = task.execution.to_proper_dict() if task.execution else {}
         execution_model_overriden = False

View File

@@ -26,6 +26,13 @@
         check_max_version: false
     }

+    pre_populate {
+        enabled: false
+        zip_files: ["/path/to/export.zip"]
+        fail_on_error: false
+        artifacts_path: "/mnt/fileserver"
+    }
+
     mongo {
         # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
         # but not declared in a data model
@@ -34,12 +41,6 @@
         aggregate {
             allow_disk_use: true
         }
-
-        pre_populate {
-            enabled: false
-            zip_file: "/path/to/export.zip"
-            fail_on_error: false
-        }
     }

     auth {
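The pre_populate block moves out of the mongo section (it now also drives fileserver artifact handling, not just the MongoDB import), and the single zip_file entry becomes zip_files, which may list zip files and/or directories to scan. A short sketch of how startup code reads it (config.get usage as in the initialize changes below; values are the defaults shown above):

    from config import config

    if config.get("apiserver.pre_populate.enabled", False):
        zip_files = config.get("apiserver.pre_populate.zip_files")             # files and/or directories
        artifacts_path = config.get("apiserver.pre_populate.artifacts_path", None)
        fail_on_error = config.get("apiserver.pre_populate.fail_on_error", False)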

View File

@@ -85,6 +85,7 @@ class Artifact(EmbeddedDocument):
 class Execution(EmbeddedDocument, ProperDictMixin):
+    meta = {"strict": strict}
     test_split = IntField(default=0)
     parameters = SafeDictField(default=dict)
     model = StringField(reference_field="Model")
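meta = {"strict": strict} ties the Execution embedded document to the module-level strict flag (driven by the mongo strict setting documented in the config diff above), so stored documents carrying fields this server version does not declare load instead of raising FieldDoesNotExist; that matters when importing pre-populated data produced by a different server version. A minimal mongoengine sketch of the behavior (standalone example, not the server model):

    from mongoengine import Document, StringField

    class LooseDoc(Document):
        # with strict disabled, unknown fields found in stored documents are ignored on load;
        # with the default strict=True they raise mongoengine.errors.FieldDoesNotExist
        meta = {"strict": False}
        name = StringField()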

View File

@@ -1,6 +1,8 @@
 from pathlib import Path
+from typing import Sequence, Union

 from config import config
+from config.info import get_default_company
 from database.model.auth import Role
 from service_repo.auth.fixed_user import FixedUser
 from .migration import _apply_migrations
@@ -11,7 +13,48 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
 log = config.logger(__package__)


-def init_mongo_data():
+def _pre_populate(company_id: str, zip_file: str):
+    if not zip_file or not Path(zip_file).is_file():
+        msg = f"Invalid pre-populate zip file: {zip_file}"
+        if config.get("apiserver.pre_populate.fail_on_error", False):
+            log.error(msg)
+            raise ValueError(msg)
+        else:
+            log.warning(msg)
+    else:
+        log.info(f"Pre-populating using {zip_file}")
+        user_id = _ensure_backend_user(
+            "__allegroai__", company_id, "Allegro.ai"
+        )
+        PrePopulate.import_from_zip(
+            zip_file,
+            company_id="",
+            user_id=user_id,
+            artifacts_path=config.get(
+                "apiserver.pre_populate.artifacts_path", None
+            ),
+        )
+
+
+def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
+    if isinstance(zip_files, str):
+        zip_files = [zip_files]
+
+    for p in map(Path, zip_files):
+        if p.is_file():
+            yield p
+        elif p.is_dir():
+            yield from p.glob("*.zip")
+        else:
+            log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
+
+
+def pre_populate_data():
+    for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
+        _pre_populate(company_id=get_default_company(), zip_file=zip_file)
+
+
+def init_mongo_data() -> bool:
     try:
         empty_dbs = _apply_migrations(log)
@@ -21,23 +64,6 @@ def init_mongo_data():
         _ensure_default_queue(company_id)

-        if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
-            zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
-            if not zip_file or not Path(zip_file).is_file():
-                msg = f"Failed pre-populating database: invalid zip file {zip_file}"
-                if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
-                    log.error(msg)
-                    raise ValueError(msg)
-                else:
-                    log.warning(msg)
-            else:
-                user_id = _ensure_backend_user(
-                    "__allegroai__", company_id, "Allegro.ai"
-                )
-                PrePopulate.import_from_zip(zip_file, user_id=user_id)
-
         fixed_mode = FixedUser.enabled()

         for user, credentials in config.get("secure.credentials", {}).items():
@@ -61,5 +87,7 @@ def init_mongo_data():
                     ensure_fixed_user(user, company_id, log=log)
                 except Exception as ex:
                     log.error(f"Failed creating fixed user {user.name}: {ex}")
+
+        return empty_dbs
     except Exception as ex:
         log.exception("Failed initializing mongodb")

View File

@@ -1,31 +1,199 @@
+import hashlib
 import importlib
+import os
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timezone
+from io import BytesIO
+from itertools import chain
+from operator import attrgetter
 from os.path import splitext
-from typing import List, Optional, Any, Type, Set, Dict
+from pathlib import Path
+from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union
+from urllib.parse import unquote, urlparse
 from zipfile import ZipFile, ZIP_BZIP2

+import attr
 import mongoengine
-from tqdm import tqdm
+from boltons.iterutils import chunked_iter
+from furl import furl
+from mongoengine import Q
+
+from bll.event import EventBLL
+from database.model import EntityVisibility
+from database.model.model import Model
+from database.model.project import Project
+from database.model.task.task import Task, ArtifactModes, TaskStatus
+from database.utils import get_options
+from utilities import json


 class PrePopulate:
-    @classmethod
-    def export_to_zip(
-        cls, filename: str, experiments: List[str] = None, projects: List[str] = None
-    ):
-        with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
-            cls._export(zfile, experiments, projects)
+    event_bll = EventBLL()
+    events_file_suffix = "_events"
+    export_tag_prefix = "Exported:"
+    export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
+
+    class JsonLinesWriter:
+        def __init__(self, file: BinaryIO):
+            self.file = file
+            self.empty = True
+
+        def __enter__(self):
+            self._write("[")
+            return self
+
+        def __exit__(self, exc_type, exc_value, exc_traceback):
+            self._write("\n]")
+
+        def _write(self, data: str):
+            self.file.write(data.encode("utf-8"))
+
+        def write(self, line: str):
+            if not self.empty:
+                self._write(",")
+            self._write("\n" + line)
+            self.empty = False
+
+    @attr.s(auto_attribs=True)
+    class _MapData:
+        files: Sequence[str] = None
+        entities: Dict[str, datetime] = None
+
+    @staticmethod
+    def _get_last_update_time(entity) -> datetime:
+        return getattr(entity, "last_update", None) or getattr(entity, "created")
     @classmethod
-    def import_from_zip(cls, filename: str, user_id: str = None):
+    def _check_for_update(
+        cls, map_file: Path, entities: dict
+    ) -> Tuple[bool, Sequence[str]]:
+        if not map_file.is_file():
+            return True, []
+
+        files = []
+        try:
+            map_data = cls._MapData(**json.loads(map_file.read_text()))
+            files = map_data.files
+            for file in files:
+                if not Path(file).is_file():
+                    return True, files
+
+            new_times = {
+                item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
+                for item in chain.from_iterable(entities.values())
+            }
+            old_times = map_data.entities
+            if set(new_times.keys()) != set(old_times.keys()):
+                return True, files
+            for id_, new_timestamp in new_times.items():
+                if new_timestamp != old_times[id_]:
+                    return True, files
+        except Exception as ex:
+            print("Error reading map file. " + str(ex))
+            return True, files
+
+        return False, files
+
+    @classmethod
+    def _write_update_file(
+        cls, map_file: Path, entities: dict, created_files: Sequence[str]
+    ):
+        map_data = cls._MapData(
+            files=created_files,
+            entities={
+                entity.id: cls._get_last_update_time(entity)
+                for entity in chain.from_iterable(entities.values())
+            },
+        )
+        map_file.write_text(json.dumps(attr.asdict(map_data)))
+
+    @staticmethod
+    def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
+        def is_fileserver_link(a: str) -> bool:
+            a = a.lower()
+            if a.startswith("https://files."):
+                return True
+            if a.startswith("http"):
+                parsed = urlparse(a)
+                if parsed.scheme in {"http", "https"} and parsed.port == 8081:
+                    return True
+            return False
+
+        fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
+        print(
+            f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
+        )
+        return fileserver_links
+
+    @classmethod
+    def export_to_zip(
+        cls,
+        filename: str,
+        experiments: Sequence[str] = None,
+        projects: Sequence[str] = None,
+        artifacts_path: str = None,
+        task_statuses: Sequence[str] = None,
+        tag_exported_entities: bool = False,
+    ) -> Sequence[str]:
+        if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
+            raise ValueError("Invalid task statuses")
+
+        file = Path(filename)
+        entities = cls._resolve_entities(
+            experiments=experiments, projects=projects, task_statuses=task_statuses
+        )
+
+        map_file = file.with_suffix(".map")
+        updated, old_files = cls._check_for_update(map_file, entities)
+        if not updated:
+            print(f"There are no updates from the last export")
+            return old_files
+        for old in old_files:
+            old_path = Path(old)
+            if old_path.is_file():
+                old_path.unlink()
+
+        zip_args = dict(mode="w", compression=ZIP_BZIP2)
+        with ZipFile(file, **zip_args) as zfile:
+            artifacts, hash_ = cls._export(
+                zfile, entities, tag_entities=tag_exported_entities
+            )
+        file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
+        file.replace(file_with_hash)
+        created_files = [str(file_with_hash)]
+
+        artifacts = cls._filter_artifacts(artifacts)
+        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
+            artifacts_file = file_with_hash.with_suffix(".artifacts")
+            with ZipFile(artifacts_file, **zip_args) as zfile:
+                cls._export_artifacts(zfile, artifacts, artifacts_path)
+            created_files.append(str(artifacts_file))
+
+        cls._write_update_file(map_file, entities, created_files)
+
+        return created_files
+
+    @classmethod
+    def import_from_zip(
+        cls, filename: str, company_id: str, user_id: str, artifacts_path: str
+    ):
         with ZipFile(filename) as zfile:
-            cls._import(zfile, user_id)
+            cls._import(zfile, company_id, user_id)
+
+        if artifacts_path and os.path.isdir(artifacts_path):
+            artifacts_file = Path(filename).with_suffix(".artifacts")
+            if artifacts_file.is_file():
+                print(f"Unzipping artifacts into {artifacts_path}")
+                with ZipFile(artifacts_file) as zfile:
+                    zfile.extractall(artifacts_path)

     @staticmethod
     def _resolve_type(
-        cls: Type[mongoengine.Document], ids: Optional[List[str]]
-    ) -> List[Any]:
+        cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
+    ) -> Sequence[Any]:
         ids = set(ids)
         items = list(cls.objects(id__in=list(ids)))
         resolved = {i.id for i in items}
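JsonLinesWriter emits a JSON array with one entity per line, so the importer (json_lines, later in this file) can stream the archive back line by line instead of loading whole document lists into memory. A minimal round-trip sketch with illustrative payloads:

    from io import BytesIO

    buf = BytesIO()
    with PrePopulate.JsonLinesWriter(buf) as w:
        w.write('{"id": "task-1"}')
        w.write('{"id": "task-2"}')
    buf.getvalue()   # b'[\n{"id": "task-1"},\n{"id": "task-2"}\n]'

    buf.seek(0)
    list(PrePopulate.json_lines(buf))   # ['{"id": "task-1"}', '{"id": "task-2"}']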
@@ -43,20 +211,24 @@ class PrePopulate:
     @classmethod
     def _resolve_entities(
-        cls, experiments: List[str] = None, projects: List[str] = None
+        cls,
+        experiments: Sequence[str] = None,
+        projects: Sequence[str] = None,
+        task_statuses: Sequence[str] = None,
     ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
-        from database.model.project import Project
-        from database.model.task.task import Task
-
         entities = defaultdict(set)

         if projects:
             print("Reading projects...")
             entities[Project].update(cls._resolve_type(Project, projects))
             print("--> Reading project experiments...")
-            objs = Task.objects(
-                project__in=list(set(filter(None, (p.id for p in entities[Project]))))
+            query = Q(
+                project__in=list(set(filter(None, (p.id for p in entities[Project])))),
+                system_tags__nin=[EntityVisibility.archived.value],
             )
+            if task_statuses:
+                query &= Q(status__in=list(set(task_statuses)))
+            objs = Task.objects(query)
             entities[Task].update(o for o in objs if o.id not in (experiments or []))

         if experiments:
@@ -69,85 +241,256 @@ class PrePopulate:
             project_ids = {p.id for p in entities[Project]}
             entities[Project].update(o for o in objs if o.id not in project_ids)

+        model_ids = {
+            model_id
+            for task in entities[Task]
+            for model_id in (task.output.model, task.execution.model)
+            if model_id
+        }
+        if model_ids:
+            print("Reading models...")
+            entities[Model] = set(Model.objects(id__in=list(model_ids)))
+
         return entities

     @classmethod
-    def _cleanup_task(cls, task):
-        from database.model.task.task import TaskStatus
-
-        task.completed = None
-        task.started = None
-        if task.execution:
-            task.execution.model = None
-            task.execution.model_desc = None
-            task.execution.model_labels = None
-        if task.output:
-            task.output.model = None
-
-        task.status = TaskStatus.created
+    def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
+        if not tags:
+            return tags
+        return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
+
+    @classmethod
+    def _cleanup_model(cls, model: Model):
+        model.company = ""
+        model.user = ""
+        model.tags = cls._filter_out_export_tags(model.tags)
+
+    @classmethod
+    def _cleanup_task(cls, task: Task):
         task.comment = "Auto generated by Allegro.ai"
-        task.created = datetime.utcnow()
-        task.last_iteration = 0
-        task.last_update = task.created
-        task.status_changed = task.created
         task.status_message = ""
         task.status_reason = ""
         task.user = ""
+        task.company = ""
+        task.tags = cls._filter_out_export_tags(task.tags)
+        if task.output:
+            task.output.destination = None
+
+    @classmethod
+    def _cleanup_project(cls, project: Project):
+        project.user = ""
+        project.company = ""
+        project.tags = cls._filter_out_export_tags(project.tags)

     @classmethod
     def _cleanup_entity(cls, entity_cls, entity):
-        from database.model.task.task import Task
-
         if entity_cls == Task:
             cls._cleanup_task(entity)
+        elif entity_cls == Model:
+            cls._cleanup_model(entity)
+        elif entity == Project:
+            cls._cleanup_project(entity)
+
+    @classmethod
+    def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
+        try:
+            for item in items:
+                item.update(upsert=False, tags=sorted(item.tags + [tag]))
+        except AttributeError:
+            pass
+
+    @classmethod
+    def _export_task_events(
+        cls, task: Task, base_filename: str, writer: ZipFile, hash_
+    ) -> Sequence[str]:
+        artifacts = []
+        filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
+        print(f"Writing task events into {writer.filename}:{filename}")
+        with BytesIO() as f:
+            with cls.JsonLinesWriter(f) as w:
+                scroll_id = None
+                while True:
+                    res = cls.event_bll.get_task_events(
+                        task.company, task.id, scroll_id=scroll_id
+                    )
+                    if not res.events:
+                        break
+                    scroll_id = res.next_scroll_id
+                    for event in res.events:
+                        if event.get("type") == "training_debug_image":
+                            url = cls._get_fixed_url(event.get("url"))
+                            if url:
+                                event["url"] = url
+                                artifacts.append(url)
+                        w.write(json.dumps(event))
+            data = f.getvalue()
+            hash_.update(data)
+            writer.writestr(filename, data)
+        return artifacts
+    @staticmethod
+    def _get_fixed_url(url: Optional[str]) -> Optional[str]:
+        if not (url and url.lower().startswith("s3://")):
+            return url
+        try:
+            fixed = furl(url)
+            fixed.scheme = "https"
+            fixed.host += ".s3.amazonaws.com"
+            return fixed.url
+        except Exception as ex:
+            print(f"Failed processing link {url}. " + str(ex))
+            return url
+
+    @classmethod
+    def _export_entity_related_data(
+        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
+    ):
+        if entity_cls == Task:
+            return [
+                *cls._get_task_output_artifacts(entity),
+                *cls._export_task_events(entity, base_filename, writer, hash_),
+            ]
+        if entity_cls == Model:
+            entity.uri = cls._get_fixed_url(entity.uri)
+            return [entity.uri] if entity.uri else []
+        return []
+
+    @classmethod
+    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
+        if not task.execution.artifacts:
+            return []
+        for a in task.execution.artifacts:
+            if a.mode == ArtifactModes.output:
+                a.uri = cls._get_fixed_url(a.uri)
+        return [
+            a.uri
+            for a in task.execution.artifacts
+            if a.mode == ArtifactModes.output and a.uri
+        ]
+
+    @classmethod
+    def _export_artifacts(
+        cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
+    ):
+        unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
+        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
+        for path in unique_paths:
+            path = path.lstrip("/")
+            full_path = os.path.join(artifacts_path, path)
+            if os.path.isfile(full_path):
+                writer.write(full_path, path)
+            else:
+                print(f"Artifact {full_path} not found")
+
     @classmethod
     def _export(
-        cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
-    ):
-        entities = cls._resolve_entities(experiments, projects)
-
-        for cls_, items in entities.items():
+        cls, writer: ZipFile, entities: dict, tag_entities: bool = False
+    ) -> Tuple[Sequence[str], str]:
+        """
+        Export the requested experiments, projects and models and return the list of artifact files.
+        Always do the export on sorted items since the order of items influences the hash.
+        """
+        artifacts = []
+        now = datetime.utcnow()
+        hash_ = hashlib.md5()
+        for cls_ in sorted(entities, key=attrgetter("__name__")):
+            items = sorted(entities[cls_], key=attrgetter("id"))
             if not items:
                 continue
-            filename = f"{cls_.__module__}.{cls_.__name__}.json"
+            base_filename = f"{cls_.__module__}.{cls_.__name__}"
+            for item in items:
+                artifacts.extend(
+                    cls._export_entity_related_data(
+                        cls_, item, base_filename, writer, hash_
+                    )
+                )
+            filename = base_filename + ".json"
             print(f"Writing {len(items)} items into {writer.filename}:{filename}")
-            with writer.open(filename, "w") as f:
-                f.write("[\n".encode("utf-8"))
-                last = len(items) - 1
-                for i, item in enumerate(items):
-                    cls._cleanup_entity(cls_, item)
-                    f.write(item.to_json().encode("utf-8"))
-                    if i != last:
-                        f.write(",".encode("utf-8"))
-                    f.write("\n".encode("utf-8"))
-                f.write("]\n".encode("utf-8"))
+            with BytesIO() as f:
+                with cls.JsonLinesWriter(f) as w:
+                    for item in items:
+                        cls._cleanup_entity(cls_, item)
+                        w.write(item.to_json())
+                data = f.getvalue()
+                hash_.update(data)
+                writer.writestr(filename, data)
+
+            if tag_entities:
+                cls._add_tag(items, now.strftime(cls.export_tag))
+
+        return artifacts, hash_.hexdigest()
     @staticmethod
-    def _import(reader: ZipFile, user_id: str = None):
-        for file_info in reader.filelist:
-            full_name = splitext(file_info.orig_filename)[0]
-            print(f"Reading {reader.filename}:{full_name}...")
-            module_name, _, class_name = full_name.rpartition(".")
-            module = importlib.import_module(module_name)
-            cls_: Type[mongoengine.Document] = getattr(module, class_name)
-
-            with reader.open(file_info) as f:
-                for item in tqdm(
-                    f.readlines(),
-                    desc=f"Writing {cls_.__name__.lower()}s into database",
-                    unit="doc",
-                ):
-                    item = (
-                        item.decode("utf-8")
-                        .strip()
-                        .lstrip("[")
-                        .rstrip("]")
-                        .rstrip(",")
-                        .strip()
-                    )
-                    if not item:
-                        continue
-                    doc = cls_.from_json(item)
-                    if user_id is not None and hasattr(doc, "user"):
-                        doc.user = user_id
-                    doc.save(force_insert=True)
+    def json_lines(file: BinaryIO):
+        for line in file:
+            clean = (
+                line.decode("utf-8")
+                .rstrip("\r\n")
+                .strip()
+                .lstrip("[")
+                .rstrip(",]")
+                .strip()
+            )
+            if not clean:
+                continue
+            yield clean
+
+    @classmethod
+    def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None):
+        """
+        Import entities and events from the zip file.
+        Start from entities since event import will require the tasks already in DB.
+        """
+        event_file_ending = cls.events_file_suffix + ".json"
+        entity_files = (
+            fi
+            for fi in reader.filelist
+            if not fi.orig_filename.endswith(event_file_ending)
+        )
+        event_files = (
+            fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
+        )
+        for files, reader_func in (
+            (entity_files, cls._import_entity),
+            (event_files, cls._import_events),
+        ):
+            for file_info in files:
+                with reader.open(file_info) as f:
+                    full_name = splitext(file_info.orig_filename)[0]
+                    print(f"Reading {reader.filename}:{full_name}...")
+                    reader_func(f, full_name, company_id, user_id)
+
+    @classmethod
+    def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str):
+        module_name, _, class_name = full_name.rpartition(".")
+        module = importlib.import_module(module_name)
+        cls_: Type[mongoengine.Document] = getattr(module, class_name)
+        print(f"Writing {cls_.__name__.lower()}s into database")
+        for item in cls.json_lines(f):
+            doc = cls_.from_json(item, created=True)
+            if hasattr(doc, "user"):
+                doc.user = user_id
+            if hasattr(doc, "company"):
+                doc.company = company_id
+            if isinstance(doc, Project):
+                cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
+                    set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
+                )
+            doc.save()
+            if isinstance(doc, Task):
+                cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
+
+    @classmethod
+    def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
+        _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
+        print(f"Writing events for task {task_id} into database")
+        for events_chunk in chunked_iter(cls.json_lines(f), 1000):
+            events = [json.loads(item) for item in events_chunk]
+            cls.event_bll.add_events(
+                company_id, events=events, worker="", allow_locked_tasks=True
+            )
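Taken together, export_to_zip now produces a content-addressed archive (the MD5 of the serialized entities and events is appended to the file name), an optional .artifacts archive with the matching fileserver files, and a .map file that lets an unchanged dataset skip the next export; import_from_zip replays entities and their events under the target company and unpacks artifacts next to the fileserver root. A hedged usage sketch (ids, statuses and paths are illustrative):

    created = PrePopulate.export_to_zip(
        "/tmp/export.zip",
        projects=["<project-id>"],
        task_statuses=["published", "closed"],
        artifacts_path="/mnt/fileserver",
        tag_exported_entities=True,
    )
    # e.g. ["/tmp/export_<md5>.zip", "/tmp/export_<md5>.artifacts"]

    PrePopulate.import_from_zip(
        "/opt/trains/db-pre-populate/export_<md5>.zip",
        company_id="<default-company-id>",
        user_id="<backend-user-id>",
        artifacts_path="/mnt/fileserver",
    )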

View File

@@ -1,5 +1,6 @@
 import atexit
 from argparse import ArgumentParser
+from hashlib import md5

 from flask import Flask, request, Response
 from flask_compress import Compress
@@ -11,10 +12,11 @@ from apierrors.base import BaseError
 from bll.statistics.stats_reporter import StatisticsReporter
 from config import config
 from elastic.initialize import init_es_data
-from mongo.initialize import init_mongo_data
+from mongo.initialize import init_mongo_data, pre_populate_data
 from service_repo import ServiceRepo, APICall
 from service_repo.auth import AuthType
 from service_repo.errors import PathParsingError
+from sync import distributed_lock
 from timing_context import TimingContext
 from updates import check_updates_thread
 from utilities import json
@@ -33,8 +35,16 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")

 database.initialize()

-init_es_data()
-init_mongo_data()
+# build a key that uniquely identifies specific mongo instance
+hosts_string = ";".join(sorted(database.get_hosts()))
+key = "db_init_" + md5(hosts_string.encode()).hexdigest()
+with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
+    print(key)
+    init_es_data()
+    empty_db = init_mongo_data()
+    if empty_db and config.get("apiserver.pre_populate.enabled", False):
+        pre_populate_data()

 ServiceRepo.load("services")
 log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")

View File

@@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
 def get_all_ex(call: APICall):
     include_stats = call.data.get("include_stats")
     stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
+    allow_public = not call.data.get("non_public", False)

     if stats_for_state:
         try:
@@ -200,7 +201,7 @@ def get_all_ex(call: APICall):
             company=call.identity.company,
             query_dict=call.data,
             query_options=get_all_query_options,
-            allow_public=True,
+            allow_public=allow_public,
         )
         conform_output_tags(call, projects)
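Passing non_public in the request restricts the listing to the company's own projects instead of also returning public ones. A sketch of how the flag is interpreted (payload is illustrative):

    call_data = {"non_public": True}
    allow_public = not call_data.get("non_public", False)   # -> False: public projects are excluded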

View File

@@ -56,7 +56,7 @@ from database.model.task.task import (
 )
 from database.utils import get_fields, parse_from_call
 from service_repo import APICall, endpoint
-from services.utils import conform_tag_fields, conform_output_tags
+from services.utils import conform_tag_fields, conform_output_tags, validate_tags
 from timing_context import TimingContext
 from utilities import safe_get
@@ -377,6 +377,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
     "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
 )
 def clone_task(call: APICall, company_id, request: CloneRequest):
+    validate_tags(request.new_task_tags, request.new_task_system_tags)
     task = task_bll.clone_task(
         company_id=company_id,
         user_id=call.identity.user,

server/sync.py (new file, 28 lines)
View File

@@ -0,0 +1,28 @@
+import time
+from contextlib import contextmanager
+from time import sleep
+
+from redis_manager import redman
+
+_redis = redman.connection("apiserver")
+
+
+@contextmanager
+def distributed_lock(name: str, timeout: int, max_wait: int = 0):
+    """
+    Context manager that acquires a distributed lock on enter
+    and releases it on exit. The lock has a TTL of `timeout` seconds.
+    If the lock cannot be acquired within `max_wait` seconds (defaults to timeout * 2)
+    an exception is raised.
+    """
+    lock_name = f"dist_lock_{name}"
+    start = time.time()
+    max_wait = max_wait or timeout * 2
+    while not _redis.set(lock_name, value="", ex=timeout, nx=True):
+        sleep(1)
+        if time.time() - start > max_wait:
+            raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
+    try:
+        yield
+    finally:
+        _redis.delete(lock_name)
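A hedged usage sketch of the new lock (the key name and the guarded function are illustrative): whichever process sets the Redis key first proceeds, others poll once per second until the key is released or its TTL expires, and give up after max_wait (twice the timeout by default).

    from sync import distributed_lock

    def initialize_once():
        """Hypothetical one-time initialization guarded by the lock."""
        pass

    # run concurrently by several apiserver instances
    with distributed_lock("db_init_example", timeout=30):
        initialize_once()
    # if the holder dies before releasing, the key expires after `timeout` seconds,
    # so waiting instances are never blocked forever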