mirror of
https://github.com/clearml/clearml-server
synced 2025-04-08 06:54:08 +00:00
Add worker runtime properties support
Refactor login and add guest mode Support artifacts in prepopulate
This commit is contained in:
parent
e12fd8f3df
commit
1c7de3a86e
@ -214,3 +214,21 @@ class JsonSerializableMixin:
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
|
||||
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
|
||||
class _Wrapped(cls):
|
||||
_callable_default = None
|
||||
|
||||
def get_default_value(self):
|
||||
if self._callable_default:
|
||||
return self._callable_default()
|
||||
return super(_Wrapped, self).get_default_value()
|
||||
|
||||
def __init__(self, *args, default=None, **kwargs):
|
||||
if default and callable(default):
|
||||
self._callable_default = default
|
||||
default = default()
|
||||
super(_Wrapped, self).__init__(*args, default=default, **kwargs)
|
||||
|
||||
return _Wrapped
|
||||
|
32
apiserver/apimodels/login.py
Normal file
32
apiserver/apimodels/login.py
Normal file
@ -0,0 +1,32 @@
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
name = StringField()
|
||||
username = StringField()
|
||||
password = StringField()
|
||||
|
||||
|
||||
class BasicMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode)
|
||||
|
||||
|
||||
class ServerErrors(Base):
|
||||
missed_es_upgrade = BoolField(default=False)
|
||||
es_connection_error = BoolField(default=False)
|
||||
|
||||
|
||||
class GetSupportedModesResponse(Base):
|
||||
basic = EmbeddedField(BasicMode)
|
||||
server_errors = EmbeddedField(ServerErrors)
|
||||
sso = DictField([str, type(None)])
|
@ -176,3 +176,28 @@ class ActivityReportSeries(Base):
|
||||
class GetActivityReportResponse(Base):
|
||||
total = EmbeddedField(ActivityReportSeries)
|
||||
active = EmbeddedField(ActivityReportSeries)
|
||||
|
||||
|
||||
class RuntimeProperty(Base):
|
||||
key = StringField()
|
||||
value = StringField()
|
||||
expiry = IntField(default=None)
|
||||
|
||||
|
||||
class GetRuntimePropertiesRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
|
||||
|
||||
class GetRuntimePropertiesResponse(Base):
|
||||
runtime_properties = ListField(RuntimeProperty)
|
||||
|
||||
|
||||
class SetRuntimePropertiesRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
runtime_properties = ListField(RuntimeProperty)
|
||||
|
||||
|
||||
class SetRuntimePropertiesResponse(Base):
|
||||
added = ListField(str)
|
||||
removed = ListField(str)
|
||||
errors = ListField(str)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Sequence, Set, Optional
|
||||
from typing import Sequence, Set, Optional, List
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
@ -16,6 +16,7 @@ from apiserver.apimodels.workers import (
|
||||
WorkerResponseEntry,
|
||||
QueueEntry,
|
||||
MachineStats,
|
||||
RuntimeProperty,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@ -416,6 +417,66 @@ class WorkerBLL:
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
|
||||
def set_runtime_properties(
|
||||
self,
|
||||
company: str,
|
||||
user: str,
|
||||
worker_id: str,
|
||||
runtime_properties: List[RuntimeProperty],
|
||||
) -> dict:
|
||||
"""Save worker entry in Redis"""
|
||||
res = {
|
||||
"added": [],
|
||||
"removed": [],
|
||||
"errors": [],
|
||||
}
|
||||
for prop in runtime_properties:
|
||||
try:
|
||||
key = self._get_runtime_property_key(company, user, worker_id, prop.key)
|
||||
if prop.expiry == 0:
|
||||
self.redis.delete(key)
|
||||
res["removed"].append(key)
|
||||
else:
|
||||
self.redis.set(
|
||||
key,
|
||||
prop.value,
|
||||
ex=prop.expiry
|
||||
)
|
||||
res["added"].append(key)
|
||||
except Exception as ex:
|
||||
msg = f"Exception: {ex}\nFailed saving property '{prop.key}: {prop.value}', skipping"
|
||||
log.exception(msg)
|
||||
res["errors"].append(ex)
|
||||
return res
|
||||
|
||||
def get_runtime_properties(
|
||||
self,
|
||||
company: str,
|
||||
user: str,
|
||||
worker_id: str,
|
||||
) -> List[RuntimeProperty]:
|
||||
match = self._get_runtime_property_key(company, user, worker_id, "*")
|
||||
with TimingContext("redis", "get_runtime_properties"):
|
||||
res = self.redis.scan_iter(match=match)
|
||||
runtime_properties = []
|
||||
for r in res:
|
||||
ttl = self.redis.ttl(r)
|
||||
runtime_properties.append(
|
||||
RuntimeProperty(
|
||||
key=r.decode()[len(match) - 1:],
|
||||
value=self.redis.get(r).decode(),
|
||||
expiry=ttl if ttl >= 0 else None
|
||||
)
|
||||
)
|
||||
return runtime_properties
|
||||
|
||||
def _get_runtime_property_key(
|
||||
self, company: str, user: str, worker_id: str, prop_id: str
|
||||
) -> str:
|
||||
"""Build redis key from company, user, worker_id and prop_id"""
|
||||
prefix = self._get_worker_key(company, user, worker_id)
|
||||
return f"{prefix}_prop_{prop_id}"
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class WorkerConversionHelper:
|
||||
|
33
apiserver/login/__init__.py
Normal file
33
apiserver/login/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
from apiserver.apimodels.login import (
|
||||
GetSupportedModesRequest,
|
||||
GetSupportedModesResponse,
|
||||
BasicMode,
|
||||
BasicGuestMode,
|
||||
ServerErrors,
|
||||
)
|
||||
from apiserver.config import info
|
||||
from apiserver.service_repo import endpoint
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse)
|
||||
def supported_modes(_, __, ___: GetSupportedModesRequest):
|
||||
guest_user = FixedUser.get_guest_user()
|
||||
if guest_user:
|
||||
guest = BasicGuestMode(
|
||||
enabled=True,
|
||||
name=guest_user.name,
|
||||
username=guest_user.username,
|
||||
password=guest_user.password,
|
||||
)
|
||||
else:
|
||||
guest = BasicGuestMode()
|
||||
|
||||
return GetSupportedModesResponse(
|
||||
basic=BasicMode(enabled=FixedUser.enabled(), guest=guest),
|
||||
sso={},
|
||||
server_errors=ServerErrors(
|
||||
missed_es_upgrade=info.missed_es_upgrade,
|
||||
es_connection_error=info.es_connection_error,
|
||||
),
|
||||
)
|
@ -4,7 +4,6 @@ import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
@ -21,17 +20,19 @@ from typing import (
|
||||
BinaryIO,
|
||||
Union,
|
||||
Mapping,
|
||||
IO,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import dpath
|
||||
import mongoengine
|
||||
from boltons.iterutils import chunked_iter
|
||||
from boltons.iterutils import chunked_iter, first
|
||||
from furl import furl
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.task.artifacts import get_artifact_id
|
||||
from apiserver.bll.task.param_utils import (
|
||||
split_param_name,
|
||||
hyperparams_default_section,
|
||||
@ -46,6 +47,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from .user import _ensure_backend_user
|
||||
|
||||
|
||||
@ -344,7 +346,7 @@ class PrePopulate:
|
||||
dpath.delete(task_data, old_param_field)
|
||||
|
||||
@classmethod
|
||||
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
|
||||
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
|
||||
"""
|
||||
Build content array that contains fixed tasks from the passed file
|
||||
For each task the old execution.parameters and model.design are
|
||||
@ -564,13 +566,13 @@ class PrePopulate:
|
||||
if not task.execution.artifacts:
|
||||
return []
|
||||
|
||||
for a in task.execution.artifacts:
|
||||
for a in task.execution.artifacts.values():
|
||||
if a.mode == ArtifactModes.output:
|
||||
a.uri = cls._get_fixed_url(a.uri)
|
||||
|
||||
return [
|
||||
a.uri
|
||||
for a in task.execution.artifacts
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
]
|
||||
|
||||
@ -630,7 +632,7 @@ class PrePopulate:
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def json_lines(file: BinaryIO):
|
||||
def json_lines(file: IO[bytes]):
|
||||
for line in file:
|
||||
clean = (
|
||||
line.decode("utf-8")
|
||||
@ -651,6 +653,7 @@ class PrePopulate:
|
||||
company_id: str = "",
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
sort_tasks_by_last_updated: bool = True,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
@ -663,35 +666,60 @@ class PrePopulate:
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
)
|
||||
event_files = (
|
||||
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
|
||||
)
|
||||
for files, reader_func in (
|
||||
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
|
||||
(event_files, cls._import_events),
|
||||
):
|
||||
for file_info in files:
|
||||
with reader.open(file_info) as f:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
reader_func(f, full_name, company_id, user_id)
|
||||
metadata = metadata or {}
|
||||
tasks = []
|
||||
for entity_file in entity_files:
|
||||
with reader.open(entity_file) as f:
|
||||
full_name = splitext(entity_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
res = cls._import_entity(f, full_name, company_id, user_id, metadata)
|
||||
if res:
|
||||
tasks = res
|
||||
|
||||
if sort_tasks_by_last_updated:
|
||||
tasks = sorted(tasks, key=attrgetter("last_update"))
|
||||
|
||||
for task in tasks:
|
||||
events_file = first(
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if fi.orig_filename.endswith(task.id + event_file_ending)
|
||||
)
|
||||
if not events_file:
|
||||
continue
|
||||
with reader.open(events_file) as f:
|
||||
full_name = splitext(events_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
cls._import_events(f, full_name, company_id, user_id)
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
f: BinaryIO,
|
||||
f: IO[bytes],
|
||||
full_name: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
):
|
||||
) -> Optional[Sequence[Task]]:
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
|
||||
tasks = []
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == Task:
|
||||
task_data = json.loads(item)
|
||||
artifacts_path = ("execution", "artifacts")
|
||||
artifacts = nested_get(task_data, artifacts_path)
|
||||
if isinstance(artifacts, list):
|
||||
nested_set(
|
||||
task_data,
|
||||
artifacts_path,
|
||||
value={get_artifact_id(a): a for a in artifacts},
|
||||
)
|
||||
item = json.dumps(task_data)
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
@ -717,10 +745,14 @@ class PrePopulate:
|
||||
doc.save()
|
||||
|
||||
if isinstance(doc, Task):
|
||||
tasks.append(doc)
|
||||
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
|
||||
|
||||
if tasks:
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
|
||||
def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
|
||||
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
|
83
apiserver/schema/services/login.conf
Normal file
83
apiserver/schema/services/login.conf
Normal file
@ -0,0 +1,83 @@
|
||||
_description: """This service provides an administrator management interface to the company's users login information."""
|
||||
|
||||
_default {
|
||||
internal: false
|
||||
allow_roles: ["system", "root", "admin"]
|
||||
}
|
||||
|
||||
supported_modes {
|
||||
authorize: false
|
||||
"0.17" {
|
||||
description: """ Return supported login modes."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
state {
|
||||
description: "ASCII base64 encoded application state"
|
||||
type: string
|
||||
}
|
||||
callback_url_prefix {
|
||||
description: "URL prefix used to generate the callback URL for each supported SSO provider"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
basic {
|
||||
type: object
|
||||
properties {
|
||||
enabled {
|
||||
description: "Basic aothentication (fixed users mode) mode enabled"
|
||||
type: boolean
|
||||
}
|
||||
guest {
|
||||
type: object
|
||||
properties {
|
||||
enabled {
|
||||
description: "Basic aothentication guest mode enabled"
|
||||
type: boolean
|
||||
}
|
||||
name {
|
||||
description: "Guest name"
|
||||
type: string
|
||||
}
|
||||
username {
|
||||
description: "Guest username"
|
||||
type: string
|
||||
}
|
||||
password {
|
||||
description: "Guest password"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sso {
|
||||
description: "SSO authentication providers"
|
||||
type: object
|
||||
additionalProperties {
|
||||
desctiprion: "Provider redirect URL"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
server_errors {
|
||||
description: "Server initialization errors"
|
||||
type: object
|
||||
properties {
|
||||
missed_es_upgrade {
|
||||
description: "Indicate that Elasticsearch database was not upgraded from version 5"
|
||||
type: boolean
|
||||
}
|
||||
es_connection_error {
|
||||
description: "Indicate an error communicating to Elasticsearch"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -499,4 +499,92 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
get_runtime_properties {
|
||||
"2.10" {
|
||||
description: "Get runtime properties for a worker"
|
||||
request {
|
||||
required: [
|
||||
worker
|
||||
]
|
||||
type: object
|
||||
properties {
|
||||
worker {
|
||||
description: "Worker ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
runtime_properties {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
key { type: string }
|
||||
value { type: string }
|
||||
expiry {
|
||||
description: "Expiry (in seconds) for a runtime property"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
set_runtime_properties {
|
||||
"2.10" {
|
||||
description: "Set runtime properties for a worker"
|
||||
request {
|
||||
required: [
|
||||
worker
|
||||
runtime_properties
|
||||
]
|
||||
type: object
|
||||
properties {
|
||||
worker {
|
||||
description: "Worker ID"
|
||||
type: string
|
||||
}
|
||||
runtime_properties {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
key { type: string }
|
||||
value { type: string }
|
||||
expiry {
|
||||
description: "Expiry (in seconds) for a runtime property. When set to null no expiry is set, when set to 0 the specified key is removed"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
added {
|
||||
type: array
|
||||
description: "keys of runtime properties added to redis"
|
||||
items: { type: string }
|
||||
}
|
||||
removed {
|
||||
type: array
|
||||
description: "keys of runtime properties removed from redis"
|
||||
items: { type: string }
|
||||
}
|
||||
errors {
|
||||
type: array
|
||||
description: "errors for keys failed to be added to redis"
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -22,6 +22,10 @@ from apiserver.apimodels.workers import (
|
||||
GetActivityReportRequest,
|
||||
GetActivityReportResponse,
|
||||
ActivityReportSeries,
|
||||
SetRuntimePropertiesRequest,
|
||||
GetRuntimePropertiesRequest,
|
||||
GetRuntimePropertiesResponse,
|
||||
SetRuntimePropertiesResponse
|
||||
)
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
@ -202,3 +206,35 @@ def get_stats(call: APICall, company_id, request: GetStatsRequest):
|
||||
for worker, stats in ret.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"workers.set_runtime_properties",
|
||||
min_version="2.10",
|
||||
request_data_model=SetRuntimePropertiesRequest,
|
||||
response_data_model=SetRuntimePropertiesResponse,
|
||||
)
|
||||
def set_runtime_properties(call: APICall, company_id, request: SetRuntimePropertiesRequest):
|
||||
res = worker_bll.set_runtime_properties(
|
||||
company=company_id,
|
||||
user=call.identity.user,
|
||||
worker_id=request.worker,
|
||||
runtime_properties=request.runtime_properties,
|
||||
)
|
||||
return SetRuntimePropertiesResponse(added=res["added"], removed=res["removed"], errors=res["errors"])
|
||||
|
||||
|
||||
@endpoint(
|
||||
"workers.get_runtime_properties",
|
||||
min_version="2.10",
|
||||
request_data_model=GetRuntimePropertiesRequest,
|
||||
response_data_model=GetRuntimePropertiesResponse,
|
||||
)
|
||||
def get_runtime_properties(call: APICall, company_id, request: GetRuntimePropertiesRequest):
|
||||
return GetRuntimePropertiesResponse(
|
||||
runtime_properties=worker_bll.get_runtime_properties(
|
||||
company=company_id,
|
||||
user=call.identity.user,
|
||||
worker_id=request.worker,
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user