mirror of
https://github.com/clearml/clearml-agent
synced 2025-01-31 17:16:51 +00:00
5ed47d2d2c
Add support for FORCE_CLEARML_AGENT_REPO env var to allow installing agent from a repo url when executing a task. Implement skip venv installation on execute and allow custom binary. Fix services mode limit implementation in docker mode.
410 lines
15 KiB
Python
410 lines
15 KiB
Python
from __future__ import unicode_literals, print_function
|
|
|
|
import copy
|
|
import re
|
|
import sys
|
|
from abc import abstractmethod
|
|
from functools import wraps
|
|
from operator import attrgetter
|
|
from traceback import print_exc
|
|
from typing import Text
|
|
|
|
from clearml_agent.helper.console import ListFormatter, print_text
|
|
from clearml_agent.helper.dicts import filter_keys
|
|
|
|
import six
|
|
from clearml_agent.backend_api import services
|
|
|
|
from clearml_agent.errors import APIError, CommandFailedError
|
|
from clearml_agent.helper.base import Singleton, return_list, print_parameters, dump_yaml, load_yaml, error, warning
|
|
from clearml_agent.interface.base import ObjectID
|
|
from clearml_agent.session import Session
|
|
|
|
|
|
class NameResolutionError(CommandFailedError):
    """
    Error raised when an object name could not be resolved.

    Keeps the base ``message`` and an optional ``suggestions`` text as attributes;
    the string form of the error is the message immediately followed by the suggestions.
    """

    def __init__(self, message, suggestions=''):
        """
        :param message: base error text (also passed to the parent exception)
        :param suggestions: text appended to the message when the error is rendered
        """
        super(NameResolutionError, self).__init__(message)
        self.message, self.suggestions = message, suggestions

    def __str__(self):
        # message first, suggestions appended verbatim (no separator is inserted)
        return ''.join((self.message, self.suggestions))
|
|
|
|
|
|
def resolve_names(func):
    """
    Decorator for command methods: resolve ``ObjectID`` arguments before calling ``func``.

    Every positional and keyword argument is inspected:

    - a single ``ObjectID`` is replaced by ``self._resolve_name(arg.name, arg.service)``;
      a resolution failure propagates to the caller
    - a list/tuple consisting only of ``ObjectID`` items is replaced by a list of resolved values.
      With exactly one item a failure is re-raised; with any other number of items each failure
      is reported through ``self.warning`` and the original name is kept in that position
    - anything else is passed through untouched
    """

    def _try_resolve(command, object_id):
        # -> (resolved value, None) on success, (original name, exc_info triple) on failure
        try:
            return command._resolve_name(object_id.name, object_id.service), None
        except NameResolutionError:
            return object_id.name, sys.exc_info()

    def _resolve_argument(command, value):
        if isinstance(value, ObjectID):
            return command._resolve_name(value.name, value.service)

        is_id_sequence = isinstance(value, (list, tuple)) and all(isinstance(x, ObjectID) for x in value)
        if not is_id_sequence:
            return value

        outcomes = [_try_resolve(command, item) for item in value]

        if len(outcomes) == 1:
            # a single requested object: failing to resolve it is fatal
            resolved, failure = outcomes[0]
            if failure:
                six.reraise(*failure)
            return [resolved]

        # several objects: warn about the ones that failed and carry on with the rest
        for _, failure in outcomes:
            if failure:
                command.warning(failure[1].message)
        return [resolved for (resolved, _) in outcomes]

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        resolved_args = [_resolve_argument(self, arg) for arg in args]
        resolved_kwargs = {key: _resolve_argument(self, value) for key, value in kwargs.items()}
        return func(self, *resolved_args, **resolved_kwargs)

    return wrapper
|
|
|
|
|
|
class BaseCommandSection(object):
    """
    Base class for command sections which do not interact with the allegro API.
    Has basic utilities for user interaction.
    """
    # module-level reporting helpers exposed as static methods of the section
    warning = staticmethod(warning)
    error = staticmethod(error)

    @staticmethod
    def log(message, *args):
        """
        Print ``message`` prefixed with ``clearml-agent: ``.

        :param message: text to print; treated as a %-format template only when ``args`` are given
        :param args: optional values interpolated into ``message`` using the ``%`` operator
        """
        # Interpolate only when there is something to interpolate: callers commonly pass an
        # already formatted string, and a literal '%' in it (e.g. a URL-encoded path) would
        # otherwise make ``message % ()`` raise.
        if args:
            message = message % args
        print("clearml-agent: {}".format(message))

    @classmethod
    def exit(cls, message, code=1):  # type: (Text, int) -> ()
        """
        Report ``message`` through ``cls.error`` and terminate via ``sys.exit(code)``.

        :param message: error text to report
        :param code: process exit code (default 1)
        """
        cls.error(message)
        sys.exit(code)
