Add worker runtime properties support

Refactor login and add guest mode
Support artifacts in prepopulate
allegroai 2021-01-05 16:56:08 +02:00
parent e12fd8f3df
commit 1c7de3a86e
9 changed files with 432 additions and 24 deletions

View File

@ -214,3 +214,21 @@ class JsonSerializableMixin:
@classmethod
def from_json(cls: Type[ModelBase], s):
return cls(**loads(s))
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
class _Wrapped(cls):
_callable_default = None
def get_default_value(self):
if self._callable_default:
return self._callable_default()
return super(_Wrapped, self).get_default_value()
def __init__(self, *args, default=None, **kwargs):
if default and callable(default):
self._callable_default = default
default = default()
super(_Wrapped, self).__init__(*args, default=default, **kwargs)
return _Wrapped
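
The callable_default wrapper above lets a jsonmodels field accept a callable as its default, so each model instance gets a freshly created default value instead of all instances sharing one mutable object. A minimal usage sketch, mirroring the pattern used in the login models further down (the Guest/Settings names here are illustrative only):

from jsonmodels import fields
from jsonmodels.models import Base

from apiserver.apimodels import callable_default


class Guest(Base):
    enabled = fields.BoolField(default=False)


class Settings(Base):
    # default=Guest is callable, so every Settings() gets its own Guest()
    # rather than all instances sharing a single embedded default object
    guest = callable_default(fields.EmbeddedField)(Guest, default=Guest)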

View File

@ -0,0 +1,32 @@
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
class BasicGuestMode(Base):
enabled = BoolField(default=False)
name = StringField()
username = StringField()
password = StringField()
class BasicMode(Base):
enabled = BoolField(default=False)
guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode)
class ServerErrors(Base):
missed_es_upgrade = BoolField(default=False)
es_connection_error = BoolField(default=False)
class GetSupportedModesResponse(Base):
basic = EmbeddedField(BasicMode)
server_errors = EmbeddedField(ServerErrors)
sso = DictField([str, type(None)])

View File

@ -176,3 +176,28 @@ class ActivityReportSeries(Base):
class GetActivityReportResponse(Base):
total = EmbeddedField(ActivityReportSeries)
active = EmbeddedField(ActivityReportSeries)
class RuntimeProperty(Base):
key = StringField()
value = StringField()
expiry = IntField(default=None)
class GetRuntimePropertiesRequest(Base):
worker = StringField(required=True)
class GetRuntimePropertiesResponse(Base):
runtime_properties = ListField(RuntimeProperty)
class SetRuntimePropertiesRequest(Base):
worker = StringField(required=True)
runtime_properties = ListField(RuntimeProperty)
class SetRuntimePropertiesResponse(Base):
added = ListField(str)
removed = ListField(str)
errors = ListField(str)

View File

@ -1,6 +1,6 @@
import itertools
from datetime import datetime, timedelta
from typing import Sequence, Set, Optional
from typing import Sequence, Set, Optional, List
import attr
import elasticsearch.helpers
@ -16,6 +16,7 @@ from apiserver.apimodels.workers import (
WorkerResponseEntry,
QueueEntry,
MachineStats,
RuntimeProperty,
)
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@ -416,6 +417,66 @@ class WorkerBLL:
added, errors = es_res[:2]
return (added == len(actions)) and not errors
def set_runtime_properties(
self,
company: str,
user: str,
worker_id: str,
runtime_properties: List[RuntimeProperty],
) -> dict:
"""Save worker entry in Redis"""
res = {
"added": [],
"removed": [],
"errors": [],
}
for prop in runtime_properties:
try:
key = self._get_runtime_property_key(company, user, worker_id, prop.key)
if prop.expiry == 0:
self.redis.delete(key)
res["removed"].append(key)
else:
self.redis.set(
key,
prop.value,
ex=prop.expiry
)
res["added"].append(key)
except Exception as ex:
msg = f"Exception: {ex}\nFailed saving property '{prop.key}: {prop.value}', skipping"
log.exception(msg)
res["errors"].append(ex)
return res
def get_runtime_properties(
self,
company: str,
user: str,
worker_id: str,
) -> List[RuntimeProperty]:
match = self._get_runtime_property_key(company, user, worker_id, "*")
with TimingContext("redis", "get_runtime_properties"):
res = self.redis.scan_iter(match=match)
runtime_properties = []
for r in res:
ttl = self.redis.ttl(r)
runtime_properties.append(
RuntimeProperty(
key=r.decode()[len(match) - 1:],
value=self.redis.get(r).decode(),
expiry=ttl if ttl >= 0 else None
)
)
return runtime_properties
def _get_runtime_property_key(
self, company: str, user: str, worker_id: str, prop_id: str
) -> str:
"""Build redis key from company, user, worker_id and prop_id"""
prefix = self._get_worker_key(company, user, worker_id)
return f"{prefix}_prop_{prop_id}"
@attr.s(auto_attribs=True)
class WorkerConversionHelper:

View File

@ -0,0 +1,33 @@
from apiserver.apimodels.login import (
GetSupportedModesRequest,
GetSupportedModesResponse,
BasicMode,
BasicGuestMode,
ServerErrors,
)
from apiserver.config import info
from apiserver.service_repo import endpoint
from apiserver.service_repo.auth.fixed_user import FixedUser
@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse)
def supported_modes(_, __, ___: GetSupportedModesRequest):
guest_user = FixedUser.get_guest_user()
if guest_user:
guest = BasicGuestMode(
enabled=True,
name=guest_user.name,
username=guest_user.username,
password=guest_user.password,
)
else:
guest = BasicGuestMode()
return GetSupportedModesResponse(
basic=BasicMode(enabled=FixedUser.enabled(), guest=guest),
sso={},
server_errors=ServerErrors(
missed_es_upgrade=info.missed_es_upgrade,
es_connection_error=info.es_connection_error,
),
)

View File

@ -4,7 +4,6 @@ import os
import re
from collections import defaultdict
from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from itertools import chain
from operator import attrgetter
@ -21,17 +20,19 @@ from typing import (
BinaryIO,
Union,
Mapping,
IO,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter
from boltons.iterutils import chunked_iter, first
from furl import furl
from mongoengine import Q
from apiserver.bll.event import EventBLL
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import (
split_param_name,
hyperparams_default_section,
@ -46,6 +47,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from .user import _ensure_backend_user
@ -344,7 +346,7 @@ class PrePopulate:
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
For each task the old execution.parameters and model.design are
@ -564,13 +566,13 @@ class PrePopulate:
if not task.execution.artifacts:
return []
for a in task.execution.artifacts:
for a in task.execution.artifacts.values():
if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri)
return [
a.uri
for a in task.execution.artifacts
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
]
@ -630,7 +632,7 @@ class PrePopulate:
return artifacts
@staticmethod
def json_lines(file: BinaryIO):
def json_lines(file: IO[bytes]):
for line in file:
clean = (
line.decode("utf-8")
@ -651,6 +653,7 @@ class PrePopulate:
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
):
"""
Import entities and events from the zip file
@ -663,35 +666,60 @@ class PrePopulate:
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
)
for files, reader_func in (
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events),
):
for file_info in files:
with reader.open(file_info) as f:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
reader_func(f, full_name, company_id, user_id)
metadata = metadata or {}
tasks = []
for entity_file in entity_files:
with reader.open(entity_file) as f:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(f, full_name, company_id, user_id, metadata)
if res:
tasks = res
if sort_tasks_by_last_updated:
tasks = sorted(tasks, key=attrgetter("last_update"))
for task in tasks:
events_file = first(
fi
for fi in reader.filelist
if fi.orig_filename.endswith(task.id + event_file_ending)
)
if not events_file:
continue
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, full_name, company_id, user_id)
@classmethod
def _import_entity(
cls,
f: BinaryIO,
f: IO[bytes],
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
) -> Optional[Sequence[Task]]:
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
for item in cls.json_lines(f):
if cls_ == Task:
task_data = json.loads(item)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
item = json.dumps(task_data)
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
@ -717,10 +745,14 @@ class PrePopulate:
doc.save()
if isinstance(doc, Task):
tasks.append(doc)
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
if tasks:
return tasks
@classmethod
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):

View File

@ -0,0 +1,83 @@
_description: """This service provides an administrator management interface to the company's users' login information."""
_default {
internal: false
allow_roles: ["system", "root", "admin"]
}
supported_modes {
authorize: false
"0.17" {
description: """ Return supported login modes."""
request {
type: object
properties {
state {
description: "ASCII base64 encoded application state"
type: string
}
callback_url_prefix {
description: "URL prefix used to generate the callback URL for each supported SSO provider"
type: string
}
}
}
response {
type: object
properties {
basic {
type: object
properties {
enabled {
description: "Basic aothentication (fixed users mode) mode enabled"
type: boolean
}
guest {
type: object
properties {
enabled {
description: "Basic aothentication guest mode enabled"
type: boolean
}
name {
description: "Guest name"
type: string
}
username {
description: "Guest username"
type: string
}
password {
description: "Guest password"
type: string
}
}
}
}
}
sso {
description: "SSO authentication providers"
type: object
additionalProperties {
description: "Provider redirect URL"
type: string
}
}
server_errors {
description: "Server initialization errors"
type: object
properties {
missed_es_upgrade {
description: "Indicate that Elasticsearch database was not upgraded from version 5"
type: boolean
}
es_connection_error {
description: "Indicate an error communicating to Elasticsearch"
type: boolean
}
}
}
}
}
}
}
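
Given the schema above, a server running in fixed-users mode with a guest account configured would return a response along these lines (values are illustrative only; shown as a Python dict):

supported_modes_response = {
    "basic": {
        "enabled": True,
        "guest": {
            "enabled": True,
            "name": "Guest",
            "username": "guest",
            "password": "guest",
        },
    },
    "sso": {},  # provider name -> redirect URL; empty when no SSO is configured
    "server_errors": {"missed_es_upgrade": False, "es_connection_error": False},
}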

View File

@ -499,4 +499,92 @@
}
}
}
get_runtime_properties {
"2.10" {
description: "Get runtime properties for a worker"
request {
required: [
worker
]
type: object
properties {
worker {
description: "Worker ID"
type: string
}
}
}
response {
type: object
properties {
runtime_properties {
type: array
items {
type: object
properties {
key { type: string }
value { type: string }
expiry {
description: "Expiry (in seconds) for a runtime property"
type: integer
}
}
}
}
}
}
}
}
set_runtime_properties {
"2.10" {
description: "Set runtime properties for a worker"
request {
required: [
worker
runtime_properties
]
type: object
properties {
worker {
description: "Worker ID"
type: string
}
runtime_properties {
type: array
items {
type: object
properties {
key { type: string }
value { type: string }
expiry {
description: "Expiry (in seconds) for a runtime property. When set to null no expiry is set, when set to 0 the specified key is removed"
type: integer
}
}
}
}
}
}
response {
type: object
properties {
added {
type: array
description: "keys of runtime properties added to redis"
items: { type: string }
}
removed {
type: array
description: "keys of runtime properties removed from redis"
items: { type: string }
}
errors {
type: array
description: "errors for keys failed to be added to redis"
items: { type: string }
}
}
}
}
}
}
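
To make the expiry semantics concrete, here is an illustrative workers.set_runtime_properties request body and the shape of the matching response (the worker ID, property names, and key prefix are examples only): a null expiry stores the property without a TTL, a positive expiry sets a TTL in seconds, and an expiry of 0 removes the property.

set_runtime_properties_request = {
    "worker": "worker-1",
    "runtime_properties": [
        {"key": "state", "value": "draining", "expiry": None},  # never expires
        {"key": "lease", "value": "gpu-0", "expiry": 600},      # expires after 10 minutes
        {"key": "obsolete", "value": "", "expiry": 0},          # deleted from Redis
    ],
}

set_runtime_properties_response = {
    "added": ["<worker-key-prefix>_prop_state", "<worker-key-prefix>_prop_lease"],
    "removed": ["<worker-key-prefix>_prop_obsolete"],
    "errors": [],
}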

View File

@ -22,6 +22,10 @@ from apiserver.apimodels.workers import (
GetActivityReportRequest,
GetActivityReportResponse,
ActivityReportSeries,
SetRuntimePropertiesRequest,
GetRuntimePropertiesRequest,
GetRuntimePropertiesResponse,
SetRuntimePropertiesResponse
)
from apiserver.bll.util import extract_properties_to_lists
from apiserver.bll.workers import WorkerBLL
@ -202,3 +206,35 @@ def get_stats(call: APICall, company_id, request: GetStatsRequest):
for worker, stats in ret.items()
]
)
@endpoint(
"workers.set_runtime_properties",
min_version="2.10",
request_data_model=SetRuntimePropertiesRequest,
response_data_model=SetRuntimePropertiesResponse,
)
def set_runtime_properties(call: APICall, company_id, request: SetRuntimePropertiesRequest):
res = worker_bll.set_runtime_properties(
company=company_id,
user=call.identity.user,
worker_id=request.worker,
runtime_properties=request.runtime_properties,
)
return SetRuntimePropertiesResponse(added=res["added"], removed=res["removed"], errors=res["errors"])
@endpoint(
"workers.get_runtime_properties",
min_version="2.10",
request_data_model=GetRuntimePropertiesRequest,
response_data_model=GetRuntimePropertiesResponse,
)
def get_runtime_properties(call: APICall, company_id, request: GetRuntimePropertiesRequest):
return GetRuntimePropertiesResponse(
runtime_properties=worker_bll.get_runtime_properties(
company=company_id,
user=call.identity.user,
worker_id=request.worker,
)
)
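
A hedged end-to-end sketch of calling the new endpoints over HTTP (the host, port, and bearer-token handling are assumptions about a typical apiserver deployment, not something this change defines):

import requests

API_URL = "http://localhost:8008"                     # assumed apiserver address
HEADERS = {"Authorization": "Bearer <access-token>"}  # assumed auth scheme

# Set a runtime property with a 10 minute TTL on worker-1
requests.post(
    f"{API_URL}/workers.set_runtime_properties",
    headers=HEADERS,
    json={
        "worker": "worker-1",
        "runtime_properties": [{"key": "lease", "value": "gpu-0", "expiry": 600}],
    },
)

# Read back all runtime properties for worker-1, including remaining TTLs
print(
    requests.post(
        f"{API_URL}/workers.get_runtime_properties",
        headers=HEADERS,
        json={"worker": "worker-1"},
    ).json()
)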