From 1c7de3a86e4417f5682597556f28ed7a8f0626c7 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 16:56:08 +0200 Subject: [PATCH] Add worker runtime properties support Refactor login and add guest mode Support artifacts in prepopulate --- apiserver/apimodels/__init__.py | 18 +++++ apiserver/apimodels/login.py | 32 ++++++++ apiserver/apimodels/workers.py | 25 ++++++ apiserver/bll/workers/__init__.py | 63 ++++++++++++++- apiserver/login/__init__.py | 33 ++++++++ apiserver/mongo/initialize/pre_populate.py | 76 ++++++++++++------ apiserver/schema/services/login.conf | 83 ++++++++++++++++++++ apiserver/schema/services/workers.conf | 90 +++++++++++++++++++++- apiserver/services/workers.py | 36 +++++++++ 9 files changed, 432 insertions(+), 24 deletions(-) create mode 100644 apiserver/apimodels/login.py create mode 100644 apiserver/login/__init__.py create mode 100644 apiserver/schema/services/login.conf diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py index dd1a090..590a00c 100644 --- a/apiserver/apimodels/__init__.py +++ b/apiserver/apimodels/__init__.py @@ -214,3 +214,21 @@ class JsonSerializableMixin: @classmethod def from_json(cls: Type[ModelBase], s): return cls(**loads(s)) + + +def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]: + class _Wrapped(cls): + _callable_default = None + + def get_default_value(self): + if self._callable_default: + return self._callable_default() + return super(_Wrapped, self).get_default_value() + + def __init__(self, *args, default=None, **kwargs): + if default and callable(default): + self._callable_default = default + default = default() + super(_Wrapped, self).__init__(*args, default=default, **kwargs) + + return _Wrapped diff --git a/apiserver/apimodels/login.py b/apiserver/apimodels/login.py new file mode 100644 index 0000000..82e9450 --- /dev/null +++ b/apiserver/apimodels/login.py @@ -0,0 +1,32 @@ +from jsonmodels.fields import StringField, BoolField, EmbeddedField +from jsonmodels.models import Base + +from apiserver.apimodels import DictField, callable_default + + +class GetSupportedModesRequest(Base): + state = StringField(help_text="ASCII base64 encoded application state") + callback_url_prefix = StringField() + + +class BasicGuestMode(Base): + enabled = BoolField(default=False) + name = StringField() + username = StringField() + password = StringField() + + +class BasicMode(Base): + enabled = BoolField(default=False) + guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode) + + +class ServerErrors(Base): + missed_es_upgrade = BoolField(default=False) + es_connection_error = BoolField(default=False) + + +class GetSupportedModesResponse(Base): + basic = EmbeddedField(BasicMode) + server_errors = EmbeddedField(ServerErrors) + sso = DictField([str, type(None)]) diff --git a/apiserver/apimodels/workers.py b/apiserver/apimodels/workers.py index 85b593d..2db1407 100644 --- a/apiserver/apimodels/workers.py +++ b/apiserver/apimodels/workers.py @@ -176,3 +176,28 @@ class ActivityReportSeries(Base): class GetActivityReportResponse(Base): total = EmbeddedField(ActivityReportSeries) active = EmbeddedField(ActivityReportSeries) + + +class RuntimeProperty(Base): + key = StringField() + value = StringField() + expiry = IntField(default=None) + + +class GetRuntimePropertiesRequest(Base): + worker = StringField(required=True) + + +class GetRuntimePropertiesResponse(Base): + runtime_properties = ListField(RuntimeProperty) + + +class SetRuntimePropertiesRequest(Base): + worker = StringField(required=True) + runtime_properties = ListField(RuntimeProperty) + + +class SetRuntimePropertiesResponse(Base): + added = ListField(str) + removed = ListField(str) + errors = ListField(str) diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py index bd8a5c1..e384026 100644 --- a/apiserver/bll/workers/__init__.py +++ b/apiserver/bll/workers/__init__.py @@ -1,6 +1,6 @@ import itertools from datetime import datetime, timedelta -from typing import Sequence, Set, Optional +from typing import Sequence, Set, Optional, List import attr import elasticsearch.helpers @@ -16,6 +16,7 @@ from apiserver.apimodels.workers import ( WorkerResponseEntry, QueueEntry, MachineStats, + RuntimeProperty, ) from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context @@ -416,6 +417,66 @@ class WorkerBLL: added, errors = es_res[:2] return (added == len(actions)) and not errors + def set_runtime_properties( + self, + company: str, + user: str, + worker_id: str, + runtime_properties: List[RuntimeProperty], + ) -> dict: + """Save worker entry in Redis""" + res = { + "added": [], + "removed": [], + "errors": [], + } + for prop in runtime_properties: + try: + key = self._get_runtime_property_key(company, user, worker_id, prop.key) + if prop.expiry == 0: + self.redis.delete(key) + res["removed"].append(key) + else: + self.redis.set( + key, + prop.value, + ex=prop.expiry + ) + res["added"].append(key) + except Exception as ex: + msg = f"Exception: {ex}\nFailed saving property '{prop.key}: {prop.value}', skipping" + log.exception(msg) + res["errors"].append(ex) + return res + + def get_runtime_properties( + self, + company: str, + user: str, + worker_id: str, + ) -> List[RuntimeProperty]: + match = self._get_runtime_property_key(company, user, worker_id, "*") + with TimingContext("redis", "get_runtime_properties"): + res = self.redis.scan_iter(match=match) + runtime_properties = [] + for r in res: + ttl = self.redis.ttl(r) + runtime_properties.append( + RuntimeProperty( + key=r.decode()[len(match) - 1:], + value=self.redis.get(r).decode(), + expiry=ttl if ttl >= 0 else None + ) + ) + return runtime_properties + + def _get_runtime_property_key( + self, company: str, user: str, worker_id: str, prop_id: str + ) -> str: + """Build redis key from company, user, worker_id and prop_id""" + prefix = self._get_worker_key(company, user, worker_id) + return f"{prefix}_prop_{prop_id}" + @attr.s(auto_attribs=True) class WorkerConversionHelper: diff --git a/apiserver/login/__init__.py b/apiserver/login/__init__.py new file mode 100644 index 0000000..0034d8f --- /dev/null +++ b/apiserver/login/__init__.py @@ -0,0 +1,33 @@ +from apiserver.apimodels.login import ( + GetSupportedModesRequest, + GetSupportedModesResponse, + BasicMode, + BasicGuestMode, + ServerErrors, +) +from apiserver.config import info +from apiserver.service_repo import endpoint +from apiserver.service_repo.auth.fixed_user import FixedUser + + +@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse) +def supported_modes(_, __, ___: GetSupportedModesRequest): + guest_user = FixedUser.get_guest_user() + if guest_user: + guest = BasicGuestMode( + enabled=True, + name=guest_user.name, + username=guest_user.username, + password=guest_user.password, + ) + else: + guest = BasicGuestMode() + + return GetSupportedModesResponse( + basic=BasicMode(enabled=FixedUser.enabled(), guest=guest), + sso={}, + server_errors=ServerErrors( + missed_es_upgrade=info.missed_es_upgrade, + es_connection_error=info.es_connection_error, + ), + ) diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index c0dfcf4..4e0d0ca 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -4,7 +4,6 @@ import os import re from collections import defaultdict from datetime import datetime, timezone -from functools import partial from io import BytesIO from itertools import chain from operator import attrgetter @@ -21,17 +20,19 @@ from typing import ( BinaryIO, Union, Mapping, + IO, ) from urllib.parse import unquote, urlparse from zipfile import ZipFile, ZIP_BZIP2 import dpath import mongoengine -from boltons.iterutils import chunked_iter +from boltons.iterutils import chunked_iter, first from furl import furl from mongoengine import Q from apiserver.bll.event import EventBLL +from apiserver.bll.task.artifacts import get_artifact_id from apiserver.bll.task.param_utils import ( split_param_name, hyperparams_default_section, @@ -46,6 +47,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus from apiserver.database.utils import get_options from apiserver.tools import safe_get from apiserver.utilities import json +from apiserver.utilities.dicts import nested_get, nested_set from .user import _ensure_backend_user @@ -344,7 +346,7 @@ class PrePopulate: dpath.delete(task_data, old_param_field) @classmethod - def _upgrade_tasks(cls, f: BinaryIO) -> bytes: + def _upgrade_tasks(cls, f: IO[bytes]) -> bytes: """ Build content array that contains fixed tasks from the passed file For each task the old execution.parameters and model.design are @@ -564,13 +566,13 @@ class PrePopulate: if not task.execution.artifacts: return [] - for a in task.execution.artifacts: + for a in task.execution.artifacts.values(): if a.mode == ArtifactModes.output: a.uri = cls._get_fixed_url(a.uri) return [ a.uri - for a in task.execution.artifacts + for a in task.execution.artifacts.values() if a.mode == ArtifactModes.output and a.uri ] @@ -630,7 +632,7 @@ class PrePopulate: return artifacts @staticmethod - def json_lines(file: BinaryIO): + def json_lines(file: IO[bytes]): for line in file: clean = ( line.decode("utf-8") @@ -651,6 +653,7 @@ class PrePopulate: company_id: str = "", user_id: str = None, metadata: Mapping[str, Any] = None, + sort_tasks_by_last_updated: bool = True, ): """ Import entities and events from the zip file @@ -663,35 +666,60 @@ class PrePopulate: if not fi.orig_filename.endswith(event_file_ending) and fi.orig_filename != cls.metadata_filename ) - event_files = ( - fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) - ) - for files, reader_func in ( - (entity_files, partial(cls._import_entity, metadata=metadata or {})), - (event_files, cls._import_events), - ): - for file_info in files: - with reader.open(file_info) as f: - full_name = splitext(file_info.orig_filename)[0] - print(f"Reading {reader.filename}:{full_name}...") - reader_func(f, full_name, company_id, user_id) + metadata = metadata or {} + tasks = [] + for entity_file in entity_files: + with reader.open(entity_file) as f: + full_name = splitext(entity_file.orig_filename)[0] + print(f"Reading {reader.filename}:{full_name}...") + res = cls._import_entity(f, full_name, company_id, user_id, metadata) + if res: + tasks = res + + if sort_tasks_by_last_updated: + tasks = sorted(tasks, key=attrgetter("last_update")) + + for task in tasks: + events_file = first( + fi + for fi in reader.filelist + if fi.orig_filename.endswith(task.id + event_file_ending) + ) + if not events_file: + continue + with reader.open(events_file) as f: + full_name = splitext(events_file.orig_filename)[0] + print(f"Reading {reader.filename}:{full_name}...") + cls._import_events(f, full_name, company_id, user_id) @classmethod def _import_entity( cls, - f: BinaryIO, + f: IO[bytes], full_name: str, company_id: str, user_id: str, metadata: Mapping[str, Any], - ): + ) -> Optional[Sequence[Task]]: module_name, _, class_name = full_name.rpartition(".") module = importlib.import_module(module_name) cls_: Type[mongoengine.Document] = getattr(module, class_name) print(f"Writing {cls_.__name__.lower()}s into database") - + tasks = [] override_project_count = 0 for item in cls.json_lines(f): + if cls_ == Task: + task_data = json.loads(item) + artifacts_path = ("execution", "artifacts") + artifacts = nested_get(task_data, artifacts_path) + if isinstance(artifacts, list): + nested_set( + task_data, + artifacts_path, + value={get_artifact_id(a): a for a in artifacts}, + ) + item = json.dumps(task_data) + doc = cls_.from_json(item, created=True) if hasattr(doc, "user"): doc.user = user_id @@ -717,10 +745,14 @@ class PrePopulate: doc.save() if isinstance(doc, Task): + tasks.append(doc) cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) + if tasks: + return tasks + @classmethod - def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _): + def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _): _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_") print(f"Writing events for task {task_id} into database") for events_chunk in chunked_iter(cls.json_lines(f), 1000): diff --git a/apiserver/schema/services/login.conf b/apiserver/schema/services/login.conf new file mode 100644 index 0000000..d2a30b9 --- /dev/null +++ b/apiserver/schema/services/login.conf @@ -0,0 +1,83 @@ +_description: """This service provides an administrator management interface to the company's users login information.""" + +_default { + internal: false + allow_roles: ["system", "root", "admin"] +} + +supported_modes { + authorize: false + "0.17" { + description: """ Return supported login modes.""" + request { + type: object + properties { + state { + description: "ASCII base64 encoded application state" + type: string + } + callback_url_prefix { + description: "URL prefix used to generate the callback URL for each supported SSO provider" + type: string + } + } + } + response { + type: object + properties { + basic { + type: object + properties { + enabled { + description: "Basic aothentication (fixed users mode) mode enabled" + type: boolean + } + guest { + type: object + properties { + enabled { + description: "Basic aothentication guest mode enabled" + type: boolean + } + name { + description: "Guest name" + type: string + } + username { + description: "Guest username" + type: string + } + password { + description: "Guest password" + type: string + } + } + } + } + } + sso { + description: "SSO authentication providers" + type: object + additionalProperties { + desctiprion: "Provider redirect URL" + type: string + } + } + server_errors { + description: "Server initialization errors" + type: object + properties { + missed_es_upgrade { + description: "Indicate that Elasticsearch database was not upgraded from version 5" + type: boolean + } + es_connection_error { + description: "Indicate an error communicating to Elasticsearch" + type: boolean + } + } + } + } + } + } +} diff --git a/apiserver/schema/services/workers.conf b/apiserver/schema/services/workers.conf index 8b10b41..8131a90 100644 --- a/apiserver/schema/services/workers.conf +++ b/apiserver/schema/services/workers.conf @@ -499,4 +499,92 @@ } } } -} + get_runtime_properties { + "2.10" { + description: "Get runtime properties for a worker" + request { + required: [ + worker + ] + type: object + properties { + worker { + description: "Worker ID" + type: string + } + } + } + response { + type: object + properties { + runtime_properties { + type: array + items { + type: object + properties { + key { type: string } + value { type: string } + expiry { + description: "Expiry (in seconds) for a runtime property" + type: integer + } + } + } + } + } + } + } + } + set_runtime_properties { + "2.10" { + description: "Set runtime properties for a worker" + request { + required: [ + worker + runtime_properties + ] + type: object + properties { + worker { + description: "Worker ID" + type: string + } + runtime_properties { + type: array + items { + type: object + properties { + key { type: string } + value { type: string } + expiry { + description: "Expiry (in seconds) for a runtime property. When set to null no expiry is set, when set to 0 the specified key is removed" + type: integer + } + } + } + } + } + } + response { + type: object + properties { + added { + type: array + description: "keys of runtime properties added to redis" + items: { type: string } + } + removed { + type: array + description: "keys of runtime properties removed from redis" + items: { type: string } + } + errors { + type: array + description: "errors for keys failed to be added to redis" + items: { type: string } + } + } + } + } + } +} \ No newline at end of file diff --git a/apiserver/services/workers.py b/apiserver/services/workers.py index 45fff0c..9e256f4 100644 --- a/apiserver/services/workers.py +++ b/apiserver/services/workers.py @@ -22,6 +22,10 @@ from apiserver.apimodels.workers import ( GetActivityReportRequest, GetActivityReportResponse, ActivityReportSeries, + SetRuntimePropertiesRequest, + GetRuntimePropertiesRequest, + GetRuntimePropertiesResponse, + SetRuntimePropertiesResponse ) from apiserver.bll.util import extract_properties_to_lists from apiserver.bll.workers import WorkerBLL @@ -202,3 +206,35 @@ def get_stats(call: APICall, company_id, request: GetStatsRequest): for worker, stats in ret.items() ] ) + + +@endpoint( + "workers.set_runtime_properties", + min_version="2.10", + request_data_model=SetRuntimePropertiesRequest, + response_data_model=SetRuntimePropertiesResponse, +) +def set_runtime_properties(call: APICall, company_id, request: SetRuntimePropertiesRequest): + res = worker_bll.set_runtime_properties( + company=company_id, + user=call.identity.user, + worker_id=request.worker, + runtime_properties=request.runtime_properties, + ) + return SetRuntimePropertiesResponse(added=res["added"], removed=res["removed"], errors=res["errors"]) + + +@endpoint( + "workers.get_runtime_properties", + min_version="2.10", + request_data_model=GetRuntimePropertiesRequest, + response_data_model=GetRuntimePropertiesResponse, +) +def get_runtime_properties(call: APICall, company_id, request: GetRuntimePropertiesRequest): + return GetRuntimePropertiesResponse( + runtime_properties=worker_bll.get_runtime_properties( + company=company_id, + user=call.identity.user, + worker_id=request.worker, + ) + )