|
|
|
|
|
|
@six.add_metaclass(Singleton)
class ServiceCommandSection(BaseCommandSection):
    """
    Base class for command sections which interact with the allegro API.
    Contains API functionality which is common across services.

    Subclasses provide the ``service`` property; it is used as the service name of every
    request issued through ``get`` / ``post`` and as the attribute looked up on the
    ``services`` package.
    """

    _worker_name = None
    # maximum number of entries listed in a single _resolve_name() message
    MAX_SUGGESTIONS = 10

    def __init__(self, *args, **kwargs):
        """
        Create the API session and the list formatter for this section.

        :param args: positional arguments forwarded to ``_get_session``
        :param kwargs: keyword arguments, first passed through ``_verify_command_states``
            and then forwarded to ``_get_session``
        """
        # NOTE(review): explicit two-argument super() kept on purpose - the class object is
        # re-created by six.add_metaclass; confirm before switching to zero-argument super()
        super(ServiceCommandSection, self).__init__()
        kwargs = self._verify_command_states(kwargs)
        self._session = self._get_session(*args, **kwargs)
        self._list_formatter = ListFormatter(self.service)

    @classmethod
    def _verify_command_states(cls, kwargs):
        """
        Conform and enforce command argument
        This is where you can automatically turn on/off switches based on different states.
        :param kwargs:
        :return: kwargs
        """
        return kwargs

    @staticmethod
    def _get_session(*args, **kwargs):
        """ Build the ``Session`` used by this section from the given arguments """
        return Session(*args, **kwargs)

    @property
    @abstractmethod
    def service(self):
        """ The name of the REST service used by this command """
        pass

    def get(self, endpoint, *args, session=None, **kwargs):
        """
        Call ``session.get`` for action ``endpoint`` of this section's service.

        :param endpoint: action name
        :param session: optional session to use instead of the section's own session
        :return: whatever ``session.get`` returns
        """
        session = session or self._session
        return session.get(service=self.service, action=endpoint, *args, **kwargs)

    def post(self, endpoint, *args, session=None, **kwargs):
        """
        Call ``session.post`` for action ``endpoint`` of this section's service.

        :param endpoint: action name
        :param session: optional session to use instead of the section's own session
        :return: whatever ``session.post`` returns
        """
        session = session or self._session
        return session.post(service=self.service, action=endpoint, *args, **kwargs)

    def get_with_act_as(self, endpoint, *args, **kwargs):
        """ Call ``get_with_act_as`` of the section's session for action ``endpoint`` """
        return self._session.get_with_act_as(service=self.service, action=endpoint, *args, **kwargs)

    @property
    def name(self):
        """ Title-cased service name """
        return self.service.title()

    @property
    def name_single(self):
        """ ``name`` with trailing 's' characters stripped """
        return self.name.rstrip('s')

    @property
    def service_single(self):
        """ ``service`` with trailing 's' characters stripped """
        return self.service.rstrip('s')

    @resolve_names
    def __info(self, id=None, yaml=None, **kwargs):
        """
        Fetch objects by id ('get_by_id' action) and output them.

        :param id: a single id or several ids (normalized with ``return_list``)
        :param yaml: optional path; when given the collected entries are also dumped to it
        :param kwargs: forwarded to ``output_info`` (e.g. ``quiet``)
        :return: dict mapping id to the retrieved entry, or None when no ids were given.
            Ids whose retrieval raised ``APIError`` are reported and left out of the dict.
        """
        ids = return_list(id)
        if not ids:
            return

        yaml_dump = {}

        for i in ids:
            get_fields = {self.service_single: i}
            try:
                info = self.get('get_by_id', **get_fields)
                yaml_dump[i] = info[self.service_single]
            except APIError:
                self.error('Failed retrieving info for {} {}'.format(self.service_single, i))

        self.output_info(yaml_dump, yaml_path=yaml, **kwargs)
        return yaml_dump

    @resolve_names
    def _info(self, *args, **kwargs):
        """ Run ``__info`` with the given arguments; the collected mapping is not returned """
        self.__info(*args, **kwargs)

    @staticmethod
    def output_info(entries, quiet=False, yaml_path=None, **_):
        """
        Print ``entries`` and/or store them in a yaml file.

        :param entries: data to output
        :param quiet: when true nothing is printed through ``print_parameters``
        :param yaml_path: when given, ``entries`` are written there with ``dump_yaml``
        """
        if not quiet and entries:
            print_parameters(entries, indent=4)

        if yaml_path:
            print('Storing entries to [{}]'.format(yaml_path))
            dump_yaml(entries, yaml_path)

    @staticmethod
    def _make_query(json, table, sort=None, projection_from_table=False, extra_fields=None):
        """
        Prepare a query dict and the list of table fields.

        :param json: base query; a shallow copy is modified and returned
        :param table: field names, either a list or a single '#'-separated string
        :param sort: optional '#'-separated string; its first token is stored as 'order_by'
        :param projection_from_table: when true the resulting field list is stored as 'only_fields'
        :param extra_fields: optional field names appended to the field list
        :return: (query dict, field list)
        """
        json = json.copy()
        if isinstance(table, six.string_types):
            table = table.split('#')
        elif extra_fields:
            # work on a copy so the caller's sequence is not modified by the extend() below
            table = list(table)

        if extra_fields:
            table.extend(extra_fields)

        if projection_from_table:
            json['only_fields'] = table

        if sort:
            # does nothing if 'order_by' is not in get_fields
            json['order_by'] = sort.split('#')[0]
        return json, table

    def _get_all(self, endpoint, json, retpoint=None):
        """
        Call ``endpoint`` with ``json`` as keyword arguments and return one key of the result.

        :param retpoint: key to read from the result; defaults to the service name
        :return: the value under that key, or an empty list when the key is missing
        """
        return self.get(endpoint, **json).get(retpoint or self.service, [])

    @resolve_names
    def _update(self,
                endpoint='update',
                send_diff=False,
                quiet=False,
                primary_key='id',
                override=None,
                model_desc=None,
                yaml=None,
                **kwargs):
        """
        Update one or more objects from a yaml file and/or keyword arguments.

        :param endpoint: action called (through ``get``) for every entry that has something to update
        :param send_diff: when true the current entry is fetched first and only values that differ
            from it (as computed by ``recursive_diff``) are sent
        :param quiet: suppress progress messages
        :param primary_key: name of the keyword argument holding the object id
        :param override: iterable of ``'dotted.key=value'`` strings or ``(key, value)`` pairs applied
            to the first entry; intermediate dicts are created as needed
        :param model_desc: path of a text file whose content is stored in
            ``execution.model_desc.prototxt`` of the first entry
        :param yaml: path of a file read with ``load_yaml``
            (presumably a mapping of object id to fields - verify against the yaml producer)
        :param kwargs: when no yaml is given these are the fields of the updated entry;
            ``force`` allows a yaml entry stored under a different id to be applied to the given id
        :raises ValueError: neither yaml nor id supplied, conflicting ids, or the update was not applied
        """
        if not yaml and primary_key not in kwargs:
            raise ValueError('Update must supply either yaml file or %s-id' % self.service_single)

        data_entries = {}
        original_data_entries = {}
        if yaml:
            data_entries = load_yaml(yaml)

        if send_diff or (not yaml and primary_key in kwargs):
            i = kwargs.get(primary_key) or next(iter(data_entries))
            original_info = self.__info(id=i, quiet=True)[i]

            if send_diff:
                original_data_entries[i] = copy.deepcopy(original_info)
            if yaml and i not in data_entries:
                if len(data_entries) > 1:
                    raise ValueError(
                        'Error: yaml file [%s] contains more than one task id' % yaml)
                first_key = next(iter(data_entries))
                if first_key != i:
                    if kwargs.get('force'):
                        if not quiet:
                            print('Warning: overriding yaml task id [%s] with id=%s' % (first_key, i))
                    else:
                        raise ValueError(
                            'Error: yaml task id [%s] != id [%s], use --force to override' % (first_key, i))
                # re-key the single yaml entry under the requested id
                data_entries = {i: data_entries[first_key]}
                data_entries[i][primary_key] = i
            elif not yaml:
                data_entries[i] = kwargs

        if model_desc:
            first_key = next(iter(data_entries))
            with open(model_desc) as f:
                proto_data = f.read()
            info = data_entries[first_key]
            info['execution']['model_desc']['prototxt'] = proto_data

        if override:
            first_key = next(iter(data_entries))
            info = data_entries[first_key]
            for p in override:
                # split on the first '=' only, so values may themselves contain '='
                key, val = p.split('=', 1) if isinstance(p, six.string_types) else p
                info_key = info
                keys = key.split('.')
                for k in keys[:-1]:
                    if not info_key.get(k):
                        info_key[k] = dict()
                    info_key = info_key[k]
                info_key[keys[-1]] = val

            # always make sure tags is a list of strings
            # split string to tokens ':'
            # examples tags='auto_generated:draft'
            # NOTE(review): nesting level reconstructed (source indentation was lost); kept under
            # `if override:` because `info` is only guaranteed to be bound here - confirm upstream
            tags = info.get('tags')
            if tags is not None:
                if isinstance(tags, six.string_types):
                    tags = tags.split(':')
                # remove empty strings from list
                info['tags'] = [t for t in tags if t]

        # send only change set
        if send_diff:
            # only send the values that changed
            for i, info in data_entries.items():
                org_info = original_data_entries[i]
                out_info = {}
                recursive_diff(org_info, info, out_info)
                data_entries[i] = out_info

        for i, info in data_entries.items():
            # nothing but (at most) the primary key means there is nothing to send
            if not info or list(info) == [primary_key]:
                if not quiet:
                    print('Skipping: nothing to update for %s id [%s]' % (self.service_single, i))
                continue
            if not quiet:
                print('Updating %s id [%s]' % (self.service_single, i))
            info[self.service_single] = i
            result = self.get(endpoint, **info)
            if not result['updated']:
                raise ValueError('Failed updating %s id [%s]' % (self.service_single, i))
            if not quiet:
                print('%s [%s] updated fields: %s' % (self.name_single, i, result.get('fields', '')))

    @resolve_names
    def remove(self, ids, **kwargs):
        """
        Send the service's ``DeleteRequest`` for every id in ``ids``.

        :param ids: a single id or several ids
        :param kwargs: forwarded to the request constructor
        """
        return self._apply_command(
            request_cls=getattr(services, self.service).DeleteRequest,
            object_ids=ids,
            response_validation_field='deleted',
            **kwargs
        )

    def _apply_command(self, request_cls, object_ids, response_validation_field=None, **kwargs):
        """
        Send one ``request_cls(object_id, **kwargs)`` request per object and report the outcome.

        :param request_cls: request class, instantiated once per object id
        :param object_ids: a single id or several ids (normalized with ``return_list``)
        :param response_validation_field: optional response attribute that must equal 1
            for the call to count as succeeded
        :param kwargs: forwarded to ``request_cls``

        Logs 'X/Y succeeded' when every call succeeded, otherwise exits with that message.
        """
        object_ids = return_list(object_ids)

        def call_one(object_id):
            error_message = '[{}]: failed'.format(object_id)
            try:
                response = self._session.send_api(request_cls(object_id, **kwargs))
            except APIError as e:
                if not self._session.debug_mode:
                    self.error('{}: {}'.format(error_message, e))
                else:
                    traceback = e.format_traceback()
                    if traceback:
                        print(traceback)
                        print('Own traceback:')
                    # NOTE(review): nesting of print_exc() reconstructed (source indentation was lost)
                    print_exc()
                return False

            if not response_validation_field or getattr(response, response_validation_field) == 1:
                return True

            self.error(error_message)
            return False

        # every object is attempted, failures do not stop the remaining calls
        succeeded = [call_one(object_id) for object_id in object_ids].count(True)
        message = '{}/{} succeeded'.format(succeeded, len(object_ids))
        if succeeded == len(object_ids):
            self.log(message)
        else:
            self.exit(message)

    def get_service(self, service_class):
        """ Instantiate ``service_class`` with the configuration of the section's session """
        return service_class(config=self._session.config)

    def _resolve_name(self, name, service=None):
        """
        Resolve an object name to an object ID.
        Operation:
            - If the argument "looks like" an ID, return it.
            - Else, get all object with names containing the argument
            - if exactly one object has the argument as its name (case insensitive), return the object's ID
            - if several objects have that name, list their IDs and exit
            - Else, raise NameResolutionError (with a list of suggestions when the query returned any objects)
        :param str name: ID (returned unmodified) or Name to resolve
        :param str service: Service to resolve from (type of object). Defaults to service represented by the class
        :return: ID of object
        :rtype: str
        :raises NameResolutionError: the service has no GetAllRequest, or no object has the given name
        """
        service = service or self.service
        if re.match(r'^[-a-f0-9]{30,}$', name):
            return name

        try:
            request_cls = getattr(services, service).GetAllRequest
        except AttributeError:
            raise NameResolutionError('Name resolution unavailable for {}'.format(service))

        request = request_cls.from_dict(dict(name=name, only_fields=['name', 'id']))
        # from_dict will ignore unrecognised keyword arguments - not all GetAll's have only_fields
        response = getattr(self._session.send_api(request), service)
        wanted = name.lower()
        matches = [db_object for db_object in response if wanted == db_object.name.lower()]

        def truncated_bullet_list(format_string, elements, callback, **kwargs):
            # render at most MAX_SUGGESTIONS bullets; `details` / `suffix` tell the user about truncation
            if len(elements) > self.MAX_SUGGESTIONS:
                kwargs.update(
                    dict(details=' (showing {}/{})'.format(self.MAX_SUGGESTIONS, len(elements)), suffix='\n...'))
            else:
                kwargs.update(dict(details='', suffix=''))
            bullet_list = '\n'.join('* {}'.format(callback(item)) for item in elements[:self.MAX_SUGGESTIONS])
            return format_string.format(bullet_list, **kwargs)

        if len(matches) == 1:
            return matches.pop().id
        elif len(matches) > 1:
            message = truncated_bullet_list(
                'Found multiple {service} with name "{name}"{details}:\n{}{suffix}',
                matches,
                callback=attrgetter('id'),
                # only the fields the template actually uses (instead of passing **locals())
                service=service,
                name=name,
            )
            self.exit(message)

        message = 'Could not find {} with name/id "{}"'.format(service.rstrip('s'), name)

        if not response:
            raise NameResolutionError(message)

        suggestions = truncated_bullet_list(
            '. Did you mean this?{details}\n{}{suffix}',
            sorted(response, key=attrgetter('name')),
            lambda db_object: '({}) {}'.format(db_object.id, db_object.name)
        )
        raise NameResolutionError(message, suggestions)
