Add worker runtime properties support

Refactor login and add guest mode
Support artifacts in prepopulate
This commit is contained in:
allegroai 2021-01-05 16:56:08 +02:00
parent e12fd8f3df
commit 1c7de3a86e
9 changed files with 432 additions and 24 deletions

View File

@ -214,3 +214,21 @@ class JsonSerializableMixin:
@classmethod @classmethod
def from_json(cls: Type[ModelBase], s): def from_json(cls: Type[ModelBase], s):
return cls(**loads(s)) return cls(**loads(s))
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
class _Wrapped(cls):
_callable_default = None
def get_default_value(self):
if self._callable_default:
return self._callable_default()
return super(_Wrapped, self).get_default_value()
def __init__(self, *args, default=None, **kwargs):
if default and callable(default):
self._callable_default = default
default = default()
super(_Wrapped, self).__init__(*args, default=default, **kwargs)
return _Wrapped

View File

@ -0,0 +1,32 @@
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
class BasicGuestMode(Base):
enabled = BoolField(default=False)
name = StringField()
username = StringField()
password = StringField()
class BasicMode(Base):
enabled = BoolField(default=False)
guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode)
class ServerErrors(Base):
missed_es_upgrade = BoolField(default=False)
es_connection_error = BoolField(default=False)
class GetSupportedModesResponse(Base):
basic = EmbeddedField(BasicMode)
server_errors = EmbeddedField(ServerErrors)
sso = DictField([str, type(None)])

View File

@ -176,3 +176,28 @@ class ActivityReportSeries(Base):
class GetActivityReportResponse(Base): class GetActivityReportResponse(Base):
total = EmbeddedField(ActivityReportSeries) total = EmbeddedField(ActivityReportSeries)
active = EmbeddedField(ActivityReportSeries) active = EmbeddedField(ActivityReportSeries)
class RuntimeProperty(Base):
key = StringField()
value = StringField()
expiry = IntField(default=None)
class GetRuntimePropertiesRequest(Base):
worker = StringField(required=True)
class GetRuntimePropertiesResponse(Base):
runtime_properties = ListField(RuntimeProperty)
class SetRuntimePropertiesRequest(Base):
worker = StringField(required=True)
runtime_properties = ListField(RuntimeProperty)
class SetRuntimePropertiesResponse(Base):
added = ListField(str)
removed = ListField(str)
errors = ListField(str)

View File

@ -1,6 +1,6 @@
import itertools import itertools
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Sequence, Set, Optional from typing import Sequence, Set, Optional, List
import attr import attr
import elasticsearch.helpers import elasticsearch.helpers
@ -16,6 +16,7 @@ from apiserver.apimodels.workers import (
WorkerResponseEntry, WorkerResponseEntry,
QueueEntry, QueueEntry,
MachineStats, MachineStats,
RuntimeProperty,
) )
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
@ -416,6 +417,66 @@ class WorkerBLL:
added, errors = es_res[:2] added, errors = es_res[:2]
return (added == len(actions)) and not errors return (added == len(actions)) and not errors
def set_runtime_properties(
self,
company: str,
user: str,
worker_id: str,
runtime_properties: List[RuntimeProperty],
) -> dict:
"""Save worker entry in Redis"""
res = {
"added": [],
"removed": [],
"errors": [],
}
for prop in runtime_properties:
try:
key = self._get_runtime_property_key(company, user, worker_id, prop.key)
if prop.expiry == 0:
self.redis.delete(key)
res["removed"].append(key)
else:
self.redis.set(
key,
prop.value,
ex=prop.expiry
)
res["added"].append(key)
except Exception as ex:
msg = f"Exception: {ex}\nFailed saving property '{prop.key}: {prop.value}', skipping"
log.exception(msg)
res["errors"].append(ex)
return res
def get_runtime_properties(
self,
company: str,
user: str,
worker_id: str,
) -> List[RuntimeProperty]:
match = self._get_runtime_property_key(company, user, worker_id, "*")
with TimingContext("redis", "get_runtime_properties"):
res = self.redis.scan_iter(match=match)
runtime_properties = []
for r in res:
ttl = self.redis.ttl(r)
runtime_properties.append(
RuntimeProperty(
key=r.decode()[len(match) - 1:],
value=self.redis.get(r).decode(),
expiry=ttl if ttl >= 0 else None
)
)
return runtime_properties
def _get_runtime_property_key(
self, company: str, user: str, worker_id: str, prop_id: str
) -> str:
"""Build redis key from company, user, worker_id and prop_id"""
prefix = self._get_worker_key(company, user, worker_id)
return f"{prefix}_prop_{prop_id}"
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class WorkerConversionHelper: class WorkerConversionHelper:

View File

