Rename server to apiserver

allegroai
2021-01-05 16:22:34 +02:00
parent 01115c1223
commit df65e1c7ad
195 changed files with 0 additions and 0 deletions

@@ -0,0 +1,89 @@
from pathlib import Path
from typing import Sequence, Union
from config import config
from config.info import get_default_company
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)
def _pre_populate(company_id: str, zip_file: str):
if not zip_file or not Path(zip_file).is_file():
msg = f"Invalid pre-populate zip file: {zip_file}"
if config.get("apiserver.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
log.info(f"Pre-populating using {zip_file}")
PrePopulate.import_from_zip(
zip_file,
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
)
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
if isinstance(zip_files, str):
zip_files = [zip_files]
    for p in map(Path, zip_files):
        if p.is_file():
            yield p
        elif p.is_dir():
            yield from p.glob("*.zip")
        else:
            log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
def pre_populate_data():
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
PrePopulate.update_featured_projects_order()
def init_mongo_data():
try:
_apply_migrations(log)
_ensure_uuid()
company_id = _ensure_company(get_default_company(), "trains", log)
_ensure_default_queue(company_id)
fixed_mode = FixedUser.enabled()
for user, credentials in config.get("secure.credentials", {}).items():
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()
if FixedUser.guest_enabled():
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:
log.exception("Failed initializing mongodb")
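
A minimal sketch of how these two entry points are presumably wired at apiserver startup; the hook name and the apiserver.pre_populate.enabled key are assumptions for illustration, not part of this commit:

def _bootstrap_mongo():  # hypothetical startup hook, name is illustrative
    # init_mongo_data() applies migrations and ensures the company, queue and users
    init_mongo_data()
    # pre_populate_data() imports the configured zip files, if any
    if config.get("apiserver.pre_populate.enabled", False):  # key name assumed
        pre_populate_data()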

@@ -0,0 +1,91 @@
import importlib.util
from datetime import datetime
from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from database import Database
from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
def get_last_server_version() -> Version:
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
return previous_versions[0] if previous_versions else Version("0.0.0")
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
"""
log = log.getChild(Path(__file__).stem)
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
for script_version in sorted(new_scripts):
script = new_scripts[script_version]
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError(
"Migration failed, aborting. Please restore backup."
)
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
log.info("Finished mongodb migrations")

@@ -0,0 +1,728 @@
import hashlib
import importlib
import os
import re
from collections import defaultdict
from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from pathlib import Path
from typing import (
Optional,
Any,
Type,
Set,
Dict,
Sequence,
Tuple,
BinaryIO,
Union,
Mapping,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q
from bll.event import EventBLL
from bll.task.param_utils import (
split_param_name,
hyperparams_default_section,
hyperparams_legacy_type,
)
from config import config
from config.info import get_default_company
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from tools import safe_get
from utilities import json
from .user import _ensure_backend_user
class PrePopulate:
event_bll = EventBLL()
events_file_suffix = "_events"
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
flags=re.IGNORECASE,
)
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
self.file = file
self.empty = True
def __enter__(self):
self._write("[")
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self._write("\n]")
def _write(self, data: str):
self.file.write(data.encode("utf-8"))
def write(self, line: str):
if not self.empty:
self._write(",")
self._write("\n" + line)
self.empty = False
@staticmethod
def _get_last_update_time(entity) -> datetime:
return getattr(entity, "last_update", None) or getattr(entity, "created")
@classmethod
def _check_for_update(
cls, map_file: Path, entities: dict, metadata_hash: str
) -> Tuple[bool, Sequence[str]]:
if not map_file.is_file():
return True, []
files = []
try:
map_data = json.loads(map_file.read_text())
files = map_data.get("files", [])
for file in files:
if not Path(file).is_file():
return True, files
new_times = {
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
for item in chain.from_iterable(entities.values())
}
old_times = map_data.get("entities", {})
if set(new_times.keys()) != set(old_times.keys()):
return True, files
for id_, new_timestamp in new_times.items():
if new_timestamp != old_times[id_]:
return True, files
if metadata_hash != map_data.get("metadata_hash", ""):
return True, files
except Exception as ex:
print("Error reading map file. " + str(ex))
return True, files
return False, files
@classmethod
def _write_update_file(
cls,
map_file: Path,
entities: dict,
created_files: Sequence[str],
metadata_hash: str,
):
map_file.write_text(
json.dumps(
dict(
files=created_files,
entities={
entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
metadata_hash=metadata_hash,
)
)
)
@staticmethod
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
def is_fileserver_link(a: str) -> bool:
a = a.lower()
if a.startswith("https://files."):
return True
if a.startswith("http"):
parsed = urlparse(a)
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
"8081"
):
return True
return False
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
print(
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
)
return fileserver_links
@classmethod
def export_to_zip(
cls,
filename: str,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
artifacts_path: str = None,
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
file = Path(filename)
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
print(f"There are no updates from the last export")
return old_files
for old in old_files:
old_path = Path(old)
if old_path.is_file():
old_path.unlink()
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash)
created_files = [str(file_with_hash)]
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))
cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)
return created_files
@classmethod
def import_from_zip(
cls,
filename: str,
artifacts_path: str,
company_id: Optional[str] = None,
user_id: str = "",
user_name: str = "",
):
metadata = None
with ZipFile(filename) as zfile:
try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
meta_public = metadata.get("public")
if company_id is None and meta_public is not None:
company_id = "" if meta_public else get_default_company()
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
user_id = _ensure_backend_user(
user_id=user_id, user_name=user_name, company_id="",
)
cls._import(zfile, company_id, user_id, metadata)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
print(f"Unzipping artifacts into {artifacts_path}")
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)
@classmethod
def upgrade_zip(cls, filename) -> Sequence:
hash_ = hashlib.md5()
task_file = cls._get_base_filename(Task) + ".json"
temp_file = Path("temp.zip")
file = Path(filename)
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
for file_info in reader.filelist:
if file_info.orig_filename == task_file:
with reader.open(file_info) as f:
content = cls._upgrade_tasks(f)
else:
content = reader.read(file_info)
writer.writestr(file_info, content)
hash_.update(content)
base_file_name, _, old_hash = file.stem.rpartition("_")
new_hash = hash_.hexdigest()
if old_hash == new_hash:
print(f"The file {filename} was not updated")
temp_file.unlink()
return []
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
temp_file.replace(new_file)
        updated = [str(new_file)]
        artifacts_file = file.with_suffix(cls.artifacts_ext)
        if artifacts_file.is_file():
            new_artifacts = new_file.with_suffix(cls.artifacts_ext)
            artifacts_file.replace(new_artifacts)
            updated.append(str(new_artifacts))
        print(f"File {str(file)} replaced with {str(new_file)}")
        file.unlink()
        return updated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
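    # Illustrative example of the conversion above (the parameter name is made up,
    # and the section/name split is assumed to follow a "section/name" convention
    # via split_param_name):
    #   before: {"execution": {"parameters": {"Args/lr": "0.01"}}}
    #   after:  {"hyperparams": {"Args": {"lr": {
    #               "section": "Args", "name": "lr",
    #               "type": hyperparams_legacy_type, "value": "0.01"}}}}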
@classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
the fields are serialized back in the same order as they were in the original file
"""
with BytesIO() as temp:
with cls.JsonLinesWriter(temp) as w:
for line in cls.json_lines(f):
task_data = Task.from_json(line).to_proper_dict()
cls._upgrade_task_data(task_data)
new_task = Task(**task_data)
w.write(new_task.to_json())
return temp.getvalue()
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])
if not featured_order:
return
def get_index(p: Project):
for index, entry in enumerate(featured_order):
if (
entry.get("id", None) == p.id
or entry.get("name", None) == p.name
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
):
return index
return 999
for project in Project.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
) -> Sequence[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
resolved = {i.id for i in items}
missing = ids - resolved
for name_candidate in missing:
results = list(cls.objects(name=name_candidate))
if not results:
print(f"ERROR: no match for `{name_candidate}`")
exit(1)
elif len(results) > 1:
print(f"ERROR: more than one match for `{name_candidate}`")
exit(1)
items.append(results[0])
return items
@classmethod
def _resolve_entities(
cls,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set)
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
print("--> Reading project experiments...")
query = Q(
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
system_tags__nin=[EntityVisibility.archived.value],
)
if task_statuses:
query &= Q(status__in=list(set(task_statuses)))
objs = Task.objects(query)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
if experiments:
print("Reading experiments...")
entities[Task].update(cls._resolve_type(Task, experiments))
print("--> Reading experiments projects...")
objs = Project.objects(
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
)
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
for task in entities[Task]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if model_ids:
print("Reading models...")
entities[Model] = set(Model.objects(id__in=list(model_ids)))
return entities
@classmethod
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
if not tags:
return tags
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
@classmethod
def _cleanup_task(cls, task: Task):
task.comment = "Auto generated by Allegro.ai"
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
task.output.destination = None
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
if entity_cls == Task:
cls._cleanup_task(entity)
elif entity_cls == Model:
cls._cleanup_model(entity)
        elif entity_cls == Project:
cls._cleanup_project(entity)
@classmethod
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
try:
for item in items:
item.update(upsert=False, tags=sorted(item.tags + [tag]))
except AttributeError:
pass
@classmethod
def _export_task_events(
cls, task: Task, base_filename: str, writer: ZipFile, hash_
) -> Sequence[str]:
artifacts = []
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
print(f"Writing task events into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
scroll_id = None
while True:
res = cls.event_bll.get_task_events(
task.company, task.id, scroll_id=scroll_id
)
if not res.events:
break
scroll_id = res.next_scroll_id
for event in res.events:
event_type = event.get("type")
if event_type == "training_debug_image":
url = cls._get_fixed_url(event.get("url"))
if url:
event["url"] = url
artifacts.append(url)
elif event_type == "plot":
plot_str: str = event.get("plot_str", "")
for match in cls.img_source_regex.findall(plot_str):
url = cls._get_fixed_url(match)
if match != url:
plot_str = plot_str.replace(match, url)
artifacts.append(url)
w.write(json.dumps(event))
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
return artifacts
@staticmethod
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
if not (url and url.lower().startswith("s3://")):
return url
try:
fixed = furl(url)
fixed.scheme = "https"
fixed.host += ".s3.amazonaws.com"
return fixed.url
except Exception as ex:
print(f"Failed processing link {url}. " + str(ex))
return url
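    # For example (illustrative bucket and key): "s3://my-bucket/artifacts/model.bin"
    # becomes "https://my-bucket.s3.amazonaws.com/artifacts/model.bin".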
@classmethod
def _export_entity_related_data(
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
):
if entity_cls == Task:
return [
*cls._get_task_output_artifacts(entity),
*cls._export_task_events(entity, base_filename, writer, hash_),
]
if entity_cls == Model:
entity.uri = cls._get_fixed_url(entity.uri)
return [entity.uri] if entity.uri else []
return []
@classmethod
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
if not task.execution.artifacts:
return []
for a in task.execution.artifacts:
if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri)
return [
a.uri
for a in task.execution.artifacts
if a.mode == ArtifactModes.output and a.uri
]
@classmethod
def _export_artifacts(
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
):
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
for path in unique_paths:
path = path.lstrip("/")
full_path = os.path.join(artifacts_path, path)
if os.path.isfile(full_path):
writer.write(full_path, path)
else:
print(f"Artifact {full_path} not found")
@staticmethod
def _get_base_filename(cls_: type):
return f"{cls_.__module__}.{cls_.__name__}"
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
        Always do the export on sorted items since the order of items influences the hash
"""
artifacts = []
now = datetime.utcnow()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
if not items:
continue
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
if tag_entities:
cls._add_tag(items, now.strftime(cls.export_tag))
return artifacts
@staticmethod
def json_lines(file: BinaryIO):
for line in file:
clean = (
line.decode("utf-8")
.rstrip("\r\n")
.strip()
.lstrip("[")
.rstrip(",]")
.strip()
)
if not clean:
continue
yield clean
@classmethod
def _import(
cls,
reader: ZipFile,
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
):
"""
Import entities and events from the zip file
Start from entities since event import will require the tasks already in DB
"""
event_file_ending = cls.events_file_suffix + ".json"
entity_files = (
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
)
for files, reader_func in (
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events),
):
for file_info in files:
with reader.open(file_info) as f:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
reader_func(f, full_name, company_id, user_id)
@classmethod
def _import_entity(
cls,
f: BinaryIO,
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database")
override_project_count = 0
for item in cls.json_lines(f):
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, Project):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
override_project_name = (
f"{override_project_name} {override_project_count + 1}"
)
override_project_count += 1
doc.name = override_project_name
doc.logo_url = metadata.get("logo_url", None)
doc.logo_blob = metadata.get("logo_blob", None)
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
)
doc.save()
if isinstance(doc, Task):
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
@classmethod
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
)
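
A usage sketch of the two public entry points above; the paths, project name and company resolution are illustrative, not taken from this commit:

# export selected projects together with their task events and artifacts
created_files = PrePopulate.export_to_zip(
    "/tmp/demo_export.zip",            # final file name gets a content-hash suffix
    projects=["Demo Project"],         # IDs or names are both accepted
    artifacts_path="/opt/fileserver",  # assumed local fileserver data root
    tag_exported_entities=True,
)

# import the exported data back into the default company
PrePopulate.import_from_zip(
    created_files[0],
    artifacts_path="/opt/fileserver",
    company_id=get_default_company(),
)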

@@ -0,0 +1,80 @@
from datetime import datetime
from logging import Logger
import attr
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.user import User
from service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
key, secret = user_data.get("key"), user_data.get("secret")
if not (key and secret):
credentials = None
else:
creds = Credentials(key=key, secret=secret)
user = AuthUser.objects(credentials__match=creds).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
credentials = [] if revoke else [creds]
user_id = user_data.get("id", f"__{user_data['name']}__")
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=credentials,
)
user.save()
return user.id
def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
given_name, _, family_name = user_name.partition(" ")
User(
id=user_id,
company=company_id,
name=user_name,
given_name=given_name,
family_name=family_name,
).save()
return user_id
def ensure_fixed_user(user: FixedUser, log: Logger):
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
try:
log.info(f"Updating user name: {user.name}")
given_name, _, family_name = user.name.partition(" ")
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
except Exception:
pass
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
return _ensure_backend_user(user.user_id, user.company, user.name)
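
An illustrative call sequence mirroring how init_mongo_data() in this package feeds these helpers from the secure.credentials configuration; every value below is a placeholder:

from logging import getLogger

user_data = {
    "name": "apiserver",             # placeholder user name
    "role": Role.user,
    "email": "apiserver@example.com",
    "key": "<access-key>",
    "secret": "<secret-key>",
}
user_id = _ensure_auth_user(user_data, company_id="<company-id>", log=getLogger("bootstrap"))
_ensure_backend_user(user_id, "<company-id>", "Api Server")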

@@ -0,0 +1,37 @@
from logging import Logger
from uuid import uuid4
from bll.queue import QueueBLL
from config import config
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)
def _ensure_company(company_id, company_name, log: Logger):
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_uuid():
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))
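
For completeness, the bootstrap order in which init_mongo_data() (in this package) calls these helpers, shown as a sketch with a placeholder company ID:

_ensure_uuid()
company_id = _ensure_company("<default-company-id>", "trains", log)
_ensure_default_queue(company_id)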