|
|
|
|
|
|
def recursive_diff(org, upd, out):
    """
    Collect into ``out`` the values of ``upd`` that differ from ``org``.

    - dict vs dict: every key of ``upd`` that is missing from ``org`` or whose value differs is
      examined. Non-dict values are copied to ``out`` unless they are None. A dict value is
      diffed recursively; if that nested level itself held no differing dict values, the complete
      updated dict is stored under the key instead of the nested diff.
    - list vs list: items of ``upd`` that are not in ``org`` are appended to ``out``
    - any other combination: ``out`` is left untouched

    :param org: original value
    :param upd: updated value
    :param out: container of the same kind as ``upd`` (dict or list), filled in place
    :return: True when at least one differing value at this level is a dict, False otherwise
        (always a bool, including the no-difference and type-mismatch cases)
    :rtype: bool
    """
    if isinstance(upd, dict) and isinstance(org, dict):
        has_nested_dict = False
        # snapshot of the differing keys, taken before anything is written
        diff_keys = [k for k in upd if k not in org or upd[k] != org[k]]
        for k in diff_keys:
            value = upd[k]
            if isinstance(value, dict):
                has_nested_dict = True
                out[k] = {}
                if not recursive_diff(org.get(k, {}), value, out[k]):
                    # the nested level had no dict differences of its own: send the dict as a whole
                    out[k] = value
            elif value is not None:
                out[k] = value
        return has_nested_dict

    if isinstance(upd, list) and isinstance(org, list):
        out.extend(item for item in upd if item not in org)

    return False
|