@ -0,0 +1,33 @@
from apiserver.apimodels.login import (
GetSupportedModesRequest,
GetSupportedModesResponse,
BasicMode,
BasicGuestMode,
ServerErrors,
)
from apiserver.config import info
from apiserver.service_repo import endpoint
from apiserver.service_repo.auth.fixed_user import FixedUser
@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse)
def supported_modes(_, __, ___: GetSupportedModesRequest):
guest_user = FixedUser.get_guest_user()
if guest_user:
guest = BasicGuestMode(
enabled=True,
name=guest_user.name,
username=guest_user.username,
password=guest_user.password,
)
else:
guest = BasicGuestMode()
return GetSupportedModesResponse(
basic=BasicMode(enabled=FixedUser.enabled(), guest=guest),
sso={},
server_errors=ServerErrors(
missed_es_upgrade=info.missed_es_upgrade,
es_connection_error=info.es_connection_error,
),
)

View File

@ -4,7 +4,6 @@ import os
import re import re
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial
from io import BytesIO from io import BytesIO
from itertools import chain from itertools import chain
from operator import attrgetter from operator import attrgetter
@ -21,17 +20,19 @@ from typing import (
BinaryIO, BinaryIO,
Union, Union,
Mapping, Mapping,
IO,
) )
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2 from zipfile import ZipFile, ZIP_BZIP2
import dpath import dpath
import mongoengine import mongoengine
from boltons.iterutils import chunked_iter from boltons.iterutils import chunked_iter, first
from furl import furl from furl import furl
from mongoengine import Q from mongoengine import Q
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import ( from apiserver.bll.task.param_utils import (
split_param_name, split_param_name,
hyperparams_default_section, hyperparams_default_section,
@ -46,6 +47,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
from apiserver.tools import safe_get from apiserver.tools import safe_get
from apiserver.utilities import json from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from .user import _ensure_backend_user from .user import _ensure_backend_user
@ -344,7 +346,7 @@ class PrePopulate:
dpath.delete(task_data, old_param_field) dpath.delete(task_data, old_param_field)
@classmethod @classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes: def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
""" """
Build content array that contains fixed tasks from the passed file Build content array that contains fixed tasks from the passed file
For each task the old execution.parameters and model.design are For each task the old execution.parameters and model.design are
@ -564,13 +566,13 @@ class PrePopulate:
if not task.execution.artifacts: if not task.execution.artifacts:
return [] return []
for a in task.execution.artifacts: for a in task.execution.artifacts.values():
if a.mode == ArtifactModes.output: if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri) a.uri = cls._get_fixed_url(a.uri)
return [ return [
a.uri a.uri
for a in task.execution.artifacts for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri if a.mode == ArtifactModes.output and a.uri
] ]
@ -630,7 +632,7 @@ class PrePopulate:
return artifacts return artifacts
@staticmethod @staticmethod
def json_lines(file: BinaryIO): def json_lines(file: IO[bytes]):
for line in file: for line in file:
clean = ( clean = (
line.decode("utf-8") line.decode("utf-8")
@ -651,6 +653,7 @@ class PrePopulate:
company_id: str = "", company_id: str = "",
user_id: str = None, user_id: str = None,
metadata: Mapping[str, Any] = None, metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
): ):
""" """
Import entities and events from the zip file Import entities and events from the zip file
@ -663,35 +666,60 @@ class PrePopulate:
if not fi.orig_filename.endswith(event_file_ending) if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename and fi.orig_filename != cls.metadata_filename
) )
event_files = ( metadata = metadata or {}
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) tasks = []
) for entity_file in entity_files:
for files, reader_func in ( with reader.open(entity_file) as f:
(entity_files, partial(cls._import_entity, metadata=metadata or {})), full_name = splitext(entity_file.orig_filename)[0]
(event_files, cls._import_events), print(f"Reading {reader.filename}:{full_name}...")
): res = cls._import_entity(f, full_name, company_id, user_id, metadata)
for file_info in files: if res:
with reader.open(file_info) as f: tasks = res
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...") if sort_tasks_by_last_updated:
reader_func(f, full_name, company_id, user_id) tasks = sorted(tasks, key=attrgetter("last_update"))
for task in tasks:
events_file = first(
fi
for fi in reader.filelist
if fi.orig_filename.endswith(task.id + event_file_ending)
)
if not events_file:
continue
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, full_name, company_id, user_id)
@classmethod @classmethod
def _import_entity( def _import_entity(
cls, cls,
f: BinaryIO, f: IO[bytes],
full_name: str, full_name: str,
company_id: str, company_id: str,
user_id: str, user_id: str,
metadata: Mapping[str, Any], metadata: Mapping[str, Any],
): ) -> Optional[Sequence[Task]]:
module_name, _, class_name = full_name.rpartition(".") module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name) cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database") print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0 override_project_count = 0
for item in cls.json_lines(f): for item in cls.json_lines(f):
if cls_ == Task:
task_data = json.loads(item)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
item = json.dumps(task_data)
doc = cls_.from_json(item, created=True) doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"): if hasattr(doc, "user"):
doc.user = user_id doc.user = user_id
@ -717,10 +745,14 @@ class PrePopulate:
doc.save() doc.save()
if isinstance(doc, Task): if isinstance(doc, Task):
tasks.append(doc)
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
if tasks:
return tasks
@classmethod @classmethod
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _): def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_") _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
print(f"Writing events for task {task_id} into database") print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000): for events_chunk in chunked_iter(cls.json_lines(f), 1000):

View File

@ -0,0 +1,83 @@
_description: """This service provides an administrator management interface to the company's users login information."""
_default {
internal: false
allow_roles: ["system", "root", "admin"]
}
supported_modes {
authorize: false
"0.17" {
description: """ Return supported login modes."""
request {
type: object
properties {
state {
description: "ASCII base64 encoded application state"
type: string
}
callback_url_prefix {
description: "URL prefix used to generate the callback URL for each supported SSO provider"
type: string
}
}
}
response {
type: object
properties {
basic {
type: object
properties {
enabled {
description: "Basic aothentication (fixed users mode) mode enabled"
type: boolean
}
guest {
type: object
properties {
enabled {
description: "Basic aothentication guest mode enabled"
type: boolean
}
name {
description: "Guest name"
type: string
}
username {
description: "Guest username"
type: string
}
password {
description: "Guest password"
type: string
}
}
}
}
}
sso {
description: "SSO authentication providers"
type: object
additionalProperties {
desctiprion: "Provider redirect URL"
type: string
}
}
server_errors {
description: "Server initialization errors"
type: object
properties {
missed_es_upgrade {
description: "Indicate that Elasticsearch database was not upgraded from version 5"
type: boolean
}
es_connection_error {
description: "Indicate an error communicating to Elasticsearch"
type: boolean
}
}
}
}
}
}
}

View File

@ -499,4 +499,92 @@
} }
} }
} }
} get_runtime_properties {
"2.10" {
description: "Get runtime properties for a worker"
request {
required: [
worker
]
type: object
properties {
worker {
description: "Worker ID"
type: string
}
}
}
response {
type: object
properties {
runtime_properties {
type: array
items {
type: object
properties {
key { type: string }
value { type: string }
expiry {
description: "Expiry (in seconds) for a runtime property"
type: integer
}
}
}
}
}
}
}
}
set_runtime_properties {
"2.10" {
description: "Set runtime properties for a worker"
request {
required: [
worker
runtime_properties
]
type: object
properties {
worker {
description: "Worker ID"
type: string
}
runtime_properties {
type: array
items {
type: object
properties {
key { type: string }
value { type: string }
expiry {
description: "Expiry (in seconds) for a runtime property. When set to null no expiry is set, when set to 0 the specified key is removed"
type: integer
}
}
}
}
}
}
response {
type: object
properties {
added {
type: array
description: "keys of runtime properties added to redis"
items: { type: string }
}
removed {
type: array
description: "keys of runtime properties removed from redis"
items: { type: string }
}
errors {
type: array
description: "errors for keys failed to be added to redis"
items: { type: string }
}
}
}
}
}
}

View File

@ -22,6 +22,10 @@ from apiserver.apimodels.workers import (
GetActivityReportRequest, GetActivityReportRequest,
GetActivityReportResponse, GetActivityReportResponse,
ActivityReportSeries, ActivityReportSeries,
SetRuntimePropertiesRequest,
GetRuntimePropertiesRequest,
GetRuntimePropertiesResponse,
SetRuntimePropertiesResponse
) )
from apiserver.bll.util import extract_properties_to_lists from apiserver.bll.util import extract_properties_to_lists
from apiserver.bll.workers import WorkerBLL from apiserver.bll.workers import WorkerBLL
@ -202,3 +206,35 @@ def get_stats(call: APICall, company_id, request: GetStatsRequest):
for worker, stats in ret.items() for worker, stats in ret.items()
] ]
) )
@endpoint(
"workers.set_runtime_properties",
min_version="2.10",
request_data_model=SetRuntimePropertiesRequest,
response_data_model=SetRuntimePropertiesResponse,
)
def set_runtime_properties(call: APICall, company_id, request: SetRuntimePropertiesRequest):
res = worker_bll.set_runtime_properties(
company=company_id,
user=call.identity.user,
worker_id=request.worker,
runtime_properties=request.runtime_properties,
)
return SetRuntimePropertiesResponse(added=res["added"], removed=res["removed"], errors=res["errors"])
@endpoint(
"workers.get_runtime_properties",
min_version="2.10",
request_data_model=GetRuntimePropertiesRequest,
response_data_model=GetRuntimePropertiesResponse,
)
def get_runtime_properties(call: APICall, company_id, request: GetRuntimePropertiesRequest):
return GetRuntimePropertiesResponse(
runtime_properties=worker_bll.get_runtime_properties(
company=company_id,
user=call.identity.user,
worker_id=request.worker,
)
)